Heim Backend-Entwicklung Python-Tutorial Das Prinzip zum Aufbau regulärer äquivarianter CNNs

Das Prinzip zum Aufbau regulärer äquivarianter CNNs

Jul 18, 2024 am 11:29 AM

Das eine Prinzip wird einfach als „Lassen Sie den Kernel rotieren“ ausgedrückt und wir werden uns in diesem Artikel darauf konzentrieren, wie Sie es in Ihren Architekturen anwenden können.

Äquivariante Architekturen ermöglichen es uns, Modelle zu trainieren, die gegenüber bestimmten Gruppenaktionen gleichgültig sind.

Um zu verstehen, was das genau bedeutet, trainieren wir dieses einfache CNN-Modell auf dem MNIST-Datensatz (einem Datensatz handgeschriebener Ziffern von 0-9).

class SimpleCNN(nn.Module):

    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.cl1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, padding=1)
        self.max_1 = nn.MaxPool2d(kernel_size=2)
        self.cl2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding=1)
        self.max_2 = nn.MaxPool2d(kernel_size=2)
        self.cl3 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=7)
        self.dense = nn.Linear(in_features=16, out_features=10)

    def forward(self, x: torch.Tensor):
        x = nn.functional.silu(self.cl1(x))
        x = self.max_1(x)
        x = nn.functional.silu(self.cl2(x))
        x = self.max_2(x)
        x = nn.functional.silu(self.cl3(x))
        x = x.view(len(x), -1)
        logits = self.dense(x)
        return logits
Nach dem Login kopieren
Accuracy on test Accuracy on 90-degree rotated test
97.3% 15.1%

Tabelle 1: Testgenauigkeit des SimpleCNN-Modells

Wie erwartet erreichen wir beim Testdatensatz eine Genauigkeit von über 95 %, aber was ist, wenn wir das Bild um 90 Grad drehen? Ohne Gegenmaßnahmen sinken die Ergebnisse auf ein knapp besseres Ergebnis als geschätzt. Dieses Modell wäre für allgemeine Anwendungen unbrauchbar.

Im Gegensatz dazu trainieren wir eine ähnliche äquivariante Architektur mit der gleichen Anzahl von Parametern, bei der die Gruppenaktionen genau die 90-Grad-Rotationen sind.

Accuracy on test Accuracy on 90-degree rotated test
96.5% 96.5%

Tabelle 2: Testgenauigkeit des EqCNN-Modells mit der gleichen Anzahl an Parametern wie das SimpleCNN-Modell

Die Genauigkeit bleibt gleich und wir haben uns nicht einmal für eine Datenerweiterung entschieden.

Diese Modelle werden mit 3D-Daten noch beeindruckender, aber wir bleiben bei diesem Beispiel, um die Kernidee zu untersuchen.

Falls Sie es selbst testen möchten, können Sie unter Github-Repo kostenlos auf den gesamten in PyTorch und JAX geschriebenen Code zugreifen, und das Training mit Docker oder Podman ist mit nur zwei Befehlen möglich.

Viel Spaß!

Was ist also Äquivarianz?

Äquivariante Architekturen garantieren die Stabilität von Funktionen unter bestimmten Gruppenaktionen. Gruppen sind einfache Strukturen, in denen Gruppenelemente kombiniert, umgekehrt oder gar nichts bewirken können.

Bei Interesse können Sie die formale Definition auf Wikipedia nachschlagen.

Für unsere Zwecke können Sie sich eine Gruppe von 90-Grad-Rotationen vorstellen, die auf quadratische Bilder wirken. Wir können ein Bild um 90, 180, 270 oder 360 Grad drehen. Um die Aktion umzukehren, wenden wir eine Drehung um 270, 180, 90 bzw. 0 Grad an. Es ist leicht zu erkennen, dass wir die als bezeichnete Gruppe kombinieren, umkehren oder nichts tun können C4C_4C4 . Das Bild visualisiert alle Aktionen auf einem Bild.

Figure 1: Rotated MNIST image by 90°, 180°, 270°, 360°, respectively
Abbildung 1: Gedrehtes MNIST-Bild um 90°, 180°, 270° bzw. 360°

Now, given an input image xxx , our CNN model classifier fθf_\thetafθ , and an arbitrary 90-degree rotation ggg , the equivariant property can be expressed as
fθ(rotate x by g)=fθ(x) f_\theta(\text{rotate } x \text{ by } g) = f_\theta(x) fθ(rotate x by g)=fθ(x)

Generally speaking, we want our image-based model to have the same outputs when rotated.

