Maison > Périphériques technologiques > IA > Comment créer un réseau neuronal simple à l'aide de PyTorch

Comment créer un réseau neuronal simple à l'aide de PyTorch

WBOY
Libérer: 2024-01-25 09:27:06
avant
706 Les gens l'ont consulté

Comment créer un réseau neuronal simple à laide de PyTorch

PyTorch est un framework d'apprentissage en profondeur basé sur Python pour créer divers réseaux de neurones. Cet article montrera comment utiliser PyTorch pour créer un réseau neuronal simple et fournira des exemples de code.

Tout d'abord, nous devons installer PyTorch. Il peut être installé à partir de la ligne de commande avec :

pip install torch
Copier après la connexion

Ensuite, nous utiliserons PyTorch pour construire un réseau neuronal simple entièrement connecté pour les tâches de classification binaire. Ce réseau de neurones aura deux couches cachées de 10 neurones chacune. Nous utiliserons la fonction d'activation sigmoïde et la fonction de perte d'entropie croisée.

Voici le code complet :

import torch
import torch.nn as nn
import torch.optim as optim

# 定义神经网络模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(2, 10)  # 第一个隐藏层
        self.fc2 = nn.Linear(10, 10)  # 第二个隐藏层
        self.fc3 = nn.Linear(10, 1)  # 输出层

    def forward(self, x):
        x = torch.sigmoid(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))
        return x

# 创建数据集
X = torch.tensor([[0, 0], [0, 1], [1, 0], [1, 1]], dtype=torch.float32)
y = torch.tensor([[0], [1], [1], [0]], dtype=torch.float32)

# 创建神经网络实例
net = Net()

# 定义损失函数和优化器
criterion = nn.BCELoss()
optimizer = optim.SGD(net.parameters(), lr=0.1)

# 训练神经网络
for epoch in range(10000):
    optimizer.zero_grad()
    output = net(X)
    loss = criterion(output, y)
    loss.backward()
    optimizer.step()

    # 打印训练损失
    if epoch % 1000 == 0:
    print('Epoch {}: loss = {}'.format(epoch, loss.item()))

# 使用训练好的神经网络进行预测
with torch.no_grad():
    output = net(X)
    predicted = (output > 0.5).float()
    print('Predicted: {}\n'.format(predicted))
Copier après la connexion

Tout d'abord, nous définissons une classe appelée Net, qui hérite de nn.Module. Cette classe contient toutes les couches du réseau neuronal. Dans cet exemple, nous définissons trois couches entièrement connectées, dont les deux premières sont des couches cachées et la dernière est la couche de sortie.

Dans la classe Net, en plus de définir une méthode directe pour décrire le processus de propagation vers l'avant du réseau neuronal, nous utilisons également la fonction d'activation sigmoïde pour transmettre la sortie de chaque couche cachée à la couche suivante.

Ensuite, nous avons créé un ensemble de données contenant quatre échantillons, où chaque échantillon possède deux caractéristiques. Nous avons également défini une instance de réseau neuronal nommée net et sélectionné BCELoss comme fonction de perte et SGD comme optimiseur.

Ensuite, nous commençons à entraîner le réseau neuronal. À chaque itération, nous remettons d'abord à zéro le gradient de l'optimiseur, puis transmettons l'ensemble de données X dans le réseau neuronal pour obtenir le résultat. Nous calculons la perte et effectuons une rétropropagation, et enfin mettons à jour les paramètres du réseau à l'aide d'un optimiseur. Nous avons également imprimé la perte d'entraînement toutes les 1 000 itérations.

Une fois la formation terminée, nous utilisons le gestionnaire de contexte no_grad pour faire des prédictions sur l'ensemble de données. Nous allons sortir les quatre prédictions et les imprimer.

Ceci est un exemple simple montrant comment créer un réseau neuronal de base à l'aide de PyTorch. PyTorch fournit de nombreux outils et fonctions pour nous aider à créer et former plus facilement des réseaux de neurones.

Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

Étiquettes associées:
source:163.com
Déclaration de ce site Web
Le contenu de cet article est volontairement contribué par les internautes et les droits d'auteur appartiennent à l'auteur original. Ce site n'assume aucune responsabilité légale correspondante. Si vous trouvez un contenu suspecté de plagiat ou de contrefaçon, veuillez contacter admin@php.cn
Tutoriels populaires
Plus>
Derniers téléchargements
Plus>
effets Web
Code source du site Web
Matériel du site Web
Modèle frontal