首页 后端开发 Python教程 PyTorch快速搭建神经网络及其保存提取方法详解

PyTorch快速搭建神经网络及其保存提取方法详解

Apr 28, 2018 am 10:56 AM
pytorch 神经网络

本篇文章主要介绍了PyTorch快速搭建神经网络及其保存提取方法详解,现在分享给大家,也给大家做个参考。一起过来看看吧

有时候我们训练了一个模型, 希望保存它下次直接使用,不需要下次再花时间去训练 ,本节我们来讲解一下PyTorch快速搭建神经网络及其保存提取方法详解

一、PyTorch快速搭建神经网络方法

先看实验代码:

import torch 
import torch.nn.functional as F 
 
# 方法1,通过定义一个Net类来建立神经网络 
class Net(torch.nn.Module): 
  def __init__(self, n_feature, n_hidden, n_output): 
    super(Net, self).__init__() 
    self.hidden = torch.nn.Linear(n_feature, n_hidden) 
    self.predict = torch.nn.Linear(n_hidden, n_output) 
 
  def forward(self, x): 
    x = F.relu(self.hidden(x)) 
    x = self.predict(x) 
    return x 
 
net1 = Net(2, 10, 2) 
print('方法1:\n', net1) 
 
# 方法2 通过torch.nn.Sequential快速建立神经网络结构 
net2 = torch.nn.Sequential( 
  torch.nn.Linear(2, 10), 
  torch.nn.ReLU(), 
  torch.nn.Linear(10, 2), 
  ) 
print('方法2:\n', net2) 
# 经验证,两种方法构建的神经网络功能相同,结构细节稍有不同 
 
''''' 
方法1: 
 Net ( 
 (hidden): Linear (2 -> 10) 
 (predict): Linear (10 -> 2) 
) 
方法2: 
 Sequential ( 
 (0): Linear (2 -> 10) 
 (1): ReLU () 
 (2): Linear (10 -> 2) 
) 
'''
登录后复制

先前学习了通过定义一个Net类来构建神经网络的方法,classNet中首先通过super函数继承torch.nn.Module模块的构造方法,再通过添加属性的方式搭建神经网络各层的结构信息,在forward方法中完善神经网络各层之间的连接信息,然后再通过定义Net类对象的方式完成对神经网络结构的构建。

构建神经网络的另一个方法,也可以说是快速构建方法,就是通过torch.nn.Sequential,直接完成对神经网络的建立。

两种方法构建得到的神经网络结构完全相同,都可以通过print函数来打印输出网络信息,不过打印结果会有些许不同。

二、PyTorch的神经网络保存和提取

在学习和研究深度学习的时候,当我们通过一定时间的训练,得到了一个比较好的模型的时候,我们当然希望将这个模型及模型参数保存下来,以备后用,所以神经网络的保存和模型参数提取重载是很有必要的。

首先,我们需要在需要保存网路结构及其模型参数的神经网络的定义、训练部分之后通过torch.save()实现对网络结构和模型参数的保存。有两种保存方式:一是保存年整个神经网络的的结构信息和模型参数信息,save的对象是网络net;二是只保存神经网络的训练模型参数,save的对象是net.state_dict(),保存结果都以.pkl文件形式存储。

对应上面两种保存方式,重载方式也有两种。对应第一种完整网络结构信息,重载的时候通过torch.load(‘.pkl')直接初始化新的神经网络对象即可。对应第二种只保存模型参数信息,需要首先搭建相同的神经网络结构,通过net.load_state_dict(torch.load('.pkl'))完成模型参数的重载。在网络比较大的时候,第一种方法会花费较多的时间。

代码实现:

import torch 
from torch.autograd import Variable 
import matplotlib.pyplot as plt 
 
torch.manual_seed(1) # 设定随机数种子 
 
# 创建数据 
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) 
y = x.pow(2) + 0.2*torch.rand(x.size()) 
x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False) 
 
# 将待保存的神经网络定义在一个函数中 
def save(): 
  # 神经网络结构 
  net1 = torch.nn.Sequential( 
    torch.nn.Linear(1, 10), 
    torch.nn.ReLU(), 
    torch.nn.Linear(10, 1), 
    ) 
  optimizer = torch.optim.SGD(net1.parameters(), lr=0.5) 
  loss_function = torch.nn.MSELoss() 
 
  # 训练部分 
  for i in range(300): 
    prediction = net1(x) 
    loss = loss_function(prediction, y) 
    optimizer.zero_grad() 
    loss.backward() 
    optimizer.step() 
 
  # 绘图部分 
  plt.figure(1, figsize=(10, 3)) 
  plt.subplot(131) 
  plt.title('net1') 
  plt.scatter(x.data.numpy(), y.data.numpy()) 
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 
 
  # 保存神经网络 
  torch.save(net1, '7-net.pkl')           # 保存整个神经网络的结构和模型参数 
  torch.save(net1.state_dict(), '7-net_params.pkl') # 只保存神经网络的模型参数 
 
