首页 > 后端开发 > Python教程 > PyTorch 中的 CocoDetection (3)

PyTorch 中的 CocoDetection (3)

Mary-Kate Olsen
发布: 2025-01-08 14:13:41
原创
700 人浏览过

请我喝杯咖啡☕

*备忘录:

  • 我的帖子解释了 CocoDetection() 如何使用带有 captions_train2014.json、instances_train2014.json 和 person_keypoints_train2014.json 的 train2014,带有 captions_val2014.json、instances_val2014.json 和 person_keypoints_val2014.json 的 val2014,带有 image_info_test2014.json 的 test2014,以及带有 image_info_test2015.json 和 image_info_test-dev2015.json 的 test2015。
  • 我的帖子解释了 CocoDetection() 如何使用带有 captions_train2017.json、instances_train2017.json 和 person_keypoints_train2017.json 的 train2017,带有 captions_val2017.json、instances_val2017.json 和 person_keypoints_val2017.json 的 val2017,以及带有 image_info_test2017.json 和 image_info_test-dev2017.json 的 test2017。
  • 我的帖子解释了 MS COCO。

CocoDetection() 可以使用 MS COCO 数据集,如下所示。*这里针对的是:带有 stuff_train2017.json 的 train2017、带有 stuff_val2017.json 的 val2017、带有 stuff_train2017.json 的 stuff_train2017_pixelmaps、带有 stuff_val2017.json 的 stuff_val2017_pixelmaps、带有 panoptic_train2017.json 的 panoptic_train2017、带有 panoptic_val2017.json 的 panoptic_val2017,以及带有 image_info_unlabeled2017.json 的 unlabeled2017:

from torchvision.datasets import CocoDetection

# Stuff-segmentation annotations for 2017 are paired with the regular
# train2017 / val2017 image folders.
stf_train2017_data = CocoDetection(
    root="data/coco/imgs/train2017",
    annFile="data/coco/anns/stuff_trainval2017/stuff_train2017.json"
)

stf_val2017_data = CocoDetection(
    root="data/coco/imgs/val2017",
    annFile="data/coco/anns/stuff_trainval2017/stuff_val2017.json"
)

# Bare expression: shown as the cell output in a notebook; it has no effect
# when this file is run as a plain script. The recorded result follows.
len(stf_train2017_data), len(stf_val2017_data)
# (118287, 5000)

# The four attempts below are kept commented out because, per the author,
# each one raises an error. They point `root` at the stuff pixel-map folders
# and the panoptic PNG folders instead of at an image folder.
# NOTE(review): the actual exception is not recorded here. Presumably the
# `file_name` entries in these JSON files do not match the files inside
# those folders, and/or the panoptic JSON's annotations lack fields the COCO
# index behind CocoDetection expects. Confirm before documenting further.
# pms_stf_train2017_data = CocoDetection(
#     root="data/coco/anns/stuff_trainval2017/stuff_train2017_pixelmaps",
#     annFile="data/coco/anns/stuff_trainval2017/stuff_train2017.json"
# ) # Error

# pms_stf_val2017_data = CocoDetection(
#     root="data/coco/anns/stuff_trainval2017/stuff_val2017_pixelmaps",
#     annFile="data/coco/anns/stuff_trainval2017/stuff_val2017.json"
# ) # Error

# pan_train2017_data = CocoDetection(
#     root="data/coco/anns/panoptic_trainval2017/panoptic_train2017",
#     annFile="data/coco/anns/panoptic_trainval2017/panoptic_train2017.json"
# ) # Error

# pan_val2017_data = CocoDetection(
#     root="data/coco/anns/panoptic_trainval2017/panoptic_val2017",
#     annFile="data/coco/anns/panoptic_trainval2017/panoptic_val2017.json"
# ) # Error

# Image-info file only: the sample outputs recorded for unlabeled2017 show
# each item coming back with an empty annotation list.
unlabeled2017_data = CocoDetection(
    root="data/coco/imgs/unlabeled2017",
    annFile="data/coco/anns/unlabeled2017/image_info_unlabeled2017.json"
)

# Bare expression (notebook display), as above.
len(unlabeled2017_data)
# 123403

# Inspect samples 2, 47 and 64 of each dataset. Indexing a CocoDetection
# returns an (image, list_of_annotation_dicts) tuple. The `#` lines under each
# expression are the outputs the author recorded, with long RLE strings and
# repeated annotations elided as `...`.
stf_train2017_data[2]
# (<PIL.Image.Image image mode=RGB size=640x428>,
#  [{'segmentation': {'counts': 'W2a0S2Q1T7mNmHS1R7mN...0100000000',
#    'size': [428, 640]}, 'area': 112666.0, 'iscrowd': 0, 'image_id': 30, 
#    'bbox': [0.0, 0.0, 640.0, 321.0], 'category_id': 119, 'id': 10000010},
#   {'segmentation': ..., 'category_id': 124, 'id': 10000011},
#   ...
#   {'segmentation': ..., 'category_id': 183, 'id': 10000014}])

