> 백엔드 개발 > 파이썬 튜토리얼 > PyTorch에서 컨벌루션 신경망 CNN을 구현하는 방법

PyTorch에서 컨벌루션 신경망 CNN을 구현하는 방법

不言
풀어 주다: 2018-04-28 10:02:42
원래의
2777명이 탐색했습니다.

이 글에서는 주로 PyTorch에서 CNN을 구현하는 방법을 소개하고 공유하겠습니다. 함께 구경하세요

1. Convolutional Neural Network

CNN(Convolutional Neural Network)은 원래 이미지 인식과 같은 문제를 해결하기 위해 설계되었습니다. 현재 CNN의 응용 분야는 이미지와 동영상에만 국한되지 않습니다. 오디오 신호 및 텍스트 데이터와 같은 시계열 신호에도 사용할 수 있습니다. 딥 러닝 아키텍처로서 CNN의 초기 매력은 이미지 데이터 전처리에 대한 요구 사항을 줄이고 복잡한 기능 엔지니어링을 피하는 것입니다. 컨볼루션 신경망에서 첫 번째 컨볼루션 레이어는 이미지의 픽셀 수준 입력을 직접 받아들입니다. 컨볼루션(필터)의 각 레이어는 데이터에서 가장 효과적인 특징을 추출합니다. 이 방법을 사용하면 이미지의 가장 기본적인 특징을 추출할 수 있습니다. 그런 다음 기능이 결합되고 추상화되어 고차 기능이 형성되므로 CNN은 이론적으로 이미지 크기 조정, 변환 및 회전에 영향을 받지 않습니다.

컨벌루션 신경망 CNN의 핵심은 풀링 계층(Pooling)에서의 로컬 연결(LocalConnection), 가중치 공유(WeightsSharing) 및 다운 샘플링(Down-Sampling)입니다. 그 중 로컬 연결과 가중치 공유는 매개변수의 양을 줄여 훈련 복잡성을 크게 줄이고 과적합을 완화합니다. 동시에 가중치 공유는 변환에 대한 컨벌루션 네트워크 허용 오차도 제공하고, 풀링 계층 다운샘플링은 출력 매개변수의 양을 더욱 줄이고 가벼운 변형에 대한 모델 허용 오차를 제공하여 모델의 일반화 능력을 향상시킵니다. 컨볼루션 레이어의 컨볼루션 연산은 적은 수의 매개변수로 이미지 내 여러 위치에서 유사한 특징을 추출하는 과정으로 이해할 수 있습니다.

2. 코드 구현

import torch 
import torch.nn as nn 
from torch.autograd import Variable 
import torch.utils.data as Data 
import torchvision 
import matplotlib.pyplot as plt 
 
torch.manual_seed(1) 
 
EPOCH = 1 
BATCH_SIZE = 50 
LR = 0.001 
DOWNLOAD_MNIST = True 
 
# 获取训练集dataset 
training_data = torchvision.datasets.MNIST( 
       root='./mnist/', # dataset存储路径 
       train=True, # True表示是train训练集,False表示test测试集 
       transform=torchvision.transforms.ToTensor(), # 将原数据规范化到(0,1)区间 
       download=DOWNLOAD_MNIST, 
       ) 
 
# 打印MNIST数据集的训练集及测试集的尺寸 
print(training_data.train_data.size()) 
print(training_data.train_labels.size()) 
# torch.Size([60000, 28, 28]) 
# torch.Size([60000]) 
 
plt.imshow(training_data.train_data[0].numpy(), cmap='gray') 
plt.title('%i' % training_data.train_labels[0]) 
plt.show() 
 
# 通过torchvision.datasets获取的dataset格式可直接可置于DataLoader 
train_loader = Data.DataLoader(dataset=training_data, batch_size=BATCH_SIZE, 
                shuffle=True) 
 
# 获取测试集dataset 
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False) 
# 取前2000个测试集样本 
test_x = Variable(torch.unsqueeze(test_data.test_data, dim=1), 
         volatile=True).type(torch.FloatTensor)[:2000]/255 
# (2000, 28, 28) to (2000, 1, 28, 28), in range(0,1) 
test_y = test_data.test_labels[:2000] 
 
class CNN(nn.Module): 
  def __init__(self): 
    super(CNN, self).__init__() 
    self.conv1 = nn.Sequential( # (1,28,28) 
           nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, 
                stride=1, padding=2), # (16,28,28) 
    # 想要con2d卷积出来的图片尺寸没有变化, padding=(kernel_size-1)/2 
           nn.ReLU(), 
           nn.MaxPool2d(kernel_size=2) # (16,14,14) 
           ) 
    self.conv2 = nn.Sequential( # (16,14,14) 
           nn.Conv2d(16, 32, 5, 1, 2), # (32,14,14) 
           nn.ReLU(), 
           nn.MaxPool2d(2) # (32,7,7) 
           ) 
    self.out = nn.Linear(32*7*7, 10) 
 
  def forward(self, x): 
    x = self.conv1(x) 
    x = self.conv2(x) 
    x = x.view(x.size(0), -1) # 将(batch,32,7,7)展平为(batch,32*7*7) 
    output = self.out(x) 
    return output 
 
