首頁 後端開發 Python教學 PyTorch快速建構神經網路及其保存擷取方法詳解

PyTorch快速建構神經網路及其保存擷取方法詳解

Apr 28, 2018 am 10:56 AM
pytorch 神經網路

本篇主要介紹了PyTorch快速搭建神經網路及其保存擷取方法詳解,現在分享給大家,也給大家做個參考。一起過來看看吧

有時候我們訓練了一個模型, 希望保存它下次直接使用,不需要下次再花時間去訓練,本節我們來講解一下PyTorch快速搭建神經網路及其保存擷取方法詳解

一、PyTorch快速搭建神經網路方法

先看實驗程式碼:

import torch 
import torch.nn.functional as F 
 
# 方法1,通过定义一个Net类来建立神经网络 
class Net(torch.nn.Module): 
  def __init__(self, n_feature, n_hidden, n_output): 
    super(Net, self).__init__() 
    self.hidden = torch.nn.Linear(n_feature, n_hidden) 
    self.predict = torch.nn.Linear(n_hidden, n_output) 
 
  def forward(self, x): 
    x = F.relu(self.hidden(x)) 
    x = self.predict(x) 
    return x 
 
net1 = Net(2, 10, 2) 
print('方法1:\n', net1) 
 
# 方法2 通过torch.nn.Sequential快速建立神经网络结构 
net2 = torch.nn.Sequential( 
  torch.nn.Linear(2, 10), 
  torch.nn.ReLU(), 
  torch.nn.Linear(10, 2), 
  ) 
print('方法2:\n', net2) 
# 经验证,两种方法构建的神经网络功能相同,结构细节稍有不同 
 
''''' 
方法1: 
 Net ( 
 (hidden): Linear (2 -> 10) 
 (predict): Linear (10 -> 2) 
) 
方法2: 
 Sequential ( 
 (0): Linear (2 -> 10) 
 (1): ReLU () 
 (2): Linear (10 -> 2) 
) 
'''
登入後複製

先前學習了透過定義一個Net類別來建構神經網路的方法,classNet中首先透過super函數繼承torch.nn.Module模組的建構方法,再透過增加屬性的方式建構神經網路各層的結構訊息,在forward方法中完善神經網路各層之間的連接訊息,然後再透過定義Net類別物件的方式完成對神經網路結構的建構。

建構神經網路的另一個方法,也可以說是快速建構方法,就是透過torch.nn.Sequential,直接完成對神經網路的建立。

兩種方法建構得到的神經網路結構完全相同,都可以透過print函數來列印輸出網路訊息,不過列印結果會有些許不同。

二、PyTorch的神經網路保存和提取

#在學習和研究深度學習的時候,當我們經過一定時間的訓練,得到了一個比較好的模型的時候,我們當然希望將這個模型及模型參數保存下來,以備後用,所以神經網路的保存和模型參數提取重載是很有必要的。

首先,我們需要在需要保存網路結構及其模型參數的神經網路的定義、訓練部分之後透過torch.save()實現網路結構和模型參數的保存。有兩種保存方式:一是保存年整個神經網路的的結構資訊和模型參數訊息,save的物件是網路net;二是只保存神經網路的訓練模型參數,save的物件是net.state_dict(),保存結果都以.pkl檔案形式儲存。

對應上面兩種保存方式,重載方式也有兩種。對應第一種完整網路結構訊息,重載的時候透過torch.load(‘.pkl')直接初始化新的神經網路物件即可。對應第二種只保存模型參數訊息,需要先搭建相同的神經網路結構,透過net.load_state_dict(torch.load('.pkl'))完成模型參數的重載。在網路比較大的時候,第一種方法會花費較多的時間。

程式碼實作:

import torch 
from torch.autograd import Variable 
import matplotlib.pyplot as plt 
 
torch.manual_seed(1) # 设定随机数种子 
 
# 创建数据 
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) 
y = x.pow(2) + 0.2*torch.rand(x.size()) 
x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False) 
 
# 将待保存的神经网络定义在一个函数中 
def save(): 
  # 神经网络结构 
  net1 = torch.nn.Sequential( 
    torch.nn.Linear(1, 10), 
    torch.nn.ReLU(), 
    torch.nn.Linear(10, 1), 
    ) 
  optimizer = torch.optim.SGD(net1.parameters(), lr=0.5) 
  loss_function = torch.nn.MSELoss() 
 
  # 训练部分 
  for i in range(300): 
    prediction = net1(x) 
    loss = loss_function(prediction, y) 
    optimizer.zero_grad() 
    loss.backward() 
    optimizer.step() 
 
  # 绘图部分 
  plt.figure(1, figsize=(10, 3)) 
  plt.subplot(131) 
  plt.title('net1') 
  plt.scatter(x.data.numpy(), y.data.numpy()) 
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 
 
  # 保存神经网络 
  torch.save(net1, '7-net.pkl')           # 保存整个神经网络的结构和模型参数 
  torch.save(net1.state_dict(), '7-net_params.pkl') # 只保存神经网络的模型参数 
 
