PyTorch上實作卷積神經網路CNN的方法
本篇文章主要介紹了PyTorch上實現卷積神經網路CNN的方法,現在分享給大家,也給大家做個參考。一起來看看吧
一、卷積神經網路
卷積神經網路(ConvolutionalNeuralNetwork,CNN)原本是為解決影像識別等問題設計的,CNN現在的應用已經不限於圖像和視頻,也可用於時間序列訊號,例如音頻訊號和文字資料等。 CNN作為一個深度學習架構被提出的最初訴求是降低對影像資料預處理的要求,避免複雜的特徵工程。在卷積神經網路中,第一個卷積層會直接接受影像像素層級的輸入,每一層卷積(濾波器)都會擷取資料中最有效的特徵,這種方法可以擷取到影像中最基礎的特徵,而後再進行組合和抽象形成更高階的特徵,因此CNN在理論上具有對影像縮放、平移和旋轉的不變性。
卷積神經網路CNN的重點是局部連結(LocalConnection)、權值共享(WeightsSharing)和池化層(Pooling)中的降採樣(Down-Sampling)。其中,局部連接和權值共享降低了參數量,使訓練複雜度大大下降並減輕了過度擬合。同時權值共享也賦予了卷積網路對平移的容忍性,池化層降採樣則進一步降低了輸出參數量並賦予模型對輕度形變的容忍性,提高了模型的泛化能力。可以把捲積層卷積操作理解為用少量參數在影像的多個位置上提取相似特徵的過程。
二、程式碼實作
import torch import torch.nn as nn from torch.autograd import Variable import torch.utils.data as Data import torchvision import matplotlib.pyplot as plt torch.manual_seed(1) EPOCH = 1 BATCH_SIZE = 50 LR = 0.001 DOWNLOAD_MNIST = True # 获取训练集dataset training_data = torchvision.datasets.MNIST( root='./mnist/', # dataset存储路径 train=True, # True表示是train训练集,False表示test测试集 transform=torchvision.transforms.ToTensor(), # 将原数据规范化到(0,1)区间 download=DOWNLOAD_MNIST, ) # 打印MNIST数据集的训练集及测试集的尺寸 print(training_data.train_data.size()) print(training_data.train_labels.size()) # torch.Size([60000, 28, 28]) # torch.Size([60000]) plt.imshow(training_data.train_data[0].numpy(), cmap='gray') plt.title('%i' % training_data.train_labels[0]) plt.show() # 通过torchvision.datasets获取的dataset格式可直接可置于DataLoader train_loader = Data.DataLoader(dataset=training_data, batch_size=BATCH_SIZE, shuffle=True) # 获取测试集dataset test_data = torchvision.datasets.MNIST(root='./mnist/', train=False) # 取前2000个测试集样本 test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1), volatile=True).type(torch.FloatTensor)[:2000]/255 # (2000, 28, 28) to (2000, 1, 28, 28), in range(0,1) test_y = test_data.test_labels[:2000] class CNN(nn.Module): def __init__(self): super(CNN, self).__init__() self.conv1 = nn.Sequential( # (1,28,28) nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2), # (16,28,28) # 想要con2d卷积出来的图片尺寸没有变化, padding=(kernel_size-1)/2 nn.ReLU(), nn.MaxPool2d(kernel_size=2) # (16,14,14) ) self.conv2 = nn.Sequential( # (16,14,14) nn.Conv2d(16, 32, 5, 1, 2), # (32,14,14) nn.ReLU(), nn.MaxPool2d(2) # (32,7,7) ) self.out = nn.Linear(32*7*7, 10) def forward(self, x): x = self.conv1(x) x = self.conv2(x) x = x.view(x.size(0), -1) # 将(batch,32,7,7)展平为(batch,32*7*7) output = self.out(x) return output cnn = CNN() print(cnn) ''''' CNN ( (conv1): Sequential ( (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) (1): ReLU () (2): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1)) ) (conv2): Sequential ( (0): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) (1): ReLU () (2): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1)) ) (out): Linear (1568 -> 10) ) ''' optimizer = torch.optim.Adam(cnn.parameters(), lr=LR) loss_function = nn.CrossEntropyLoss() for epoch in range(EPOCH): for step, (x, y) in enumerate(train_loader): b_x = Variable(x) b_y = Variable(y) output = cnn(b_x) loss = loss_function(output, b_y) optimizer.zero_grad() loss.backward() optimizer.step() if step % 100 == 0: test_output = cnn(test_x) pred_y = torch.max(test_output, 1)[1].data.squeeze() accuracy = sum(pred_y == test_y) / test_y.size(0) print('Epoch:', epoch, '|Step:', step, '|train loss:%.4f'%loss.data[0], '|test accuracy:%.4f'%accuracy) test_output = cnn(test_x[:10]) pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze() print(pred_y, 'prediction number') print(test_y[:10].numpy(), 'real number') ''''' Epoch: 0 |Step: 0 |train loss:2.3145 |test accuracy:0.1040 Epoch: 0 |Step: 100 |train loss:0.5857 |test accuracy:0.8865 Epoch: 0 |Step: 200 |train loss:0.0600 |test accuracy:0.9380 Epoch: 0 |Step: 300 |train loss:0.0996 |test accuracy:0.9345 Epoch: 0 |Step: 400 |train loss:0.0381 |test accuracy:0.9645 Epoch: 0 |Step: 500 |train loss:0.0266 |test accuracy:0.9620 Epoch: 0 |Step: 600 |train loss:0.0973 |test accuracy:0.9685 Epoch: 0 |Step: 700 |train loss:0.0421 |test accuracy:0.9725 Epoch: 0 |Step: 800 |train loss:0.0654 |test accuracy:0.9710 Epoch: 0 |Step: 900 |train loss:0.1333 |test accuracy:0.9740 Epoch: 0 |Step: 1000 |train loss:0.0289 |test accuracy:0.9720 Epoch: 0 |Step: 1100 |train loss:0.0429 |test accuracy:0.9770 [7 2 1 0 4 1 4 9 5 9] prediction number [7 2 1 0 4 1 4 9 5 9] real number '''
三、分析解讀
透過利用torchvision.datasets可以快速取得可以直接置於DataLoader中的dataset格式的數據,透過train參數控制是取得訓練資料集還是測試資料集,也可以在取得的時候便直接轉換成訓練所需的資料格式。
卷積神經網路的建構是透過定義一個CNN類別來實現,卷積層conv1,conv2及out層以類別屬性的形式定義,各層之間的銜接資訊在forward中定義,定義的時候要留意各層的神經元數量。
CNN ( (conv1): Sequential ( (0): Conv2d(1, 16,kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) (1): ReLU () (2): MaxPool2d (size=(2,2), stride=(2, 2), dilation=(1, 1)) ) (conv2): Sequential ( (0): Conv2d(16, 32,kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) (1): ReLU () (2): MaxPool2d (size=(2,2), stride=(2, 2), dilation=(1, 1)) ) (out): Linear (1568 ->10) )
#經過實驗可見,在EPOCH=1的訓練結果中,測試集準確率可達到97.7%。
相關推薦:#PyTorch批次訓練及最佳化器比較#################################### ######
以上是PyTorch上實作卷積神經網路CNN的方法的詳細內容。更多資訊請關注PHP中文網其他相關文章!

