首頁 後端開發 Python教學 pytorch + visdom CNN處理自建圖片資料集的方法

pytorch + visdom CNN處理自建圖片資料集的方法

Jun 04, 2018 pm 04:19 PM
pytorch 圖片

這篇文章主要介紹了關於pytorch visdom CNN處理自建圖片資料集的方法,有著一定的參考價值,現在分享給大家,有需要的朋友可以參考一下

環境

系統:win10

cpu:i7-6700HQ

gpu:gtx965m

python : 3.6

pytorch :0.3

資料下載

#來源自Sasank Chilamkurthy 的教學;資料:下載連結。

下載後解壓縮放到專案根目錄:

 

資料集為用來分類 螞蟻和蜜蜂。有大約120個訓練圖像,每個類別有75個驗證圖像。

資料導入

可以使用 torchvision.datasets.ImageFolder(root,transforms) 模組 可以將 圖片轉換為 tensor。

先定義transform:

ata_transforms = {
  'train': transforms.Compose([
    # 随机切成224x224 大小图片 统一图片格式
    transforms.RandomResizedCrop(224),
    # 图像翻转
    transforms.RandomHorizontalFlip(),
    # totensor 归一化(0,255) >> (0,1)  normalize  channel=(channel-mean)/std
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  ]),
  "val" : transforms.Compose([
    # 图片大小缩放 统一图片格式
    transforms.Resize(256),
    # 以中心裁剪
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  ])
}
登入後複製

#匯入,載入資料:

##

data_dir = './hymenoptera_data'
# trans data
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
# load data
data_loaders = {x: DataLoader(image_datasets[x], batch_size=BATCH_SIZE, shuffle=True) for x in ['train', 'val']}

data_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
print(data_sizes, class_names)
登入後複製
{'train': 244, 'val': 153} ['ants', 'bees']
登入後複製

#訓練集244圖片,測試集153圖片。

視覺化部分圖片看看,由於visdom支援tensor輸入,不用換成numpy,直接用tensor計算可以:

inputs, classes = next(iter(data_loaders['val']))
out = torchvision.utils.make_grid(inputs)
inp = torch.transpose(out, 0, 2)
mean = torch.FloatTensor([0.485, 0.456, 0.406])
std = torch.FloatTensor([0.229, 0.224, 0.225])
inp = std * inp + mean
inp = torch.transpose(inp, 0, 2)
viz.images(inp)
登入後複製

建立CNN

net 根據上一篇的處理cifar10的改了一下規格:

class CNN(nn.Module):
  def __init__(self, in_dim, n_class):
    super(CNN, self).__init__()
    self.cnn = nn.Sequential(
      nn.BatchNorm2d(in_dim),
      nn.ReLU(True),
      nn.Conv2d(in_dim, 16, 7), # 224 >> 218
      nn.BatchNorm2d(16),
      nn.ReLU(inplace=True),
      nn.MaxPool2d(2, 2), # 218 >> 109
      nn.ReLU(True),
      nn.Conv2d(16, 32, 5), # 105
      nn.BatchNorm2d(32),
      nn.ReLU(True),
      nn.Conv2d(32, 64, 5), # 101
      nn.BatchNorm2d(64),
      nn.ReLU(True),
      nn.Conv2d(64, 64, 3, 1, 1),
      nn.BatchNorm2d(64),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2), # 101 >> 50
      nn.Conv2d(64, 128, 3, 1, 1), #
      nn.BatchNorm2d(128),
      nn.ReLU(True),
      nn.MaxPool2d(3), # 50 >> 16
    )
    self.fc = nn.Sequential(
      nn.Linear(128*16*16, 120),
      nn.BatchNorm1d(120),
      nn.ReLU(True),
      nn.Linear(120, n_class))
  def forward(self, x):
    out = self.cnn(x)
    out = self.fc(out.view(-1, 128*16*16))
    return out

# 输入3层rgb ,输出 分类 2    
model = CNN(3, 2)
登入後複製

loss,最佳化函數:

line = viz.line(Y=np.arange(10))
loss_f = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
登入後複製

#參數:

BATCH_SIZE = 4
LR = 0.001
EPOCHS = 10
登入後複製

##運行10個epoch 看看:

[9/10] train_loss:0.650|train_acc:0.639|test_loss:0.621|test_acc0.706
[10/10] train_loss:0.645|train_acc:0.627|test_loss:0.654|test_acc0.686
Training complete in 1m 16s
Best val Acc: 0.712418
登入後複製

運行20個看看:

[19/20] train_loss:0.592|train_acc:0.701|test_loss:0.563|test_acc0.712
[20/20] train_loss:0.564|train_acc:0.721|test_loss:0.571|test_acc0.706
Training complete in 2m 30s
Best val Acc: 0.745098
登入後複製

