Maison > Périphériques technologiques > IA > le corps du texte

Classification d'images avec apprentissage en quelques prises de vue à l'aide de PyTorch

WBOY
Libérer: 2023-04-09 10:51:05
avant
1401 Les gens l'ont consulté

Ces dernières années, les modèles basés sur l'apprentissage profond ont donné de bons résultats dans des tâches telles que la détection d'objets et la reconnaissance d'images. Sur des ensembles de données de classification d'images complexes comme ImageNet, qui contient 1 000 classifications d'objets différentes, certains modèles dépassent désormais les niveaux humains. Mais ces modèles s'appuient sur un processus de formation supervisé, ils sont considérablement affectés par la disponibilité de données de formation étiquetées, et les classes que les modèles sont capables de détecter sont limitées aux classes sur lesquelles ils ont été formés.

Comme il n'y a pas suffisamment d'images étiquetées pour toutes les classes pendant la formation, ces modèles peuvent être moins utiles dans des contextes réels. Et nous voulons que le modèle soit capable de reconnaître les classes qu'il n'a pas vues lors de l'entraînement, car il est presque impossible de s'entraîner sur des images de tous les objets potentiels. Le problème pour lequel nous apprendrons à partir de quelques échantillons est appelé apprentissage en quelques coups.

Qu'est-ce que l'apprentissage en quelques étapes ?

Classification dimages avec apprentissage en quelques prises de vue à laide de PyTorch

L'apprentissage en quelques étapes est un sous-domaine de l'apprentissage automatique. Cela implique de classer de nouvelles données avec seulement quelques échantillons de formation et données de supervision. Le modèle que nous avons créé fonctionne raisonnablement bien avec seulement un petit nombre d'échantillons d'apprentissage.

Considérez le scénario suivant : Dans le domaine médical, pour certaines maladies rares, il se peut qu'il n'y ait pas suffisamment d'images radiographiques pour la formation. Pour de tels scénarios, la création d’un classificateur d’apprentissage en quelques étapes est la solution parfaite.

Variation dans de petits échantillons

Généralement, les chercheurs ont identifié quatre types :

  1. N-Shot Learning (NSL)
  2. Few-Shot Learning (FSL)
  3. One-Shot Learning (OSL)
  4. Zero-Shot Learning (ZSL)

Quand on parle de FSL, on fait généralement référence à la classification N-way-K-Shot. N représente le nombre de classes et K représente le nombre d'échantillons à former dans chaque classe. Ainsi, N-Shot Learning est considéré comme un concept plus large que tous les autres concepts. On peut dire que Few-Shot, One-Shot et Zero-Shot sont des sous-domaines de NSL. Alors que l'apprentissage zéro-shot vise à classer les classes invisibles sans aucun exemple de formation.

Dans One-Shot Learning, il n'y a qu'un seul échantillon par classe. Few-Shot propose 2 à 5 échantillons par classe, ce qui signifie que Few-Shot est une version plus flexible de One-Shot Learning.

Méthode d'apprentissage sur petits échantillons

Généralement, deux méthodes doivent être envisagées lors de la résolution du problème d'apprentissage par quelques coups :

Approche au niveau des données (DLA)

Cette stratégie est très simple, s'il n'y a pas suffisamment de données pour créer un modèle solide et éviter les sous-ajustements et les surajustements, alors davantage de données doivent être ajoutées. Pour cette raison, de nombreux problèmes de FLS peuvent être résolus en exploitant davantage de données provenant d’un ensemble de données sous-jacentes plus vaste. Une caractéristique notable de l'ensemble de données de base est qu'il lui manque les classes qui constituent notre ensemble de support pour le défi Few-Shot. Par exemple, si nous souhaitons classer une certaine espèce d’oiseau, l’ensemble de données sous-jacent peut contenir des images de nombreux autres oiseaux.

Approche au niveau des paramètres (PLA)

Du point de vue du niveau des paramètres, les échantillons d'apprentissage par quelques tirs sont relativement faciles à surajuster car ils ont généralement de grands espaces de grande dimension. Restreindre l'espace des paramètres, utiliser la régularisation et utiliser une fonction de perte appropriée aideront à résoudre ce problème. Un petit nombre d'échantillons d'apprentissage seront utilisés par le modèle pour généraliser.

