PyTorch の FashionMNIST

Patricia Arquette
リリース: 2024-12-11 15:24:16
オリジナル
759 人が閲覧しました

コーヒー買ってきて☕

*私の投稿ではファッション MNIST について説明しています。

FashionMNIST() は、以下に示すように Fashion-MNIST データセットを使用できます。

*メモ:

  • 最初の引数は root(Required-Type:str または pathlib.Path) です。 *絶対パスまたは相対パスが可能です。
  • 2 番目の引数は train(Optional-Default:True-Type:bool) です。 ※Trueの場合はトレーニングデータ(60,000枚)、Falseの場合はテストデータ(10,000枚)を使用します。
  • 3 番目の引数は、transform(Optional-Default:None-Type:callable) です。
  • 4 番目の引数は target_transform(Optional-Default:None-Type:callable) です。
  • 5 番目の引数は download(Optional-Default:False-Type:bool) です。 *メモ:
    • True の場合、データセットはインターネットからダウンロードされ、ルートに抽出 (解凍) されます。
    • これが True で、データセットが既にダウンロードされている場合、データセットは抽出されます。
    • これが True で、データセットがすでにダウンロードされ抽出されている場合は、何も起こりません。
    • データセットが既にダウンロードされ抽出されている場合は、その方が高速であるため、False にする必要があります。
    • データセット (t10k-images-idx3-ubyte.gz、t10k-labels-idx1-ubyte.gz、train-images-idx3-ubyte.gz、train-labels-idx1-ubyte) を手動でダウンロードして抽出できます。 gz) ここから data/FashionMNIST/raw/ に移動します。
from torchvision.datasets import FashionMNIST

train_data = FashionMNIST(
    root="data"
)

train_data = FashionMNIST(
    root="data",
    train=True,
    transform=None,
    target_transform=None,
    download=False
)

test_data = FashionMNIST(
    root="data",
    train=False
)

len(train_data), len(test_data)
# (60000, 10000)

train_data
# Dataset FashionMNIST
#     Number of datapoints: 60000
#     Root location: data
#     Split: Train

train_data.root
# 'data'

train_data.train
# True

print(train_data.transform)
# None

print(train_data.target_transform)
# None

train_data.download
# <bound method MNIST.download of Dataset FashionMNIST
#     Number of datapoints: 60000
#     Root location: data
#     Split: Train>

len(train_data.classes)
# 10

train_data.classes
# ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
#  'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']

train_data[0]
# (<PIL.Image.Image image mode=L size=28x28>, 9)

train_data[1]
# (<PIL.Image.Image image mode=L size=28x28>, 0)

train_data[2]
# (<PIL.Image.Image image mode=L size=28x28>, 0)

train_data[3]
# (<PIL.Image.Image image mode=L size=28x28>, 3)

train_data[4]
# (<PIL.Image.Image image mode=L size=28x28>, 0)

import matplotlib.pyplot as plt

def show_images(data, main_title=None):
    plt.figure(figsize=(8, 4))
    plt.suptitle(t=main_title, y=1.0, fontsize=14)
    for i, (image, label) in enumerate(data, 1):
        plt.subplot(2, 5, i)
        plt.tight_layout()
        plt.title(label)
        plt.imshow(image)
        if i == 10:
            break
    plt.show()

show_images(data=train_data, main_title="train_data")
show_images(data=test_data, main_title="test_data")
ログイン後にコピー

FashionMNIST in PyTorch

以上がPyTorch の FashionMNISTの詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。

ソース:dev.to
このウェブサイトの声明
この記事の内容はネチズンが自主的に寄稿したものであり、著作権は原著者に帰属します。このサイトは、それに相当する法的責任を負いません。盗作または侵害の疑いのあるコンテンツを見つけた場合は、admin@php.cn までご連絡ください。
著者別の最新記事
人気のチュートリアル
詳細>
最新のダウンロード
詳細>
ウェブエフェクト
公式サイト
サイト素材
フロントエンドテンプレート