Heim Backend-Entwicklung Python-Tutorial Das Prinzip zum Aufbau regulärer äquivarianter CNNs

Das Prinzip zum Aufbau regulärer äquivarianter CNNs

Jul 18, 2024 am 11:29 AM

Das eine Prinzip wird einfach als „Lassen Sie den Kernel rotieren“ ausgedrückt und wir werden uns in diesem Artikel darauf konzentrieren, wie Sie es in Ihren Architekturen anwenden können.

Äquivariante Architekturen ermöglichen es uns, Modelle zu trainieren, die gegenüber bestimmten Gruppenaktionen gleichgültig sind.

Um zu verstehen, was das genau bedeutet, trainieren wir dieses einfache CNN-Modell auf dem MNIST-Datensatz (einem Datensatz handgeschriebener Ziffern von 0-9).

class SimpleCNN(nn.Module):

    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.cl1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, padding=1)
        self.max_1 = nn.MaxPool2d(kernel_size=2)
        self.cl2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding=1)
        self.max_2 = nn.MaxPool2d(kernel_size=2)
        self.cl3 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=7)
        self.dense = nn.Linear(in_features=16, out_features=10)

    def forward(self, x: torch.Tensor):
        x = nn.functional.silu(self.cl1(x))
        x = self.max_1(x)
        x = nn.functional.silu(self.cl2(x))
        x = self.max_2(x)
        x = nn.functional.silu(self.cl3(x))
        x = x.view(len(x), -1)
        logits = self.dense(x)
        return logits
Nach dem Login kopieren
Accuracy on test Accuracy on 90-degree rotated test
97.3% 15.1%

Tabelle 1: Testgenauigkeit des SimpleCNN-Modells

Wie erwartet erreichen wir beim Testdatensatz eine Genauigkeit von über 95 %, aber was ist, wenn wir das Bild um 90 Grad drehen? Ohne Gegenmaßnahmen sinken die Ergebnisse auf ein knapp besseres Ergebnis als geschätzt. Dieses Modell wäre für allgemeine Anwendungen unbrauchbar.

Im Gegensatz dazu trainieren wir eine ähnliche äquivariante Architektur mit der gleichen Anzahl von Parametern, bei der die Gruppenaktionen genau die 90-Grad-Rotationen sind.

Accuracy on test Accuracy on 90-degree rotated test
96.5% 96.5%

Tabelle 2: Testgenauigkeit des EqCNN-Modells mit der gleichen Anzahl an Parametern wie das SimpleCNN-Modell

Die Genauigkeit bleibt gleich und wir haben uns nicht einmal für eine Datenerweiterung entschieden.

Diese Modelle werden mit 3D-Daten noch beeindruckender, aber wir bleiben bei diesem Beispiel, um die Kernidee zu untersuchen.

Falls Sie es selbst testen möchten, können Sie unter Github-Repo kostenlos auf den gesamten in PyTorch und JAX geschriebenen Code zugreifen, und das Training mit Docker oder Podman ist mit nur zwei Befehlen möglich.

Viel Spaß!

Was ist also Äquivarianz?

Äquivariante Architekturen garantieren die Stabilität von Funktionen unter bestimmten Gruppenaktionen. Gruppen sind einfache Strukturen, in denen Gruppenelemente kombiniert, umgekehrt oder gar nichts bewirken können.

Bei Interesse können Sie die formale Definition auf Wikipedia nachschlagen.

Für unsere Zwecke können Sie sich eine Gruppe von 90-Grad-Rotationen vorstellen, die auf quadratische Bilder wirken. Wir können ein Bild um 90, 180, 270 oder 360 Grad drehen. Um die Aktion umzukehren, wenden wir eine Drehung um 270, 180, 90 bzw. 0 Grad an. Es ist leicht zu erkennen, dass wir die als bezeichnete Gruppe kombinieren, umkehren oder nichts tun können C4C_4C4 . Das Bild visualisiert alle Aktionen auf einem Bild.