# 载入整个神经网络的结构及其模型参数 
def reload_net(): 
  net2 = torch.load('7-net.pkl') 
  prediction = net2(x) 
 
  plt.subplot(132) 
  plt.title('net2') 
  plt.scatter(x.data.numpy(), y.data.numpy()) 
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 
 
# 只载入神经网络的模型参数,神经网络的结构需要与保存的神经网络相同的结构 
def reload_params(): 
  # 首先搭建相同的神经网络结构 
  net3 = torch.nn.Sequential( 
    torch.nn.Linear(1, 10), 
    torch.nn.ReLU(), 
    torch.nn.Linear(10, 1), 
    ) 
 
  # 载入神经网络的模型参数 
  net3.load_state_dict(torch.load('7-net_params.pkl')) 
  prediction = net3(x) 
 
  plt.subplot(133) 
  plt.title('net3') 
  plt.scatter(x.data.numpy(), y.data.numpy()) 
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 
 
# 运行测试 
save() 
reload_net() 
reload_params()
登录后复制

实验结果:

相关推荐:

PyTorch上实现卷积神经网络CNN的方法

详解PyTorch批训练及优化器比较

Pytorch入门之mnist分类实例

以上是PyTorch快速搭建神经网络及其保存提取方法详解的详细内容。更多信息请关注PHP中文网其他相关文章!

本站声明
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn

热AI工具

Undresser.AI Undress

Undresser.AI Undress

人工智能驱动的应用程序,用于创建逼真的裸体照片

AI Clothes Remover

AI Clothes Remover

用于从照片中去除衣服的在线人工智能工具。

Undress AI Tool

Undress AI Tool

免费脱衣服图片

Clothoff.io

Clothoff.io

AI脱衣机

AI Hentai Generator

AI Hentai Generator

免费生成ai无尽的。

热工具

记事本++7.3.1

记事本++7.3.1

好用且免费的代码编辑器

SublimeText3汉化版

SublimeText3汉化版

中文版,非常好用

禅工作室 13.0.1

禅工作室 13.0.1

功能强大的PHP集成开发环境

Dreamweaver CS6

Dreamweaver CS6

视觉化网页开发工具

SublimeText3 Mac版

SublimeText3 Mac版

神级代码编辑软件(SublimeText3)

科大讯飞:华为升腾 910B 能力基本可对标英伟达 A100,正合力打造我国通用人工智能新底座 科大讯飞:华为升腾 910B 能力基本可对标英伟达 A100,正合力打造我国通用人工智能新底座 Oct 22, 2023 pm 06:13 PM

本站10月22日消息,今年第三季度,科大讯飞实现净利润2579万元,同比下降81.86%;前三季度净利润9936万元,同比下降76.36%。科大讯飞副总裁江涛在Q3业绩说明会上透露,讯飞已于2023年初与华为升腾启动专项攻关,与华为联合研发高性能算子库,合力打造我国通用人工智能新底座,让国产大模型架构在自主创新的软硬件基础之上。他指出,目前华为升腾910B能力已经基本做到可对标英伟达A100。在即将举行的科大讯飞1024全球开发者节上,讯飞和华为在人工智能算力底座上将有进一步联合发布。他还提到,

PyCharm与PyTorch完美结合:安装配置步骤详解 PyCharm与PyTorch完美结合:安装配置步骤详解 Feb 21, 2024 pm 12:00 PM

PyCharm是一款强大的集成开发环境(IDE),而PyTorch是深度学习领域备受欢迎的开源框架。在机器学习和深度学习领域,使用PyCharm和PyTorch进行开发可以极大地提高开发效率和代码质量。本文将详细介绍如何在PyCharm中安装配置PyTorch,并附上具体的代码示例,帮助读者更好地利用这两者的强大功能。第一步:安装PyCharm和Python

YOLO不死!YOLOv9出炉:性能速度SOTA~ YOLO不死!YOLOv9出炉:性能速度SOTA~ Feb 26, 2024 am 11:31 AM

