PyTorch로 신경망을 빠르게 구축하고 저장 및 추출하는 방법에 대한 자세한 설명

不言
풀어 주다: 2018-04-28 10:56:06
원래의
2541명이 탐색했습니다.

이 글에서는 신경망을 빠르게 구축하기 위한 PyTorch를 주로 소개하고 저장 및 추출 방법에 대한 자세한 설명을 공유하고 참고하겠습니다. 함께 살펴보겠습니다

때때로 우리는 모델을 훈련시켰고 다음번 훈련에 시간을 들이지 않고 직접 사용할 수 있도록 저장하고 싶을 때가 있습니다. 이 섹션에서는 PyTorch를 사용하여 신경망을 빠르게 구축하는 방법과 저장 방법을 설명하겠습니다.

1. PyTorch로 빠르게 신경망을 구축하는 방법

먼저 실험 코드를 살펴보세요.

import torch 
import torch.nn.functional as F 
 
# 方法1,通过定义一个Net类来建立神经网络 
class Net(torch.nn.Module): 
  def __init__(self, n_feature, n_hidden, n_output): 
    super(Net, self).__init__() 
    self.hidden = torch.nn.Linear(n_feature, n_hidden) 
    self.predict = torch.nn.Linear(n_hidden, n_output) 
 
  def forward(self, x): 
    x = F.relu(self.hidden(x)) 
    x = self.predict(x) 
    return x 
 
net1 = Net(2, 10, 2) 
print('方法1:\n', net1) 
 
# 方法2 通过torch.nn.Sequential快速建立神经网络结构 
net2 = torch.nn.Sequential( 
  torch.nn.Linear(2, 10), 
  torch.nn.ReLU(), 
  torch.nn.Linear(10, 2), 
  ) 
print('方法2:\n', net2) 
# 经验证,两种方法构建的神经网络功能相同,结构细节稍有不同 
 
''''' 
方法1: 
 Net ( 
 (hidden): Linear (2 -> 10) 
 (predict): Linear (10 -> 2) 
) 
方法2: 
 Sequential ( 
 (0): Linear (2 -> 10) 
 (1): ReLU () 
 (2): Linear (10 -> 2) 
) 
'''
로그인 후 복사

이전에 신경망을 구축하는 방법을 다음과 같이 정의했습니다. Net 클래스에서 먼저 super를 전달합니다. 이 함수는 torch.nn.Module 모듈의 구성 방법을 상속한 다음 속성을 추가하여 신경망의 각 계층의 구조 정보를 구축하고 각 계층 간의 연결 정보를 향상시킵니다. 신경망을 순방향 메소드로 정의한 다음 Net 클래스 객체를 정의하여 신경망 구조의 구성을 완료합니다.

빠른 구축 방법이라고도 할 수 있는 신경망을 구축하는 또 다른 방법은 torch.nn.Sequential을 통해 직접 신경망 구축을 완료하는 것입니다.

두 가지 방법으로 구성한 신경망 구조는 완전히 동일하며, 인쇄 기능을 통해 네트워크 정보를 출력할 수 있지만 인쇄 결과는 조금씩 다릅니다.

2. PyTorch 신경망 저장 및 추출

딥러닝을 학습하고 연구할 때 특정 기간의 훈련을 거쳐 더 나은 모델을 얻으면 당연히 이 모델을 사용하기를 바라며 모델 매개변수는 저장됩니다. 나중에 사용하려면 신경망을 저장하고 모델 매개변수를 추출하고 다시 로드해야 합니다.

먼저 네트워크 구조와 모델 매개변수를 저장해야 하는 신경망의 정의 및 훈련 부분 이후에 torch.save()를 통해 네트워크 구조와 모델 매개변수를 저장해야 합니다. 저장 방법에는 두 가지가 있는데, 하나는 전체 신경망의 구조 정보와 모델 매개변수 정보를 저장하는 것이고, 다른 하나는 신경망의 훈련 모델 매개변수만 저장하는 것이고, 저장 객체는 net.state_dict()이며, 저장된 결과는 .pkl 파일 형태로 저장됩니다.

