Maison > Périphériques technologiques > IA > le corps du texte

Améliorez les points clés de Pytorch et améliorez l'optimiseur !

WBOY
Libérer: 2024-01-05 13:22:01
avant
1188 Les gens l'ont consulté

Salut, je m'appelle Xiaozhuang !

Aujourd'hui, nous parlons de l'optimiseur dans Pytorch.

Le choix de l'optimiseur a un impact direct sur l'effet d'entraînement et la rapidité du modèle d'apprentissage profond. Différents optimiseurs sont adaptés à différents problèmes, et leurs différences de performances peuvent amener le modèle à converger plus rapidement et de manière plus stable, ou à mieux fonctionner sur une tâche spécifique. Par conséquent, lors du choix d’un optimiseur, des compromis et des décisions doivent être faits en fonction des caractéristiques du problème spécifique.

Par conséquent, choisir le bon optimiseur est crucial pour régler les modèles d’apprentissage profond. Le choix de l'optimiseur affectera non seulement de manière significative les performances du modèle, mais également l'efficacité du processus de formation.

PyTorch fournit une variété d'optimiseurs qui peuvent être utilisés pour entraîner les réseaux de neurones et mettre à jour les poids des modèles. Ces optimiseurs incluent les communs SGD, Adam, RMSprop, etc. Chaque optimiseur a ses caractéristiques uniques et ses scénarios applicables. Le choix d'un optimiseur approprié peut accélérer la convergence des modèles et améliorer les résultats de la formation. Lorsque vous utilisez l'optimiseur, vous devez définir des hyperparamètres tels que le taux d'apprentissage et la perte de poids, ainsi que définir des fonctions de perte et des paramètres de modèle.

突破Pytorch核心点,优化器 !!

Optimiseurs courants

Listons d'abord quelques optimiseurs couramment utilisés dans PyTorch et donnons-en une brève introduction :

Comprenons comment fonctionne SGD (descente de gradient stochastique). SGD est un algorithme d'optimisation couramment utilisé pour résoudre les paramètres des modèles d'apprentissage automatique. Il estime le gradient en sélectionnant aléatoirement un petit lot d’échantillons et utilise la direction négative du gradient pour mettre à jour les paramètres. Cela permet d'optimiser progressivement les performances du modèle au cours d'un processus itératif. L'avantage de SGD est une efficacité de calcul élevée, particulièrement adaptée à la

La descente de gradient stochastique est un algorithme d'optimisation couramment utilisé pour minimiser la fonction de perte. Il fonctionne en calculant le gradient des poids par rapport à la fonction de perte et en mettant à jour les poids dans le sens négatif du gradient. Cet algorithme est largement utilisé en apprentissage automatique et en apprentissage profond.

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
Copier après la connexion

(2) Adam

Adam est un algorithme d'optimisation du taux d'apprentissage adaptatif qui combine les idées d'AdaGrad et de RMSProp. Par rapport à l'algorithme traditionnel de descente de gradient, Adam peut calculer différents taux d'apprentissage pour chaque paramètre, s'adaptant ainsi mieux aux caractéristiques des différents paramètres. En ajustant de manière adaptative le taux d'apprentissage, Adam peut améliorer la vitesse de convergence et les performances du modèle.

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
Copier après la connexion

(3) Adagrad

Adagrad est un algorithme d'optimisation du taux d'apprentissage adaptatif qui ajuste le taux d'apprentissage en fonction du gradient historique des paramètres. Cependant, à mesure que le taux d’apprentissage diminue progressivement, la formation peut s’arrêter prématurément.

optimizer = torch.optim.Adagrad(model.parameters(), lr=learning_rate)
Copier après la connexion

(4) RMSProp

RMSProp est également un algorithme de taux d'apprentissage adaptatif qui ajuste le taux d'apprentissage en considérant la moyenne mobile du gradient.

optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)
Copier après la connexion

(5) Adadelta

Adadelta est un algorithme d'optimisation du taux d'apprentissage adaptatif et une version améliorée de RMSProp, qui ajuste dynamiquement le taux d'apprentissage en considérant la moyenne mobile du gradient et la moyenne mobile des paramètres.

