
MNIST in PyTorch

Susan Sarandon
Release: 2024-12-23 05:04:31


*My post explains MNIST.

MNIST() can use the MNIST dataset as shown below:

*Memos:

  • The 1st argument is root(Required-Type:str or pathlib.Path). *An absolute or relative path is possible.
  • The 2nd argument is train(Optional-Default:True-Type:bool). *If it's True, train data(60,000 samples) is used while if it's False, test data(10,000 samples) is used.
  • The 3rd argument is transform(Optional-Default:None-Type:callable). *It's applied to each image (a PIL image).
  • The 4th argument is target_transform(Optional-Default:None-Type:callable). *It's applied to each label (an int).
  • The 5th argument is download(Optional-Default:False-Type:bool): *Memos:
    • If it's True, the dataset is downloaded from the internet and extracted(unzipped) to root/MNIST/raw/.
    • If it's True and the dataset is already downloaded, it's extracted.
    • If it's True and the dataset is already downloaded and extracted, nothing happens.
    • It should be False if the dataset is already downloaded and extracted because it's faster.
    • You can manually download and extract the dataset from here to e.g. data/MNIST/raw/.
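The files that end up in data/MNIST/raw/, whether MNIST() downloads them or you extract them by hand, are uncompressed IDX files: a 4-byte magic number (two zero bytes, a data-type code and the number of dimensions), one big-endian 32-bit size per dimension, then the raw bytes. A minimal sketch of reading one with NumPy, assuming unsigned-byte data (type code 0x08) as MNIST uses; read_idx_ubyte is a hypothetical helper, not a torchvision function:

```python
import struct
from pathlib import Path

import numpy as np

def read_idx_ubyte(path):
    """Read an uncompressed IDX file of unsigned bytes into a NumPy array."""
    data = Path(path).read_bytes()
    # Magic number: two zero bytes, a data-type code, and the number of dimensions.
    zero, type_code, ndim = struct.unpack(">HBB", data[:4])
    if zero != 0 or type_code != 0x08:  # 0x08 = unsigned byte
        raise ValueError(f"not an unsigned-byte IDX file: {path}")
    # One big-endian 32-bit size per dimension, e.g. (60000, 28, 28) for training images.
    shape = struct.unpack(f">{ndim}I", data[4 : 4 + 4 * ndim])
    offset = 4 + 4 * ndim
    # The remaining bytes are the pixel (or label) values in row-major order.
    return np.frombuffer(data, dtype=np.uint8, offset=offset).reshape(shape)
```

For example, read_idx_ubyte("data/MNIST/raw/train-labels-idx1-ubyte")[:4] should hold the same labels (5, 0, 4, 1) that train_data[0] to train_data[3] return below.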
from torchvision.datasets import MNIST

train_data = MNIST(
    root="data"
)

train_data = MNIST(
    root="data",
    train=True,
    transform=None,
    target_transform=None,
    download=False
)

train_data
# Dataset MNIST
#     Number of datapoints: 60000
#     Root location: data
#     Split: Train

train_data.root
# 'data'

train_data.train
# True

print(train_data.transform)
# None

print(train_data.target_transform)
# None

train_data.download
# <bound method MNIST.download of Dataset MNIST
#     Number of datapoints: 60000
#     Root location: data
#     Split: Train>

train_data[0]
# (<PIL.Image.Image image mode=L size=28x28>, 5)

train_data[1]
# (<PIL.Image.Image image mode=L size=28x28>, 0)

train_data[2]
# (<PIL.Image.Image image mode=L size=28x28>, 4)

train_data[3]
# (<PIL.Image.Image image mode=L size=28x28>, 1)

train_data.classes
# ['0 - zero',
#  '1 - one',
#  '2 - two',
#  '3 - three',
#  '4 - four',
#  '5 - five',
#  '6 - six',
#  '7 - seven',
#  '8 - eight',
#  '9 - nine']
from torchvision.datasets import MNIST

train_data = MNIST(
    root="data"
)

test_data = MNIST(
    root="data",
    train=False
)

import matplotlib.pyplot as plt

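def show_images(data):
    plt.figure(figsize=(10, 2))
    col = 4  # number of images shown in one row
    # enumerate() starts at 1 because subplot indices start at 1
    for i, (image, label) in enumerate(data, 1):
        plt.subplot(1, col, i)
        plt.title(label)
        plt.imshow(image)
        if i == col:  # stop after the first `col` samples
            break
    plt.show()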

show_images(data=train_data)
show_images(data=test_data)
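show_images() relies on an interactive window and on matplotlib's default colormap, which renders the grayscale digits in color. A minimal sketch of a variant that draws in grayscale, returns the figure and can save it to a file; the name plot_images, the path parameter and the Agg backend are assumptions for running without a display:

```python
from itertools import islice

import matplotlib
matplotlib.use("Agg")  # assumption: headless run; drop this line to keep an interactive backend
import matplotlib.pyplot as plt

def plot_images(data, col=4, path=None):
    """Draw the first `col` (image, label) pairs of `data` side by side."""
    # squeeze=False keeps `axes` 2-D, so axes[0] works even when col == 1
    fig, axes = plt.subplots(1, col, figsize=(10, 2), squeeze=False)
    for ax, (image, label) in zip(axes[0], islice(data, col)):
        ax.set_title(str(label))
        ax.imshow(image, cmap="gray")  # grayscale instead of the default colormap
        ax.axis("off")                 # hide the pixel-index ticks
    if path is not None:
        fig.savefig(path)
    return fig
```

plot_images(train_data, path="train.png") would save the first four training digits. itertools.islice replaces the manual counter-and-break loop, so the number of images is a parameter rather than a hard-coded constant.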



source:dev.to