Dans cet article, nous allons voir comment utiliser des modèles de détection d'objets comme YOLO ainsi que des modèles d'intégration multimodaux comme CLIP pour améliorer la récupération d'images.
Voici l'idée : la récupération d'images CLIP fonctionne comme suit : nous intégrons les images que nous avons à l'aide d'un modèle CLIP et les stockons quelque part, comme dans une base de données vectorielles. Ensuite, lors de l'inférence, nous pouvons utiliser une image de requête ou une invite, l'intégrer et trouver les images les plus proches des intégrations stockées qui peuvent être récupérées. Le problème est que les images intégrées contiennent trop d'objets ou que certains objets sont en arrière-plan et que nous souhaitons toujours que notre système les récupère. En effet, CLIP intègre l'image dans son ensemble. Pensez-y comme ce qu'est un modèle d'intégration de mots par rapport à un modèle d'intégration de phrases. Nous voulons pouvoir rechercher des mots équivalents à des objets dans une image. La solution consiste donc à décomposer l’image en différents objets à l’aide d’un modèle de détection d’objets. Ensuite, intégrez ces images décomposées mais liez-les à leur image parent. Cela nous permettra de récupérer les cultures et d'obtenir le parent d'où provient la culture. Voyons comment cela fonctionne.
!pip install -q ultralytics torch matplotlib numpy pillow zipfile36 transformers from ultralytics import YOLO import matplotlib.pyplot as plt from PIL import pillow import os from Zipfile import Zipfile, BadZipFile import torch from transformers import CLIPProcessor, CLIPModel, CLIPVisionModelWithProjection, CLIPTextModelWithProjection
!wget http://images.cocodataset.org/zips/val2017.zip -O coco_val2017.zip def extract_zip_file(extract_path): try: with ZipFile(extract_path+".zip") as zfile: zfile.extractall(extract_path) # remove zipfile zfileTOremove=f"{extract_path}"+".zip" if os.path.isfile(zfileTOremove): os.remove(zfileTOremove) else: print("Error: %s file not found" % zfileTOremove) except BadZipFile as e: print("Error:", e) extract_val_path = "./coco_val2017" extract_zip_file(extract_val_path)
Nous pouvons ensuite prendre certaines images et créer une liste d'exemples.
source = ['coco_val2017/val2017/000000000139.jpg', '/content/coco_val2017/val2017/000000000632.jpg', '/content/coco_val2017/val2017/000000000776.jpg', '/content/coco_val2017/val2017/000000001503.jpg', '/content/coco_val2017/val2017/000000001353.jpg', '/content/coco_val2017/val2017/000000003661.jpg']
Dans cet exemple, nous allons utiliser le dernier modèle Ultralytics Yolo10x avec OpenAI clip-vit-base-patch32 .
device = "cuda" # YOLO Model model = YOLO('yolov10x.pt') # Clip model model_id = "openai/clip-vit-base-patch32" image_model = CLIPVisionModelWithProjection.from_pretrained(model_id, device_map = device) text_model = CLIPTextModelWithProjection.from_pretrained(model_id, device_map = device) processor = CLIPProcessor.from_pretrained(model_id)
results = model(source=source, device = "cuda")
Montrons-nous les résultats avec cet extrait de code
# Visualize the results fig, ax = plt.subplots(2, 3, figsize=(15, 10)) for i, r in enumerate(results): # Plot results image im_bgr = r.plot() # BGR-order numpy array im_rgb = Image.fromarray(im_bgr[..., ::-1]) # RGB-order PIL image ax[i%2, i//2].imshow(im_rgb) ax[i%2, i//2].set_title(f"Image {i+1}")
Nous pouvons donc voir que le modèle YOLO fonctionne assez bien pour détecter les objets dans les images. Il fait quelques erreurs en étiquetant le moniteur comme TV. Mais c'est bien. Les classes réelles attribuées par YOLO ne sont pas si essentielles car nous allons utiliser CLIP pour faire l'inférence.
class CroppedImage: def __init__(self, parent, box, cls): self.parent = parent self.box = box self.cls = cls def display(self, ax = None): im_rgb = Image.open(self.parent) cropped_image = im_rgb.crop(self.box) if ax is not None: ax.imshow(cropped_image) ax.set_title(self.cls) else: plt.figure(figsize=(10, 10)) plt.imshow(cropped_image) plt.title(self.cls) plt.show() def get_cropped_image(self): im_rgb = Image.open(self.parent) cropped_image = im_rgb.crop(self.box) return cropped_image def __str__(self): return f"CroppedImage(parent={self.parent}, boxes={self.box}, cls={self.cls})" def __repr__(self): return self.__str__() class YOLOImage: def __init__(self, image_path, cropped_images): self.image_path = str(image_path) self.cropped_images = cropped_images def get_image(self): return Image.open(self.image_path) def get_caption(self): cls =[] for cropped_image in self.cropped_images: cls.append(cropped_image.cls) unique_cls = set(cls) count_cls = {cls: cls.count(cls) for cls in unique_cls} count_string = " ".join(f"{count} {cls}," for cls, count in count_cls.items()) return "this image contains " + count_string def __str__(self): return self.__repr__() def __repr__(self): cls =[] for cropped_image in self.cropped_images: cls.append(cropped_image.cls) return f"YOLOImage(image={self.image_path}, cropped_images={cls})" class ImageEmbedding: def __init__(self, image_path, embedding, cropped_image = None): self.image_path = image_path self.cropped_image = cropped_image self.embedding = embedding
La classe CroppedImage représente une partie d'une image recadrée à partir d'une image parent plus grande. Il est initialisé avec le chemin d'accès à l'image parent, le cadre de délimitation définissant la zone de recadrage et une étiquette de classe (par exemple, « chat » ou « chien »). Cette classe inclut des méthodes pour afficher l’image recadrée et la récupérer en tant qu’objet image. La méthode d'affichage permet de visualiser la partie recadrée soit sur un axe fourni, soit en créant une nouvelle figure, la rendant polyvalente pour différents cas d'utilisation. De plus, les méthodes __str__ et __repr__ sont implémentées pour une représentation sous forme de chaîne simple et informative de l'objet.
La classe YOLOImage est conçue pour gérer les images traitées avec le modèle de détection d'objets YOLO. Il prend le chemin d'accès à l'image d'origine et une liste d'instances CroppedImage qui représentent les objets détectés dans l'image. La classe fournit des méthodes pour ouvrir et afficher l'image complète et générer une légende résumant les objets détectés dans l'image. La méthode de légende regroupe et compte les étiquettes de classe uniques des images recadrées, fournissant une description concise du contenu de l'image. Cette classe est particulièrement utile pour gérer et interpréter les résultats des tâches de détection d'objets.
La classe ImageEmbedding a une image et son intégration associée, qui est une représentation numérique des caractéristiques de l'image. Cette classe peut être initialisée avec le chemin d'accès à l'image, le vecteur d'intégration et éventuellement une instance CroppedImage si l'intégration correspond à une partie recadrée spécifique de l'image. La classe ImageEmbedding est essentielle pour les tâches impliquant la similarité, la classification et la récupération d'images, car elle fournit un moyen structuré de stocker et d'accéder aux données d'image ainsi qu'à leurs fonctionnalités calculées. Cette intégration facilite des flux de travail efficaces de traitement d'images et d'apprentissage automatique.
yolo_images: list[YOLOImage]= [] names= model.names for i, r in enumerate(results): crops:list[CroppedImage] = [] boxes = r.boxes classes = r.boxes.cls for j, box in enumerate(r.boxes): box = tuple(box.xyxy.flatten().cpu().numpy()) cropped_image = CroppedImage(parent = r.path, box = box, cls = names[classes[j].int().item()]) crops.append(cropped_image) yolo_images.append(YOLOImage(image_path=r.path, cropped_images=crops))
image_embeddings = [] for image in yolo_images: input = processor.image_processor(images= image.get_image(), return_tensors = 'pt') input.to(device) embeddings = image_model(pixel_values = input.pixel_values).image_embeds embeddings = embeddings/embeddings.norm(p=2, dim = -1, keepdim = True) # Normalize the embeddings image_embedding = ImageEmbedding(image_path = image.image_path, embedding = embeddings) image_embeddings.append(image_embedding) for cropped_image in image.cropped_images: input = processor.image_processor(images= cropped_image.get_cropped_image(), return_tensors = 'pt') input.to(device) embeddings = image_model(pixel_values = input.pixel_values).image_embeds embeddings = embeddings/embeddings.norm(p=2, dim = -1, keepdim = True) # Normalize the embeddings image_embedding = ImageEmbedding(image_path = image.image_path, embedding = embeddings, cropped_image = cropped_image) image_embeddings.append(image_embedding) **image_embeddings_tensor = torch.stack([image_embedding.embedding for image_embedding in image_embeddings]).squeeze()**
Nous pouvons désormais prendre ces intégrations d'images et les stocker dans une base de données vectorielles si nous le souhaitons. Mais dans cet exemple, nous utiliserons simplement la technique du produit scalaire interne pour vérifier la similarité et récupérer les images.
query = "image of a flowerpot" text_embedding = processor.tokenizer(query, return_tensors="pt").to(device) text_embedding = text_model(**text_embedding).text_embeds similarities = (torch.matmul(text_embedding, image_embeddings_tensor.T)).flatten().detach().cpu().numpy() # get the top 5 similar images k = 5 top_k_indices = similarities.argsort()[-k:] # Display the top 5 results fig, ax = plt.subplots(2, 5, figsize=(20, 5)) for i, index in enumerate(top_k_indices): if image_embeddings[index].cropped_image is not None: image_embeddings[index].cropped_image.display(ax = ax[0][i]) else: ax[0][i].imshow(Image.open(image_embeddings[index].image_path)) ax[1][i].imshow(Image.open(image_embeddings[index].image_path)) ax[0][i].axis('off') ax[1][i].axis('off') ax[1][i].set_title("Original Image") plt.show()
Vous pouvez voir que nous sommes capables de récupérer même les petites plantes cachées en arrière-plan. De plus, parfois, il extrait l'image originale comme résultat car nous l'intégrons également.
Cela peut être une technique très puissante. Vous pouvez également affiner les modèles de détection et d'intégration pour vos propres images et améliorer encore plus les performances.
Un inconvénient est que nous devons exécuter le modèle CLIP sur tous les objets détectés. Une façon d'atténuer ce problème consiste à limiter le nombre de boîtes produites par YOLO.
Vous pouvez consulter le code sur Colab sur ce lien.
Vous voulez vous connecter ?
?Mon site Web
?Mon Twitter
?Mon LinkedIn
Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!