Rumah pembangunan bahagian belakang Tutorial Python Rincian untuk Membina CNN Setara Biasa

Rincian untuk Membina CNN Setara Biasa

Jul 18, 2024 am 11:29 AM

Satu prinsip hanya dinyatakan sebagai 'Biarkan kernel berputar' dan kami akan menumpukan dalam artikel ini tentang cara anda boleh menerapkannya dalam seni bina anda.

Seni bina setara membolehkan kami melatih model yang tidak peduli dengan tindakan kumpulan tertentu.

Untuk memahami maksud ini sebenarnya, mari kita latih model CNN mudah ini pada set data MNIST (set data digit tulisan tangan dari 0-9).

class SimpleCNN(nn.Module):

    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.cl1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, padding=1)
        self.max_1 = nn.MaxPool2d(kernel_size=2)
        self.cl2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding=1)
        self.max_2 = nn.MaxPool2d(kernel_size=2)
        self.cl3 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=7)
        self.dense = nn.Linear(in_features=16, out_features=10)

    def forward(self, x: torch.Tensor):
        x = nn.functional.silu(self.cl1(x))
        x = self.max_1(x)
        x = nn.functional.silu(self.cl2(x))
        x = self.max_2(x)
        x = nn.functional.silu(self.cl3(x))
        x = x.view(len(x), -1)
        logits = self.dense(x)
        return logits
Salin selepas log masuk
Accuracy on test Accuracy on 90-degree rotated test
97.3% 15.1%

Jadual 1: Uji ketepatan model SimpleCNN

Seperti yang dijangkakan, kami mendapat lebih 95% ketepatan pada set data ujian, tetapi bagaimana jika kami memutarkan imej sebanyak 90 darjah? Tanpa sebarang tindakan balas yang dikenakan, keputusan menurun kepada hanya lebih baik sedikit daripada meneka. Model ini tidak berguna untuk aplikasi umum.

Sebaliknya, mari kita latih seni bina setara yang serupa dengan bilangan parameter yang sama, di mana tindakan kumpulan adalah tepat putaran 90 darjah.

Accuracy on test Accuracy on 90-degree rotated test
96.5% 96.5%

Jadual 2: Uji ketepatan model EqCNN dengan jumlah parameter yang sama seperti model SimpleCNN

Ketepatan tetap sama dan kami tidak memilih untuk menambah data.

Model ini menjadi lebih mengagumkan dengan data 3D, tetapi kami akan tetap menggunakan contoh ini untuk meneroka idea teras.

Sekiranya anda ingin mengujinya sendiri, anda boleh mengakses semua kod yang ditulis dalam kedua-dua PyTorch dan JAX secara percuma di bawah Github-Repo, dan latihan dengan Docker atau Podman boleh dilakukan dengan hanya dua arahan.

Selamat mencuba!

Jadi Apakah Kesetaraan?

Senibina setara menjamin kestabilan ciri di bawah tindakan kumpulan tertentu. Kumpulan ialah struktur ringkas di mana elemen kumpulan boleh digabungkan, diterbalikkan atau tidak melakukan apa-apa.

Anda boleh mencari definisi rasmi di Wikipedia jika anda berminat.

Untuk tujuan kami, anda boleh memikirkan sekumpulan putaran 90 darjah yang bertindak pada imej segi empat sama. Kita boleh memutarkan imej sebanyak 90, 180, 270 atau 360 darjah. Untuk membalikkan tindakan, kami menggunakan putaran 270, 180, 90 atau 0 darjah masing-masing. Adalah mudah untuk melihat bahawa kita boleh menggabungkan, membalikkan atau melakukan apa-apa dengan kumpulan yang dilambangkan sebagai C4C_4C4 . Imej menggambarkan semua tindakan pada imej.

Figure 1: Rotated MNIST image by 90°, 180°, 270°, 360°, respectively
Rajah 1: Imej MNIST diputar sebanyak 90°, 180°, 270°, 360°, masing-masing

Now, given an input image xxx , our CNN model classifier fθf_\thetafθ , and an arbitrary 90-degree rotation ggg , the equivariant property can be expressed as
fθ(rotate x by g)=fθ(x) f_\theta(\text{rotate } x \text{ by } g) = f_\theta(x) fθ(rotate x by g)=fθ(x)

Generally speaking, we want our image-based model to have the same outputs when rotated.

