PyTorch快速建構神經網路及其保存擷取方法詳解
本篇主要介紹了PyTorch快速搭建神經網路及其保存擷取方法詳解,現在分享給大家,也給大家做個參考。一起過來看看吧
有時候我們訓練了一個模型, 希望保存它下次直接使用,不需要下次再花時間去訓練,本節我們來講解一下PyTorch快速搭建神經網路及其保存擷取方法詳解
一、PyTorch快速搭建神經網路方法
先看實驗程式碼:
import torch import torch.nn.functional as F # 方法1,通过定义一个Net类来建立神经网络 class Net(torch.nn.Module): def __init__(self, n_feature, n_hidden, n_output): super(Net, self).__init__() self.hidden = torch.nn.Linear(n_feature, n_hidden) self.predict = torch.nn.Linear(n_hidden, n_output) def forward(self, x): x = F.relu(self.hidden(x)) x = self.predict(x) return x net1 = Net(2, 10, 2) print('方法1:\n', net1) # 方法2 通过torch.nn.Sequential快速建立神经网络结构 net2 = torch.nn.Sequential( torch.nn.Linear(2, 10), torch.nn.ReLU(), torch.nn.Linear(10, 2), ) print('方法2:\n', net2) # 经验证,两种方法构建的神经网络功能相同,结构细节稍有不同 ''''' 方法1: Net ( (hidden): Linear (2 -> 10) (predict): Linear (10 -> 2) ) 方法2: Sequential ( (0): Linear (2 -> 10) (1): ReLU () (2): Linear (10 -> 2) ) '''
先前學習了透過定義一個Net類別來建構神經網路的方法,classNet中首先透過super函數繼承torch.nn.Module模組的建構方法,再透過增加屬性的方式建構神經網路各層的結構訊息,在forward方法中完善神經網路各層之間的連接訊息,然後再透過定義Net類別物件的方式完成對神經網路結構的建構。
建構神經網路的另一個方法,也可以說是快速建構方法,就是透過torch.nn.Sequential,直接完成對神經網路的建立。
兩種方法建構得到的神經網路結構完全相同,都可以透過print函數來列印輸出網路訊息,不過列印結果會有些許不同。
二、PyTorch的神經網路保存和提取
#在學習和研究深度學習的時候,當我們經過一定時間的訓練,得到了一個比較好的模型的時候,我們當然希望將這個模型及模型參數保存下來,以備後用,所以神經網路的保存和模型參數提取重載是很有必要的。
首先,我們需要在需要保存網路結構及其模型參數的神經網路的定義、訓練部分之後透過torch.save()實現網路結構和模型參數的保存。有兩種保存方式:一是保存年整個神經網路的的結構資訊和模型參數訊息,save的物件是網路net;二是只保存神經網路的訓練模型參數,save的物件是net.state_dict(),保存結果都以.pkl檔案形式儲存。
對應上面兩種保存方式,重載方式也有兩種。對應第一種完整網路結構訊息,重載的時候透過torch.load(‘.pkl')直接初始化新的神經網路物件即可。對應第二種只保存模型參數訊息,需要先搭建相同的神經網路結構,透過net.load_state_dict(torch.load('.pkl'))完成模型參數的重載。在網路比較大的時候,第一種方法會花費較多的時間。
程式碼實作:
import torch from torch.autograd import Variable import matplotlib.pyplot as plt torch.manual_seed(1) # 设定随机数种子 # 创建数据 x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) y = x.pow(2) + 0.2*torch.rand(x.size()) x, y = Variable(x, requires_grad=False), Variable(y, requires_grad=False) # 将待保存的神经网络定义在一个函数中 def save(): # 神经网络结构 net1 = torch.nn.Sequential( torch.nn.Linear(1, 10), torch.nn.ReLU(), torch.nn.Linear(10, 1), ) optimizer = torch.optim.SGD(net1.parameters(), lr=0.5) loss_function = torch.nn.MSELoss() # 训练部分 for i in range(300): prediction = net1(x) loss = loss_function(prediction, y) optimizer.zero_grad() loss.backward() optimizer.step() # 绘图部分 plt.figure(1, figsize=(10, 3)) plt.subplot(131) plt.title('net1') plt.scatter(x.data.numpy(), y.data.numpy()) plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) # 保存神经网络 torch.save(net1, '7-net.pkl') # 保存整个神经网络的结构和模型参数 torch.save(net1.state_dict(), '7-net_params.pkl') # 只保存神经网络的模型参数 # 载入整个神经网络的结构及其模型参数 def reload_net(): net2 = torch.load('7-net.pkl') prediction = net2(x) plt.subplot(132) plt.title('net2') plt.scatter(x.data.numpy(), y.data.numpy()) plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) # 只载入神经网络的模型参数,神经网络的结构需要与保存的神经网络相同的结构 def reload_params(): # 首先搭建相同的神经网络结构 net3 = torch.nn.Sequential( torch.nn.Linear(1, 10), torch.nn.ReLU(), torch.nn.Linear(10, 1), ) # 载入神经网络的模型参数 net3.load_state_dict(torch.load('7-net_params.pkl')) prediction = net3(x) plt.subplot(133) plt.title('net3') plt.scatter(x.data.numpy(), y.data.numpy()) plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) # 运行测试 save() reload_net() reload_params()
實驗結果:
相關推薦:
#
以上是PyTorch快速建構神經網路及其保存擷取方法詳解的詳細內容。更多資訊請關注PHP中文網其他相關文章!

熱AI工具

Undresser.AI Undress
人工智慧驅動的應用程序,用於創建逼真的裸體照片

AI Clothes Remover
用於從照片中去除衣服的線上人工智慧工具。

Undress AI Tool
免費脫衣圖片

Clothoff.io
AI脫衣器

Video Face Swap
使用我們完全免費的人工智慧換臉工具,輕鬆在任何影片中換臉!

熱門文章

熱工具

記事本++7.3.1
好用且免費的程式碼編輯器

SublimeText3漢化版
中文版,非常好用

禪工作室 13.0.1
強大的PHP整合開發環境

Dreamweaver CS6
視覺化網頁開發工具

SublimeText3 Mac版
神級程式碼編輯軟體(SublimeText3)

本站10月22日消息,今年第三季度,科大讯飞实现净利润2579万元,同比下降81.86%;前三季度净利润9936万元,同比下降76.36%。科大讯飞副总裁江涛在Q3业绩说明会上透露,讯飞已于2023年初与华为昇腾启动专项攻关,与华为联合研发高性能算子库,合力打造我国通用人工智能新底座,让国产大模型架构在自主创新的软硬件基础之上。他指出,目前华为昇腾910B能力已经基本做到可对标英伟达A100。在即将举行的科大讯飞1024全球开发者节上,讯飞和华为在人工智能算力底座上将有进一步联合发布。他还提到,

PyCharm是一款強大的整合開發環境(IDE),而PyTorch則是深度學習領域備受歡迎的開源架構。在機器學習和深度學習領域,使用PyCharm和PyTorch進行開發可以大大提高開發效率和程式碼品質。本文將詳細介紹如何在PyCharm中安裝設定PyTorch,並附上具體的程式碼範例,幫助讀者更好地利用這兩者的強大功能。第一步:安裝PyCharm和Python

如今的深度學習方法專注於設計最適合的目標函數,以使模型的預測結果與實際情況最接近。同時,必須設計一個合適的架構,以便為預測取得足夠的資訊。現有方法忽略了一個事實,當輸入資料經過逐層特徵提取和空間變換時,大量資訊將會遺失。本文將深入探討資料透過深度網路傳輸時的重要問題,即資訊瓶頸和可逆函數。基於此提出了可編程梯度資訊(PGI)的概念,以應對深度網路實現多目標所需的各種變化。 PGI可以為目標任務提供完整的輸入訊息,以計算目標函數,從而獲得可靠的梯度資訊以更新網路權重。此外設計了一種新的輕量級網路架

目前主流的AI晶片主要分為三類,GPU、FPGA、ASIC。 GPU、FPGA皆是前期較成熟的晶片架構,屬於通用型晶片。 ASIC屬於為AI特定場景定制的晶片。業界已經確認CPU不適用於AI計算,但在AI應用領域也是不可或缺。 GPU方案GPU與CPU的架構比較CPU遵循的是馮諾依曼架構,其核心是儲存程式/資料、序列順序執行。因此CPU的架構中需要大量的空間去放置儲存單元(Cache)和控制單元(Control),相較之下運算單元(ALU)只佔據了很小的一部分,所以CPU在進行大規模平行運算

在我的世界(Minecraft)中,紅石是一種非常重要的物品。它是遊戲中獨特的材料,開關、紅石火把和紅石塊等能對導線或物體提供類似電流的能量。紅石電路可以為你建造用於控製或激活其他機械的結構,其本身既可以被設計為用於響應玩家的手動激活,也可以反複輸出信號或者響應非玩家引發的變化,如生物移動、物品掉落、植物生長、日夜更替等等。因此,在我的世界中,紅石能夠控制的機械類別極其多,小到簡單機械如自動門、光開關和頻閃電源,大到佔地巨大的電梯、自動農場、小遊戲平台甚至遊戲內建的計算機。近日,B站UP主@

在自然語言生成任務中,取樣方法是從生成模型中獲得文字輸出的一種技術。這篇文章將討論5種常用方法,並使用PyTorch進行實作。 1.GreedyDecoding在貪婪解碼中,生成模型根據輸入序列逐個時間步地預測輸出序列的單字。在每個時間步,模型會計算每個單字的條件機率分佈,然後選擇具有最高條件機率的單字作為當前時間步的輸出。這個單字成為下一個時間步的輸入,生成過程會持續直到滿足某種終止條件,例如產生了指定長度的序列或產生了特殊的結束標記。 GreedyDecoding的特點是每次選擇當前條件機率最

在详细了解去噪扩散概率模型(DDPM)的工作原理之前,我们先来了解一下生成式人工智能的一些发展情况,这也是DDPM的基础研究之一。VAEVAE使用编码器、概率潜在空间和解码器。在训练过程中,编码器预测每个图像的均值和方差,并从高斯分布中对这些值进行采样。采样的结果传递到解码器中,解码器将输入图像转换为与输出图像相似的形式。KL散度用于计算损失。VAE的一个显著优势是其能够生成多样化的图像。在采样阶段,可以直接从高斯分布中采样,并通过解码器生成新的图像。GAN在变分自编码器(VAEs)的短短一年之

PyTorch作為一個功能強大的深度學習框架,被廣泛應用於各類機器學習專案。 PyCharm作為一個強大的Python整合開發環境,在實現深度學習任務時也能提供很好的支援。本文將詳細介紹如何在PyCharm中安裝PyTorch,並提供具體的程式碼範例,幫助讀者快速上手使用PyTorch進行深度學習任務。第一步:安裝PyCharm首先,我們需要確保已經在電腦上