Les performances peuvent être améliorées en guidant le modèle dans un large espace de paramètres. Les méthodes d'optimisation normales peuvent ne pas produire de résultats précis en raison du manque de données d'entraînement.

Pour les raisons ci-dessus, entraîner notre modèle pour trouver le meilleur chemin à travers l'espace des paramètres produit les meilleurs résultats de prédiction. Cette approche est appelée méta-apprentissage.

Algorithme de classification d'images d'apprentissage pour petits échantillons

Il existe 4 méthodes courantes d'apprentissage pour petits échantillons :

Méta-apprentissage indépendant du modèle Méta-apprentissage indépendant du modèle

Méta-apprentissage basé sur les gradients (GBML) Le principe est la base MAML. En GBML, les méta-apprenants acquièrent une expérience préalable grâce à la formation du modèle de base et à l'apprentissage des fonctionnalités partagées dans toutes les représentations de tâches. Chaque fois qu'il y a une nouvelle tâche à apprendre, le méta-apprenant est affiné en utilisant son expérience existante et la quantité minimale de nouvelles données de formation fournies par la nouvelle tâche.

Généralement, si nous initialisons les paramètres de manière aléatoire et les mettons à jour plusieurs fois, l'algorithme ne convergera pas vers de bonnes performances. MAML tente de résoudre ce problème. MAML fournit une initialisation fiable de l'apprenant des métaparamètres avec seulement quelques étapes de gradient et sans surapprentissage, afin que de nouvelles tâches puissent être apprises de manière optimale et rapide.

