CelebA は PyTorch です

Susan Sarandon
リリース: 2024-12-28 02:36:10
オリジナル
139 人が閲覧しました

コーヒー買ってきて☕

*私の投稿では CelebA について説明しています。

CelebA() は以下に示すように CelebA データセットを使用できます。

*メモ:

  • 最初の引数は root(Required-Type:str または pathlib.Path) です。 *絶対パスまたは相対パスが可能です。
  • 2番目の引数はsplit(Optional-Default:"train"-Type:str)です。 ※「train」(162,770枚)、「valid」(19,867枚)、「test」(19,962枚)、「all」(202,599枚)が設定可能です。
  • 3 番目の引数は target_type(Optional-Default:"attr"-Type:str または str のリスト) です。 *メモ:
    • 「attr」、「identity」、「bbox」、および/または「landmarks」を設定できます。
    • 空のリストを設定することもできます。
    • 同じ値を複数設定できます。
    • 値の順序が異なる場合、その要素の順序も異なります。
  • 4 番目の引数は、transform(Optional-Default:None-Type:callable) です。
  • 5 番目の引数は target_transform(Optional-Default:None-Type:callable) です。
  • 6 番目の引数は download(Optional-Default:False-Type:bool) です。 *メモ:
    • True の場合、データセットはインターネットからダウンロードされ、ルートに抽出 (解凍) されます。
    • これが True で、データセットが既にダウンロードされている場合、データセットは抽出されます。
    • これが True で、データセットがすでにダウンロードされ抽出されている場合は、何も起こりません。
    • データセットがすでにダウンロードされ抽出されている場合は、その方が高速であるため、False にする必要があります。
    • データセットをダウンロードするには gdown が必要です。
    • データセット (img_align_celeba.zip、identity_CelebA.txt、list_attr_celeba.txt、list_bbox_celeba.txt、list_eval_partition.txt、list_landmarks_align_celeba.txt を含む) をここから data/celeba/ に手動でダウンロードして抽出できます。
from torchvision.datasets import CelebA

train_attr_data = CelebA(
    root="data"
)

train_attr_data = CelebA(
    root="data",
    split="train",
    target_type="attr",
    transform=None,
    target_transform=None,
    download=False
)

valid_identity_data = CelebA(
    root="data",
    split="valid",
    target_type="identity"
)

test_bbox_data = CelebA(
    root="data",
    split="test",
    target_type="bbox"
)

all_landmarks_data = CelebA(
    root="data",
    split="all",
    target_type="landmarks"
)

all_empty_data = CelebA(
    root="data",
    split="all",
    target_type=[]
)

all_all_data = CelebA(
    root="data",
    split="all",
    target_type=["attr", "identity", "bbox", "landmarks"]
)

len(train_attr_data), len(valid_identity_data), len(test_bbox_data)
# (162770, 19867, 19962)

len(all_landmarks_data), len(all_empty_data), len(all_all_data)
# (202599, 202599, 202599)

train_attr_data
# Dataset CelebA
#     Number of datapoints: 162770
#     Root location: data
#     Target type: ['attr']
#     Split: train

train_attr_data.root
# 'data'

train_attr_data.split
# 'train'

train_attr_data.target_type
# ['attr']

print(train_attr_data.transform)
# None

print(train_attr_data.target_transform)
# None

train_attr_data.download
# <bound method CelebA.download of Dataset CelebA
#     Number of datapoints: 162770
#     Root location: data
#     Target type: ['attr']
#     Split: train>

len(train_attr_data.attr), train_attr_data.attr
# (162770, tensor([[0, 1, 1, ..., 0, 0, 1],
#                  [0, 0, 0, ..., 0, 0, 1],
#                  [0, 0, 0, ..., 0, 0, 1],
#                  ...,
#                  [1, 0, 1, ..., 0, 1, 1],
#                  [0, 0, 0, ..., 0, 0, 1],
#                  [0, 1, 1, ..., 1, 0, 1]]))

len(train_attr_data.attr_names), train_attr_data.attr_names
# (41, ['5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 
#       'Bags_Under_Eyes', 'Bald', 'Bangs', 'Big_Lips', 'Big_Nose',
#       'Black_Hair', 'Blond_Hair', 'Blurry', 'Brown_Hair',
#       ...
#       'Wearing_Necklace', 'Wearing_Necktie', 'Young', ''])

