Eingehende Analyse der Kernpunkte von Pytorch, CNN-Entschlüsselung!

王林
Freigeben: 2024-01-04 19:18:16
nach vorne
1313 Leute haben es durchsucht

Hallo, ich bin Xiaozhuang!

Anfänger sind möglicherweise nicht mit der Erstellung von Faltungs-Neuronalen Netzen (CNN) vertraut. Lassen Sie uns dies anhand eines vollständigen Falls veranschaulichen.

CNN ist ein Deep-Learning-Modell, das häufig bei der Bildklassifizierung, Zielerkennung, Bildgenerierung und anderen Aufgaben verwendet wird. Es extrahiert automatisch Merkmale von Bildern durch Faltungsschichten und Pooling-Schichten und führt die Klassifizierung durch vollständig verbundene Schichten durch. Der Schlüssel zu diesem Modell besteht darin, Faltungs- und Pooling-Operationen zu verwenden, um lokale Merkmale in Bildern effektiv zu erfassen und sie über mehrschichtige Netzwerke zu kombinieren, um eine erweiterte Merkmalsextraktion und Klassifizierung von Bildern zu erreichen.

Prinzip

1. Faltungsschicht:

Die Faltungsschicht extrahiert Merkmale aus dem Eingabebild durch Faltungsoperationen. Bei diesem Vorgang handelt es sich um einen lernbaren Faltungskern, der über das Eingabebild gleitet und das Skalarprodukt unter dem Schiebefenster berechnet. Dieser Prozess hilft bei der Extraktion lokaler Merkmale und verbessert so die Wahrnehmung der Übersetzungsinvarianz durch das Netzwerk.

Formel:

突破Pytorch核心点,CNN !!!

wobei x die Eingabe, w der Faltungskern und b der Bias ist.

2. Pooling-Schicht:

Die Pooling-Schicht ist eine häufig verwendete Dimensionsreduktionstechnologie. Ihre Funktion besteht darin, die räumliche Dimension der Daten zu reduzieren und dadurch den Berechnungsaufwand zu reduzieren und die wichtigsten Merkmale zu extrahieren. Unter diesen ist Max Pooling eine gängige Pooling-Methode, bei der der größte Wert in jedem Fenster als Vertreter ausgewählt wird. Durch maximales Pooling können wir die Komplexität der Daten reduzieren und die Recheneffizienz des Modells verbessern, während wichtige Informationen erhalten bleiben.

Formel (Max Pooling):

突破Pytorch核心点,CNN !!!

3. Vollständig verbundene Schicht:

Die vollständig verbundene Schicht spielt eine wichtige Rolle im neuronalen Netzwerk. Sie extrahiert die Faltungs- und Pooling-Feature-Maps . Jedes Neuron in der vollständig verbundenen Schicht ist mit allen Neuronen in der vorherigen Schicht verbunden, sodass eine Merkmalssynthese und -klassifizierung erreicht werden kann.

Praktische Schritte und detaillierte Erklärung

1. Schritte

  • Importieren Sie die erforderlichen Bibliotheken und Module.
  • Definieren Sie die Netzwerkstruktur: Verwenden Sie nn.Module, um eine davon geerbte benutzerdefinierte neuronale Netzwerkklasse zu definieren und die Faltungsschicht, die Aktivierungsfunktion, die Pooling-Schicht und die vollständig verbundene Schicht zu definieren.
  • Verlustfunktion und Optimierer definieren.
  • Daten laden und vorverarbeiten.
  • Trainieren Sie das Netzwerk: Trainieren Sie Netzwerkparameter mithilfe von Trainingsdaten iterativ.
  • Testnetzwerk: Verwenden Sie Testdaten, um die Modellleistung zu bewerten.

2. Code-Implementierung

