Cet article présente principalement PyTorch pour construire rapidement un réseau de neurones et une explication détaillée de ses méthodes de sauvegarde et d'extraction. Maintenant, je le partage avec vous et le donne comme référence. Jetons un coup d'oeil ensemble
Parfois, nous avons entraîné un modèle et souhaitons le sauvegarder pour une utilisation directe la prochaine fois sans passer du temps à l'entraîner à nouveau la prochaine fois. Dans cette section, nous expliquerons comment construire rapidement un réseau neuronal avec. PyTorch et son explication détaillée de la méthode de sauvegarde et d'extraction
1. Méthode PyTorch pour construire rapidement un réseau de neurones
Regardez le code expérimental en premier :
import torch import torch.nn.functional as F # 方法1,通过定义一个Net类来建立神经网络 class Net(torch.nn.Module): def __init__(self, n_feature, n_hidden, n_output): super(Net, self).__init__() self.hidden = torch.nn.Linear(n_feature, n_hidden) self.predict = torch.nn.Linear(n_hidden, n_output) def forward(self, x): x = F.relu(self.hidden(x)) x = self.predict(x) return x net1 = Net(2, 10, 2) print('方法1:\n', net1) # 方法2 通过torch.nn.Sequential快速建立神经网络结构 net2 = torch.nn.Sequential( torch.nn.Linear(2, 10), torch.nn.ReLU(), torch.nn.Linear(10, 2), ) print('方法2:\n', net2) # 经验证,两种方法构建的神经网络功能相同,结构细节稍有不同 ''''' 方法1: Net ( (hidden): Linear (2 -> 10) (predict): Linear (10 -> 2) ) 方法2: Sequential ( (0): Linear (2 -> 10) (1): ReLU () (2): Linear (10 -> 2) ) '''
Auparavant, j'ai appris à construire un réseau de neurones en définissant une classe Net. Dans classNet, on hérite d'abord. la méthode de construction du module torch.nn.Module via la super fonction. Construisez ensuite les informations structurelles de chaque couche du réseau neuronal en ajoutant des attributs, améliorez les informations de connexion entre chaque couche du réseau neuronal dans la méthode forward, puis terminer la construction de la structure du réseau neuronal en définissant l'objet de classe Net.
Une autre façon de construire un réseau de neurones, qui peut également être considérée comme une méthode de construction rapide, consiste à terminer directement l'établissement du réseau de neurones via torch.nn.Sequential.
Les structures de réseau neuronal construites par les deux méthodes sont exactement les mêmes et les informations du réseau peuvent être imprimées via la fonction d'impression, mais les résultats d'impression seront légèrement différents.
2. Préservation et extraction du réseau neuronal PyTorch
Lors de l'apprentissage et de la recherche sur l'apprentissage profond, lorsque nous traversons une certaine période de formation, Quand nous obtenons un meilleur modèle, bien sûr, nous voulons enregistrer le modèle et les paramètres du modèle pour une utilisation ultérieure, il est donc nécessaire de sauvegarder le réseau neuronal et d'extraire et de recharger les paramètres du modèle.
Tout d'abord, nous devons enregistrer la structure du réseau et les paramètres du modèle via torch.save() après la définition et la formation de la partie du réseau neuronal qui doit enregistrer la structure du réseau et ses paramètres de modèle. Il existe deux méthodes de sauvegarde : l'une consiste à sauvegarder les informations structurelles et les informations sur les paramètres du modèle de l'ensemble du réseau neuronal, et l'objet de la sauvegarde est le réseau, l'autre consiste à sauvegarder uniquement les paramètres du modèle de formation du réseau neuronal, et le l'objet de la sauvegarde est net.state_dict(), les résultats enregistrés sont stockés sous forme de fichiers .pkl.
correspond aux deux méthodes de sauvegarde ci-dessus, et il existe également deux méthodes de rechargement. Correspondant aux premières informations complètes sur la structure du réseau, vous pouvez directement initialiser le nouvel objet de réseau neuronal via torch.load('.pkl') lors du rechargement. Correspondant à la deuxième méthode consistant à enregistrer uniquement les informations sur les paramètres du modèle, vous devez d'abord créer la même structure de réseau neuronal et terminer le rechargement des paramètres du modèle via net.load_state_dict(torch.load('.pkl')). Lorsque le réseau est relativement étendu, la première méthode prendra plus de temps.
Mise en œuvre du code :
import torch from torch.autograd import Variable import matplotlib.pyplot as plt torch.manual_seed(1) # 设定随机数种子 # 创建数据 x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) y = x.pow(2) + 0.2*torch.rand(x.size()) x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False) # 将待保存的神经网络定义在一个函数中 def save(): # 神经网络结构 net1 = torch.nn.Sequential( torch.nn.Linear(1, 10), torch.nn.ReLU(), torch.nn.Linear(10, 1), ) optimizer = torch.optim.SGD(net1.parameters(), lr=0.5) loss_function = torch.nn.MSELoss() # 训练部分 for i in range(300): prediction = net1(x) loss = loss_function(prediction, y) optimizer.zero_grad() loss.backward() optimizer.step() # 绘图部分 plt.figure(1, figsize=(10, 3)) plt.subplot(131) plt.title('net1') plt.scatter(x.data.numpy(), y.data.numpy()) plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) # 保存神经网络 torch.save(net1, '7-net.pkl') # 保存整个神经网络的结构和模型参数 torch.save(net1.state_dict(), '7-net_params.pkl') # 只保存神经网络的模型参数 # 载入整个神经网络的结构及其模型参数 def reload_net(): net2 = torch.load('7-net.pkl') prediction = net2(x) plt.subplot(132) plt.title('net2') plt.scatter(x.data.numpy(), y.data.numpy()) plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) # 只载入神经网络的模型参数,神经网络的结构需要与保存的神经网络相同的结构 def reload_params(): # 首先搭建相同的神经网络结构 net3 = torch.nn.Sequential( torch.nn.Linear(1, 10), torch.nn.ReLU(), torch.nn.Linear(10, 1), ) # 载入神经网络的模型参数 net3.load_state_dict(torch.load('7-net_params.pkl')) prediction = net3(x) plt.subplot(133) plt.title('net3') plt.scatter(x.data.numpy(), y.data.numpy()) plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) # 运行测试 save() reload_net() reload_params()
Résultats expérimentaux :
Recommandations associées :
Comment implémenter le réseau neuronal convolutionnel CNN sur PyTorch
Explication détaillée de la formation par lots PyTorch et comparaison des optimiseurs
Introduction à l'exemple de classification mnist Pytorch
Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!