As such, equivariant models promise us architectures with baked-in symmetries. In the following section, we will see how our principle can be applied to achieve this property.

How to Make Our CNN Equivariant

The problem is the following: When the image rotates, the features rotate too. But as already hinted, we could also compute the features for each rotation upfront by rotating the kernel.
We could actually rotate the kernel, but it is much easier to rotate the feature map itself, thus avoiding interference with PyTorch's autodifferentiation algorithm altogether.

So, in code, our CNN kernel

x = nn.functional.silu(self.cl1(x))
Nach dem Login kopieren

now acts on all four rotated images:

x_0 = x
x_90 = torch.rot90(x, k=1, dims=(2, 3))
x_180 = torch.rot90(x, k=2, dims=(2, 3))
x_270 = torch.rot90(x, k=3, dims=(2, 3))

x_0 = nn.functional.silu(self.cl1(x_0))
x_90 = nn.functional.silu(self.cl1(x_90))
x_180 = nn.functional.silu(self.cl1(x_180))
x_270 = nn.functional.silu(self.cl1(x_270))
Nach dem Login kopieren

Or more compactly written as a 3D convolution:

self.cl1 = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=(1, 3, 3))
...
x = torch.stack([x_0, x_90, x_180, x_270], dim=-3)
x = nn.functional.silu(self.cl1(x))
Nach dem Login kopieren

The resulting equivariant model has just a few lines more compared to the version above:

class EqCNN(nn.Module):

    def __init__(self):
        super(EqCNN, self).__init__()
        self.cl1 = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=(1, 3, 3))
        self.max_1 = nn.MaxPool3d(kernel_size=(1, 2, 2))
        self.cl2 = nn.Conv3d(in_channels=8, out_channels=16, kernel_size=(1, 3, 3))
        self.max_2 = nn.MaxPool3d(kernel_size=(1, 2, 2))
        self.cl3 = nn.Conv3d(in_channels=16, out_channels=16, kernel_size=(1, 5, 5))
        self.dense = nn.Linear(in_features=16, out_features=10)

    def forward(self, x: torch.Tensor):
        x_0 = x
        x_90 = torch.rot90(x, k=1, dims=(2, 3))
        x_180 = torch.rot90(x, k=2, dims=(2, 3))
        x_270 = torch.rot90(x, k=3, dims=(2, 3))

        x = torch.stack([x_0, x_90, x_180, x_270], dim=-3)
        x = nn.functional.silu(self.cl1(x))
        x = self.max_1(x)

        x = nn.functional.silu(self.cl2(x))
        x = self.max_2(x)

        x = nn.functional.silu(self.cl3(x))

        x = x.squeeze()
        x = torch.max(x, dim=-1).values
        logits = self.dense(x)
        return logits
Nach dem Login kopieren

But why is this equivariant to rotations?
First, observe that we get four copies of each feature map at each stage. At the end of the pipeline, we combine all of them with a max operation.

This is key, the max operation is indifferent to which place the rotated version of the feature ends up in.

To understand what is happening, let us plot the feature maps after the first convolution stage.

Figure 2: Feature maps for all four rotations
Figure 2: Feature maps for all four rotations

And now the same features after we rotate the input by 90 degrees.

Figure 3: Feature maps for all four rotations after the input image was rotated
Abbildung 3: Feature-Maps für alle vier Rotationen, nachdem das Eingabebild gedreht wurde

Ich habe die entsprechenden Karten farblich markiert. Jede Feature-Map wird um eins verschoben. Da der endgültige Max-Operator das gleiche Ergebnis für diese verschobenen Feature-Maps berechnet, erhalten wir die gleichen Ergebnisse.

In meinem Code habe ich nach der letzten Faltung nicht zurückrotiert, da meine Kernel das Bild zu einem eindimensionalen Array verdichten. Wenn Sie dieses Beispiel näher erläutern möchten, müssen Sie diese Tatsache berücksichtigen.

Die Berücksichtigung von Gruppenaktionen oder „Kernelrotationen“ spielt eine entscheidende Rolle beim Entwurf anspruchsvollerer Architekturen.

Ist es ein kostenloses Mittagessen?

Nein, wir bezahlen mit Rechengeschwindigkeit, induktivem Bias und einer komplexeren Implementierung.

Der letzte Punkt lässt sich einigermaßen mit Bibliotheken wie E3NN lösen, in denen der Großteil der schweren Mathematik abstrahiert wird. Dennoch muss man beim Architekturentwurf einiges berücksichtigen.