Les étapes sont les suivantes :

  1. Le méta-apprenant crée sa propre copie C au début de chaque épisode,
  2. C est entraîné sur cet épisode (avec l'aide du modèle de base),
  3. C paires Prédictions sont effectués sur l'ensemble de requêtes,
  4. La perte calculée à partir de ces prédictions est utilisée pour mettre à jour C,
  5. Cela continue jusqu'à ce que l'entraînement sur tous les épisodes soit terminé.

Classification dimages avec apprentissage en quelques prises de vue à laide de PyTorch

Le plus gros avantage de cette technique est qu'elle est considérée comme indépendante du choix de l'algorithme de méta-apprentissage. Par conséquent, les méthodes MAML sont largement utilisées dans de nombreux algorithmes d’apprentissage automatique qui nécessitent une adaptation rapide, notamment dans les réseaux de neurones profonds.

Matching Networks

La première méthode d'apprentissage métrique créée pour résoudre le problème FSL est le Matching Network (MN).

Un grand ensemble de données de base est requis lors de l'utilisation de la méthode de réseau de correspondance pour résoudre le problème d'apprentissage en quelques coups. .

Après avoir divisé cet ensemble de données en plusieurs épisodes, pour chaque épisode, le réseau de correspondance effectue les opérations suivantes :

  • Chaque image de l'ensemble de support et de l'ensemble de requêtes est transmise à un CNN qui génère des fonctionnalités pour elles. L'intégration de
  • requête l'image est obtenue à l'aide du modèle formé sur l'ensemble de supports pour obtenir la distance cosinus des fonctionnalités intégrées, classées par softmax
  • La perte d'entropie croisée des résultats de classification est rétropropagée via CNN pour mettre à jour le modèle d'intégration de fonctionnalités

Le réseau correspondant peut être utilisé de cette manière Apprenez à créer des intégrations d'images. MN est capable de classer les photos en utilisant cette méthode sans aucune connaissance préalable particulière des catégories. Il compare simplement plusieurs instances de la classe.

Étant donné que les catégories varient d'un épisode à l'autre, le réseau de correspondance calcule les attributs d'image (caractéristiques) qui sont importants pour la distinction des catégories. Lors de l'utilisation de la classification standard, l'algorithme sélectionne les caractéristiques uniques à chaque catégorie.

Réseaux prototypiques

Le réseau prototypique (PN) est similaire au réseau correspondant. Il améliore les performances de l'algorithme grâce à quelques changements subtils. PN obtient de meilleurs résultats que MN, mais leur processus de formation est essentiellement le même, il suffit de comparer certaines intégrations d'images de requête de l'ensemble de support, mais le réseau prototype propose des stratégies différentes.

Nous devons créer un prototype de la classe en PN : l'intégration de la classe créée en faisant la moyenne des intégrations des images dans la classe. Seuls ces prototypes de classe sont ensuite utilisés pour comparer les intégrations d’images de requête. Lorsqu'il est utilisé pour des problèmes d'apprentissage à échantillon unique, il est comparable aux réseaux d'appariement.

Réseau relationnel

Le réseau relationnel peut être considéré comme héritant des résultats de la recherche sur toutes les méthodes mentionnées ci-dessus. RN est basé sur les idées de PN mais contient des améliorations significatives de l'algorithme.

La fonction de distance utilisée par cette méthode est apprenable, plutôt que de la définir à l'avance comme les études précédentes. Le module de relation se trouve au-dessus du module d'intégration, qui est la partie qui calcule les intégrations et les prototypes de classe à partir de l'image d'entrée.

L'entrée du module de relation entraînable (fonction de distance) est l'intégration de l'image de requête avec le prototype de chaque classe, et la sortie est le score de relation de chaque correspondance de classe. Le score de relation est transmis via Softmax pour obtenir une prédiction.

Classification d'images avec apprentissage en quelques prises de vue à l'aide de PyTorch

Zero-shot learning à l'aide d'Open-AI Clip

CLIP (Contrastive Language-Image Pre-Training) est un réseau de neurones entraîné sur diverses paires (image, texte). Il peut prédire les fragments de texte les plus pertinents pour une image donnée sans être directement optimisé pour la tâche (similaire à la fonctionnalité zéro tir de GPT-2 et 3).

CLIP peut atteindre les performances du ResNet50 original sur ImageNet "zéro échantillon" et ne nécessite l'utilisation d'aucun exemple étiqueté. Il surmonte plusieurs défis majeurs en vision par ordinateur. Ci-dessous, nous utilisons Pytorch pour implémenter un modèle de classification simple.

Présentation du package

! pip install ftfy regex tqdm
 ! pip install git+https://github.com/openai/CLIP.gitimport numpy as np
 import torch
 from pkg_resources import packaging
 
 print("Torch version:", torch.__version__)
Copier après la connexion

Chargement du modèle

import clipclip.available_models() # it will list the names of available CLIP modelsmodel, preprocess = clip.load("ViT-B/32")
 model.cuda().eval()
 input_resolution = model.visual.input_resolution
 context_length = model.context_length
 vocab_size = model.vocab_size
 
 print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
 print("Input resolution:", input_resolution)
 print("Context length:", context_length)
 print("Vocab size:", vocab_size)
Copier après la connexion

Prétraitement des images

Nous saisirons 8 exemples d'images et leurs descriptions textuelles dans le modèle et comparerons les similitudes entre les fonctionnalités correspondantes.

Le tokenizer n'est pas sensible à la casse et nous sommes libres de donner toute description textuelle appropriée.

 import os
 import skimage
 import IPython.display
 import matplotlib.pyplot as plt
 from PIL import Image
 import numpy as np
 
 from collections import OrderedDict
 import torch
 
 %matplotlib inline
 %config InlineBackend.figure_format = 'retina'
 
 # images in skimage to use and their textual descriptions
 descriptions = {
"page": "a page of text about segmentation",
"chelsea": "a facial photo of a tabby cat",
"astronaut": "a portrait of an astronaut with the American flag",
"rocket": "a rocket standing on a launchpad",
"motorcycle_right": "a red motorcycle standing in a garage",
"camera": "a person looking at a camera on a tripod",
"horse": "a black-and-white silhouette of a horse",
"coffee": "a cup of coffee on a saucer"
 }original_images = []
 images = []
 texts = []
 plt.figure(figsize=(16, 5))
 
 for filename in [filename for filename in os.listdir(skimage.data_dir) if filename.endswith(".png") or filename.endswith(".jpg")]:
name = os.path.splitext(filename)[0]
if name not in descriptions:
continue
 
image = Image.open(os.path.join(skimage.data_dir, filename)).convert("RGB")
 
plt.subplot(2, 4, len(images) + 1)
plt.imshow(image)
plt.title(f"{filename}n{descriptions[name]}")
plt.xticks([])
plt.yticks([])
 
original_images.append(image)
images.append(preprocess(image))
texts.append(descriptions[name])
 
 plt.tight_layout()
Copier après la connexion

La visualisation des résultats est la suivante :

Classification dimages avec apprentissage en quelques prises de vue à laide de PyTorch

Nous normalisons les images, étiquetons chaque entrée de texte et exécutons la propagation avant du modèle pour obtenir les caractéristiques des images et du texte.

image_input = torch.tensor(np.stack(images)).cuda()
 text_tokens = clip.tokenize(["This is " + desc for desc in texts]).cuda()
 
 with torch.no_grad():
Copier après la connexion

Nous normalisons les caractéristiques, calculons le produit scalaire de chaque paire et effectuons un calcul de similarité cosinus

 image_features /= image_features.norm(dim=-1, keepdim=True)
 text_features /= text_features.norm(dim=-1, keepdim=True)
 similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T
 
 count = len(descriptions)
 
 plt.figure(figsize=(20, 14))
 plt.imshow(similarity, vmin=0.1, vmax=0.3)
 # plt.colorbar()
 plt.yticks(range(count), texts, fontsize=18)
 plt.xticks([])
 for i, image in enumerate(original_images):
plt.imshow(image, extent=(i - 0.5, i + 0.5, -1.6, -0.6), origin="lower")
 for x in range(similarity.shape[1]):
for y in range(similarity.shape[0]):
plt.text(x, y, f"{similarity[y, x]:.2f}", ha="center", va="center", size=12)
 
 for side in ["left", "top", "right", "bottom"]:
plt.gca().spines[side].set_visible(False)
 
 plt.xlim([-0.5, count - 0.5])
 plt.ylim([count + 0.5, -2])
 
 plt.title("Cosine similarity between text and image features", size=20)
Copier après la connexion

Classification dimages avec apprentissage en quelques prises de vue à laide de PyTorch

Classification d'image à échantillon zéro

 from torchvision.datasets import CIFAR100
 cifar100 = CIFAR100(os.path.expanduser("~/.cache"), transform=preprocess, download=True)
 text_descriptions = [f"This is a photo of a {label}" for label in cifar100.classes]
 text_tokens = clip.tokenize(text_descriptions).cuda()
 with torch.no_grad():
text_features = model.encode_text(text_tokens).float()
text_features /= text_features.norm(dim=-1, keepdim=True)
 
 text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
 top_probs, top_labels = text_probs.cpu().topk(5, dim=-1)
 plt.figure(figsize=(16, 16))
 for i, image in enumerate(original_images):
plt.subplot(4, 4, 2 * i + 1)
plt.imshow(image)
plt.axis("off")
 
plt.subplot(4, 4, 2 * i + 2)
y = np.arange(top_probs.shape[-1])
plt.grid()
plt.barh(y, top_probs[i])
plt.gca().invert_yaxis()
plt.gca().set_axisbelow(True)
plt.yticks(y, [cifar100.classes[index] for index in top_labels[i].numpy()])
plt.xlabel("probability")
 
 plt.subplots_adjust(wspace=0.5)
 plt.show()
Copier après la connexion

Classification dimages avec apprentissage en quelques prises de vue à laide de PyTorch

Vous pouvez voir que l'effet de classification est toujours très bon OK

Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

Étiquettes associées:
source:51cto.com
Déclaration de ce site Web
Le contenu de cet article est volontairement contribué par les internautes et les droits d'auteur appartiennent à l'auteur original. Ce site n'assume aucune responsabilité légale correspondante. Si vous trouvez un contenu suspecté de plagiat ou de contrefaçon, veuillez contacter admin@php.cn
Tutoriels populaires
Plus>
Derniers téléchargements
Plus>
effets Web
Code source du site Web
Matériel du site Web
Modèle frontal