Maison > développement back-end > Tutoriel Python > ImageNet dans PyTorch

ImageNet dans PyTorch

Barbara Streisand
Libérer: 2025-01-04 22:25:40
original
491 Les gens l'ont consulté

Achetez-moi un café☕

*Mon message explique ImageNet.

ImageNet() peut utiliser l'ensemble de données ImageNet comme indiqué ci-dessous :

*Mémos :

  • Le 1er argument est root (Required-Type:str ou pathlib.Path). *Un chemin absolu ou relatif est possible.
  • Le 2ème argument est split(Optional-Default:"train"-Type:str) : *Mémos :
    • "train" (1 281 167 images) ou "val" (50 000 images) peuvent y être définis.
    • "test" (100 000 images) n'est pas pris en charge, j'ai donc demandé la fonctionnalité sur GitHub.
  • Il existe un argument de transformation (Optional-Default:None-Type:callable). *transform= doit être utilisé.
  • Il existe un argument target_transform (Optional-Default:None-Type:callable). - Il existe un argument de transformation (Optional-Default:None-Type:callable). *target_transform= doit être utilisé.
  • Il existe un argument de chargement (Optional-Default:torchvision.datasets.folder.default_loader-Type:callable). *loader= doit être utilisé.
  • Vous devez télécharger manuellement l'ensemble de données (ILSVRC2012_devkit_t12.tar.gz, ILSVRC2012_img_train.tar et ILSVRC2012_img_val.tar dans data/, puis exécuter ImageNet() extrait et charge l'ensemble de données.
  • À propos de l'étiquette des classes pour les indices du train et de l'image de validation respectivement, tanche&Tinca tinca(0) sont 0~1299 et 0~49, poisson rouge &Carassius auratus(1) sont 1300~2599 et 50~99, grand requin blanc&requin blanc&mangeur d'hommes&requin mangeur d'hommes&Carcharodon carcharias(2) sont 2600~3899 et 100~149, requin tigre&Galeocerdo cuvieri(3) sont 3900~5199 et 150~199, requin marteau&requin marteau (4) sont 5200~6499 et 200~249, raie électrique&crampfish&poisson engourdi&torpille(5) sont 6500~7799 et 250~ 299, raie pastenague(6) vaut 7800~9099 et 250~299, coq(7) vaut 9100~10399 et 300~349, poule(8) vaut 10400 ~11699 et 350~399, autruche&Struthio camelus(9) sont 11700~12999 et 400~449, etc.
from torchvision.datasets import ImageNet
from torchvision.datasets.folder import default_loader

train_data = ImageNet(
    root="data"
)

train_data = ImageNet(
    root="data",
    split="train",
    transform=None,
    target_transform=None,
    loader=default_loader
)

val_data = ImageNet(
    root="data",
    split="val"
)

len(train_data), len(val_data)
# (1281167, 50000)

train_data
# Dataset ImageNet
#     Number of datapoints: 1281167
#     Root location: D:/data
#     Split: train

train_data.root
# 'data'

train_data.split
# 'train'

print(train_data.transform)
# None

print(train_data.target_transform)
# None

train_data.loader
# <function torchvision.datasets.folder.default_loader(path: str) -> Any>

len(train_data.classes), train_data.classes
# (1000,
#  [('tench', 'Tinca tinca'), ('goldfish', 'Carassius auratus'),
#   ('great white shark', 'white shark', 'man-eater', 'man-eating shark',
#    'Carcharodon carcharias'), ('tiger shark', 'Galeocerdo cuvieri'),
#   ('hammerhead', 'hammerhead shark'), ('electric ray', 'crampfish',
#    'numbfish', 'torpedo'), ('stingray',), ('cock',), ('hen',),
#   ('ostrich', 'Struthio camelus'), ..., ('bolete',), ('ear', 'spike',
#    'capitulum'), ('toilet tissue', 'toilet paper', 'bathroom tissue')])

train_data[0]
# (<PIL.Image.Image image mode=RGB size=250x250>, 0)

train_data[1]
# (<PIL.Image.Image image mode=RGB size=200x150>, 0)

train_data[2]
# (<PIL.Image.Image image mode=RGB size=500x375>, 0)

train_data[1300]
# (<PIL.Image.Image image mode=RGB size=640x480>, 1)

train_data[2600]
# (<PIL.Image.Image image mode=RGB size=500x375>, 2)

val_data[0]
# (<PIL.Image.Image image mode=RGB size=500x375>, 0)

val_data[1]
# (<PIL.Image.Image image mode=RGB size=500x375>, 0)

val_data[2]
# (<PIL.Image.Image image mode=RGB size=500x375>, 0)

val_data[50]
# (<PIL.Image.Image image mode=RGB size=500x500>, 1)

val_data[100]
# (<PIL.Image.Image image mode=RGB size=679x444>, 2)

import matplotlib.pyplot as plt

def show_images(data, ims, main_title=None):
    plt.figure(figsize=[12, 6])
    plt.suptitle(t=main_title, y=1.0, fontsize=14)
    for i, j in enumerate(iterable=ims, start=1):
        plt.subplot(2, 5, i)
        im, lab = data[j]
        plt.imshow(X=im)
        plt.title(label=lab)
    plt.tight_layout(h_pad=3.0)
    plt.show()

train_ims = [0, 1, 2, 1300, 2600, 3900, 5200, 6500, 7800, 9100]
val_ims = [0, 1, 2, 50, 100, 150, 200, 250, 300, 350]

show_images(data=train_data, ims=train_ims, main_title="train_data")
show_images(data=val_data, ims=val_ims, main_title="val_data")
Copier après la connexion

ImageNet in PyTorch

ImageNet in PyTorch

Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

source:dev.to
Déclaration de ce site Web
Le contenu de cet article est volontairement contribué par les internautes et les droits d'auteur appartiennent à l'auteur original. Ce site n'assume aucune responsabilité légale correspondante. Si vous trouvez un contenu suspecté de plagiat ou de contrefaçon, veuillez contacter admin@php.cn
Derniers articles par auteur
Tutoriels populaires
Plus>
Derniers téléchargements
Plus>
effets Web
Code source du site Web
Matériel du site Web
Modèle frontal