##準確率比較低:只有74.5%

我們使用models 裡的resnet18 運行10個epoch:

model = torchvision.models.resnet18(True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)
登入後複製
[9/10] train_loss:0.621|train_acc:0.652|test_loss:0.588|test_acc0.667
[10/10] train_loss:0.610|train_acc:0.680|test_loss:0.561|test_acc0.667
Training complete in 1m 24s
Best val Acc: 0.686275
登入後複製

效果也很一般,想要短時間內就訓練出效果很好的models,我們可以下載訓練好的state,在此基礎上訓練:

model = torchvision.models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)
登入後複製
[9/10] train_loss:0.308|train_acc:0.877|test_loss:0.160|test_acc0.941
[10/10] train_loss:0.267|train_acc:0.885|test_loss:0.148|test_acc0.954
Training complete in 1m 25s
Best val Acc: 0.954248
登入後複製
10個epoch直接的到95%的準確率。


相關推薦:

pytorch visdom 處理簡單分類問題

###################

以上是pytorch + visdom CNN處理自建圖片資料集的方法的詳細內容。更多資訊請關注PHP中文網其他相關文章!

本網站聲明
本文內容由網友自願投稿,版權歸原作者所有。本站不承擔相應的法律責任。如發現涉嫌抄襲或侵權的內容,請聯絡admin@php.cn

熱AI工具

Undresser.AI Undress

Undresser.AI Undress

人工智慧驅動的應用程序,用於創建逼真的裸體照片

AI Clothes Remover

AI Clothes Remover

用於從照片中去除衣服的線上人工智慧工具。

Undress AI Tool

Undress AI Tool

免費脫衣圖片

Clothoff.io

Clothoff.io

AI脫衣器

Video Face Swap

Video Face Swap

使用我們完全免費的人工智慧換臉工具,輕鬆在任何影片中換臉!

熱門文章

<🎜>:泡泡膠模擬器無窮大 - 如何獲取和使用皇家鑰匙
4 週前 By 尊渡假赌尊渡假赌尊渡假赌
北端:融合系統,解釋
4 週前 By 尊渡假赌尊渡假赌尊渡假赌
Mandragora:巫婆樹的耳語 - 如何解鎖抓鉤
3 週前 By 尊渡假赌尊渡假赌尊渡假赌

熱工具

記事本++7.3.1

記事本++7.3.1

好用且免費的程式碼編輯器

SublimeText3漢化版

SublimeText3漢化版

中文版,非常好用

禪工作室 13.0.1

禪工作室 13.0.1

強大的PHP整合開發環境

Dreamweaver CS6

Dreamweaver CS6

視覺化網頁開發工具

SublimeText3 Mac版

SublimeText3 Mac版

神級程式碼編輯軟體(SublimeText3)

熱門話題

Java教學
1670
14
CakePHP 教程
1428
52
Laravel 教程
1329
25
PHP教程
1276
29
C# 教程
1256
24
小紅書發布自動儲存圖片怎麼解決?發布自動保存圖片在哪裡? 小紅書發布自動儲存圖片怎麼解決?發布自動保存圖片在哪裡? Mar 22, 2024 am 08:06 AM

隨著社群媒體的不斷發展,小紅書已經成為越來越多年輕人分享生活、發現美好事物的平台。許多用戶在發布圖片時遇到了自動儲存的問題,這讓他們感到十分困擾。那麼,如何解決這個問題呢?一、小紅書發布自動儲存圖片怎麼解決? 1.清除快取首先,我們可以嘗試清除小紅書的快取資料。步驟如下:(1)開啟小紅書,點選右下角的「我的」按鈕;(2)在個人中心頁面,找到「設定」並點選;(3)向下捲動,找到「清除快取」選項,點擊確認。清除快取後,重新進入小紅書,嘗試發布圖片看是否解決了自動儲存的問題。 2.更新小紅書版本確保你的小

抖音評論裡怎麼發圖片?評論區圖片入口在哪裡? 抖音評論裡怎麼發圖片?評論區圖片入口在哪裡? Mar 21, 2024 pm 09:12 PM

隨著抖音短影片的火爆,用戶們在留言區互動變得更加豐富多彩。有些用戶希望在評論中分享圖片,以便更好地表達自己的觀點或情感。那麼,抖音評論裡怎麼發圖片呢?本文將為你詳細解答這個問題,並為你提供一些相關的技巧和注意事項。一、抖音評論裡怎麼發圖片? 1.開啟抖音:首先,你需要開啟抖音APP,並登入你的帳號。 2.找到評論區:瀏覽或發布短影片時,找到想要評論的地方,點擊「評論」按鈕。 3.輸入評論內容:在留言區輸入你的評論內容。 4.選擇傳送圖片:在輸入評論內容的介面,你會看到一個「圖片」按鈕或「+」號按鈕,點