Figure 1: Rotated MNIST image by 90°, 180°, 270°, 360°, respectively
Abbildung 1: Gedrehtes MNIST-Bild um 90°, 180°, 270° bzw. 360°

Now, given an input image xxx , our CNN model classifier fθf_\thetafθ , and an arbitrary 90-degree rotation ggg , the equivariant property can be expressed as
fθ(rotate x by g)=fθ(x) f_\theta(\text{rotate } x \text{ by } g) = f_\theta(x) fθ(rotate x by g)=fθ(x)

Generally speaking, we want our image-based model to have the same outputs when rotated.

As such, equivariant models promise us architectures with baked-in symmetries. In the following section, we will see how our principle can be applied to achieve this property.

How to Make Our CNN Equivariant

The problem is the following: When the image rotates, the features rotate too. But as already hinted, we could also compute the features for each rotation upfront by rotating the kernel.
We could actually rotate the kernel, but it is much easier to rotate the feature map itself, thus avoiding interference with PyTorch's autodifferentiation algorithm altogether.

So, in code, our CNN kernel

x = nn.functional.silu(self.cl1(x))
Nach dem Login kopieren

now acts on all four rotated images:

x_0 = x
x_90 = torch.rot90(x, k=1, dims=(2, 3))
x_180 = torch.rot90(x, k=2, dims=(2, 3))
x_270 = torch.rot90(x, k=3, dims=(2, 3))

x_0 = nn.functional.silu(self.cl1(x_0))
x_90 = nn.functional.silu(self.cl1(x_90))
x_180 = nn.functional.silu(self.cl1(x_180))
x_270 = nn.functional.silu(self.cl1(x_270))
Nach dem Login kopieren

Or more compactly written as a 3D convolution:

self.cl1 = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=(1, 3, 3))
...
x = torch.stack([x_0, x_90, x_180, x_270], dim=-3)
x = nn.functional.silu(self.cl1(x))
Nach dem Login kopieren

The resulting equivariant model has just a few lines more compared to the version above:

class EqCNN(nn.Module):

    def __init__(self):
        super(EqCNN, self).__init__()
        self.cl1 = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=(1, 3, 3))
        self.max_1 = nn.MaxPool3d(kernel_size=(1, 2, 2))
        self.cl2 = nn.Conv3d(in_channels=8, out_channels=16, kernel_size=(1, 3, 3))
        self.max_2 = nn.MaxPool3d(kernel_size=(1, 2, 2))
        self.cl3 = nn.Conv3d(in_channels=16, out_channels=16, kernel_size=(1, 5, 5))
        self.dense = nn.Linear(in_features=16, out_features=10)

    def forward(self, x: torch.Tensor):
        x_0 = x
        x_90 = torch.rot90(x, k=1, dims=(2, 3))
        x_180 = torch.rot90(x, k=2, dims=(2, 3))
        x_270 = torch.rot90(x, k=3, dims=(2, 3))

        x = torch.stack([x_0, x_90, x_180, x_270], dim=-3)
        x = nn.functional.silu(self.cl1(x))
        x = self.max_1(x)

        x = nn.functional.silu(self.cl2(x))
        x = self.max_2(x)

        x = nn.functional.silu(self.cl3(x))

        x = x.squeeze()
        x = torch.max(x, dim=-1).values
        logits = self.dense(x)
        return logits
Nach dem Login kopieren

But why is this equivariant to rotations?
First, observe that we get four copies of each feature map at each stage. At the end of the pipeline, we combine all of them with a max operation.

This is key, the max operation is indifferent to which place the rotated version of the feature ends up in.

To understand what is happening, let us plot the feature maps after the first convolution stage.

Figure 2: Feature maps for all four rotations
Figure 2: Feature maps for all four rotations

And now the same features after we rotate the input by 90 degrees.

