Heim > Backend-Entwicklung > Python-Tutorial > So implementieren Sie das Faltungs-Neuronale Netzwerk CNN auf PyTorch

So implementieren Sie das Faltungs-Neuronale Netzwerk CNN auf PyTorch

不言
Freigeben: 2018-04-28 10:02:42
Original
2776 Leute haben es durchsucht

In diesem Artikel wird hauptsächlich die Methode zur Implementierung des Faltungs-Neuronalen Netzwerks CNN auf PyTorch vorgestellt. Jetzt werde ich es mit Ihnen teilen und Ihnen eine Referenz geben. Werfen wir gemeinsam einen Blick darauf

1. Convolutional Neural Network

Convolutional Neural Network (CNN) wurde ursprünglich zur Lösung der Bilderkennung entwickelt Aufgrund dieser Probleme beschränken sich die aktuellen Anwendungen von CNN nicht nur auf Bilder und Videos, sondern können auch für Zeitreihensignale wie Audiosignale und Textdaten verwendet werden. Der anfängliche Reiz von CNN als Deep-Learning-Architektur besteht darin, die Anforderungen an die Vorverarbeitung von Bilddaten zu reduzieren und komplexes Feature-Engineering zu vermeiden. In einem Faltungs-Neuronalen Netzwerk akzeptiert die erste Faltungsschicht direkt die Eingabe auf Pixelebene. Jede Faltungsschicht (Filter) extrahiert die effektivsten Merkmale in den Daten. Mit dieser Methode können die grundlegendsten Merkmale des Bildes extrahiert werden . Features werden dann kombiniert und abstrahiert, um Features höherer Ordnung zu bilden, sodass CNN theoretisch invariant gegenüber Bildskalierung, Translation und Rotation ist.

Die Schlüsselpunkte des Faltungs-Neuronalen Netzwerks CNN sind lokale Verbindung (LocalConnection), Gewichtsteilung (WeightsSharing) und Downsampling (Down-Sampling) in der Pooling-Schicht (Pooling). Unter anderem reduzieren lokale Verbindungen und Gewichtsverteilung die Anzahl der Parameter, reduzieren die Trainingskomplexität erheblich und lindern Überanpassungen. Gleichzeitig verleiht die Gewichtsteilung dem Faltungsnetzwerk auch Toleranz gegenüber Übersetzungen, und das Downsampling der Pooling-Schicht reduziert die Anzahl der Ausgabeparameter weiter und verleiht dem Modell Toleranz gegenüber leichten Verformungen, wodurch die Generalisierungsfähigkeit des Modells verbessert wird. Die Faltungsoperation der Faltungsschicht kann als ein Prozess verstanden werden, bei dem ähnliche Merkmale an mehreren Stellen im Bild mit einer kleinen Anzahl von Parametern extrahiert werden.

2. Code-Implementierung

import torch 
import torch.nn as nn 
from torch.autograd import Variable 
import torch.utils.data as Data 
import torchvision 
import matplotlib.pyplot as plt 
 
torch.manual_seed(1) 
 
EPOCH = 1 
BATCH_SIZE = 50 
LR = 0.001 
DOWNLOAD_MNIST = True 
 
# 获取训练集dataset 
training_data = torchvision.datasets.MNIST( 
       root='./mnist/', # dataset存储路径 
       train=True, # True表示是train训练集,False表示test测试集 
       transform=torchvision.transforms.ToTensor(), # 将原数据规范化到(0,1)区间 
       download=DOWNLOAD_MNIST, 
       ) 
 
# 打印MNIST数据集的训练集及测试集的尺寸 
print(training_data.train_data.size()) 
print(training_data.train_labels.size()) 
# torch.Size([60000, 28, 28]) 
# torch.Size([60000]) 
 
plt.imshow(training_data.train_data[0].numpy(), cmap='gray') 
plt.title('%i' % training_data.train_labels[0]) 
plt.show() 
 
# 通过torchvision.datasets获取的dataset格式可直接可置于DataLoader 
train_loader = Data.DataLoader(dataset=training_data, batch_size=BATCH_SIZE, 
                shuffle=True) 
 
# 获取测试集dataset 
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False) 
# 取前2000个测试集样本 
test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1), 
         volatile=True).type(torch.FloatTensor)[:2000]/255 
# (2000, 28, 28) to (2000, 1, 28, 28), in range(0,1) 
test_y = test_data.test_labels[:2000] 
 
class CNN(nn.Module): 
  def __init__(self): 
    super(CNN, self).__init__() 
    self.conv1 = nn.Sequential( # (1,28,28) 
           nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, 
                stride=1, padding=2), # (16,28,28) 
    # 想要con2d卷积出来的图片尺寸没有变化, padding=(kernel_size-1)/2 
           nn.ReLU(), 
           nn.MaxPool2d(kernel_size=2) # (16,14,14) 
           ) 
    self.conv2 = nn.Sequential( # (16,14,14) 
           nn.Conv2d(16, 32, 5, 1, 2), # (32,14,14) 
           nn.ReLU(), 
           nn.MaxPool2d(2) # (32,7,7) 
           ) 
    self.out = nn.Linear(32*7*7, 10) 
 
  def forward(self, x): 
    x = self.conv1(x) 
    x = self.conv2(x) 
    x = x.view(x.size(0), -1) # 将(batch,32,7,7)展平为(batch,32*7*7) 
    output = self.out(x) 
    return output 
 