As such, equivariant models promise us architectures with baked-in symmetries. In the following section, we will see how our principle can be applied to achieve this property.

How to Make Our CNN Equivariant

The problem is the following: When the image rotates, the features rotate too. But as already hinted, we could also compute the features for each rotation upfront by rotating the kernel.
We could actually rotate the kernel, but it is much easier to rotate the feature map itself, thus avoiding interference with PyTorch's autodifferentiation algorithm altogether.

So, in code, our CNN kernel

x = nn.functional.silu(self.cl1(x))
Salin selepas log masuk

now acts on all four rotated images:

x_0 = x
x_90 = torch.rot90(x, k=1, dims=(2, 3))
x_180 = torch.rot90(x, k=2, dims=(2, 3))
x_270 = torch.rot90(x, k=3, dims=(2, 3))

x_0 = nn.functional.silu(self.cl1(x_0))
x_90 = nn.functional.silu(self.cl1(x_90))
x_180 = nn.functional.silu(self.cl1(x_180))
x_270 = nn.functional.silu(self.cl1(x_270))
Salin selepas log masuk

Or more compactly written as a 3D convolution:

self.cl1 = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=(1, 3, 3))
...
x = torch.stack([x_0, x_90, x_180, x_270], dim=-3)
x = nn.functional.silu(self.cl1(x))
Salin selepas log masuk

The resulting equivariant model has just a few lines more compared to the version above:

class EqCNN(nn.Module):

    def __init__(self):
        super(EqCNN, self).__init__()
        self.cl1 = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=(1, 3, 3))
        self.max_1 = nn.MaxPool3d(kernel_size=(1, 2, 2))
        self.cl2 = nn.Conv3d(in_channels=8, out_channels=16, kernel_size=(1, 3, 3))
        self.max_2 = nn.MaxPool3d(kernel_size=(1, 2, 2))
        self.cl3 = nn.Conv3d(in_channels=16, out_channels=16, kernel_size=(1, 5, 5))
        self.dense = nn.Linear(in_features=16, out_features=10)

    def forward(self, x: torch.Tensor):
        x_0 = x
        x_90 = torch.rot90(x, k=1, dims=(2, 3))
        x_180 = torch.rot90(x, k=2, dims=(2, 3))
        x_270 = torch.rot90(x, k=3, dims=(2, 3))

        x = torch.stack([x_0, x_90, x_180, x_270], dim=-3)
        x = nn.functional.silu(self.cl1(x))
        x = self.max_1(x)

        x = nn.functional.silu(self.cl2(x))
        x = self.max_2(x)

        x = nn.functional.silu(self.cl3(x))

        x = x.squeeze()
        x = torch.max(x, dim=-1).values
        logits = self.dense(x)
        return logits
Salin selepas log masuk

But why is this equivariant to rotations?
First, observe that we get four copies of each feature map at each stage. At the end of the pipeline, we combine all of them with a max operation.

This is key, the max operation is indifferent to which place the rotated version of the feature ends up in.

To understand what is happening, let us plot the feature maps after the first convolution stage.

Figure 2: Feature maps for all four rotations
Figure 2: Feature maps for all four rotations

And now the same features after we rotate the input by 90 degrees.

Figure 3: Feature maps for all four rotations after the input image was rotated
Rajah 3: Peta ciri untuk keempat-empat putaran selepas imej input diputar

Saya mengekod warna peta yang sepadan. Setiap peta ciri dianjakkan oleh satu. Apabila operator maks akhir mengira hasil yang sama untuk peta ciri yang dialih ini, kami memperoleh hasil yang sama.

Dalam kod saya, saya tidak berputar kembali selepas lilitan terakhir, kerana kernel saya memekatkan imej kepada tatasusunan satu dimensi. Jika anda ingin mengembangkan contoh ini, anda perlu mengambil kira fakta ini.

Perakaunan untuk tindakan kumpulan atau "putaran kernel" memainkan peranan penting dalam reka bentuk seni bina yang lebih canggih.

Adakah ia Makan Tengah Hari Percuma?

Tidak, kami membayar dalam kelajuan pengiraan, bias induktif dan pelaksanaan yang lebih kompleks.

Perkara terakhir agak diselesaikan dengan perpustakaan seperti E3NN, di mana kebanyakan matematik berat diabstrakkan. Walau bagaimanapun, seseorang perlu mengambil kira banyak semasa reka bentuk seni bina.

