Maison développement back-end Tutoriel Python Le principe pour créer des CNN équivariants réguliers

Le principe pour créer des CNN équivariants réguliers

Jul 18, 2024 am 11:29 AM

Le principe est simplement énoncé comme « Laissez le noyau tourner » et nous nous concentrerons dans cet article sur la façon dont vous pouvez l'appliquer dans vos architectures.

Les architectures équivariantes permettent de former des modèles indifférents à certaines actions de groupe.

Pour comprendre ce que cela signifie exactement, entraînons ce modèle CNN simple sur l'ensemble de données MNIST (un ensemble de données de chiffres manuscrits de 0 à 9).

class SimpleCNN(nn.Module):

    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.cl1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, padding=1)
        self.max_1 = nn.MaxPool2d(kernel_size=2)
        self.cl2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding=1)
        self.max_2 = nn.MaxPool2d(kernel_size=2)
        self.cl3 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=7)
        self.dense = nn.Linear(in_features=16, out_features=10)

    def forward(self, x: torch.Tensor):
        x = nn.functional.silu(self.cl1(x))
        x = self.max_1(x)
        x = nn.functional.silu(self.cl2(x))
        x = self.max_2(x)
        x = nn.functional.silu(self.cl3(x))
        x = x.view(len(x), -1)
        logits = self.dense(x)
        return logits
Copier après la connexion
Accuracy on test Accuracy on 90-degree rotated test
97.3% 15.1%

Tableau 1 : Précision des tests du modèle SimpleCNN

Comme prévu, nous obtenons une précision de plus de 95 % sur l'ensemble de données de test, mais que se passe-t-il si nous faisons pivoter l'image de 90 degrés ? Sans aucune contre-mesure appliquée, les résultats tombent à peine meilleurs que ce que l’on aurait pu deviner. Ce modèle serait inutile pour les applications générales.

En revanche, formons une architecture équivariante similaire avec le même nombre de paramètres, où les actions de groupe sont exactement les rotations de 90 degrés.

Accuracy on test Accuracy on 90-degree rotated test
96.5% 96.5%

Tableau 2 : Test de précision du modèle EqCNN avec le même nombre de paramètres que le modèle SimpleCNN

La précision reste la même, et nous n'avons même pas opté pour l'augmentation des données.

Ces modèles deviennent encore plus impressionnants avec des données 3D, mais nous nous en tiendrons à cet exemple pour explorer l'idée de base.

Si vous souhaitez le tester par vous-même, vous pouvez accéder gratuitement à tout le code écrit en PyTorch et JAX sous Github-Repo, et la formation avec Docker ou Podman est possible avec seulement deux commandes.

Amusez-vous !

Alors, qu’est-ce que l’équivariance ?

Les architectures équivariantes garantissent la stabilité des fonctionnalités sous certaines actions de groupe. Les groupes sont des structures simples où les éléments du groupe peuvent être combinés, inversés ou ne rien faire.

Vous pouvez rechercher la définition formelle sur Wikipédia si vous êtes intéressé.

Pour nos besoins, vous pouvez penser à un groupe de rotations de 90 degrés agissant sur des images carrées. Nous pouvons faire pivoter une image de 90, 180, 270 ou 360 degrés. Pour inverser l'action, nous appliquons respectivement une rotation de 270, 180, 90 ou 0 degrés. Il est simple de voir que nous pouvons combiner, inverser ou ne rien faire avec le groupe noté C4C_4C4 . L'image visualise toutes les actions sur une image.

Figure 1: Rotated MNIST image by 90°, 180°, 270°, 360°, respectively
Figure 1 : Image MNIST pivotée de 90°, 180°, 270°, 360°, respectivement

Now, given an input image xxx , our CNN model classifier fθf_\thetafθ , and an arbitrary 90-degree rotation ggg , the equivariant property can be expressed as
fθ(rotate x by g)=fθ(x) f_\theta(\text{rotate } x \text{ by } g) = f_\theta(x) fθ(rotate x by g)=fθ(x)