熱AI工具

Undresser.AI Undress
人工智慧驅動的應用程序,用於創建逼真的裸體照片

AI Clothes Remover
用於從照片中去除衣服的線上人工智慧工具。

Undress AI Tool
免費脫衣圖片

Clothoff.io
AI脫衣器

AI Hentai Generator
免費產生 AI 無盡。

熱門文章

熱工具

記事本++7.3.1
好用且免費的程式碼編輯器

SublimeText3漢化版
中文版,非常好用

禪工作室 13.0.1
強大的PHP整合開發環境

Dreamweaver CS6
視覺化網頁開發工具

SublimeText3 Mac版
神級程式碼編輯軟體(SublimeText3)

熱門話題

番茄小說是一款非常熱門的小說閱讀軟體,我們在番茄小說中經常會有新的小說和漫畫可以去閱讀,每一本小說和漫畫都很有意思,很多小伙伴也想著要去寫小說來賺取賺取零用錢,在把自己想要寫的小說內容編輯成文字,那麼我們要怎麼樣在這裡面去寫小說呢?小伙伴們都不知道,那就讓我們一起到本站本站中花點時間來看寫小說的方法介紹。分享番茄小說寫小說方法教學 1、先在手機上打開番茄免費小說app,點擊個人中心——作家中心 2、跳到番茄作家助手頁面——點擊創建新書在小說的結