Eine oberflächliche Schwäche ist der vierfache Rechenaufwand für die Berechnung aller gedrehten Feature-Layer. Moderne Hardware mit Massenparallelisierung kann dieser Belastung jedoch problemlos entgegenwirken. Im Gegensatz dazu würde das Training eines einfachen CNN mit Datenerweiterung die Trainingszeit leicht um das Zehnfache überschreiten. Noch schlimmer wird es bei 3D-Rotationen, bei denen die Datenerweiterung etwa das 500-fache des Trainingsaufwands erfordern würde, um alle möglichen Rotationen zu kompensieren.

Insgesamt ist der Entwurf eines Äquivarianzmodells oft ein lohnenswerter Preis, wenn man stabile Funktionen wünscht.

Was kommt als nächstes?

Äquivariante Modelldesigns haben in den letzten Jahren explosionsartig zugenommen, und in diesem Artikel haben wir kaum an der Oberfläche gekratzt. Tatsächlich haben wir nicht einmal das volle Potenzial ausgeschöpft C4C_4C4 Gruppe noch. Wir hätten vollständige 3D-Kernel verwenden können. Allerdings erreicht unser Modell bereits eine Genauigkeit von über 95 %, sodass es kaum einen Grund gibt, mit diesem Beispiel noch weiter zu gehen.

Außer CNNs haben Forscher diese Prinzipien erfolgreich auf kontinuierliche Gruppen übertragen, darunter SO(2) SO(2)SO(2) (die Gruppe aller Drehungen in der Ebene) und SE(3) SE(3)SE(3) (die Gruppe aller Translationen und Rotationen im 3D-Raum).

Meiner Erfahrung nach sind diese Modelle absolut umwerfend und erreichen eine Leistung, wenn sie von Grund auf trainiert werden, vergleichbar mit der Leistung von Basismodellen, die auf mehrfach größeren Datensätzen trainiert werden.

Lassen Sie mich wissen, wenn Sie möchten, dass ich mehr zu diesem Thema schreibe.

Weitere Referenzen

Falls Sie eine formelle Einführung in dieses Thema wünschen, finden Sie hier eine hervorragende Zusammenstellung von Artikeln, die die gesamte Geschichte der Äquivarianz im maschinellen Lernen abdecken.
AEN

Ich habe tatsächlich vor, ein ausführliches, praktisches Tutorial zu diesem Thema zu erstellen. Sie können sich bereits jetzt für meine Mailingliste anmelden und ich werde Ihnen im Laufe der Zeit kostenlose Versionen zur Verfügung stellen, zusammen mit einem direkten Kanal für Feedback und Fragen und Antworten.

Wir sehen uns :)

Das obige ist der detaillierte Inhalt vonDas Prinzip zum Aufbau regulärer äquivarianter CNNs. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!

Erklärung dieser Website
Der Inhalt dieses Artikels wird freiwillig von Internetnutzern beigesteuert und das Urheberrecht liegt beim ursprünglichen Autor. Diese Website übernimmt keine entsprechende rechtliche Verantwortung. Wenn Sie Inhalte finden, bei denen der Verdacht eines Plagiats oder einer Rechtsverletzung besteht, wenden Sie sich bitte an admin@php.cn

Heiße KI -Werkzeuge

Undresser.AI Undress

Undresser.AI Undress

KI-gestützte App zum Erstellen realistischer Aktfotos

AI Clothes Remover

AI Clothes Remover

Online-KI-Tool zum Entfernen von Kleidung aus Fotos.

Undress AI Tool

Undress AI Tool

Ausziehbilder kostenlos

Clothoff.io

Clothoff.io

KI-Kleiderentferner

AI Hentai Generator

AI Hentai Generator

Erstellen Sie kostenlos Ai Hentai.

Heißer Artikel

R.E.P.O. Energiekristalle erklärten und was sie tun (gelber Kristall)
1 Monate vor By 尊渡假赌尊渡假赌尊渡假赌
R.E.P.O. Beste grafische Einstellungen
1 Monate vor By 尊渡假赌尊渡假赌尊渡假赌
Will R.E.P.O. Crossplay haben?
1 Monate vor By 尊渡假赌尊渡假赌尊渡假赌

Heiße Werkzeuge

Notepad++7.3.1

Notepad++7.3.1

Einfach zu bedienender und kostenloser Code-Editor

SublimeText3 chinesische Version

SublimeText3 chinesische Version

Chinesische Version, sehr einfach zu bedienen

Senden Sie Studio 13.0.1

Senden Sie Studio 13.0.1