Generally speaking, we want our image-based model to have the same outputs when rotated.

As such, equivariant models promise us architectures with baked-in symmetries. In the following section, we will see how our principle can be applied to achieve this property.

How to Make Our CNN Equivariant

The problem is the following: When the image rotates, the features rotate too. But as already hinted, we could also compute the features for each rotation upfront by rotating the kernel.
We could actually rotate the kernel, but it is much easier to rotate the feature map itself, thus avoiding interference with PyTorch's autodifferentiation algorithm altogether.

So, in code, our CNN kernel

x = nn.functional.silu(self.cl1(x))
Copier après la connexion

now acts on all four rotated images:

x_0 = x
x_90 = torch.rot90(x, k=1, dims=(2, 3))
x_180 = torch.rot90(x, k=2, dims=(2, 3))
x_270 = torch.rot90(x, k=3, dims=(2, 3))

x_0 = nn.functional.silu(self.cl1(x_0))
x_90 = nn.functional.silu(self.cl1(x_90))
x_180 = nn.functional.silu(self.cl1(x_180))
x_270 = nn.functional.silu(self.cl1(x_270))
Copier après la connexion

Or more compactly written as a 3D convolution:

self.cl1 = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=(1, 3, 3))
...
x = torch.stack([x_0, x_90, x_180, x_270], dim=-3)
x = nn.functional.silu(self.cl1(x))
Copier après la connexion

The resulting equivariant model has just a few lines more compared to the version above:

class EqCNN(nn.Module):

    def __init__(self):
        super(EqCNN, self).__init__()
        self.cl1 = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=(1, 3, 3))
        self.max_1 = nn.MaxPool3d(kernel_size=(1, 2, 2))
        self.cl2 = nn.Conv3d(in_channels=8, out_channels=16, kernel_size=(1, 3, 3))
        self.max_2 = nn.MaxPool3d(kernel_size=(1, 2, 2))
        self.cl3 = nn.Conv3d(in_channels=16, out_channels=16, kernel_size=(1, 5, 5))
        self.dense = nn.Linear(in_features=16, out_features=10)

    def forward(self, x: torch.Tensor):
        x_0 = x
        x_90 = torch.rot90(x, k=1, dims=(2, 3))
        x_180 = torch.rot90(x, k=2, dims=(2, 3))
        x_270 = torch.rot90(x, k=3, dims=(2, 3))

        x = torch.stack([x_0, x_90, x_180, x_270], dim=-3)
        x = nn.functional.silu(self.cl1(x))
        x = self.max_1(x)

        x = nn.functional.silu(self.cl2(x))
        x = self.max_2(x)

        x = nn.functional.silu(self.cl3(x))

        x = x.squeeze()
        x = torch.max(x, dim=-1).values
        logits = self.dense(x)
        return logits
Copier après la connexion

But why is this equivariant to rotations?
First, observe that we get four copies of each feature map at each stage. At the end of the pipeline, we combine all of them with a max operation.

This is key, the max operation is indifferent to which place the rotated version of the feature ends up in.

To understand what is happening, let us plot the feature maps after the first convolution stage.

Figure 2: Feature maps for all four rotations
Figure 2: Feature maps for all four rotations

And now the same features after we rotate the input by 90 degrees.

Figure 3: Feature maps for all four rotations after the input image was rotated
Figure 3 : Cartes de caractéristiques pour les quatre rotations après la rotation de l'image d'entrée

J'ai codé par couleur les cartes correspondantes. Chaque carte de fonctionnalités est décalée de un. Comme l'opérateur max final calcule le même résultat pour ces cartes de caractéristiques décalées, nous obtenons les mêmes résultats.

Dans mon code, je n'ai pas effectué de rotation après la convolution finale, car mes noyaux condensent l'image en un tableau unidimensionnel. Si vous souhaitez développer cet exemple, vous devrez tenir compte de ce fait.

La comptabilisation des actions de groupe ou des « rotations du noyau » joue un rôle essentiel dans la conception d'architectures plus sophistiquées.

Est-ce un déjeuner gratuit ?

