CelebA in PyTorch

Susan Sarandon
Published: 2024-12-28 02:36:10

*My post explains CelebA.

CelebA() can use the CelebA dataset as shown below:

*Memos:

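  • The 1st argument is root (Required-Type: str or pathlib.Path). *An absolute or relative path is possible.
  • The 2nd argument is split (Optional-Default: "train"-Type: str). *"train" (162,770 images), "valid" (19,867 images), "test" (19,962 images) or "all" (202,599 images) can be set to it.
  • The 3rd argument is target_type (Optional-Default: "attr"-Type: str or list of str): *Memos:
    • "attr", "identity", "bbox" and/or "landmarks" can be set to it.
    • An empty list can also be set to it.
    • The same value can be set more than once.
    • If the order of the values is different, the order of the target elements is also different.
  • The 4th argument is transform (Optional-Default: None-Type: callable).
  • The 5th argument is target_transform (Optional-Default: None-Type: callable).
  • The 6th argument is download (Optional-Default: False-Type: bool): *Memos:
    • If it's True, the dataset is downloaded from the internet and extracted (unzipped) to root.
    • If it's True and the dataset is already downloaded, it's extracted.
    • If it's True and the dataset is already downloaded and extracted, nothing happens.
    • It should be False if the dataset is already downloaded and extracted because it's faster.
    • gdown is required to download the dataset.
    • You can manually download and extract the dataset (img_align_celeba.zip with identity_CelebA.txt, list_attr_celeba.txt, list_bbox_celeba.txt, list_eval_partition.txt and list_landmarks_align_celeba.txt) from here.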
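
If you go the manual-download route, it can help to confirm that the files named in the last memo are in place before passing download=False. The following is only a minimal sketch using the standard library: the helper name missing_celeba_files and its dataset_dir argument are hypothetical, and it does not say where torchvision expects the files relative to root, so pass whichever folder you put them in.

```python
from pathlib import Path

# File names taken from the manual-download memo above.
EXPECTED_FILES = (
    "img_align_celeba.zip",
    "identity_CelebA.txt",
    "list_attr_celeba.txt",
    "list_bbox_celeba.txt",
    "list_eval_partition.txt",
    "list_landmarks_align_celeba.txt",
)

def missing_celeba_files(dataset_dir):
    """Return the expected file names that are not present in dataset_dir.

    Hypothetical helper: it only checks that each file exists, not its
    contents or checksum.
    """
    dataset_dir = Path(dataset_dir)
    return [name for name in EXPECTED_FILES
            if not (dataset_dir / name).is_file()]
```

An empty list from missing_celeba_files() suggests the manual download is complete; a non-empty list names what still has to be fetched.
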
from torchvision.datasets import CelebA

train_attr_data = CelebA(
    root="data"
)

train_attr_data = CelebA(
    root="data",
    split="train",
    target_type="attr",
    transform=None,
    target_transform=None,
    download=False
)

valid_identity_data = CelebA(
    root="data",
    split="valid",
    target_type="identity"
)

test_bbox_data = CelebA(
    root="data",
    split="test",
    target_type="bbox"
)

all_landmarks_data = CelebA(
    root="data",
    split="all",
    target_type="landmarks"
)

all_empty_data = CelebA(
    root="data",
    split="all",
    target_type=[]
)

all_all_data = CelebA(
    root="data",
    split="all",
    target_type=["attr", "identity", "bbox", "landmarks"]
)

len(train_attr_data), len(valid_identity_data), len(test_bbox_data)
# (162770, 19867, 19962)

len(all_landmarks_data), len(all_empty_data), len(all_all_data)
# (202599, 202599, 202599)

train_attr_data
# Dataset CelebA
#     Number of datapoints: 162770
#     Root location: data
#     Target type: ['attr']
#     Split: train

train_attr_data.root
# 'data'

train_attr_data.split
# 'train'

train_attr_data.target_type
# ['attr']

print(train_attr_data.transform)
# None

print(train_attr_data.target_transform)
# None

train_attr_data.download
# <bound method CelebA.download of Dataset CelebA
#     Number of datapoints: 162770
#     Root location: data
#     Target type: ['attr']
#     Split: train>

len(train_attr_data.attr), train_attr_data.attr
# (162770, tensor([[0, 1, 1, ..., 0, 0, 1],
#                  [0, 0, 0, ..., 0, 0, 1],
#                  [0, 0, 0, ..., 0, 0, 1],
#                  ...,
#                  [1, 0, 1, ..., 0, 1, 1],
#                  [0, 0, 0, ..., 0, 0, 1],
#                  [0, 1, 1, ..., 1, 0, 1]]))

len(train_attr_data.attr_names), train_attr_data.attr_names
# (41, ['5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive', 
#       'Bags_Under_Eyes', 'Bald', 'Bangs', 'Big_Lips', 'Big_Nose',
#       'Black_Hair', 'Blond_Hair', 'Blurry', 'Brown_Hair',
#       ...
#       'Wearing_Necklace', 'Wearing_Necktie', 'Young', ''])

len(train_attr_data.identity), train_attr_data.identity
# (162770, tensor([[2880], [2937], [8692], ..., [7391], [8610], [2304]]))

len(train_attr_data.bbox), train_attr_data.bbox
# (162770, tensor([[95, 71, 226, 313],
#                  [72, 94, 221, 306],
#                  [216, 59, 91, 126],
#                  ...,
#                  [103, 103, 143, 198],
#                  [30, 59, 216, 280],
#                  [376, 4, 372, 515]]))

len(train_attr_data.landmarks_align), train_attr_data.landmarks_align
# (162770, tensor([[69, 109, 106, ..., 152, 108, 154],
#                  [69, 110, 107, ..., 151, 108, 153],
#                  [76, 112, 104, ..., 156, 98, 158],
#                  ...,
#                  [69, 113, 109, ..., 151, 110, 151],
#                  [68, 112, 109, ..., 150, 108, 151],
#                  [70, 111, 107, ..., 153, 102, 152]]))

train_attr_data[0]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([0, 1, 1, 0, 0, 0, 0, 0, 0, 0,
#          0, 1, 0, 0, 0, 0, 0, 0, 1, 1,
#          0, 1, 0, 0, 1, 0, 0, 1, 0, 0,
#          0, 1, 1, 0, 1, 0, 1, 0, 0, 1]))

train_attr_data[1]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([0, 0, 0, 1, 0, 0, 0, 1, 0, 0,
#          0, 1, 0, 0, 0, 0, 0, 0, 0, 1,
#          0, 1, 0, 0, 1, 0, 0, 0, 0, 0,
#          0, 1, 0, 0, 0, 0, 0, 0, 0, 1]))

train_attr_data[2]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([0, 0, 0, 0, 0, 0, 1, 0, 0, 0,
#          1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
#          1, 0, 0, 1, 1, 0, 0, 1, 0, 0,
#          0, 0, 0, 1, 0, 0, 0, 0, 0, 1]))

valid_identity_data[0]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor(2594))

valid_identity_data[1]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor(2795))

valid_identity_data[2]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor(947))

test_bbox_data[0]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([147, 82, 120, 166]))

test_bbox_data[1]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([106, 34, 140, 194]))

test_bbox_data[2]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([107, 78, 109, 151]))

all_landmarks_data[0]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([69, 109, 106, 113, 77, 142, 73, 152, 108, 154]))

all_landmarks_data[1]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([69, 110, 107, 112, 81, 135, 70, 151, 108, 153]))

all_landmarks_data[2]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  tensor([76, 112, 104, 106, 108, 128, 74, 156, 98, 158]))

all_empty_data[0]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>, None)

all_empty_data[1]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>, None)

all_empty_data[2]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>, None)

all_all_data[0]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  (tensor([0, 1, 1, 0, 0, 0, 0, 0, 0, 0,
#           0, 1, 0, 0, 0, 0, 0, 0, 1, 1,
#           0, 1, 0, 0, 1, 0, 0, 1, 0, 0,
#           0, 1, 1, 0, 1, 0, 1, 0, 0, 1]),
#   tensor(2880),
#   tensor([95, 71, 226, 313]),
#   tensor([69, 109, 106, 113, 77, 142, 73, 152, 108, 154])))

all_all_data[1]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  (tensor([0, 0, 0, 1, 0, 0, 0, 1, 0, 0,
#           0, 1, 0, 0, 0, 0, 0, 0, 0, 1,
#           0, 1, 0, 0, 1, 0, 0, 0, 0, 0,
#           0, 1, 0, 0, 0, 0, 0, 0, 0, 1]),
#   tensor(2937),
#   tensor([72, 94, 221, 306]),
#   tensor([69, 110, 107, 112, 81, 135, 70, 151, 108, 153])))

all_all_data[2]
# (<PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=178x218>,
#  (tensor([0, 0, 0, 0, 0, 0, 1, 0, 0, 0,
#           1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
#           1, 0, 0, 1, 1, 0, 0, 1, 0, 0,
#           0, 0, 0, 1, 0, 0, 0, 0, 0, 1]),
#   tensor(8692),
#   tensor([216, 59, 91, 126]),
#   tensor([76, 112, 104, 106, 108, 128, 74, 156, 98, 158])))

import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.patches import Circle

def show_images(data, main_title=None):
    """Show the first 10 samples in a 2x5 grid, annotated by target_type."""
    if "attr" in data.target_type and len(data.target_type) == 1 \
        or not data.target_type:
        plt.figure(figsize=(12, 6))
        plt.suptitle(t=main_title, y=1.0, fontsize=14)
        for i, (im, _) in enumerate(data, start=1):
            plt.subplot(2, 5, i)
            plt.imshow(X=im)
            if i == 10:
                break
        plt.tight_layout(h_pad=3.0)
        plt.show()
    elif "identity" in data.target_type and len(data.target_type) == 1:
        plt.figure(figsize=(12, 6))
        plt.suptitle(t=main_title, y=1.0, fontsize=14)
        for i, (im, lab) in enumerate(data, start=1):
            plt.subplot(2, 5, i)
            plt.title(label=lab.item())
            plt.imshow(X=im)
            if i == 10:
                break
        plt.tight_layout(h_pad=3.0)
        plt.show()
    elif "bbox" in data.target_type and len(data.target_type) == 1:
        fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(12, 6))
        fig.suptitle(t=main_title, y=1.0, fontsize=14)
        for (i, (im, (x, y, w, h))), axis \
            in zip(enumerate(data, start=1), axes.ravel()):
            axis.imshow(X=im)
            rect = Rectangle(xy=(x, y), width=w, height=h,
                             linewidth=3, edgecolor='r',
                             facecolor='none')
            axis.add_patch(p=rect)
            if i == 10:
                break
        fig.tight_layout(h_pad=3.0)
        plt.show()
    elif "landmarks" in data.target_type and len(data.target_type) == 1:
        plt.figure(figsize=(12, 6))
        plt.suptitle(t=main_title, y=1.0, fontsize=14)
        for i, (im, lm) in enumerate(data, start=1):
            px = lm[0::2]  # even indices hold the x coordinates
            py = lm[1::2]  # odd indices hold the y coordinates
            plt.subplot(2, 5, i)
            plt.imshow(X=im)
            plt.scatter(x=px, y=py)
            if i == 10:
                break
        plt.tight_layout(h_pad=3.0)
        plt.show()
    elif len(data.target_type) == 4:
        fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(12, 6))
        fig.suptitle(t=main_title, y=1.0, fontsize=14)
        for (i, (im, (_, lab, (x, y, w, h), lm))), axis \
            in zip(enumerate(data, start=1), axes.ravel()):
            axis.set_title(label=lab.item())
            axis.imshow(X=im)
            rect = Rectangle(xy=(x, y), width=w, height=h,
                             linewidth=3, edgecolor='r',
                             facecolor='none', clip_on=True)
            axis.add_patch(p=rect)
            # Each chunk of 2 values is one (x, y) landmark.
            for px, py in lm.split(2):
                axis.add_patch(p=Circle(xy=(px, py)))
            # `axis.scatter()` and `axis.plot()` of `plt.subplots()` don't work
            # properly. They shrink images so use `axis.add_patch()` instead.
            # axis.scatter(x=lm[0::2], y=lm[1::2])
            # axis.plot(lm[0::2], lm[1::2])
            if i == 10:
                break
        fig.tight_layout(h_pad=3.0)
        plt.show()

show_images(data=train_attr_data, main_title="train_attr_data")
show_images(data=valid_identity_data, main_title="valid_identity_data")
show_images(data=test_bbox_data, main_title="test_bbox_data")
show_images(data=all_landmarks_data, main_title="all_landmarks_data")
show_images(data=all_empty_data, main_title="all_empty_data")
show_images(data=all_all_data, main_title="all_all_data")
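
The outputs above show three shapes for the target: None for an empty target_type, a single value for one type, and a tuple for several types. The sketch below mirrors that arrangement in plain Python so the effect of order and repeated values is easy to see. It is not torchvision's code: pack_target is a hypothetical helper, and the values are plain ints and lists rather than tensors.

```python
def pack_target(values_by_type, target_type):
    """Arrange one sample's target the way the outputs above are arranged.

    Hypothetical helper for illustration only.
    """
    if isinstance(target_type, str):
        target_type = [target_type]
    # Keep the caller's order, and keep duplicates if a name is repeated.
    picked = [values_by_type[name] for name in target_type]
    if not picked:
        return None          # empty list: no target
    if len(picked) == 1:
        return picked[0]     # one type: the bare value
    return tuple(picked)     # several types: a tuple in the given order

# Values of the first sample, copied from the outputs above.
sample = {
    "identity": 2880,
    "bbox": [95, 71, 226, 313],
    "landmarks": [69, 109, 106, 113, 77, 142, 73, 152, 108, 154],
}

print(pack_target(sample, "identity"))                # 2880
print(pack_target(sample, []))                        # None
print(pack_target(sample, ["bbox", "identity"]))      # ([95, 71, 226, 313], 2880)
print(pack_target(sample, ["identity", "identity"]))  # (2880, 2880)
```
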

[Images: six figures captioned "CelebA in PyTorch"]
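
show_images() treats a landmarks target as alternating x and y values and a bbox target as (x, y, w, h). Assuming that layout, a small sketch of both conversions on plain Python lists follows; landmark_pairs and bbox_corners are hypothetical helper names, and the sample values are copied from the outputs above.

```python
def landmark_pairs(flat):
    """Group a flat [x1, y1, x2, y2, ...] sequence into (x, y) tuples."""
    if len(flat) % 2 != 0:
        raise ValueError("expected an even number of values")
    return list(zip(flat[0::2], flat[1::2]))

def bbox_corners(bbox):
    """Convert (x, y, w, h) into (top-left, bottom-right) corner points."""
    x, y, w, h = bbox
    return (x, y), (x + w, y + h)

# all_landmarks_data[0] and test_bbox_data[0] from the outputs above
print(landmark_pairs([69, 109, 106, 113, 77, 142, 73, 152, 108, 154]))
# [(69, 109), (106, 113), (77, 142), (73, 152), (108, 154)]
print(bbox_corners([147, 82, 120, 166]))
# ((147, 82), (267, 248))
```

With real samples the targets are tensors, so call .tolist() on them first if you want plain ints.
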

Source: dev.to