Satu kelemahan cetek ialah 4x kos pengiraan untuk mengira semua lapisan ciri yang diputar. Walau bagaimanapun, perkakasan moden dengan paralelisasi jisim boleh mengatasi beban ini dengan mudah. Sebaliknya, melatih CNN mudah dengan penambahan data dengan mudah akan melebihi 10x dalam masa latihan. Ini menjadi lebih teruk lagi untuk putaran 3D yang mana penambahan data memerlukan kira-kira 500x jumlah latihan untuk mengimbangi semua putaran yang mungkin.

Secara keseluruhannya, reka bentuk model kesetaraan lebih kerap daripada bukan harga yang patut dibayar jika seseorang mahukan ciri yang stabil.

Apakah Seterusnya?

Reka bentuk model setara telah meletup dalam beberapa tahun kebelakangan ini, dan dalam artikel ini, kami hampir tidak mencalarkan permukaan. Malah, kami tidak mengeksploitasi sepenuhnya C4C_4C4 kumpulan lagi. Kami boleh menggunakan kernel 3D penuh. Walau bagaimanapun, model kami sudah mencapai ketepatan lebih 95%, jadi tiada sebab untuk pergi lebih jauh dengan contoh ini.

Selain CNN, penyelidik telah berjaya menterjemahkan prinsip ini kepada kumpulan berterusan, termasuk SO(2) JADI(2)JADI(2) (kumpulan semua putaran dalam satah) dan SE(3) SE(3)SE(3) (kumpulan semua terjemahan dan putaran dalam ruang 3D).

Menurut pengalaman saya, model ini benar-benar mengagumkan dan mencapai prestasi, apabila dilatih dari awal, setanding dengan prestasi model asas yang dilatih pada set data berbilang kali ganda lebih besar.

Beri tahu saya jika anda mahu saya menulis lebih lanjut mengenai topik ini.

Rujukan Lanjut

Sekiranya anda mahukan pengenalan rasmi kepada topik ini, berikut ialah kompilasi kertas kerja yang sangat baik, merangkumi sejarah lengkap kesetaraan dalam Pembelajaran Mesin.
AEN

Saya sebenarnya bercadang untuk membuat tutorial mendalam dan praktikal mengenai topik ini. Anda sudah boleh mendaftar untuk senarai mel saya dan saya akan memberikan anda versi percuma dari semasa ke semasa, bersama-sama saluran terus untuk maklum balas dan Soal Jawab.

Jumpa lagi :)

Atas ialah kandungan terperinci Rincian untuk Membina CNN Setara Biasa. Untuk maklumat lanjut, sila ikut artikel berkaitan lain di laman web China PHP!

Kenyataan Laman Web ini
Kandungan artikel ini disumbangkan secara sukarela oleh netizen, dan hak cipta adalah milik pengarang asal. Laman web ini tidak memikul tanggungjawab undang-undang yang sepadan. Jika anda menemui sebarang kandungan yang disyaki plagiarisme atau pelanggaran, sila hubungi admin@php.cn

Alat AI Hot

Undresser.AI Undress

Undresser.AI Undress

Apl berkuasa AI untuk mencipta foto bogel yang realistik

AI Clothes Remover

AI Clothes Remover

Alat AI dalam talian untuk mengeluarkan pakaian daripada foto.

Undress AI Tool

Undress AI Tool

Gambar buka pakaian secara percuma

Clothoff.io

Clothoff.io

Penyingkiran pakaian AI

AI Hentai Generator

AI Hentai Generator

Menjana ai hentai secara percuma.

Artikel Panas

R.E.P.O. Kristal tenaga dijelaskan dan apa yang mereka lakukan (kristal kuning)
1 bulan yang lalu By 尊渡假赌尊渡假赌尊渡假赌
R.E.P.O. Tetapan grafik terbaik
1 bulan yang lalu By 尊渡假赌尊渡假赌尊渡假赌
Akan R.E.P.O. Ada Crossplay?
1 bulan yang lalu By 尊渡假赌尊渡假赌尊渡假赌

Alat panas

Notepad++7.3.1

Notepad++7.3.1

Editor kod yang mudah digunakan dan percuma

SublimeText3 versi Cina

SublimeText3 versi Cina

Versi Cina, sangat mudah digunakan

Hantar Studio 13.0.1

Hantar Studio 13.0.1

Persekitaran pembangunan bersepadu PHP yang berkuasa

