首页 后端开发 Python教程 pytorch + visdom CNN处理自建图片数据集的方法

pytorch + visdom CNN处理自建图片数据集的方法

Jun 04, 2018 pm 04:19 PM
pytorch 图片

这篇文章主要介绍了关于pytorch + visdom CNN处理自建图片数据集的方法,有着一定的参考价值,现在分享给大家,有需要的朋友可以参考一下

环境

系统:win10

cpu:i7-6700HQ

gpu:gtx965m

python : 3.6

pytorch :0.3

数据下载

来源自Sasank Chilamkurthy 的教程; 数据:下载链接。

下载后解压放到项目根目录:

 

数据集为用来分类 蚂蚁和蜜蜂。有大约120个训练图像,每个类有75个验证图像。

数据导入

可以使用 torchvision.datasets.ImageFolder(root,transforms) 模块 可以将 图片转换为 tensor。

先定义transform:

ata_transforms = {
  'train': transforms.Compose([
    # 随机切成224x224 大小图片 统一图片格式
    transforms.RandomResizedCrop(224),
    # 图像翻转
    transforms.RandomHorizontalFlip(),
    # totensor 归一化(0,255) >> (0,1)  normalize  channel=(channel-mean)/std
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  ]),
  "val" : transforms.Compose([
    # 图片大小缩放 统一图片格式
    transforms.Resize(256),
    # 以中心裁剪
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  ])
}
登录后复制

导入,加载数据:

data_dir = './hymenoptera_data'
# trans data
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val']}
# load data
data_loaders = {x: DataLoader(image_datasets[x], batch_size=BATCH_SIZE, shuffle=True) for x in ['train', 'val']}

data_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
print(data_sizes, class_names)
登录后复制
{'train': 244, 'val': 153} ['ants', 'bees']
登录后复制

训练集 244图片 , 测试集153图片 。

可视化部分图片看看,由于visdom支持tensor输入 ,不用换成numpy,直接用tensor计算即可 :

inputs, classes = next(iter(data_loaders['val']))
out = torchvision.utils.make_grid(inputs)
inp = torch.transpose(out, 0, 2)
mean = torch.FloatTensor([0.485, 0.456, 0.406])
std = torch.FloatTensor([0.229, 0.224, 0.225])
inp = std * inp + mean
inp = torch.transpose(inp, 0, 2)
viz.images(inp)
登录后复制

创建CNN

net 根据上一篇的处理cifar10的改了一下规格:

class CNN(nn.Module):
  def __init__(self, in_dim, n_class):
    super(CNN, self).__init__()
    self.cnn = nn.Sequential(
      nn.BatchNorm2d(in_dim),
      nn.ReLU(True),
      nn.Conv2d(in_dim, 16, 7), # 224 >> 218
      nn.BatchNorm2d(16),
      nn.ReLU(inplace=True),
      nn.MaxPool2d(2, 2), # 218 >> 109
      nn.ReLU(True),
      nn.Conv2d(16, 32, 5), # 105
      nn.BatchNorm2d(32),
      nn.ReLU(True),
      nn.Conv2d(32, 64, 5), # 101
      nn.BatchNorm2d(64),
      nn.ReLU(True),
      nn.Conv2d(64, 64, 3, 1, 1),
      nn.BatchNorm2d(64),
      nn.ReLU(True),
      nn.MaxPool2d(2, 2), # 101 >> 50
      nn.Conv2d(64, 128, 3, 1, 1), #
      nn.BatchNorm2d(128),
      nn.ReLU(True),
      nn.MaxPool2d(3), # 50 >> 16
    )
    self.fc = nn.Sequential(
      nn.Linear(128*16*16, 120),
      nn.BatchNorm1d(120),
      nn.ReLU(True),
      nn.Linear(120, n_class))
  def forward(self, x):
    out = self.cnn(x)
    out = self.fc(out.view(-1, 128*16*16))
    return out

# 输入3层rgb ,输出 分类 2    
model = CNN(3, 2)
登录后复制

loss,优化函数:

line = viz.line(Y=np.arange(10))
loss_f = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
登录后复制

参数:

BATCH_SIZE = 4
LR = 0.001
EPOCHS = 10
登录后复制

运行 10个 epoch 看看:

[9/10] train_loss:0.650|train_acc:0.639|test_loss:0.621|test_acc0.706
[10/10] train_loss:0.645|train_acc:0.627|test_loss:0.654|test_acc0.686
Training complete in 1m 16s
Best val Acc: 0.712418
登录后复制