len(train_attr_data.identity), train_attr_data.identity
# (162770, tensor([[2880], [2937], [8692], ..., [7391], [8610], [2304]]))

len(train_attr_data.bbox), train_attr_data.bbox
# (162770, tensor([[95, 71, 226, 313],
#                  [72, 94, 221, 306],
#                  [216, 59, 91, 126],
#                  ...,
#                  [103, 103, 143, 198],
#                  [30, 59, 216, 280],
#                  [376, 4, 372, 515]]))

len(train_attr_data.landmarks_align), train_attr_data.landmarks_align
# (162770, tensor([[69, 109, 106, ..., 152, 108, 154],
#                  [69, 110, 107, ..., 151, 108, 153],
#                  [76, 112, 104, ..., 156, 98, 158],
#                  ...,
#                  [69, 113, 109, ..., 151, 110, 151],
#                  [68, 112, 109, ..., 150, 108, 151],
#                  [70, 111, 107, ..., 153, 102, 152]]))

train_attr_data[0]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([0, 1, 1, 0, 0, 0, 0, 0, 0, 0,
#          0, 1, 0, 0, 0, 0, 0, 0, 1, 1,
#          0, 1, 0, 0, 1, 0, 0, 1, 0, 0,
#          0, 1, 1, 0, 1, 0, 1, 0, 0, 1]))

train_attr_data[1]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([0, 0, 0, 1, 0, 0, 0, 1, 0, 0,
#          0, 1, 0, 0, 0, 0, 0, 0, 0, 1,
#          0, 1, 0, 0, 1, 0, 0, 0, 0, 0,
#          0, 1, 0, 0, 0, 0, 0, 0, 0, 1]))

train_attr_data[2]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([0, 0, 0, 0, 0, 0, 1, 0, 0, 0,
#          1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
#          1, 0, 0, 1, 1, 0, 0, 1, 0, 0,
#          0, 0, 0, 1, 0, 0, 0, 0, 0, 1]))

valid_identity_data[0]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor(2594))

valid_identity_data[1]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor(2795))

valid_identity_data[2]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor(947))

test_bbox_data[0]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([147, 82, 120, 166]))

test_bbox_data[1]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([106, 34, 140, 194]))

test_bbox_data[2]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([107, 78, 109, 151]))

all_landmarks_data[0]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([69, 109, 106, 113, 77, 142, 73, 152, 108, 154]))

all_landmarks_data[1]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([69, 110, 107, 112, 81, 135, 70, 151, 108, 153]))

all_landmarks_data[2]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([76, 112, 104, 106, 108, 128, 74, 156, 98, 158]))

all_empty_data[0]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>, None)

all_empty_data[1]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>, None)

all_empty_data[2]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>, None)