stf_train2017_data[47]
# (<PIL.Image.Image image mode=RGB size=640x427>,
#  [{'segmentation': {'counts': '\\j1h0[<a0G2N001O0...00001O0000',
#    'size': [427, 640]}, 'area': 65213.0, 'iscrowd': 0, 'image_id': 294,
#    'bbox': [140.0, 0.0, 500.0, 326.0], 'category_id': 98, 'id': 10000284}, 
#   {'segmentation': ..., 'category_id': 123, 'id': 10000285},
#   ...
#   {'segmentation': ..., 'category_id': 183, 'id': 10000291}])

stf_train2017_data[64]
# (<PIL.Image.Image image mode=RGB size=480x640>,
#  [{'segmentation': {'counts': '0[9e:1O000000O100000...O5mc0F^Zj7',
#    'size': [640, 480]}, 'area': 20503.0, 'iscrowd': 0, 'image_id': 370,
#    'bbox': [0.0, 0.0, 79.0, 316.0], 'category_id': 102, 'id': 10000383},
#   {'segmentation': ..., 'category_id': 105, 'id': 10000384},
#   ...
#   {'segmentation': ..., 'category_id': 183, 'id': 10000389}])

stf_val2017_data[2]
# (<PIL.Image.Image image mode=RGB size=640x483>,
#  [{'segmentation': {'counts': '\9g5]9O1O1O;EU1kNU1...VMKQ?NY`d3',
#    'size': [483, 640]}, 'area': 5104.0, 'iscrowd': 0, 'image_id': 632,
#    'bbox': [0.0, 300.0, 392.0, 183.0], 'category_id': 93, 'id': 20000017},
#   {'segmentation': ..., 'category_id': 128, 'id': 20000018},
#   ...
#   {'segmentation': ..., 'category_id': 183, 'id': 20000020}])

stf_val2017_data[47]
# (<PIL.Image.Image image mode=RGB size=640x480>,
#  [{'segmentation': {'counts': '[da7T1X>D3M2J5M4M4LoQg1',
#    'size': [480, 640]}, 'area': 122.0, 'iscrowd': 0, 'image_id': 5001,
#    'bbox': [515.0, 235.0, 7.0, 36.0], 'category_id': 104, 'id': 20000247},
#   {'segmentation': ..., 'category_id': 105, 'id': 20000248},
#   ...
#   {'segmentation': ..., 'category_id': 183, 'id': 20000256}])

stf_val2017_data[64]
# NOTE(review): the recorded output originally read size=640x483 here, which
# contradicts the RLE 'size': [500, 375] below. PIL reports width x height,
# while RLE size is [height, width], so it is corrected to 375x500 (likely a
# copy-paste of the [2] output above). Confirm by re-running.
# (<PIL.Image.Image image mode=RGB size=375x500>,
#  [{'segmentation': {'counts': 'U<^1W>N020mN]B2e>N1O...Mb@N^?2hd2',
#    'size': [500, 375]}, 'area': 2404.0, 'iscrowd': 0, 'image_id': 6763,
#    'bbox': [0.0, 235.0, 369.0, 237.0], 'category_id': 105, 'id': 20000356},
#   {'segmentation': ..., 'category_id': 123, 'id': 20000357},
#   ...
#   {'segmentation': ..., 'category_id': 183, 'id': 20000362}])

# unlabeled2017 items carry no annotations: the second element is always [].
unlabeled2017_data[2]
# (<PIL.Image.Image image mode=RGB size=640x427>, [])

unlabeled2017_data[47]
# (<PIL.Image.Image image mode=RGB size=428x640>, [])

unlabeled2017_data[64]
# (<PIL.Image.Image image mode=RGB size=640x480>, [])

import matplotlib.pyplot as plt
from matplotlib.patches import Polygon, Rectangle
import numpy as np
from pycocotools import mask

def _bbox_edge_colors():
    """Return an endless iterator over the bounding-box edge colors.

    ``itertools.cycle`` visits every color, including the last one ('w').
    The previous hand-rolled counter wrapped one step early, so a white box
    was never drawn.
    """
    import itertools

    return itertools.cycle(['g', 'r', 'c', 'm', 'y', 'w'])