Non, nous payons en vitesse de calcul, en biais inductif et en une mise en œuvre plus complexe.

Ce dernier point est quelque peu résolu avec des bibliothèques telles que E3NN, où la plupart des mathématiques lourdes sont abstraites. Néanmoins, il faut tenir compte de beaucoup de choses lors de la conception de l'architecture.

Une faiblesse superficielle est le coût de calcul 4x pour le calcul de toutes les couches d'entités pivotées. Cependant, le matériel moderne doté de parallélisation de masse peut facilement contrecarrer cette charge. En revanche, la formation d’un simple CNN avec augmentation des données dépasserait facilement 10 fois le temps de formation. Cela est encore pire pour les rotations 3D où l'augmentation des données nécessiterait environ 500 fois la quantité d'entraînement pour compenser toutes les rotations possibles.

Dans l'ensemble, la conception d'un modèle d'équivariance est le plus souvent un prix qui vaut la peine d'être payé si l'on veut des fonctionnalités stables.

Quelle est la prochaine étape ?

Les conceptions de modèles équivariants ont explosé ces dernières années, et dans cet article, nous avons à peine effleuré la surface. En fait, nous n'avons même pas exploité pleinement C4C_4C4 groupe encore. Nous aurions pu utiliser des noyaux entièrement 3D. Cependant, notre modèle atteint déjà une précision de plus de 95 %, il n'y a donc aucune raison d'aller plus loin avec cet exemple.

Outre les CNN, les chercheurs ont réussi à traduire ces principes en groupes continus, notamment SO(2) SO(2)SO(2) (le groupe de toutes les rotations dans le plan) et SE(3) SE(3)SE(3) (le groupe de toutes les traductions et rotations dans l'espace 3D).

D'après mon expérience, ces modèles sont absolument époustouflants et atteignent des performances, lorsqu'ils sont entraînés à partir de zéro, comparables aux performances des modèles de base entraînés sur des ensembles de données plusieurs fois plus grands.

Faites-moi savoir si vous souhaitez que j'écrive davantage sur ce sujet.

Autres références

Au cas où vous souhaiteriez une introduction formelle à ce sujet, voici une excellente compilation d'articles, couvrant l'histoire complète de l'équivariance dans l'apprentissage automatique.
AEN

Je prévois en fait de créer un tutoriel pratique et approfondi sur ce sujet. Vous pouvez déjà vous inscrire à ma liste de diffusion et je vous fournirai des versions gratuites au fil du temps, ainsi qu'un canal direct pour les commentaires et les questions-réponses.

À bientôt :)

Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

Déclaration de ce site Web
Le contenu de cet article est volontairement contribué par les internautes et les droits d'auteur appartiennent à l'auteur original. Ce site n'assume aucune responsabilité légale correspondante. Si vous trouvez un contenu suspecté de plagiat ou de contrefaçon, veuillez contacter admin@php.cn

Outils d'IA chauds

Undresser.AI Undress

Undresser.AI Undress

Application basée sur l'IA pour créer des photos de nu réalistes

AI Clothes Remover

AI Clothes Remover

Outil d'IA en ligne pour supprimer les vêtements des photos.

Undress AI Tool

Undress AI Tool

Images de déshabillage gratuites

Clothoff.io

Clothoff.io

Dissolvant de vêtements AI

AI Hentai Generator

AI Hentai Generator

Générez AI Hentai gratuitement.

Article chaud

R.E.P.O. Crystals d'énergie expliqués et ce qu'ils font (cristal jaune)
1 Il y a quelques mois By 尊渡假赌尊渡假赌尊渡假赌
R.E.P.O. Meilleurs paramètres graphiques
1 Il y a quelques mois By 尊渡假赌尊渡假赌尊渡假赌
Will R.E.P.O. Vous avez un jeu croisé?
1 Il y a quelques mois By 尊渡假赌尊渡假赌尊渡假赌

Outils chauds

Bloc-notes++7.3.1

Bloc-notes++7.3.1

Éditeur de code facile à utiliser et gratuit