all_all_data[0]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  (tensor([0, 1, 1, 0, 0, 0, 0, 0, 0, 0,
#           0, 1, 0, 0, 0, 0, 0, 0, 1, 1,
#           0, 1, 0, 0, 1, 0, 0, 1, 0, 0,
#           0, 1, 1, 0, 1, 0, 1, 0, 0, 1]),
#   tensor(2880),
#   tensor([95, 71, 226, 313]),
#   tensor([69, 109, 106, 113, 77, 142, 73, 152, 108, 154])))

all_all_data[1]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  (tensor([0, 0, 0, 1, 0, 0, 0, 1, 0, 0,
#           0, 1, 0, 0, 0, 0, 0, 0, 0, 1,
#           0, 1, 0, 0, 1, 0, 0, 0, 0, 0,
#           0, 1, 0, 0, 0, 0, 0, 0, 0, 1]),
#   tensor(2937),
#   tensor([72, 94, 221, 306]),
#   tensor([69, 110, 107, 112, 81, 135, 70, 151, 108, 153])))

all_all_data[2]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  (tensor([0, 0, 0, 0, 0, 0, 1, 0, 0, 0,
#           1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
#           1, 0, 0, 1, 1, 0, 0, 1, 0, 0,
#           0, 0, 0, 1, 0, 0, 0, 0, 0, 1]),
#  tensor(8692),
#  tensor([216, 59, 91, 126]),
#  tensor([76, 112, 104, 106, 108, 128, 74, 156, 98, 158])))

import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.patches import Circle

def show_images(data, main_title=None):
    if "attr" in data.target_type and len(data.target_type) == 1 \
        or not data.target_type:
        plt.figure(figsize=(12, 6))
        plt.suptitle(t=main_title, y=1.0, fontsize=14)
        for i, (im, _) in enumerate(data, start=1):
            plt.subplot(2, 5, i)
            plt.imshow(X=im)
            if i == 10:
                break
        plt.tight_layout(h_pad=3.0)
        plt.show()
    elif "identity" in data.target_type and len(data.target_type) == 1:
        plt.figure(figsize=(12, 6))
        plt.suptitle(t=main_title, y=1.0, fontsize=14)
        for i, (im, lab) in enumerate(data, start=1):
            plt.subplot(2, 5, i)
            plt.title(label=lab.item())
            plt.imshow(X=im)
            if i == 10:
                break
        plt.tight_layout(h_pad=3.0)
        plt.show()
    elif "bbox" in data.target_type and len(data.target_type) == 1:
        fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(12, 6))
        fig.suptitle(t=main_title, y=1.0, fontsize=14)
        for (i, (im, (x, y, w, h))), axis \
            in zip(enumerate(data, start=1), axes.ravel()):
            axis.imshow(X=im)
            rect = Rectangle(xy=(x, y), width=w, height=h,
                             linewidth=3, edgecolor='r',
                             facecolor='none')
            axis.add_patch(p=rect)
            if i == 10:
                break
        fig.tight_layout(h_pad=3.0)
        plt.show()
    elif "landmarks" in data.target_type and len(data.target_type) == 1:
        plt.figure(figsize=(12, 6))
        plt.suptitle(t=main_title, y=1.0, fontsize=14)
        for i, (im, lm) in enumerate(data, start=1):
            px = []
            py = []
            for j, v in enumerate(lm):
                if j%2 == 0:
                    px.append(v)
                else:
                    py.append(v)
            plt.subplot(2, 5, i)
            plt.imshow(X=im)
            plt.scatter(x=px, y=py)
            if i == 10:
                break
        plt.tight_layout(h_pad=3.0)
        plt.show()
    elif len(data.target_type) == 4:
        fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(12, 6))
        fig.suptitle(t=main_title, y=1.0, fontsize=14)
        for (i, (im, (_, lab, (x, y, w, h), lm))), axis \
            in zip(enumerate(data, start=1), axes.ravel()):
            axis.set_title(label=lab.item())
            axis.imshow(X=im)
            rect = Rectangle(xy=(x, y), width=w, height=h,
                             linewidth=3, edgecolor='r',
                             facecolor='none', clip_on=True)
            axis.add_patch(p=rect)
            for j, (px, py) in enumerate(lm.split(2)):
                axis.add_patch(p=Circle(xy=(px, py)))
            # for j, v in enumerate(lm):
            #     if j%2 == 0:
            #         px.append(v)
            #     else:
            #         py.append(v)
            # axis.scatter(x=px, y=py)
            # axis.plot(px, py)
# `axis.scatter()` and `axis.plot()` of `plt.subplots()` don't work
# properly. They shrink images so use `axis.add_patch()` instead.
            if i == 10:
                break
        fig.tight_layout(h_pad=3.0)
        plt.show()

show_images(data=train_attr_data, main_title="train_attr_data")
show_images(data=valid_identity_data, main_title="valid_identity_data")
show_images(data=test_bbox_data, main_title="test_bbox_data")
show_images(data=all_landmarks_data, main_title="all_landmarks_data")
show_images(data=all_empty_data, main_title="all_empty_data")
show_images(data=all_all_data, main_title="all_all_data")
ログイン後にコピー

CelebA in PyTorch

CelebA in PyTorch

CelebA in PyTorch

CelebA in PyTorch

CelebA in PyTorch

CelebA in PyTorch

以上がCelebA は PyTorch ですの詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。

ソース:dev.to
このウェブサイトの声明
この記事の内容はネチズンが自主的に寄稿したものであり、著作権は原著者に帰属します。このサイトは、それに相当する法的責任を負いません。盗作または侵害の疑いのあるコンテンツを見つけた場合は、admin@php.cn までご連絡ください。
著者別の最新記事
人気のチュートリアル
詳細>
最新のダウンロード
詳細>
ウェブエフェクト
公式サイト
サイト素材
フロントエンドテンプレート