七彩虹主機板在中國國內市場享有較高的知名度和市場佔有率,但是有些七彩虹主機板的用戶還不清楚怎麼進入bios進行設定呢?針對這一情況,小編專門為大家帶來了兩種進入七彩虹主機板bios的方法,快來試試吧!方法一:使用u盤啟動快捷鍵直接進入u盤裝系統七彩虹主機板一鍵啟動u盤的快捷鍵是ESC或F11,首先使用黑鯊裝機大師製作一個黑鯊U盤啟動盤,然後開啟電腦,當看到開機畫面的時候,連續按下鍵盤上的ESC或F11鍵以後將會進入到一個啟動項順序選擇的窗口,將遊標移到顯示“USB”的地方,然

而後悔莫及、人們常常會因為一些原因不小心刪除某些聯絡人、微信作為一款廣泛使用的社群軟體。幫助用戶解決這個問題,本文將介紹如何透過簡單的方法找回被刪除的聯絡人。 1.了解微信聯絡人刪除機制這為我們找回被刪除的聯絡人提供了可能性、微信中的聯絡人刪除機制是將其從通訊錄中移除,但並未完全刪除。 2.使用微信內建「通訊錄恢復」功能微信提供了「通訊錄恢復」節省時間和精力,使用者可以透過此功能快速找回先前刪除的聯絡人,功能。 3.進入微信設定頁面點選右下角,開啟微信應用程式「我」再點選右上角設定圖示、進入設定頁面,,

字體大小的設定成為了重要的個人化需求,隨著手機成為人們日常生活的重要工具。以滿足不同使用者的需求、本文將介紹如何透過簡單的操作,提升手機使用體驗,調整手機字體大小。為什麼需要調整手機字體大小-調整字體大小可以使文字更清晰易讀-適合不同年齡段用戶的閱讀需求-方便視力不佳的用戶使用手機系統自帶字體大小設置功能-如何進入系統設置界面-在在設定介面中找到並進入"顯示"選項-找到"字體大小"選項並進行調整第三方應用調整字體大小-下載並安裝支援字體大小調整的應用程式-開啟應用程式並進入相關設定介面-根據個人

Win11管理員權限取得方法匯總在Windows11作業系統中,管理員權限是非常重要的權限之一,可以讓使用者對系統進行各種操作。有時候,我們可能需要取得管理員權限來完成一些操作,例如安裝軟體、修改系統設定等。下面就為大家總結了一些取得Win11管理員權限的方法,希望能幫助大家。 1.使用快捷鍵在Windows11系統中,可以透過快捷鍵的方式快速開啟命令提

手機遊戲成為了人們生活中不可或缺的一部分,隨著科技的發展。它以其可愛的龍蛋形象和有趣的孵化過程吸引了眾多玩家的關注,而其中一款備受矚目的遊戲就是手機版龍蛋。幫助玩家們在遊戲中更好地培養和成長自己的小龍,本文將向大家介紹手機版龍蛋的孵化方法。 1.選擇合適的龍蛋種類玩家需要仔細選擇自己喜歡並且適合自己的龍蛋種類,根據遊戲中提供的不同種類的龍蛋屬性和能力。 2.提升孵化機的等級玩家需要透過完成任務和收集道具來提升孵化機的等級,孵化機的等級決定了孵化速度和孵化成功率。 3.收集孵化所需的資源玩家需要在遊戲中

Oracle版本查詢方法詳解Oracle是目前世界上最受歡迎的關聯式資料庫管理系統之一,它提供了豐富的功能和強大的效能,廣泛應用於企業。在進行資料庫管理和開發過程中,了解Oracle資料庫的版本是非常重要的。本文將詳細介紹如何查詢Oracle資料庫的版本信息,並給出具體的程式碼範例。查詢資料庫版本的SQL語句在Oracle資料庫中,可以透過執行簡單的SQL語句

在現今社會,手機已經成為我們生活中不可或缺的一部分。而微信作為我們日常溝通、工作、生活的重要工具,更是經常被使用。然而,在處理不同事務時可能需要分開兩個微信帳號,這就要求手機能夠支援同時登入兩個微信帳號。華為手機作為國內知名品牌,很多人使用,那麼華為手機開啟兩個微信帳號的方法是怎麼樣的呢?下面就來揭秘一下這個方法。首先,要在華為手機上同時使用兩個微信帳號,最簡
