Maison > développement back-end > Tutoriel Python > Explication détaillée de l'utilisation de PyTorch pour implémenter la détection et le suivi de cibles

Explication détaillée de l'utilisation de PyTorch pour implémenter la détection et le suivi de cibles

coldplay.xixi
Libérer: 2020-12-11 17:18:45
avant
8916 Les gens l'ont consulté

Tutoriel PythonLa colonne présente l'utilisation de PyTorch pour réaliser la détection et le suivi de cibles

Explication détaillée de l'utilisation de PyTorch pour implémenter la détection et le suivi de cibles

Beaucoup de recommandations d'apprentissage gratuites, merci Veuillez visiter le tutoriel Python(vidéo)

Introduction

Dans l'article d'hier, nous avons présenté comment utiliser vos propres images dans PyTorch pour entraîner un classificateur d'images, puis l'utiliser pour la reconnaissance d'images. Cet article montrera comment utiliser un classificateur pré-entraîné pour détecter plusieurs objets dans les images et les suivre dans des vidéos.

Détection d'objets dans les images

Il existe de nombreux algorithmes pour la détection d'objets, YOLO et SSD sont actuellement les algorithmes les plus populaires. Dans cet article, nous utiliserons YOLOv3. Nous ne discuterons pas de YOLO en détail ici. Si vous souhaitez en savoir plus, vous pouvez vous référer au lien ci-dessous ~ (https://pjreddie.com/darknet/yolo/)

Commençons, toujours à partir du module d'import :

from models import *
from utils import *
import os, sys, time, datetime, random
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.autograd import Variable
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image
Copier après la connexion

Chargez ensuite la configuration et les poids pré-entraînés, ainsi que certaines valeurs prédéfinies, notamment : la taille de l'image, le seuil de confiance et le seuil de suppression non maximal.

config_path='config/yolov3.cfg'
weights_path='config/yolov3.weights'
class_path='config/coco.names'
img_size=416
conf_thres=0.8
nms_thres=0.4
# Load model and weights
model = Darknet(config_path, img_size=img_size)
model.load_weights(weights_path)
model.cuda()
model.eval()
classes = utils.load_classes(class_path)
Tensor = torch.cuda.FloatTensor
Copier après la connexion

La fonction suivante renverra les résultats de détection de l'image spécifiée.

def detect_image(img):
    # scale and pad image
    ratio = min(img_size/img.size[0], img_size/img.size[1])
    imw = round(img.size[0] * ratio)
    imh = round(img.size[1] * ratio)
    img_transforms=transforms.Compose([transforms.Resize((imh,imw)),
         transforms.Pad((max(int((imh-imw)/2),0), 
              max(int((imw-imh)/2),0), max(int((imh-imw)/2),0),
              max(int((imw-imh)/2),0)), (128,128,128)),
         transforms.ToTensor(),
         ])
    # convert image to Tensor
    image_tensor = img_transforms(img).float()
    image_tensor = image_tensor.unsqueeze_(0)
    input_img = Variable(image_tensor.type(Tensor))
    # run inference on the model and get detections
    with torch.no_grad():
        detections = model(input_img)
        detections = utils.non_max_suppression(detections, 80, 
                        conf_thres, nms_thres)
    return detections[0]
Copier après la connexion

Enfin, obtenons les résultats de la détection en chargeant une image, puis affichons-la avec un cadre de délimitation autour de l'objet détecté. Et utilisez différentes couleurs pour différentes classes afin de les différencier.