# `show_images1()` doesn't work very well for images with segmentations,
# so for those use `show_images2()`, which relies more on the original
# pycocotools (COCO API) drawing functions.
def show_images1(data, ims, main_title=None):
    """Show up to three samples of a CocoDetection dataset side by side.

    Each sample is fetched from ``data`` exactly once and its image is drawn.
    If the sample's first annotation has a ``segmentation``, the axis is
    titled with the image id. Then, for every annotation:

    * the mask is decoded with ``pycocotools.mask.decode`` and its nonzero
      pixel coordinates are plotted;
    * its ``bbox`` is drawn as a rectangle, with edge colors cycling through
      g, r, c, m, y, w.

    Samples without annotations, or whose annotations lack a segmentation,
    are shown as the plain image.

    Args:
        data: a ``torchvision.datasets.CocoDetection``, or anything indexable
            that returns ``(image, list_of_annotation_dicts)``.
        ims: indices into ``data``. Only the first three are drawn, since
            the figure has exactly three axes.
        main_title: optional figure super-title.
    """
    fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(14, 8))
    fig.suptitle(t=main_title, y=0.9, fontsize=14)
    for i, axis in zip(ims, axes.ravel()):
        # Indexing the dataset loads the image from disk, so do it once per
        # sample instead of once per check.
        im, anns = data[i]
        axis.imshow(X=im)
        if not (anns and "segmentation" in anns[0]):
            continue
        axis.set_title(label=anns[0]["image_id"])
        edge_colors = _bbox_edge_colors()  # restart the cycle per image
        for ann in anns:
            # Assumes RLE segmentations (as in the stuff annotations used
            # here). Polygon segmentations would first need converting,
            # e.g. via mask.frPyObjects. TODO confirm for other annotation files.
            binary_mask = mask.decode(rleObjs=ann['segmentation'])
            y_plts, x_plts = np.nonzero(a=np.squeeze(a=binary_mask))
            # Plots mask pixel coordinates as a connected line, which is
            # presumably why this view "doesn't work very well".
            axis.plot(x_plts, y_plts, alpha=0.4)
            x, y, w, h = ann['bbox']
            rect = Rectangle(xy=(x, y), width=w, height=h,
                             linewidth=3, edgecolor=next(edge_colors),
                             facecolor='none', zorder=2)
            axis.add_patch(p=rect)
    fig.tight_layout()
    plt.show()

# Sample indices drawn for every dataset (show_images1 uses the first three).
ims = (2, 47, 64)

# One figure per dataset, rendered in this order.
show_images1(data=stf_train2017_data, ims=ims,
             main_title="stf_train2017_data")
show_images1(data=stf_val2017_data, ims=ims, 
             main_title="stf_val2017_data")
# unlabeled2017 samples have no annotations, so these are plain images.
show_images1(data=unlabeled2017_data, ims=ims,
             main_title="unlabeled2017_data")

def show_images2(data, index, main_title=None):
    """Show one CocoDetection sample using pycocotools' own drawing code.

    If the sample's first annotation has a ``segmentation``, the image is
    drawn four times, using ``data.coco.getAnnIds`` / ``loadAnns`` /
    ``showAnns``:

    1. all annotations of the image, with bounding boxes;
    2. all annotations, without bounding boxes;
    3. annotations filtered with ``iscrowd=False``, with bounding boxes;
    4. annotations whose area lies in ``[0, 5000]``, with bounding boxes.

    Otherwise (no annotations, or annotations without segmentation) the
    plain image is shown once.

    Args:
        data: a ``torchvision.datasets.CocoDetection``. Its ``coco``
            attribute (the pycocotools ``COCO`` index) is used for
            annotation lookup and drawing.
        index: index of the sample in ``data``.
        main_title: optional figure super-title.
    """
    img, img_anns = data[index]
    if not (img_anns and "segmentation" in img_anns[0]):
        # Previously an image whose annotations lacked a segmentation drew
        # nothing at all. Show the plain image instead.
        plt.figure(figsize=(11, 8))
        plt.imshow(X=img)
        plt.suptitle(t=main_title, y=1, fontsize=14)
        plt.show()
        return

    img_id = img_anns[0]['image_id']
    coco = data.coco

    def show_image(imgIds, areaRng=None, iscrowd=None, draw_bbox=False):
        """Draw ``img`` with the annotations of ``imgIds`` that pass the filters."""
        # None instead of a shared mutable `[]` default. An empty list means
        # "no area filter" to getAnnIds.
        if areaRng is None:
            areaRng = []
        plt.figure(figsize=(11, 8))
        plt.imshow(X=img)
        plt.suptitle(t=main_title, y=1, fontsize=14)
        plt.title(label=imgIds, fontsize=14)
        # Query with the imgIds argument. The previous version ignored it and
        # always used the enclosing img_id.
        anns_ids = coco.getAnnIds(imgIds=imgIds, areaRng=areaRng,
                                  iscrowd=iscrowd)
        anns = coco.loadAnns(ids=anns_ids)
        coco.showAnns(anns=anns, draw_bbox=draw_bbox)
        plt.show()

    show_image(imgIds=img_id, draw_bbox=True)
    show_image(imgIds=img_id, draw_bbox=False)
    show_image(imgIds=img_id, iscrowd=False, draw_bbox=True)
    show_image(imgIds=img_id, areaRng=[0, 5000], draw_bbox=True)

# Title now matches the dataset actually shown (it previously said
# "stf_train2017_data" while displaying a val2017 sample).
show_images2(data=stf_val2017_data, index=47,
             main_title="stf_val2017_data")
登录后复制

show_images1():

Image description

Image description

Image description

show_images2():

Image description

Image description

Image description

Image description

以上是PyTorch 中的 CocoDetection (3)的详细内容。更多信息请关注PHP中文网其他相关文章!

来源:dev.to
本站声明
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn
作者最新文章
热门教程
更多>
最新下载
更多>
网站特效
网站源码
网站素材
前端模板