cnn = CNN() 
print(cnn) 
''''' 
CNN ( 
 (conv1): Sequential ( 
  (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) 
  (1): ReLU () 
  (2): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1)) 
 ) 
 (conv2): Sequential ( 
  (0): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) 
  (1): ReLU () 
  (2): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1)) 
 ) 
 (out): Linear (1568 -> 10) 
) 
''' 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR) 
loss_function = nn.CrossEntropyLoss() 
 
for epoch in range(EPOCH): 
  for step, (x, y) in enumerate(train_loader): 
    b_x = Variable(x) 
    b_y = Variable(y) 
 
    output = cnn(b_x) 
    loss = loss_function(output, b_y) 
    optimizer.zero_grad() 
    loss.backward() 
    optimizer.step() 
 
    if step % 100 == 0: 
      test_output = cnn(test_x) 
      pred_y = torch.max(test_output, 1)[1].data.squeeze() 
      accuracy = sum(pred_y == test_y) / test_y.size(0) 
      print('Epoch:', epoch, '|Step:', step, 
         '|train loss:%.4f'%loss.data[0], '|test accuracy:%.4f'%accuracy) 
 
test_output = cnn(test_x[:10]) 
pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze() 
print(pred_y, 'prediction number') 
print(test_y[:10].numpy(), 'real number') 
''''' 
Epoch: 0 |Step: 0 |train loss:2.3145 |test accuracy:0.1040 
Epoch: 0 |Step: 100 |train loss:0.5857 |test accuracy:0.8865 
Epoch: 0 |Step: 200 |train loss:0.0600 |test accuracy:0.9380 
Epoch: 0 |Step: 300 |train loss:0.0996 |test accuracy:0.9345 
Epoch: 0 |Step: 400 |train loss:0.0381 |test accuracy:0.9645 
Epoch: 0 |Step: 500 |train loss:0.0266 |test accuracy:0.9620 
Epoch: 0 |Step: 600 |train loss:0.0973 |test accuracy:0.9685 
Epoch: 0 |Step: 700 |train loss:0.0421 |test accuracy:0.9725 
Epoch: 0 |Step: 800 |train loss:0.0654 |test accuracy:0.9710 
Epoch: 0 |Step: 900 |train loss:0.1333 |test accuracy:0.9740 
Epoch: 0 |Step: 1000 |train loss:0.0289 |test accuracy:0.9720 
Epoch: 0 |Step: 1100 |train loss:0.0429 |test accuracy:0.9770 
[7 2 1 0 4 1 4 9 5 9] prediction number 
[7 2 1 0 4 1 4 9 5 9] real number 
'''
Nach dem Login kopieren

3. Analyse und Interpretation

Durch die Verwendung von Torchvision.datasets können Sie schnell Daten im Datensatzformat erhalten, die direkt im DataLoader platziert werden können. Verwenden Sie den Zugparameter, um zu steuern, ob Um den Trainingsdatensatz zu erhalten oder zu testen, kann der Datensatz nach Erhalt auch direkt in das für das Training erforderliche Datenformat konvertiert werden.

Der Aufbau eines Faltungs-Neuronalen Netzwerks wird durch die Definition einer CNN-Klasse erreicht. Die Faltungsschichten conv1, conv2 und out werden in Form von Klassenattributen definiert. Achten Sie bei der Definition immer auf die Anzahl der Neuronen in jeder Schicht.

Die Netzwerkstruktur von CNN ist wie folgt:

CNN (
 (conv1): Sequential (
  (0): Conv2d(1, 16,kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (1): ReLU ()
  (2): MaxPool2d (size=(2,2), stride=(2, 2), dilation=(1, 1))
 )
 (conv2): Sequential (
  (0): Conv2d(16, 32,kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (1): ReLU ()
  (2): MaxPool2d (size=(2,2), stride=(2, 2), dilation=(1, 1))
 )
 (out): Linear (1568 ->10)
)
Nach dem Login kopieren

Durch Experimente kann festgestellt werden, dass in den Trainingsergebnissen von EPOCH=1, Die Genauigkeit des Testsatzes kann 97,7 % erreichen.

Verwandte Empfehlungen:

Detaillierte Erläuterung des PyTorch-Batch-Trainings und Optimierervergleichs

Das obige ist der detaillierte Inhalt vonSo implementieren Sie das Faltungs-Neuronale Netzwerk CNN auf PyTorch. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!

Verwandte Etiketten:
Quelle:php.cn
Erklärung dieser Website
Der Inhalt dieses Artikels wird freiwillig von Internetnutzern beigesteuert und das Urheberrecht liegt beim ursprünglichen Autor. Diese Website übernimmt keine entsprechende rechtliche Verantwortung. Wenn Sie Inhalte finden, bei denen der Verdacht eines Plagiats oder einer Rechtsverletzung besteht, wenden Sie sich bitte an admin@php.cn
Beliebte Tutorials
Mehr>
Neueste Downloads
Mehr>
Web-Effekte
Quellcode der Website
Website-Materialien
Frontend-Vorlage