import torchimport torch.nn as nnimport torch.optim as optimfrom torchvision import datasets, transforms# 定义卷积神经网络类class SimpleCNN(nn.Module):def __init__(self):super(SimpleCNN, self).__init__()# 卷积层1self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=3, stride=1, padding=1)self.relu = nn.ReLU()self.pool = nn.MaxPool2d(kernel_size=2, stride=2)# 卷积层2self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1)# 全连接层self.fc1 = nn.Linear(32 * 7 * 7, 10)# 输入大小根据数据调整def forward(self, x):x = self.conv1(x)x = self.relu(x)x = self.pool(x)x = self.conv2(x)x = self.relu(x)x = self.pool(x)x = x.view(-1, 32 * 7 * 7)x = self.fc1(x)return x# 定义损失函数和优化器net = SimpleCNN()criterion = nn.CrossEntropyLoss()optimizer = optim.Adam(net.parameters(), lr=0.001)# 加载和预处理数据transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)# 训练网络num_epochs = 5for epoch in range(num_epochs):for i, (images, labels) in enumerate(train_loader):optimizer.zero_grad()outputs = net(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()if (i+1) % 100 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item()}')# 测试网络net.eval()with torch.no_grad():correct = 0total = 0for images, labels in test_loader:outputs = net(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = correct / totalprint('Accuracy on the test set: {}%'.format(100 * accuracy))
Nach dem Login kopieren

Dieses Beispiel zeigt ein einfaches CNN-Modell, trainiert und getestet mit dem MNIST-Datensatz.

Als nächstes fügen wir einen Visualisierungsschritt hinzu, um die Leistung und den Trainingsprozess des Modells intuitiver zu verstehen.

Visualisierung

1. Matplotlib importieren

import matplotlib.pyplot as plt
Nach dem Login kopieren

2. Zeichnen Sie Verlust und Genauigkeit während des Trainings auf:

Zeichnen Sie während der Trainingsschleife den Verlust und die Genauigkeit jeder Epoche auf.

# 在训练循环中添加以下代码train_loss_list = []accuracy_list = []for epoch in range(num_epochs):running_loss = 0.0correct = 0total = 0for i, (images, labels) in enumerate(train_loader):optimizer.zero_grad()outputs = net(images)loss = criterion(outputs, labels)loss.backward()optimizer.step()running_loss += loss.item()_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()if (i+1) % 100 == 0:print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {loss.item()}')epoch_loss = running_loss / len(train_loader)accuracy = correct / totaltrain_loss_list.append(epoch_loss)accuracy_list.append(accuracy)
Nach dem Login kopieren

3. Verlust und Genauigkeit visualisieren:

# 在训练循环后,添加以下代码plt.figure(figsize=(12, 4))# 可视化损失plt.subplot(1, 2, 1)plt.plot(range(1, num_epochs + 1), train_loss_list, label='Training Loss')plt.title('Training Loss')plt.xlabel('Epochs')plt.ylabel('Loss')plt.legend()# 可视化准确率plt.subplot(1, 2, 2)plt.plot(range(1, num_epochs + 1), accuracy_list, label='Accuracy')plt.title('Accuracy')plt.xlabel('Epochs')plt.ylabel('Accuracy')plt.legend()plt.tight_layout()plt.show()
Nach dem Login kopieren

Auf diese Weise können wir die Veränderungen des Trainingsverlusts und der Genauigkeit nach dem Trainingsprozess sehen.

Nach dem Import des Codes können Sie den visuellen Inhalt und das Format nach Bedarf anpassen.

Das obige ist der detaillierte Inhalt vonEingehende Analyse der Kernpunkte von Pytorch, CNN-Entschlüsselung!. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!

Verwandte Etiketten:
Quelle:51cto.com
Erklärung dieser Website
Der Inhalt dieses Artikels wird freiwillig von Internetnutzern beigesteuert und das Urheberrecht liegt beim ursprünglichen Autor. Diese Website übernimmt keine entsprechende rechtliche Verantwortung. Wenn Sie Inhalte finden, bei denen der Verdacht eines Plagiats oder einer Rechtsverletzung besteht, wenden Sie sich bitte an admin@php.cn
Beliebte Tutorials
Mehr>
Neueste Downloads
Mehr>
Web-Effekte
Quellcode der Website
Website-Materialien
Frontend-Vorlage