Dreamweaver CS6

Dreamweaver CS6

Alat pembangunan web visual

SublimeText3 versi Mac

SublimeText3 versi Mac

Perisian penyuntingan kod peringkat Tuhan (SublimeText3)

Bagaimana untuk menyelesaikan masalah kebenaran yang dihadapi semasa melihat versi Python di Terminal Linux? Bagaimana untuk menyelesaikan masalah kebenaran yang dihadapi semasa melihat versi Python di Terminal Linux? Apr 01, 2025 pm 05:09 PM

Penyelesaian kepada Isu Kebenaran Semasa Melihat Versi Python di Terminal Linux Apabila anda cuba melihat versi Python di Terminal Linux, masukkan Python ...

Bagaimana cara menyalin seluruh lajur satu data ke dalam data data lain dengan struktur yang berbeza di Python? Bagaimana cara menyalin seluruh lajur satu data ke dalam data data lain dengan struktur yang berbeza di Python? Apr 01, 2025 pm 11:15 PM

Apabila menggunakan Perpustakaan Pandas Python, bagaimana untuk menyalin seluruh lajur antara dua data data dengan struktur yang berbeza adalah masalah biasa. Katakan kita mempunyai dua DAT ...

Bagaimana Mengajar Asas Pengaturcaraan Pemula Komputer Dalam Kaedah Projek dan Masalah Dikemukakan Dalam masa 10 Jam? Bagaimana Mengajar Asas Pengaturcaraan Pemula Komputer Dalam Kaedah Projek dan Masalah Dikemukakan Dalam masa 10 Jam? Apr 02, 2025 am 07:18 AM

Bagaimana Mengajar Asas Pengaturcaraan Pemula Komputer Dalam masa 10 jam? Sekiranya anda hanya mempunyai 10 jam untuk mengajar pemula komputer beberapa pengetahuan pengaturcaraan, apa yang akan anda pilih untuk mengajar ...

Bagaimana untuk mengelakkan dikesan oleh penyemak imbas apabila menggunakan fiddler di mana-mana untuk membaca lelaki-dalam-tengah? Bagaimana untuk mengelakkan dikesan oleh penyemak imbas apabila menggunakan fiddler di mana-mana untuk membaca lelaki-dalam-tengah? Apr 02, 2025 am 07:15 AM

Cara mengelakkan dikesan semasa menggunakan fiddlerevery di mana untuk bacaan lelaki-dalam-pertengahan apabila anda menggunakan fiddlerevery di mana ...

Apakah ungkapan biasa? Apakah ungkapan biasa? Mar 20, 2025 pm 06:25 PM

Ekspresi biasa adalah alat yang berkuasa untuk memadankan corak dan manipulasi teks dalam pengaturcaraan, meningkatkan kecekapan dalam pemprosesan teks merentasi pelbagai aplikasi.

Bagaimanakah uvicorn terus mendengar permintaan http tanpa serving_forever ()? Bagaimanakah uvicorn terus mendengar permintaan http tanpa serving_forever ()? Apr 01, 2025 pm 10:51 PM

Bagaimanakah Uvicorn terus mendengar permintaan HTTP? Uvicorn adalah pelayan web ringan berdasarkan ASGI. Salah satu fungsi terasnya ialah mendengar permintaan HTTP dan teruskan ...

Apakah beberapa perpustakaan Python yang popular dan kegunaan mereka? Apakah beberapa perpustakaan Python yang popular dan kegunaan mereka? Mar 21, 2025 pm 06:46 PM

Artikel ini membincangkan perpustakaan Python yang popular seperti Numpy, Pandas, Matplotlib, Scikit-Learn, Tensorflow, Django, Flask, dan Permintaan, memperincikan kegunaan mereka dalam pengkomputeran saintifik, analisis data, visualisasi, pembelajaran mesin, pembangunan web, dan h

Bagaimana secara dinamik membuat objek melalui rentetan dan panggil kaedahnya dalam Python? Bagaimana secara dinamik membuat objek melalui rentetan dan panggil kaedahnya dalam Python? Apr 01, 2025 pm 11:18 PM

Di Python, bagaimana untuk membuat objek secara dinamik melalui rentetan dan panggil kaedahnya? Ini adalah keperluan pengaturcaraan yang biasa, terutamanya jika perlu dikonfigurasikan atau dijalankan ...

See all articles