# 载入整个神经网络的结构及其模型参数 
def reload_net(): 
  net2 = torch.load('7-net.pkl') 
  prediction = net2(x) 
 
  plt.subplot(132) 
  plt.title('net2') 
  plt.scatter(x.data.numpy(), y.data.numpy()) 
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 
 
# 只载入神经网络的模型参数,神经网络的结构需要与保存的神经网络相同的结构 
def reload_params(): 
  # 首先搭建相同的神经网络结构 
  net3 = torch.nn.Sequential( 
    torch.nn.Linear(1, 10), 
    torch.nn.ReLU(), 
    torch.nn.Linear(10, 1), 
    ) 
 
  # 载入神经网络的模型参数 
  net3.load_state_dict(torch.load('7-net_params.pkl')) 
  prediction = net3(x) 
 
  plt.subplot(133) 
  plt.title('net3') 
  plt.scatter(x.data.numpy(), y.data.numpy()) 
  plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) 
 
# 运行测试 
save() 
reload_net() 
reload_params()
登入後複製

實驗結果:

相關推薦:

PyTorch上實作卷積神經網路CNN的方法

#詳解PyTorch批訓練及最佳化器比較

Pytorch入門之mnist分類實例

#

以上是PyTorch快速建構神經網路及其保存擷取方法詳解的詳細內容。更多資訊請關注PHP中文網其他相關文章!

本網站聲明
本文內容由網友自願投稿,版權歸原作者所有。本站不承擔相應的法律責任。如發現涉嫌抄襲或侵權的內容,請聯絡admin@php.cn

熱AI工具

Undresser.AI Undress

Undresser.AI Undress

人工智慧驅動的應用程序,用於創建逼真的裸體照片

AI Clothes Remover

AI Clothes Remover

用於從照片中去除衣服的線上人工智慧工具。

Undress AI Tool

Undress AI Tool

免費脫衣圖片

Clothoff.io

Clothoff.io

AI脫衣器

Video Face Swap

Video Face Swap

使用我們完全免費的人工智慧換臉工具,輕鬆在任何影片中換臉!

熱工具

記事本++7.3.1

記事本++7.3.1

好用且免費的程式碼編輯器

SublimeText3漢化版

SublimeText3漢化版

中文版,非常好用

禪工作室 13.0.1

禪工作室 13.0.1

強大的PHP整合開發環境

Dreamweaver CS6

Dreamweaver CS6

視覺化網頁開發工具

SublimeText3 Mac版

SublimeText3 Mac版

神級程式碼編輯軟體(SublimeText3)

科大讯飞:华为昇腾 910B 能力基本可对标英伟达 A100,正合力打造我国通用人工智能新底座 科大讯飞:华为昇腾 910B 能力基本可对标英伟达 A100,正合力打造我国通用人工智能新底座 Oct 22, 2023 pm 06:13 PM

本站10月22日消息,今年第三季度,科大讯飞实现净利润2579万元,同比下降81.86%;前三季度净利润9936万元,同比下降76.36%。科大讯飞副总裁江涛在Q3业绩说明会上透露,讯飞已于2023年初与华为昇腾启动专项攻关,与华为联合研发高性能算子库,合力打造我国通用人工智能新底座,让国产大模型架构在自主创新的软硬件基础之上。他指出,目前华为昇腾910B能力已经基本做到可对标英伟达A100。在即将举行的科大讯飞1024全球开发者节上,讯飞和华为在人工智能算力底座上将有进一步联合发布。他还提到,

PyCharm與PyTorch完美結合:安裝設定步驟詳解 PyCharm與PyTorch完美結合:安裝設定步驟詳解 Feb 21, 2024 pm 12:00 PM

PyCharm是一款強大的整合開發環境(IDE),而PyTorch則是深度學習領域備受歡迎的開源架構。在機器學習和深度學習領域,使用PyCharm和PyTorch進行開發可以大大提高開發效率和程式碼品質。本文將詳細介紹如何在PyCharm中安裝設定PyTorch,並附上具體的程式碼範例,幫助讀者更好地利用這兩者的強大功能。第一步:安裝PyCharm和Python

YOLO不死! YOLOv9出爐:性能速度SOTA~ YOLO不死! YOLOv9出爐:性能速度SOTA~ Feb 26, 2024 am 11:31 AM