Leistungsstarke integrierte PHP-Entwicklungsumgebung

Dreamweaver CS6

Dreamweaver CS6

Visuelle Webentwicklungstools

SublimeText3 Mac-Version

SublimeText3 Mac-Version

Codebearbeitungssoftware auf Gottesniveau (SublimeText3)

Wie löste ich das Problem der Berechtigungen beim Betrachten der Python -Version in Linux Terminal? Wie löste ich das Problem der Berechtigungen beim Betrachten der Python -Version in Linux Terminal? Apr 01, 2025 pm 05:09 PM

Lösung für Erlaubnisprobleme beim Betrachten der Python -Version in Linux Terminal Wenn Sie versuchen, die Python -Version in Linux Terminal anzuzeigen, geben Sie Python ein ...

Wie kann ich die gesamte Spalte eines Datenrahmens effizient in einen anderen Datenrahmen mit verschiedenen Strukturen in Python kopieren? Wie kann ich die gesamte Spalte eines Datenrahmens effizient in einen anderen Datenrahmen mit verschiedenen Strukturen in Python kopieren? Apr 01, 2025 pm 11:15 PM

Bei der Verwendung von Pythons Pandas -Bibliothek ist das Kopieren von ganzen Spalten zwischen zwei Datenrahmen mit unterschiedlichen Strukturen ein häufiges Problem. Angenommen, wir haben zwei Daten ...

Wie lehre ich innerhalb von 10 Stunden die Grundlagen für Computer-Anfänger-Programmierbasis in Projekt- und problemorientierten Methoden? Wie lehre ich innerhalb von 10 Stunden die Grundlagen für Computer-Anfänger-Programmierbasis in Projekt- und problemorientierten Methoden? Apr 02, 2025 am 07:18 AM

Wie lehre ich innerhalb von 10 Stunden die Grundlagen für Computer -Anfänger für Programmierungen? Wenn Sie nur 10 Stunden Zeit haben, um Computer -Anfänger zu unterrichten, was Sie mit Programmierkenntnissen unterrichten möchten, was würden Sie dann beibringen ...

Wie kann man vom Browser vermeiden, wenn man überall Fiddler für das Lesen des Menschen in der Mitte verwendet? Wie kann man vom Browser vermeiden, wenn man überall Fiddler für das Lesen des Menschen in der Mitte verwendet? Apr 02, 2025 am 07:15 AM

Wie kann man nicht erkannt werden, wenn Sie Fiddlereverywhere für Man-in-the-Middle-Lesungen verwenden, wenn Sie FiddLereverywhere verwenden ...

Was sind reguläre Ausdrücke? Was sind reguläre Ausdrücke? Mar 20, 2025 pm 06:25 PM

Regelmäßige Ausdrücke sind leistungsstarke Tools für Musteranpassung und Textmanipulation in der Programmierung, wodurch die Effizienz bei der Textverarbeitung in verschiedenen Anwendungen verbessert wird.

Wie hört Uvicorn kontinuierlich auf HTTP -Anfragen ohne Serving_forver () an? Wie hört Uvicorn kontinuierlich auf HTTP -Anfragen ohne Serving_forver () an? Apr 01, 2025 pm 10:51 PM

Wie hört Uvicorn kontinuierlich auf HTTP -Anfragen an? Uvicorn ist ein leichter Webserver, der auf ASGI basiert. Eine seiner Kernfunktionen ist es, auf HTTP -Anfragen zu hören und weiterzumachen ...

Was sind einige beliebte Python -Bibliotheken und ihre Verwendung? Was sind einige beliebte Python -Bibliotheken und ihre Verwendung? Mar 21, 2025 pm 06:46 PM

In dem Artikel werden beliebte Python-Bibliotheken wie Numpy, Pandas, Matplotlib, Scikit-Learn, TensorFlow, Django, Flask und Anfragen erörtert, die ihre Verwendung in wissenschaftlichen Computing, Datenanalyse, Visualisierung, maschinellem Lernen, Webentwicklung und h beschreiben

Wie erstelle ich dynamisch ein Objekt über eine Zeichenfolge und rufe seine Methoden in Python auf? Wie erstelle ich dynamisch ein Objekt über eine Zeichenfolge und rufe seine Methoden in Python auf? Apr 01, 2025 pm 11:18 PM

Wie erstellt in Python ein Objekt dynamisch über eine Zeichenfolge und ruft seine Methoden auf? Dies ist eine häufige Programmieranforderung, insbesondere wenn sie konfiguriert oder ausgeführt werden muss ...

See all articles