SublimeText3 version chinoise

SublimeText3 version chinoise

Version chinoise, très simple à utiliser

Envoyer Studio 13.0.1

Envoyer Studio 13.0.1

Puissant environnement de développement intégré PHP

Dreamweaver CS6

Dreamweaver CS6

Outils de développement Web visuel

SublimeText3 version Mac

SublimeText3 version Mac

Logiciel d'édition de code au niveau de Dieu (SublimeText3)

Comment résoudre le problème des autorisations rencontré lors de la visualisation de la version Python dans le terminal Linux? Comment résoudre le problème des autorisations rencontré lors de la visualisation de la version Python dans le terminal Linux? Apr 01, 2025 pm 05:09 PM

Solution aux problèmes d'autorisation Lors de la visualisation de la version Python dans Linux Terminal Lorsque vous essayez d'afficher la version Python dans Linux Terminal, entrez Python ...

Comment copier efficacement la colonne entière d'une dataframe dans une autre dataframe avec différentes structures dans Python? Comment copier efficacement la colonne entière d'une dataframe dans une autre dataframe avec différentes structures dans Python? Apr 01, 2025 pm 11:15 PM

Lorsque vous utilisez la bibliothèque Pandas de Python, comment copier des colonnes entières entre deux frames de données avec différentes structures est un problème courant. Supposons que nous ayons deux dats ...

Comment enseigner les bases de la programmation novice en informatique dans le projet et les méthodes axées sur les problèmes dans les 10 heures? Comment enseigner les bases de la programmation novice en informatique dans le projet et les méthodes axées sur les problèmes dans les 10 heures? Apr 02, 2025 am 07:18 AM

Comment enseigner les bases de la programmation novice en informatique dans les 10 heures? Si vous n'avez que 10 heures pour enseigner à l'informatique novice des connaissances en programmation, que choisissez-vous d'enseigner ...

Comment éviter d'être détecté par le navigateur lors de l'utilisation de Fiddler partout pour la lecture de l'homme au milieu? Comment éviter d'être détecté par le navigateur lors de l'utilisation de Fiddler partout pour la lecture de l'homme au milieu? Apr 02, 2025 am 07:15 AM

Comment éviter d'être détecté lors de l'utilisation de FiddlereVerywhere pour les lectures d'homme dans le milieu lorsque vous utilisez FiddlereVerywhere ...

Que sont les expressions régulières? Que sont les expressions régulières? Mar 20, 2025 pm 06:25 PM

Les expressions régulières sont des outils puissants pour la correspondance des motifs et la manipulation du texte dans la programmation, améliorant l'efficacité du traitement de texte sur diverses applications.

Comment Uvicorn écoute-t-il en permanence les demandes HTTP sans servir_forever ()? Comment Uvicorn écoute-t-il en permanence les demandes HTTP sans servir_forever ()? Apr 01, 2025 pm 10:51 PM

Comment Uvicorn écoute-t-il en permanence les demandes HTTP? Uvicorn est un serveur Web léger basé sur ASGI. L'une de ses fonctions principales est d'écouter les demandes HTTP et de procéder ...

Quelles sont les bibliothèques Python populaires et leurs utilisations? Quelles sont les bibliothèques Python populaires et leurs utilisations? Mar 21, 2025 pm 06:46 PM

L'article traite des bibliothèques Python populaires comme Numpy, Pandas, Matplotlib, Scikit-Learn, Tensorflow, Django, Flask et Demandes, détaillant leurs utilisations dans le calcul scientifique, l'analyse des données, la visualisation, l'apprentissage automatique, le développement Web et H et H

Comment créer dynamiquement un objet via une chaîne et appeler ses méthodes dans Python? Comment créer dynamiquement un objet via une chaîne et appeler ses méthodes dans Python? Apr 01, 2025 pm 11:18 PM

Dans Python, comment créer dynamiquement un objet via une chaîne et appeler ses méthodes? Il s'agit d'une exigence de programmation courante, surtout si elle doit être configurée ou exécutée ...

See all articles