在本文中,我们将了解如何使用 YOLO 等对象检测模型以及 CLIP 等多模态嵌入模型来更好地进行图像检索。
这个想法是:CLIP 图像检索的工作原理如下:我们使用 CLIP 模型嵌入我们拥有的图像并将它们存储在某个地方,例如矢量数据库中。然后,在推理过程中,我们可以使用查询图像或提示,将其嵌入,并从可检索的存储嵌入中找到最接近的图像。问题是当嵌入图像有太多对象或某些对象在背景中时,我们仍然希望我们的系统检索它们。这是因为 CLIP 将图像作为一个整体嵌入。可以将其想象为词嵌入模型与句子嵌入模型的关系。我们希望能够搜索与图像中的对象等效的单词。因此,解决方案是使用对象检测模型将图像分解为不同的对象。然后,嵌入这些分解的图像,但将它们链接到其父图像。这将使我们能够检索作物并获得作物起源的亲本。 让我们看看它是如何工作的。
!pip install -q ultralytics torch matplotlib numpy pillow zipfile36 transformers from ultralytics import YOLO import matplotlib.pyplot as plt from PIL import pillow import os from Zipfile import Zipfile, BadZipFile import torch from transformers import CLIPProcessor, CLIPModel, CLIPVisionModelWithProjection, CLIPTextModelWithProjection
!wget http://images.cocodataset.org/zips/val2017.zip -O coco_val2017.zip def extract_zip_file(extract_path): try: with ZipFile(extract_path+".zip") as zfile: zfile.extractall(extract_path) # remove zipfile zfileTOremove=f"{extract_path}"+".zip" if os.path.isfile(zfileTOremove): os.remove(zfileTOremove) else: print("Error: %s file not found" % zfileTOremove) except BadZipFile as e: print("Error:", e) extract_val_path = "./coco_val2017" extract_zip_file(extract_val_path)
然后我们可以拍摄一些图像并创建示例列表。
source = ['coco_val2017/val2017/000000000139.jpg', '/content/coco_val2017/val2017/000000000632.jpg', '/content/coco_val2017/val2017/000000000776.jpg', '/content/coco_val2017/val2017/000000001503.jpg', '/content/coco_val2017/val2017/000000001353.jpg', '/content/coco_val2017/val2017/000000003661.jpg']
在此示例中,我们将使用最新的 Ultralytics Yolo10x 模型以及 OpenAI Clip-vit-base-patch32 。
device = "cuda" # YOLO Model model = YOLO('yolov10x.pt') # Clip model model_id = "openai/clip-vit-base-patch32" image_model = CLIPVisionModelWithProjection.from_pretrained(model_id, device_map = device) text_model = CLIPTextModelWithProjection.from_pretrained(model_id, device_map = device) processor = CLIPProcessor.from_pretrained(model_id)
results = model(source=source, device = "cuda")
让我们用此代码片段向我们展示结果
# Visualize the results fig, ax = plt.subplots(2, 3, figsize=(15, 10)) for i, r in enumerate(results): # Plot results image im_bgr = r.plot() # BGR-order numpy array im_rgb = Image.fromarray(im_bgr[..., ::-1]) # RGB-order PIL image ax[i%2, i//2].imshow(im_rgb) ax[i%2, i//2].set_title(f"Image {i+1}")
所以我们可以看到YOLO模型在检测图像中的物体方面效果很好。它确实会犯一些错误,将显示器标记为电视。但那很好。 YOLO 分配的实际类并不是那么重要,因为我们将使用 CLIP 进行推理。
class CroppedImage: def __init__(self, parent, box, cls): self.parent = parent self.box = box self.cls = cls def display(self, ax = None): im_rgb = Image.open(self.parent) cropped_image = im_rgb.crop(self.box) if ax is not None: ax.imshow(cropped_image) ax.set_title(self.cls) else: plt.figure(figsize=(10, 10)) plt.imshow(cropped_image) plt.title(self.cls) plt.show() def get_cropped_image(self): im_rgb = Image.open(self.parent) cropped_image = im_rgb.crop(self.box) return cropped_image def __str__(self): return f"CroppedImage(parent={self.parent}, boxes={self.box}, cls={self.cls})" def __repr__(self): return self.__str__() class YOLOImage: def __init__(self, image_path, cropped_images): self.image_path = str(image_path) self.cropped_images = cropped_images def get_image(self): return Image.open(self.image_path) def get_caption(self): cls =[] for cropped_image in self.cropped_images: cls.append(cropped_image.cls) unique_cls = set(cls) count_cls = {cls: cls.count(cls) for cls in unique_cls} count_string = " ".join(f"{count} {cls}," for cls, count in count_cls.items()) return "this image contains " + count_string def __str__(self): return self.__repr__() def __repr__(self): cls =[] for cropped_image in self.cropped_images: cls.append(cropped_image.cls) return f"YOLOImage(image={self.image_path}, cropped_images={cls})" class ImageEmbedding: def __init__(self, image_path, embedding, cropped_image = None): self.image_path = image_path self.cropped_image = cropped_image self.embedding = embedding
CroppedImage 类表示从较大的父图像中裁剪出的图像的一部分。它使用父图像的路径、定义裁剪区域的边界框和类标签(例如“猫”或“狗”)进行初始化。此类包含显示裁剪图像并将其作为图像对象检索的方法。该显示方法允许在提供的轴上或通过创建新图形来可视化裁剪部分,使其适用于不同的用例。此外,还实现了 __str__ 和 __repr__ 方法,以便轻松且信息丰富地表示对象的字符串。
YOLOImage 类旨在处理使用 YOLO 对象检测模型处理的图像。它获取原始图像的路径和代表图像中检测到的对象的 CroppedImage 实例列表。该类提供了打开和显示完整图像以及生成总结图像中检测到的对象的标题的方法。标题方法聚合并计算裁剪图像中的唯一类标签,提供图像内容的简洁描述。此类对于管理和解释对象检测任务的结果特别有用。
ImageEmbedding 类具有图像及其关联的嵌入,它是图像特征的数字表示。可以使用图像的路径、嵌入向量以及可选的 CroppedImage 实例(如果嵌入对应于图像的特定裁剪部分)来初始化此类。 ImageEmbedding 类对于涉及图像相似性、分类和检索的任务至关重要,因为它提供了一种结构化方法来存储和访问图像数据及其计算特征。这种集成促进了高效的图像处理和机器学习工作流程。
yolo_images: list[YOLOImage]= [] names= model.names for i, r in enumerate(results): crops:list[CroppedImage] = [] boxes = r.boxes classes = r.boxes.cls for j, box in enumerate(r.boxes): box = tuple(box.xyxy.flatten().cpu().numpy()) cropped_image = CroppedImage(parent = r.path, box = box, cls = names[classes[j].int().item()]) crops.append(cropped_image) yolo_images.append(YOLOImage(image_path=r.path, cropped_images=crops))
image_embeddings = [] for image in yolo_images: input = processor.image_processor(images= image.get_image(), return_tensors = 'pt') input.to(device) embeddings = image_model(pixel_values = input.pixel_values).image_embeds embeddings = embeddings/embeddings.norm(p=2, dim = -1, keepdim = True) # Normalize the embeddings image_embedding = ImageEmbedding(image_path = image.image_path, embedding = embeddings) image_embeddings.append(image_embedding) for cropped_image in image.cropped_images: input = processor.image_processor(images= cropped_image.get_cropped_image(), return_tensors = 'pt') input.to(device) embeddings = image_model(pixel_values = input.pixel_values).image_embeds embeddings = embeddings/embeddings.norm(p=2, dim = -1, keepdim = True) # Normalize the embeddings image_embedding = ImageEmbedding(image_path = image.image_path, embedding = embeddings, cropped_image = cropped_image) image_embeddings.append(image_embedding) **image_embeddings_tensor = torch.stack([image_embedding.embedding for image_embedding in image_embeddings]).squeeze()**
如果愿意,我们现在可以将这些图像嵌入存储在矢量数据库中。但在这个例子中,我们将仅使用内点积技术来检查相似性并检索图像。
query = "image of a flowerpot" text_embedding = processor.tokenizer(query, return_tensors="pt").to(device) text_embedding = text_model(**text_embedding).text_embeds similarities = (torch.matmul(text_embedding, image_embeddings_tensor.T)).flatten().detach().cpu().numpy() # get the top 5 similar images k = 5 top_k_indices = similarities.argsort()[-k:] # Display the top 5 results fig, ax = plt.subplots(2, 5, figsize=(20, 5)) for i, index in enumerate(top_k_indices): if image_embeddings[index].cropped_image is not None: image_embeddings[index].cropped_image.display(ax = ax[0][i]) else: ax[0][i].imshow(Image.open(image_embeddings[index].image_path)) ax[1][i].imshow(Image.open(image_embeddings[index].image_path)) ax[0][i].axis('off') ax[1][i].axis('off') ax[1][i].set_title("Original Image") plt.show()
您可以看到,我们甚至能够检索隐藏在背景中的小植物。有时它也会拉出原始图像作为结果,因为我们也嵌入了它。
这是一项非常强大的技术。您还可以微调您自己的图像的检测和嵌入模型,并进一步提高性能。
一个缺点是我们必须对检测到的所有对象运行 CLIP 模型。缓解这种情况的一种方法是限制 YOLO 生产的盒子数量。
您可以通过此链接查看 Colab 上的代码。
想要连接吗?
?我的网站
?我的推特
?我的 LinkedIn
以上是使用 YOLO 和 CLIP 来改进检索的详细内容。更多信息请关注PHP中文网其他相关文章!