Figure 3: Feature maps for all four rotations after the input image was rotated
Abbildung 3: Feature-Maps für alle vier Rotationen, nachdem das Eingabebild gedreht wurde

Ich habe die entsprechenden Karten farblich markiert. Jede Feature-Map wird um eins verschoben. Da der endgültige Max-Operator das gleiche Ergebnis für diese verschobenen Feature-Maps berechnet, erhalten wir die gleichen Ergebnisse.

In meinem Code habe ich nach der letzten Faltung nicht zurückrotiert, da meine Kernel das Bild zu einem eindimensionalen Array verdichten. Wenn Sie dieses Beispiel näher erläutern möchten, müssen Sie diese Tatsache berücksichtigen.

Die Berücksichtigung von Gruppenaktionen oder „Kernelrotationen“ spielt eine entscheidende Rolle beim Entwurf anspruchsvollerer Architekturen.

Ist es ein kostenloses Mittagessen?

Nein, wir bezahlen mit Rechengeschwindigkeit, induktivem Bias und einer komplexeren Implementierung.

Der letzte Punkt lässt sich einigermaßen mit Bibliotheken wie E3NN lösen, in denen der Großteil der schweren Mathematik abstrahiert wird. Dennoch muss man beim Architekturentwurf einiges berücksichtigen.

Eine oberflächliche Schwäche ist der vierfache Rechenaufwand für die Berechnung aller gedrehten Feature-Layer. Moderne Hardware mit Massenparallelisierung kann dieser Belastung jedoch problemlos entgegenwirken. Im Gegensatz dazu würde das Training eines einfachen CNN mit Datenerweiterung die Trainingszeit leicht um das Zehnfache überschreiten. Noch schlimmer wird es bei 3D-Rotationen, bei denen die Datenerweiterung etwa das 500-fache des Trainingsaufwands erfordern würde, um alle möglichen Rotationen zu kompensieren.

Insgesamt ist der Entwurf eines Äquivarianzmodells oft ein lohnenswerter Preis, wenn man stabile Funktionen wünscht.

Was kommt als nächstes?

Äquivariante Modelldesigns haben in den letzten Jahren explosionsartig zugenommen, und in diesem Artikel haben wir kaum an der Oberfläche gekratzt. Tatsächlich haben wir nicht einmal das volle Potenzial ausgeschöpft C4C_4C4 Gruppe noch. Wir hätten vollständige 3D-Kernel verwenden können. Allerdings erreicht unser Modell bereits eine Genauigkeit von über 95 %, sodass es kaum einen Grund gibt, mit diesem Beispiel noch weiter zu gehen.

Außer CNNs haben Forscher diese Prinzipien erfolgreich auf kontinuierliche Gruppen übertragen, darunter SO(2) SO(2)SO(2) (die Gruppe aller Drehungen in der Ebene) und SE(3) SE(3)SE(3) (die Gruppe aller Translationen und Rotationen im 3D-Raum).

Meiner Erfahrung nach sind diese Modelle absolut umwerfend und erreichen eine Leistung, wenn sie von Grund auf trainiert werden, vergleichbar mit der Leistung von Basismodellen, die auf mehrfach größeren Datensätzen trainiert werden.

Lassen Sie mich wissen, wenn Sie möchten, dass ich mehr zu diesem Thema schreibe.

Weitere Referenzen

Falls Sie eine formelle Einführung in dieses Thema wünschen, finden Sie hier eine hervorragende Zusammenstellung von Artikeln, die die gesamte Geschichte der Äquivarianz im maschinellen Lernen abdecken.
AEN

Ich habe tatsächlich vor, ein ausführliches, praktisches Tutorial zu diesem Thema zu erstellen. Sie können sich bereits jetzt für meine Mailingliste anmelden und ich werde Ihnen im Laufe der Zeit kostenlose Versionen zur Verfügung stellen, zusammen mit einem direkten Kanal für Feedback und Fragen und Antworten.

Wir sehen uns :)