如今的深度學習方法專注於設計最適合的目標函數,以使模型的預測結果與實際情況最接近。同時,必須設計一個合適的架構,以便為預測取得足夠的資訊。現有方法忽略了一個事實,當輸入資料經過逐層特徵提取和空間變換時,大量資訊將會遺失。本文將深入探討資料透過深度網路傳輸時的重要問題,即資訊瓶頸和可逆函數。基於此提出了可編程梯度資訊(PGI)的概念,以應對深度網路實現多目標所需的各種變化。 PGI可以為目標任務提供完整的輸入訊息,以計算目標函數,從而獲得可靠的梯度資訊以更新網路權重。此外設計了一種新的輕量級網路架

一文通覽自動駕駛三大主流晶片架構 一文通覽自動駕駛三大主流晶片架構 Apr 12, 2023 pm 12:07 PM

目前主流的AI晶片主要分為三類,GPU、FPGA、ASIC。 GPU、FPGA皆是前期較成熟的晶片架構,屬於通用型晶片。 ASIC屬於為AI特定場景定制的晶片。業界已經確認CPU不適用於AI計算,但在AI應用領域也是不可或缺。 GPU方案GPU與CPU的架構比較CPU遵循的是馮諾依曼架構,其核心是儲存程式/資料、序列順序執行。因此CPU的架構中需要大量的空間去放置儲存單元(Cache)和控制單元(Control),相較之下運算單元(ALU)只佔據了很小的一部分,所以CPU在進行大規模平行運算

'B站UP主成功打造全球首個基於紅石的神經網絡在社交媒體引起轟動,得到Yann LeCun的點贊讚賞' 'B站UP主成功打造全球首個基於紅石的神經網絡在社交媒體引起轟動,得到Yann LeCun的點贊讚賞' May 07, 2023 pm 10:58 PM

在我的世界(Minecraft)中,紅石是一種非常重要的物品。它是遊戲中獨特的材料,開關、紅石火把和紅石塊等能對導線或物體提供類似電流的能量。紅石電路可以為你建造用於控製或激活其他機械的結構,其本身既可以被設計為用於響應玩家的手動激活,也可以反複輸出信號或者響應非玩家引發的變化,如生物移動、物品掉落、植物生長、日夜更替等等。因此,在我的世界中,紅石能夠控制的機械類別極其多,小到簡單機械如自動門、光開關和頻閃電源,大到佔地巨大的電梯、自動農場、小遊戲平台甚至遊戲內建的計算機。近日,B站UP主@

自然語言生成任務中的五種採樣方法介紹和Pytorch程式碼實現 自然語言生成任務中的五種採樣方法介紹和Pytorch程式碼實現 Feb 20, 2024 am 08:50 AM

在自然語言生成任務中,取樣方法是從生成模型中獲得文字輸出的一種技術。這篇文章將討論5種常用方法,並使用PyTorch進行實作。 1.GreedyDecoding在貪婪解碼中,生成模型根據輸入序列逐個時間步地預測輸出序列的單字。在每個時間步,模型會計算每個單字的條件機率分佈,然後選擇具有最高條件機率的單字作為當前時間步的輸出。這個單字成為下一個時間步的輸入,生成過程會持續直到滿足某種終止條件,例如產生了指定長度的序列或產生了特殊的結束標記。 GreedyDecoding的特點是每次選擇當前條件機率最

用PyTorch實現雜訊去除擴散模型 用PyTorch實現雜訊去除擴散模型 Jan 14, 2024 pm 10:33 PM

在详细了解去噪扩散概率模型(DDPM)的工作原理之前,我们先来了解一下生成式人工智能的一些发展情况,这也是DDPM的基础研究之一。VAEVAE使用编码器、概率潜在空间和解码器。在训练过程中,编码器预测每个图像的均值和方差,并从高斯分布中对这些值进行采样。采样的结果传递到解码器中,解码器将输入图像转换为与输出图像相似的形式。KL散度用于计算损失。VAE的一个显著优势是其能够生成多样化的图像。在采样阶段,可以直接从高斯分布中采样,并通过解码器生成新的图像。GAN在变分自编码器(VAEs)的短短一年之

安裝PyTorch的PyCharm教學 安裝PyTorch的PyCharm教學 Feb 24, 2024 am 10:09 AM

PyTorch作為一個功能強大的深度學習框架,被廣泛應用於各類機器學習專案。 PyCharm作為一個強大的Python整合開發環境,在實現深度學習任務時也能提供很好的支援。本文將詳細介紹如何在PyCharm中安裝PyTorch,並提供具體的程式碼範例,幫助讀者快速上手使用PyTorch進行深度學習任務。第一步:安裝PyCharm首先,我們需要確保已經在電腦上

See all articles