运行 20个看看:

[19/20] train_loss:0.592|train_acc:0.701|test_loss:0.563|test_acc0.712
[20/20] train_loss:0.564|train_acc:0.721|test_loss:0.571|test_acc0.706
Training complete in 2m 30s
Best val Acc: 0.745098
登录后复制

准确率比较低:只有74.5%

我们使用models 里的 resnet18 运行 10个epoch:

model = torchvision.models.resnet18(True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)
登录后复制
[9/10] train_loss:0.621|train_acc:0.652|test_loss:0.588|test_acc0.667
[10/10] train_loss:0.610|train_acc:0.680|test_loss:0.561|test_acc0.667
Training complete in 1m 24s
Best val Acc: 0.686275
登录后复制

效果也很一般,想要短时间内就训练出效果很好的models,我们可以下载训练好的state,在此基础上训练:

model = torchvision.models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)
登录后复制
[9/10] train_loss:0.308|train_acc:0.877|test_loss:0.160|test_acc0.941
[10/10] train_loss:0.267|train_acc:0.885|test_loss:0.148|test_acc0.954
Training complete in 1m 25s
Best val Acc: 0.954248
登录后复制

10个epoch直接的到95%的准确率。

相关推荐:

pytorch + visdom 处理简单分类问题

以上是pytorch + visdom CNN处理自建图片数据集的方法的详细内容。更多信息请关注PHP中文网其他相关文章!

本站声明
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn

热AI工具

Undresser.AI Undress

Undresser.AI Undress

人工智能驱动的应用程序,用于创建逼真的裸体照片

AI Clothes Remover

AI Clothes Remover

用于从照片中去除衣服的在线人工智能工具。

Undress AI Tool

Undress AI Tool

免费脱衣服图片

Clothoff.io

Clothoff.io

AI脱衣机

Video Face Swap

Video Face Swap

使用我们完全免费的人工智能换脸工具轻松在任何视频中换脸!

热工具

记事本++7.3.1

记事本++7.3.1

好用且免费的代码编辑器

SublimeText3汉化版

SublimeText3汉化版

中文版,非常好用

禅工作室 13.0.1

禅工作室 13.0.1

功能强大的PHP集成开发环境

Dreamweaver CS6

Dreamweaver CS6

视觉化网页开发工具

SublimeText3 Mac版

SublimeText3 Mac版

神级代码编辑软件(SublimeText3)

小红书发布自动保存图片怎么解决?发布自动保存图片在哪里? 小红书发布自动保存图片怎么解决?发布自动保存图片在哪里? Mar 22, 2024 am 08:06 AM

随着社交媒体的不断发展,小红书已经成为越来越多年轻人分享生活、发现美好事物的平台。许多用户在发布图片时遇到了自动保存的问题,这让他们感到十分困扰。那么,如何解决这个问题呢?一、小红书发布自动保存图片怎么解决?1.清除缓存首先,我们可以尝试清除小红书的缓存数据。步骤如下:(1)打开小红书,点击右下角的“我的”按钮;(2)在个人中心页面,找到“设置”并点击;(3)向下滚动,找到“清除缓存”选项,点击确认。清除缓存后,重新进入小红书,尝试发布图片看是否解决了自动保存的问题。2.更新小红书版本确保你的小

抖音评论里怎么发图片?评论区图片入口在哪里? 抖音评论里怎么发图片?评论区图片入口在哪里? Mar 21, 2024 pm 09:12 PM

随着抖音短视频的火爆,用户们在评论区互动变得更加丰富多彩。有些用户希望在评论中分享图片,以更好地表达自己的观点或情感。那么,抖音评论里怎么发图片呢?本文将为你详细解答这个问题,并为你提供一些相关的技巧和注意事项。一、抖音评论里怎么发图片?1.打开抖音:首先,你需要打开抖音APP,并登录你的账号。2.找到评论区:在浏览或发布短视频时,找到想要评论的地方,点击“评论”按钮。3.输入评论内容:在评论区输入你的评论内容。4.选择发送图片:在输入评论内容的界面,你会看到一个“图片”按钮或者“+”号按钮,点

PyCharm与PyTorch完美结合:安装配置步骤详解 PyCharm与PyTorch完美结合:安装配置步骤详解 Feb 21, 2024 pm 12:00 PM