Das obige ist der detaillierte Inhalt vonDas Prinzip zum Aufbau regulärer äquivarianter CNNs. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!

Erklärung dieser Website
Der Inhalt dieses Artikels wird freiwillig von Internetnutzern beigesteuert und das Urheberrecht liegt beim ursprünglichen Autor. Diese Website übernimmt keine entsprechende rechtliche Verantwortung. Wenn Sie Inhalte finden, bei denen der Verdacht eines Plagiats oder einer Rechtsverletzung besteht, wenden Sie sich bitte an admin@php.cn

Heiße KI -Werkzeuge

Undresser.AI Undress

Undresser.AI Undress

KI-gestützte App zum Erstellen realistischer Aktfotos

AI Clothes Remover

AI Clothes Remover

Online-KI-Tool zum Entfernen von Kleidung aus Fotos.

Undress AI Tool

Undress AI Tool

Ausziehbilder kostenlos

Clothoff.io

Clothoff.io

KI-Kleiderentferner

Video Face Swap

Video Face Swap

Tauschen Sie Gesichter in jedem Video mühelos mit unserem völlig kostenlosen KI-Gesichtstausch-Tool aus!

Heißer Artikel

<🎜>: Bubble Gum Simulator Infinity - So erhalten und verwenden Sie Royal Keys
4 Wochen vor By 尊渡假赌尊渡假赌尊渡假赌
Nordhold: Fusionssystem, erklärt
4 Wochen vor By 尊渡假赌尊渡假赌尊渡假赌
Mandragora: Flüstern des Hexenbaum
3 Wochen vor By 尊渡假赌尊渡假赌尊渡假赌

Heiße Werkzeuge

Notepad++7.3.1

Notepad++7.3.1

Einfach zu bedienender und kostenloser Code-Editor

SublimeText3 chinesische Version

SublimeText3 chinesische Version

Chinesische Version, sehr einfach zu bedienen

Senden Sie Studio 13.0.1

Senden Sie Studio 13.0.1

Leistungsstarke integrierte PHP-Entwicklungsumgebung

Dreamweaver CS6

Dreamweaver CS6

Visuelle Webentwicklungstools

SublimeText3 Mac-Version

SublimeText3 Mac-Version

Codebearbeitungssoftware auf Gottesniveau (SublimeText3)

Heiße Themen

Java-Tutorial
1671
14
PHP-Tutorial
1276
29
C#-Tutorial
1256
24
Python vs. C: Lernkurven und Benutzerfreundlichkeit Python vs. C: Lernkurven und Benutzerfreundlichkeit Apr 19, 2025 am 12:20 AM

Python ist leichter zu lernen und zu verwenden, während C leistungsfähiger, aber komplexer ist. 1. Python -Syntax ist prägnant und für Anfänger geeignet. Durch die dynamische Tippen und die automatische Speicherverwaltung können Sie die Verwendung einfach zu verwenden, kann jedoch zur Laufzeitfehler führen. 2.C bietet Steuerung und erweiterte Funktionen auf niedrigem Niveau, geeignet für Hochleistungsanwendungen, hat jedoch einen hohen Lernschwellenwert und erfordert manuellem Speicher und Typensicherheitsmanagement.

Python und Zeit: Machen Sie das Beste aus Ihrer Studienzeit Python und Zeit: Machen Sie das Beste aus Ihrer Studienzeit Apr 14, 2025 am 12:02 AM

Um die Effizienz des Lernens von Python in einer begrenzten Zeit zu maximieren, können Sie Pythons DateTime-, Zeit- und Zeitplanmodule verwenden. 1. Das DateTime -Modul wird verwendet, um die Lernzeit aufzuzeichnen und zu planen. 2. Das Zeitmodul hilft, die Studie zu setzen und Zeit zu ruhen. 3. Das Zeitplanmodul arrangiert automatisch wöchentliche Lernaufgaben.