如今的深度学习方法专注于设计最适合的目标函数,以使模型的预测结果与实际情况最接近。同时,必须设计一个合适的架构,以便为预测获取足够的信息。现有方法忽略了一个事实,即当输入数据经过逐层特征提取和空间变换时,大量信息将会丢失。本文将深入探讨数据通过深度网络传输时的重要问题,即信息瓶颈和可逆函数。基于此提出了可编程梯度信息(PGI)的概念,以应对深度网络实现多目标所需的各种变化。PGI可以为目标任务提供完整的输入信息,以计算目标函数,从而获得可靠的梯度信息以更新网络权重。此外设计了一种新的轻量级网络架

一文通览自动驾驶三大主流芯片架构 一文通览自动驾驶三大主流芯片架构 Apr 12, 2023 pm 12:07 PM

当前主流的AI芯片主要分为三类,GPU、FPGA、ASIC。GPU、FPGA均是前期较为成熟的芯片架构,属于通用型芯片。ASIC属于为AI特定场景定制的芯片。行业内已经确认CPU不适用于AI计算,但是在AI应用领域也是必不可少。 GPU方案GPU与CPU的架构对比CPU遵循的是冯·诺依曼架构,其核心是存储程序/数据、串行顺序执行。因此CPU的架构中需要大量的空间去放置存储单元(Cache)和控制单元(Control),相比之下计算单元(ALU)只占据了很小的一部分,所以CPU在进行大规模并行计算

自然语言生成任务中的五种采样方法介绍和Pytorch代码实现 自然语言生成任务中的五种采样方法介绍和Pytorch代码实现 Feb 20, 2024 am 08:50 AM

在自然语言生成任务中,采样方法是从生成模型中获得文本输出的一种技术。这篇文章将讨论5种常用方法,并使用PyTorch进行实现。1、GreedyDecoding在贪婪解码中,生成模型根据输入序列逐个时间步地预测输出序列的单词。在每个时间步,模型会计算每个单词的条件概率分布,然后选择具有最高条件概率的单词作为当前时间步的输出。这个单词成为下一个时间步的输入,生成过程会持续直到满足某种终止条件,比如生成了指定长度的序列或者生成了特殊的结束标记。GreedyDecoding的特点是每次选择当前条件概率最

"B站UP主成功打造全球首个基于红石的神经网络在社交媒体引起轰动,得到Yann LeCun的点赞赞赏" "B站UP主成功打造全球首个基于红石的神经网络在社交媒体引起轰动,得到Yann LeCun的点赞赞赏" May 07, 2023 pm 10:58 PM

在我的世界(Minecraft)中,红石是一种非常重要的物品。它是游戏中的一种独特材料,开关、红石火把和红石块等能对导线或物体提供类似电流的能量。红石电路可以为你建造用于控制或激活其他机械的结构,其本身既可以被设计为用于响应玩家的手动激活,也可以反复输出信号或者响应非玩家引发的变化,如生物移动、物品掉落、植物生长、日夜更替等等。因此,在我的世界中,红石能够控制的机械类别极其多,小到简单机械如自动门、光开关和频闪电源,大到占地巨大的电梯、自动农场、小游戏平台甚至游戏内建的计算机。近日,B站UP主@

用PyTorch实现噪声去除扩散模型 用PyTorch实现噪声去除扩散模型 Jan 14, 2024 pm 10:33 PM

在详细了解去噪扩散概率模型(DDPM)的工作原理之前,我们先来了解一下生成式人工智能的一些发展情况,这也是DDPM的基础研究之一。 VAEVAE使用编码器、概率潜在空间和解码器。在训练过程中,编码器预测每个图像的均值和方差,并从高斯分布中对这些值进行采样。采样的结果传递到解码器中,解码器将输入图像转换为与输出图像相似的形式。 KL散度用于计算损失。 VAE的一个显着优势是其能够生成多样化的图像。在采样阶段,可以直接从高斯分布中采样,并通过解码器生成新的图像。 GAN在变分自编码器(VAEs)的短短一年之

安装PyTorch的PyCharm教程 安装PyTorch的PyCharm教程 Feb 24, 2024 am 10:09 AM

PyTorch作为一款功能强大的深度学习框架,被广泛应用于各类机器学习项目中。PyCharm作为一款强大的Python集成开发环境,在实现深度学习任务时也能提供很好的支持。本文将详细介绍如何在PyCharm中安装PyTorch,并提供具体的代码示例,帮助读者快速上手使用PyTorch进行深度学习任务。第一步:安装PyCharm首先,我们需要确保已经在计算机上

See all articles