PyCharm與PyTorch完美結合:安裝設定步驟詳解 PyCharm與PyTorch完美結合:安裝設定步驟詳解 Feb 21, 2024 pm 12:00 PM

PyCharm是一款強大的整合開發環境(IDE),而PyTorch則是深度學習領域備受歡迎的開源架構。在機器學習和深度學習領域,使用PyCharm和PyTorch進行開發可以大大提高開發效率和程式碼品質。本文將詳細介紹如何在PyCharm中安裝設定PyTorch,並附上具體的程式碼範例,幫助讀者更好地利用這兩者的強大功能。第一步:安裝PyCharm和Python

ppt怎麼讓圖片一張一張出來 ppt怎麼讓圖片一張一張出來 Mar 25, 2024 pm 04:00 PM

在PowerPoint中,讓圖片逐一顯示是常用的技巧,可以透過設定動畫效果來實現。本指南詳細介紹了實現此技巧的步驟,包括基本設定、圖片插入、新增動畫、調整動畫順序和時間。此外,還提供了進階設定和調整,例如使用觸發器、調整動畫速度和順序,以及預覽動畫效果。透過遵循這些步驟和技巧,使用者可以輕鬆地在PowerPoint中設定圖片逐一出現,從而提升簡報的視覺效果並吸引觀眾的注意力。

如何使用HTML、CSS和jQuery實現圖片合併展示的進階功能 如何使用HTML、CSS和jQuery實現圖片合併展示的進階功能 Oct 27, 2023 pm 04:36 PM

如何使用HTML、CSS和jQuery實現圖片合併展示的高級功能概述:在網頁設計中,圖片展示是一個重要的環節,而圖片合併展示是提高頁面加載速度和提升用戶體驗的常用技巧之一。本文將介紹如何使用HTML、CSS和jQuery來實現圖片合併展示的進階功能,並提供具體的程式碼範例。一、HTML佈局:首先,我們需要在HTML中建立一個容器來展示合併後的圖片。可以使用di

在 iPhone 上讓圖片更清晰的 6 種方法 在 iPhone 上讓圖片更清晰的 6 種方法 Mar 04, 2024 pm 06:25 PM

Apple最近的iPhone可以透過清晰的細節、飽和度和亮度來捕捉回憶。但有時,您可能會遇到一些問題,這些問題可能會導致影像看起來不那麼清晰。儘管iPhone相機上的自動對焦已經取得了長足的進步,可以讓您快速拍照,但相機在某些情況下可能會錯誤地對焦錯誤的拍攝對象,從而使照片在不需要的區域更加模糊。如果iPhone上的照片看起來失焦或整體缺乏清晰度,以下貼文應該可以幫助您使它們更清晰。如何在iPhone上讓圖片更清晰[6種方法]您可以嘗試使用本機的「照片」應用程式來清理照片。如果您需要更多功能和選項

網頁圖片載入不出來怎麼辦? 6種解決辦法 網頁圖片載入不出來怎麼辦? 6種解決辦法 Mar 15, 2024 am 10:30 AM

  有網友發現打開瀏覽器網頁,網頁上的圖片遲遲加載不出來,是怎麼回事?檢查過網路是正常的,那是哪裡出現了問題呢?下面小編就來跟大家介紹一下網頁圖片載入不出來的六種解決方法。網頁圖片載入不出來:  1、網速問題網頁顯示不出圖片有可能是因為電腦的網路速度比較慢,電腦中開啟的軟體比較多,  而我們造訪的圖片比較大,這就可能因為載入逾時,導致圖片顯示不出來,  可以將比較佔網速的軟體將關掉,可以去任務管理器查看一下。  2、造訪人數過多  網頁顯示不出圖片還有可能是因為我們造訪的網頁,在同時段造訪的

自然語言生成任務中的五種採樣方法介紹和Pytorch程式碼實現 自然語言生成任務中的五種採樣方法介紹和Pytorch程式碼實現 Feb 20, 2024 am 08:50 AM

在自然語言生成任務中,取樣方法是從生成模型中獲得文字輸出的一種技術。這篇文章將討論5種常用方法,並使用PyTorch進行實作。 1.GreedyDecoding在貪婪解碼中,生成模型根據輸入序列逐個時間步地預測輸出序列的單字。在每個時間步,模型會計算每個單字的條件機率分佈,然後選擇具有最高條件機率的單字作為當前時間步的輸出。這個單字成為下一個時間步的輸入,生成過程會持續直到滿足某種終止條件,例如產生了指定長度的序列或產生了特殊的結束標記。 GreedyDecoding的特點是每次選擇當前條件機率最

See all articles