optimizer = torch.optim.Adadelta(model.parameters(), lr=learning_rate)
Copier après la connexion

Un cas complet

Ici, parlons de la façon d'utiliser PyTorch pour former un simple réseau neuronal convolutionnel (CNN) pour la reconnaissance de chiffres manuscrits.

Ce cas utilise l'ensemble de données MNIST et utilise la bibliothèque Matplotlib pour tracer la courbe de perte et la courbe de précision.

import torchimport torch.nn as nnimport torch.optim as optimfrom torchvision import datasets, transformsfrom torch.utils.data import DataLoaderimport matplotlib.pyplot as plt# 设置随机种子torch.manual_seed(42)# 定义数据转换transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])# 下载和加载MNIST数据集train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)test_loader = DataLoader(test_dataset, batch_size=1000, shuffle=False)# 定义简单的卷积神经网络模型class CNN(nn.Module):def __init__(self):super(CNN, self).__init__()self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)self.relu = nn.ReLU()self.pool = nn.MaxPool2d(kernel_size=2, stride=2)self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)self.fc1 = nn.Linear(64 * 7 * 7, 128)self.fc2 = nn.Linear(128, 10)def forward(self, x):x = self.conv1(x)x = self.relu(x)x = self.pool(x)x = self.conv2(x)x = self.relu(x)x = self.pool(x)x = x.view(-1, 64 * 7 * 7)x = self.fc1(x)x = self.relu(x)x = self.fc2(x)return x# 创建模型、损失函数和优化器model = CNN()criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(model.parameters(), lr=0.001)# 训练模型num_epochs = 5train_losses = []train_accuracies = []for epoch in range(num_epochs):model.train()total_loss = 0.0correct = 0total = 0for inputs, labels in train_loader:optimizer.zero_grad()outputs = model(inputs)loss = criterion(outputs, labels)loss.backward()optimizer.step()total_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = correct / totaltrain_losses.append(total_loss / len(train_loader))train_accuracies.append(accuracy)print(f"Epoch {epoch+1}/{num_epochs}, Loss: {train_losses[-1]:.4f}, Accuracy: {accuracy:.4f}")# 绘制损失曲线和准确率曲线plt.figure(figsize=(10, 5))plt.subplot(1, 2, 1)plt.plot(train_losses, label='Training Loss')plt.title('Training Loss')plt.xlabel('Epoch')plt.ylabel('Loss')plt.legend()plt.subplot(1, 2, 2)plt.plot(train_accuracies, label='Training Accuracy')plt.title('Training Accuracy')plt.xlabel('Epoch')plt.ylabel('Accuracy')plt.legend()plt.tight_layout()plt.show()# 在测试集上评估模型model.eval()correct = 0total = 0with torch.no_grad():for inputs, labels in test_loader:outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = correct / totalprint(f"Accuracy on test set: {accuracy * 100:.2f}%")
Copier après la connexion

Dans le code ci-dessus, nous définissons un simple réseau de neurones convolutifs (CNN), entraîné à l'aide de la perte d'entropie croisée et de l'optimiseur Adam.

Pendant le processus de formation, nous avons enregistré la perte et la précision de chaque époque, et utilisé la bibliothèque Matplotlib pour tracer la courbe de perte et la courbe de précision.

突破Pytorch核心点,优化器 !!

Je m'appelle Xiao Zhuang, à la prochaine fois !

Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

Étiquettes associées:
source:51cto.com
Déclaration de ce site Web
Le contenu de cet article est volontairement contribué par les internautes et les droits d'auteur appartiennent à l'auteur original. Ce site n'assume aucune responsabilité légale correspondante. Si vous trouvez un contenu suspecté de plagiat ou de contrefaçon, veuillez contacter admin@php.cn
Tutoriels populaires
Plus>
Derniers téléchargements
Plus>
effets Web
Code source du site Web
Matériel du site Web
Modèle frontal
À propos de nous Clause de non-responsabilité Sitemap
Site Web PHP chinois:Formation PHP en ligne sur le bien-être public,Aidez les apprenants PHP à grandir rapidement!