PyCharm是一款强大的集成开发环境(IDE),而PyTorch是深度学习领域备受欢迎的开源框架。在机器学习和深度学习领域,使用PyCharm和PyTorch进行开发可以极大地提高开发效率和代码质量。本文将详细介绍如何在PyCharm中安装配置PyTorch,并附上具体的代码示例,帮助读者更好地利用这两者的强大功能。第一步:安装PyCharm和Python

在 iPhone 上使图片更清晰的 6 种方法 在 iPhone 上使图片更清晰的 6 种方法 Mar 04, 2024 pm 06:25 PM

Apple最近的iPhone可以通过清晰的细节、饱和度和亮度来捕捉回忆。但有时,您可能会遇到一些问题,这些问题可能会导致图像看起来不那么清晰。尽管iPhone相机上的自动对焦已经取得了长足的进步,可以让您快速拍照,但相机在某些情况下可能会错误地对焦错误的拍摄对象,从而使照片在不需要的区域更加模糊。如果iPhone上的照片看起来失焦或总体上缺乏清晰度,以下帖子应该可以帮助您使它们更清晰。如何在iPhone上使图片更清晰[6种方法]您可以尝试使用本机的“照片”应用来清理照片。如果您需要更多功能和选项

ppt怎么让图片一张一张出来 ppt怎么让图片一张一张出来 Mar 25, 2024 pm 04:00 PM

在PowerPoint中,让图片逐一显示是一种常用的技巧,可以通过设置动画效果来实现。本指南详细介绍了实现这一技巧的步骤,包括基本设置、图片插入、添加动画、调整动画顺序和时间。此外,还提供了高级设置和调整,例如使用触发器、调整动画速度和顺序,以及预览动画效果。通过遵循这些步骤和技巧,用户可以轻松地在PowerPoint中设置图片逐一出现,从而提升演示文稿的视觉效果并吸引观众的注意力。

网页图片加载不出来怎么办?6种解决办法 网页图片加载不出来怎么办?6种解决办法 Mar 15, 2024 am 10:30 AM

  有网友发现打开浏览器网页,网页上的图片迟迟加载不出来,是怎么回事?检查过网络是正常的,那是哪里出现了问题呢?下面小编就给大家介绍一下网页图片加载不出来的六种解决方法。  网页图片加载不出来:  1、网速问题  网页显示不出图片有可能是因为电脑的网速比较慢,电脑中开启的软件比较多,  而我们访问的图片比较大,这就可能因为加载超时,导致图片显示不出来,  可以将比较占网速的软件将关掉,可以去任务管理器查看一下。  2、访问人数过多  网页显示不出图片还有可能是因为我们访问的网页,在同时间段访问的

福昕PDF阅读器如何将pdf文档转成jpg图片-福昕PDF阅读器将pdf文档转成jpg图片的方法 福昕PDF阅读器如何将pdf文档转成jpg图片-福昕PDF阅读器将pdf文档转成jpg图片的方法 Mar 04, 2024 pm 05:49 PM

你们是不是也在使用福昕PDF阅读器软件呢?那么你们知道福昕PDF阅读器如何将pdf文档转成jpg图片吗?下面这篇文章就为大伙带来了福昕PDF阅读器将pdf文档转成jpg图片的方法,感兴趣的小伙伴们快来下文看看吧。先启动福昕PDF阅读器,接着在顶部工具栏找到“特色功能”,然后选择“PDF转其他”功能。在接下来,打开一个名为“福昕pdf在线转换”的网页。在页面上方右侧点击“登录”按钮进行登录,然后打开“PDF转图片”功能。之后点击上传按钮并将想要转换成图片的pdf文件添加进来,添加完毕后点击“开始转

如何使用HTML、CSS和jQuery实现图片合并展示的高级功能 如何使用HTML、CSS和jQuery实现图片合并展示的高级功能 Oct 27, 2023 pm 04:36 PM

如何使用HTML、CSS和jQuery实现图片合并展示的高级功能概述:在网页设计中,图片展示是一个重要的环节,而图片合并展示是提高页面加载速度和提升用户体验的常用技巧之一。本文将介绍如何使用HTML、CSS和jQuery来实现图片合并展示的高级功能,并提供具体的代码示例。一、HTML布局:首先,我们需要在HTML中创建一个容器来展示合并后的图片。可以使用di

See all articles