cnn = CNN() 
print(cnn) 
''''' 
CNN ( 
 (conv1): Sequential ( 
  (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) 
  (1): ReLU () 
  (2): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1)) 
 ) 
 (conv2): Sequential ( 
  (0): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)) 
  (1): ReLU () 
  (2): MaxPool2d (size=(2, 2), stride=(2, 2), dilation=(1, 1)) 
 ) 
 (out): Linear (1568 -> 10) 
) 
''' 
optimizer = torch.optim.Adam(cnn.parameters(), lr=LR) 
loss_function = nn.CrossEntropyLoss() 
 
for epoch in range(EPOCH): 
  for step, (x, y) in enumerate(train_loader): 
    b_x = Variable(x) 
    b_y = Variable(y) 
 
    output = cnn(b_x) 
    loss = loss_function(output, b_y) 
    optimizer.zero_grad() 
    loss.backward() 
    optimizer.step() 
 
    if step % 100 == 0: 
      test_output = cnn(test_x) 
      pred_y = torch.max(test_output, 1)[1].data.squeeze() 
      accuracy = sum(pred_y == test_y) / test_y.size(0) 
      print('Epoch:', epoch, '|Step:', step, 
         '|train loss:%.4f'%loss.data[0], '|test accuracy:%.4f'%accuracy) 
 
test_output = cnn(test_x[:10]) 
pred_y = torch.max(test_output, 1)[1].data.numpy().squeeze() 
print(pred_y, 'prediction number') 
print(test_y[:10].numpy(), 'real number') 
''''' 
Epoch: 0 |Step: 0 |train loss:2.3145 |test accuracy:0.1040 
Epoch: 0 |Step: 100 |train loss:0.5857 |test accuracy:0.8865 
Epoch: 0 |Step: 200 |train loss:0.0600 |test accuracy:0.9380 
Epoch: 0 |Step: 300 |train loss:0.0996 |test accuracy:0.9345 
Epoch: 0 |Step: 400 |train loss:0.0381 |test accuracy:0.9645 
Epoch: 0 |Step: 500 |train loss:0.0266 |test accuracy:0.9620 
Epoch: 0 |Step: 600 |train loss:0.0973 |test accuracy:0.9685 
Epoch: 0 |Step: 700 |train loss:0.0421 |test accuracy:0.9725 
Epoch: 0 |Step: 800 |train loss:0.0654 |test accuracy:0.9710 
Epoch: 0 |Step: 900 |train loss:0.1333 |test accuracy:0.9740 
Epoch: 0 |Step: 1000 |train loss:0.0289 |test accuracy:0.9720 
Epoch: 0 |Step: 1100 |train loss:0.0429 |test accuracy:0.9770 
[7 2 1 0 4 1 4 9 5 9] prediction number 
[7 2 1 0 4 1 4 9 5 9] real number 
'''
로그인 후 복사

3. 분석 및 해석

torchvision.datasets를 사용하면 직접 배치할 수 있는 데이터 세트 형식의 데이터를 빠르게 얻을 수 있습니다. DataLoader.train 매개변수는 학습 데이터 세트를 가져올지 테스트 데이터 세트를 가져올지 제어하거나, 가져올 때 학습에 필요한 데이터 형식으로 직접 변환할 수 있습니다.

컨볼루션 신경망의 구성은 CNN 클래스를 정의하여 이루어집니다. 컨볼루션 레이어는 CNN 클래스 속성의 형태로 정의됩니다. 각 레이어 간의 연결 정보는 정의할 때 지불됩니다. 각각의 레이어에 있는 뉴런의 수에 주목하세요.

CNN의 네트워크 구조는 다음과 같습니다.

CNN (
 (conv1): Sequential (
  (0): Conv2d(1, 16,kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (1): ReLU ()
  (2): MaxPool2d (size=(2,2), stride=(2, 2), dilation=(1, 1))
 )
 (conv2): Sequential (
  (0): Conv2d(16, 32,kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (1): ReLU ()
  (2): MaxPool2d (size=(2,2), stride=(2, 2), dilation=(1, 1))
 )
 (out): Linear (1568 ->10)
)
로그인 후 복사

EPOCH=1의 학습 결과에서 테스트 세트 정확도가 97.7%에 도달할 수 있다는 것을 실험을 통해 알 수 있습니다.

관련 권장 사항:

PyTorch 배치 훈련 및 최적화 프로그램 비교에 대한 자세한 설명

위 내용은 PyTorch에서 컨벌루션 신경망 CNN을 구현하는 방법의 상세 내용입니다. 자세한 내용은 PHP 중국어 웹사이트의 기타 관련 기사를 참조하세요!

관련 라벨:
원천:php.cn
본 웹사이트의 성명
본 글의 내용은 네티즌들의 자발적인 기여로 작성되었으며, 저작권은 원저작자에게 있습니다. 본 사이트는 이에 상응하는 법적 책임을 지지 않습니다. 표절이나 침해가 의심되는 콘텐츠를 발견한 경우 admin@php.cn으로 문의하세요.
인기 튜토리얼
더>
최신 다운로드
더>
웹 효과
웹사이트 소스 코드
웹사이트 자료
프론트엔드 템플릿