Python vs. C: Erforschung von Leistung und Effizienz erforschen Python vs. C: Erforschung von Leistung und Effizienz erforschen Apr 18, 2025 am 12:20 AM

Python ist in der Entwicklungseffizienz besser als C, aber C ist in der Ausführungsleistung höher. 1. Pythons prägnante Syntax und reiche Bibliotheken verbessern die Entwicklungseffizienz. 2. Die Kompilierungsmerkmale von Compilation und die Hardwarekontrolle verbessern die Ausführungsleistung. Bei einer Auswahl müssen Sie die Entwicklungsgeschwindigkeit und die Ausführungseffizienz basierend auf den Projektanforderungen abwägen.

Python lernen: Ist 2 Stunden tägliches Studium ausreichend? Python lernen: Ist 2 Stunden tägliches Studium ausreichend? Apr 18, 2025 am 12:22 AM

Ist es genug, um Python für zwei Stunden am Tag zu lernen? Es hängt von Ihren Zielen und Lernmethoden ab. 1) Entwickeln Sie einen klaren Lernplan, 2) Wählen Sie geeignete Lernressourcen und -methoden aus, 3) praktizieren und prüfen und konsolidieren Sie praktische Praxis und Überprüfung und konsolidieren Sie und Sie können die Grundkenntnisse und die erweiterten Funktionen von Python während dieser Zeit nach und nach beherrschen.

Python vs. C: Verständnis der wichtigsten Unterschiede Python vs. C: Verständnis der wichtigsten Unterschiede Apr 21, 2025 am 12:18 AM

Python und C haben jeweils ihre eigenen Vorteile, und die Wahl sollte auf Projektanforderungen beruhen. 1) Python ist aufgrund seiner prägnanten Syntax und der dynamischen Typisierung für die schnelle Entwicklung und Datenverarbeitung geeignet. 2) C ist aufgrund seiner statischen Tipp- und manuellen Speicherverwaltung für hohe Leistung und Systemprogrammierung geeignet.

Welches ist Teil der Python Standard Library: Listen oder Arrays? Welches ist Teil der Python Standard Library: Listen oder Arrays? Apr 27, 2025 am 12:03 AM

PythonlistsarePartThestandardlibrary, whilearraysarenot.listarebuilt-in, vielseitig und UNDUSEDFORSPORINGECollections, während dieArrayRay-thearrayModulei und loses und loses und losesaluseduetolimitedFunctionality.

Python: Automatisierung, Skript- und Aufgabenverwaltung Python: Automatisierung, Skript- und Aufgabenverwaltung Apr 16, 2025 am 12:14 AM

Python zeichnet sich in Automatisierung, Skript und Aufgabenverwaltung aus. 1) Automatisierung: Die Sicherungssicherung wird durch Standardbibliotheken wie OS und Shutil realisiert. 2) Skriptschreiben: Verwenden Sie die PSUTIL -Bibliothek, um die Systemressourcen zu überwachen. 3) Aufgabenverwaltung: Verwenden Sie die Zeitplanbibliothek, um Aufgaben zu planen. Die Benutzerfreundlichkeit von Python und die Unterstützung der reichhaltigen Bibliothek machen es zum bevorzugten Werkzeug in diesen Bereichen.

Python für wissenschaftliches Computer: Ein detailliertes Aussehen Python für wissenschaftliches Computer: Ein detailliertes Aussehen Apr 19, 2025 am 12:15 AM

Zu den Anwendungen von Python im wissenschaftlichen Computer gehören Datenanalyse, maschinelles Lernen, numerische Simulation und Visualisierung. 1.Numpy bietet effiziente mehrdimensionale Arrays und mathematische Funktionen. 2. Scipy erweitert die Numpy -Funktionalität und bietet Optimierungs- und lineare Algebra -Tools. 3.. Pandas wird zur Datenverarbeitung und -analyse verwendet. 4.Matplotlib wird verwendet, um verschiedene Grafiken und visuelle Ergebnisse zu erzeugen.

See all articles