# load image and get detections
img_path = "images/blueangels.jpg"
prev_time = time.time()
img = Image.open(img_path)
detections = detect_image(img)
inference_time = datetime.timedelta(seconds=time.time() - prev_time)
print ('Inference Time: %s' % (inference_time))
# Get bounding-box colors
cmap = plt.get_cmap('tab20b')
colors = [cmap(i) for i in np.linspace(0, 1, 20)]
img = np.array(img)
plt.figure()
fig, ax = plt.subplots(1, figsize=(12,9))
ax.imshow(img)
pad_x = max(img.shape[0] - img.shape[1], 0) * (img_size / max(img.shape))
pad_y = max(img.shape[1] - img.shape[0], 0) * (img_size / max(img.shape))
unpad_h = img_size - pad_y
unpad_w = img_size - pad_x
if detections is not None:
    unique_labels = detections[:, -1].cpu().unique()
    n_cls_preds = len(unique_labels)
    bbox_colors = random.sample(colors, n_cls_preds)
    # browse detections and draw bounding boxes
    for x1, y1, x2, y2, conf, cls_conf, cls_pred in detections:
        box_h = ((y2 - y1) / unpad_h) * img.shape[0]
        box_w = ((x2 - x1) / unpad_w) * img.shape[1]
        y1 = ((y1 - pad_y // 2) / unpad_h) * img.shape[0]
        x1 = ((x1 - pad_x // 2) / unpad_w) * img.shape[1]
        color = bbox_colors[int(np.where(
             unique_labels == int(cls_pred))[0])]
        bbox = patches.Rectangle((x1, y1), box_w, box_h,
             linewidth=2, edgecolor=color, facecolor='none')
        ax.add_patch(bbox)
        plt.text(x1, y1, s=classes[int(cls_pred)], 
                color='white', verticalalignment='top',
                bbox={'color': color, 'pad': 0})
plt.axis('off')
# save image
plt.savefig(img_path.replace(".jpg", "-det.jpg"),        
                  bbox_inches='tight', pad_inches=0.0)
plt.show()
Copier après la connexion

Voici quelques-uns de nos résultats de tests :

Explication détaillée de l'utilisation de PyTorch pour implémenter la détection et le suivi de cibles

Explication détaillée de l'utilisation de PyTorch pour implémenter la détection et le suivi de cibles

Explication détaillée de l'utilisation de PyTorch pour implémenter la détection et le suivi de cibles

Suivi d'objets en vidéo

Vous savez maintenant comment détecter différents objets dans les images. Lorsque vous le regardez image par image dans une vidéo, vous verrez ces cases de suivi bouger. Mais s'il y a plusieurs objets dans ces images vidéo, comment savoir si l'objet dans une image est le même que l'objet dans l'image précédente ? C'est ce qu'on appelle le suivi d'objet et utilise plusieurs détections pour identifier un objet spécifique.

Il existe plusieurs algorithmes pour faire cela, dans cet article j'ai décidé d'utiliser SORT (Simple Online and Realtime Tracking), qui utilise un filtre de Kalman pour prédire la trajectoire d'une cible précédemment identifiée et la comparer avec la nouvelle détection Les résultats correspondants sont très pratiques et rapides.

Commençons maintenant à écrire le code, les 3 premiers extraits de code seront les mêmes que ceux de la détection d'image unique car ils traitent de la détection YOLO sur une seule image. La différence vient dans la dernière partie, pour chaque détection nous appelons la fonction Update de l'objet Sort pour obtenir une référence à l'objet dans l'image. Ainsi, contrairement à la détection régulière de l'exemple précédent (incluant les coordonnées du cadre de sélection et la prédiction de classe), nous obtiendrons l'objet suivi, comprenant en plus des paramètres ci-dessus, un identifiant d'objet. Et vous devez utiliser OpenCV pour lire la vidéo et afficher les images vidéo.

videopath = 'video/interp.mp4'
%pylab inline 
import cv2
from IPython.display import clear_output
cmap = plt.get_cmap('tab20b')
colors = [cmap(i)[:3] for i in np.linspace(0, 1, 20)]
# initialize Sort object and video capture
from sort import *
vid = cv2.VideoCapture(videopath)
mot_tracker = Sort()
#while(True):
for ii in range(40):
    ret, frame = vid.read()
    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    pilimg = Image.fromarray(frame)
    detections = detect_image(pilimg)
    img = np.array(pilimg)
    pad_x = max(img.shape[0] - img.shape[1], 0) * 
            (img_size / max(img.shape))
    pad_y = max(img.shape[1] - img.shape[0], 0) * 
            (img_size / max(img.shape))
    unpad_h = img_size - pad_y
    unpad_w = img_size - pad_x
    if detections is not None:
        tracked_objects = mot_tracker.update(detections.cpu())
        unique_labels = detections[:, -1].cpu().unique()
        n_cls_preds = len(unique_labels)
        for x1, y1, x2, y2, obj_id, cls_pred in tracked_objects:
            box_h = int(((y2 - y1) / unpad_h) * img.shape[0])
            box_w = int(((x2 - x1) / unpad_w) * img.shape[1])
            y1 = int(((y1 - pad_y // 2) / unpad_h) * img.shape[0])
            x1 = int(((x1 - pad_x // 2) / unpad_w) * img.shape[1])
            color = colors[int(obj_id) % len(colors)]
            color = [i * 255 for i in color]
            cls = classes[int(cls_pred)]
            cv2.rectangle(frame, (x1, y1), (x1+box_w, y1+box_h),
                         color, 4)
            cv2.rectangle(frame, (x1, y1-35), (x1+len(cls)*19+60,
                         y1), color, -1)
            cv2.putText(frame, cls + "-" + str(int(obj_id)), 
                        (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 
                        1, (255,255,255), 3)
    fig=figure(figsize=(12, 8))
    title("Video Stream")
    imshow(frame)
    show()
    clear_output(wait=True)
Copier après la connexion

Recommandations d'apprentissage gratuites associées : programmation php(vidéo)

Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

Étiquettes associées:
source:csdn.net
Déclaration de ce site Web
Le contenu de cet article est volontairement contribué par les internautes et les droits d'auteur appartiennent à l'auteur original. Ce site n'assume aucune responsabilité légale correspondante. Si vous trouvez un contenu suspecté de plagiat ou de contrefaçon, veuillez contacter admin@php.cn
Tutoriels populaires
Plus>
Derniers téléchargements
Plus>
effets Web
Code source du site Web
Matériel du site Web
Modèle frontal