위의 두 가지 저장 방법에 해당하며, 다시 로드하는 방법도 두 가지가 있습니다. 첫 번째 완전한 네트워크 구조 정보에 해당하여, 다시 로드할 때 torch.load('.pkl')을 통해 새로운 신경망 객체를 직접 초기화할 수 있습니다. 모델 매개변수 정보만 저장하는 두 번째 방법에 해당하면 먼저 동일한 신경망 구조를 구축하고 net.load_state_dict(torch.load('.pkl'))를 통해 모델 매개변수 다시 로드를 완료해야 합니다. 네트워크 규모가 상대적으로 큰 경우 첫 번째 방법을 사용하면 시간이 더 오래 걸립니다.

코드 구현:

import torch 
from torch.autograd import Variable 
import matplotlib.pyplot as plt 
 
torch.manual_seed(1) # 设定随机数种子 
 
# 创建数据 
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) 
y = x.pow(2) + 0.2*torch.rand(x.size()) 
x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False) 
 
# 将待保存的神经网络定义在一个函数中 
def save(): 
  # 神经网络结构 
  net1 = torch.nn.Sequential( 
    torch.nn.Linear(1, 10), 
    torch.nn.ReLU(), 
    torch.nn.Linear(10, 1), 
    ) 
  optimizer = torch.optim.SGD(net1.parameters(), lr=0.5) 
  loss_function = torch.nn.MSELoss() 
 
  # 训练部分 
  for i in range(300): 
    prediction = net1(x) 
    loss = loss_function(prediction, y) 
    optimizer.zero_grad() 
    loss.backward() 
    optimizer.step() 
 
  # 绘图部分 
  plt.figure(1, figsize=(10, 3)) 
  plt.subplot(131) 
  plt.title('net1') 
  plt.scatter(x.data.numpy(), y.data.numpy()) 
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 
 
  # 保存神经网络 
  torch.save(net1, '7-net.pkl')           # 保存整个神经网络的结构和模型参数 
  torch.save(net1.state_dict(), '7-net_params.pkl') # 只保存神经网络的模型参数 
 
# 载入整个神经网络的结构及其模型参数 
def reload_net(): 
  net2 = torch.load('7-net.pkl') 
  prediction = net2(x) 
 
  plt.subplot(132) 
  plt.title('net2') 
  plt.scatter(x.data.numpy(), y.data.numpy()) 
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 
 
# 只载入神经网络的模型参数,神经网络的结构需要与保存的神经网络相同的结构 
def reload_params(): 
  # 首先搭建相同的神经网络结构 
  net3 = torch.nn.Sequential( 
    torch.nn.Linear(1, 10), 
    torch.nn.ReLU(), 
    torch.nn.Linear(10, 1), 
    ) 
 
  # 载入神经网络的模型参数 
  net3.load_state_dict(torch.load('7-net_params.pkl')) 
  prediction = net3(x) 
 
  plt.subplot(133) 
  plt.title('net3') 
  plt.scatter(x.data.numpy(), y.data.numpy()) 
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 
 
# 运行测试 
save() 
reload_net() 
reload_params()
로그인 후 복사

실험 결과:

관련 권장 사항:

PyTorch에서 CNN을 구현하는 방법

PyT에 대한 자세한 설명 orch 배치 훈련 및 최적화 비교

Pytorch 소개 - mnist 분류 예제

위 내용은 PyTorch로 신경망을 빠르게 구축하고 저장 및 추출하는 방법에 대한 자세한 설명의 상세 내용입니다. 자세한 내용은 PHP 중국어 웹사이트의 기타 관련 기사를 참조하세요!

관련 라벨:
원천:php.cn
본 웹사이트의 성명
본 글의 내용은 네티즌들의 자발적인 기여로 작성되었으며, 저작권은 원저작자에게 있습니다. 본 사이트는 이에 상응하는 법적 책임을 지지 않습니다. 표절이나 침해가 의심되는 콘텐츠를 발견한 경우 admin@php.cn으로 문의하세요.
인기 튜토리얼
더>
최신 다운로드
더>
웹 효과
웹사이트 소스 코드
웹사이트 자료
프론트엔드 템플릿