Rumah > pembangunan bahagian belakang > Tutorial Python > FashionMNIST dalam PyTorch

FashionMNIST dalam PyTorch

Patricia Arquette
Lepaskan: 2024-12-11 15:24:16
asal
773 orang telah melayarinya

Beli Saya Kopi☕

*Siaran saya menerangkan Fashion-MNIST.

FashionMNIST() boleh menggunakan dataset Fashion-MNIST seperti yang ditunjukkan di bawah:

*Memo:

  • Argumen pertama ialah root(Required-Type:str or pathlib.Path). *Laluan mutlak atau relatif boleh dilakukan.
  • Argumen ke-2 ialah train(Pilihan-Lalai:True-Type:bool). *Jika Benar, data kereta api(60,000 imej) digunakan manakala jika Salah, data ujian(10,000 imej) digunakan.
  • Argumen ke-3 ialah transform(Optional-Default:None-Type:callable).
  • Argumen ke-4 ialah target_transform(Optional-Default:None-Type:callable).
  • Argumen ke-5 ialah muat turun(Optional-Default:False-Type:bool): *Memo:
    • Jika Benar, set data dimuat turun dari internet dan diekstrak (dibuka zip) ke akar.
    • Jika ia Benar dan set data sudah dimuat turun, ia akan diekstrak.
    • Jika ia Benar dan set data sudah dimuat turun dan diekstrak, tiada apa yang berlaku.
    • Ia sepatutnya Palsu jika set data sudah dimuat turun dan diekstrak kerana ia lebih pantas.
    • Anda boleh memuat turun dan mengekstrak set data secara manual (t10k-images-idx3-ubyte.gz, t10k-labels-idx1-ubyte.gz, train-images-idx3-ubyte.gz dan train-labels-idx1-ubyte. gz) dari sini ke data/FashionMNIST/raw/.
from torchvision.datasets import FashionMNIST

train_data = FashionMNIST(
    root="data"
)

train_data = FashionMNIST(
    root="data",
    train=True,
    transform=None,
    target_transform=None,
    download=False
)

test_data = FashionMNIST(
    root="data",
    train=False
)

len(train_data), len(test_data)
# (60000, 10000)

train_data
# Dataset FashionMNIST
#     Number of datapoints: 60000
#     Root location: data
#     Split: Train

train_data.root
# 'data'

train_data.train
# True

print(train_data.transform)
# None

print(train_data.target_transform)
# None

train_data.download
# <bound method MNIST.download of Dataset FashionMNIST
#     Number of datapoints: 60000
#     Root location: data
#     Split: Train>

len(train_data.classes)
# 10

train_data.classes
# ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
#  'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

train_data[0]
# (<PIL.Image.Image image mode=L size=28x28>, 9)

train_data[1]
# (<PIL.Image.Image image mode=L size=28x28>, 0)

train_data[2]
# (<PIL.Image.Image image mode=L size=28x28>, 0)

train_data[3]
# (<PIL.Image.Image image mode=L size=28x28>, 3)

train_data[4]
# (<PIL.Image.Image image mode=L size=28x28>, 0)

import matplotlib.pyplot as plt

def show_images(data, main_title=None):
    plt.figure(figsize=(8, 4))
    plt.suptitle(t=main_title, y=1.0, fontsize=14)
    for i, (image, label) in enumerate(data, 1):
        plt.subplot(2, 5, i)
        plt.tight_layout()
        plt.title(label)
        plt.imshow(image)
        if i == 10:
            break
    plt.show()

show_images(data=train_data, main_title="train_data")
show_images(data=test_data, main_title="test_data")
Salin selepas log masuk

FashionMNIST in PyTorch

Atas ialah kandungan terperinci FashionMNIST dalam PyTorch. Untuk maklumat lanjut, sila ikut artikel berkaitan lain di laman web China PHP!

Kenyataan Laman Web ini
Kandungan artikel ini disumbangkan secara sukarela oleh netizen, dan hak cipta adalah milik pengarang asal. Laman web ini tidak memikul tanggungjawab undang-undang yang sepadan. Jika anda menemui sebarang kandungan yang disyaki plagiarisme atau pelanggaran, sila hubungi admin@php.cn
Artikel terbaru oleh pengarang
Tutorial Popular
Lagi>
Muat turun terkini
Lagi>
kesan web
Kod sumber laman web
Bahan laman web
Templat hujung hadapan