백엔드 개발 파이썬 튜토리얼 정규 등변량 CNN을 구축하는 원리

정규 등변량 CNN을 구축하는 원리

Jul 18, 2024 am 11:29 AM

한 가지 원칙은 간단하게 '커널을 회전시키세요'라고 명시되어 있으며 이 기사에서는 이를 아키텍처에 적용하는 방법에 중점을 둘 것입니다.

등변 아키텍처를 사용하면 특정 그룹 활동에 무관한 모델을 훈련할 수 있습니다.

이것이 정확히 무엇을 의미하는지 이해하기 위해 MNIST 데이터세트(0-9의 손으로 쓴 숫자 데이터세트)에서 이 간단한 CNN 모델을 훈련시켜 보겠습니다.

class SimpleCNN(nn.Module):

    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.cl1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, padding=1)
        self.max_1 = nn.MaxPool2d(kernel_size=2)
        self.cl2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding=1)
        self.max_2 = nn.MaxPool2d(kernel_size=2)
        self.cl3 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=7)
        self.dense = nn.Linear(in_features=16, out_features=10)

    def forward(self, x: torch.Tensor):
        x = nn.functional.silu(self.cl1(x))
        x = self.max_1(x)
        x = nn.functional.silu(self.cl2(x))
        x = self.max_2(x)
        x = nn.functional.silu(self.cl3(x))
        x = x.view(len(x), -1)
        logits = self.dense(x)
        return logits
로그인 후 복사
Accuracy on test Accuracy on 90-degree rotated test
97.3% 15.1%

표 1: SimpleCNN 모델의 정확도 테스트

예상대로 테스트 데이터 세트에서 95% 이상의 정확도를 얻었습니다. 그런데 이미지를 90도 회전하면 어떻게 될까요? 어떤 대책도 적용하지 않으면 결과는 추측보다 약간 더 나은 수준으로 떨어집니다. 이 모델은 일반 응용 프로그램에는 쓸모가 없습니다.

반대로, 그룹 작업이 정확히 90도 회전하는 동일한 수의 매개변수를 사용하여 유사한 등변 아키텍처를 훈련해 보겠습니다.

Accuracy on test Accuracy on 90-degree rotated test
96.5% 96.5%

표 2: SimpleCNN 모델과 동일한 양의 매개변수를 사용하여 EqCNN 모델의 정확도 테스트

정확도는 그대로 유지되었으며 데이터 증강도 선택하지 않았습니다.

이러한 모델은 3D 데이터로 더욱 인상적이지만 핵심 아이디어를 탐구하기 위해 이 예를 고수하겠습니다.

직접 테스트해보고 싶다면 Github-Repo에서 PyTorch와 JAX로 작성된 모든 코드에 무료로 접근할 수 있으며, 단 두 개의 명령어만으로 Docker나 Podman을 이용한 학습도 가능합니다.

즐거운 시간 보내세요!

그렇다면 등분산이란 무엇입니까?

등가 아키텍처는 특정 그룹 작업에서 기능의 안정성을 보장합니다. 그룹은 그룹 요소를 결합하거나 뒤집거나 아무것도 할 수 없는 간단한 구조입니다.

관심이 있으시면 Wikipedia에서 공식적인 정의를 찾아보실 수 있습니다.

우리의 목적에 맞게 정사각형 이미지에 작용하는 90도 회전 그룹을 생각해 볼 수 있습니다. 이미지를 90도, 180도, 270도, 360도로 회전할 수 있습니다. 동작을 되돌리려면 각각 270도, 180도, 90도 또는 0도 회전을 적용합니다. 으로 표시된 그룹을 사용하면 결합, 반전 또는 아무것도 수행할 수 없음을 쉽게 알 수 있습니다. C4C_44 . 이미지는 이미지의 모든 동작을 시각화합니다.

Figure 1: Rotated MNIST image by 90°, 180°, 270°, 360°, respectively
그림 1: MNIST 이미지를 각각 90°, 180°, 270°, 360° 회전

Now, given an input image xxx , our CNN model classifier fθf_\thetafθ , and an arbitrary 90-degree rotation ggg , the equivariant property can be expressed as
fθ(rotate x by g)=fθ(x) f_\theta(\text{rotate } x \text{ by } g) = f_\theta(x) fθ(rotate x by g)=fθ(x)

Generally speaking, we want our image-based model to have the same outputs when rotated.

As such, equivariant models promise us architectures with baked-in symmetries. In the following section, we will see how our principle can be applied to achieve this property.

How to Make Our CNN Equivariant

The problem is the following: When the image rotates, the features rotate too. But as already hinted, we could also compute the features for each rotation upfront by rotating the kernel.
We could actually rotate the kernel, but it is much easier to rotate the feature map itself, thus avoiding interference with PyTorch's autodifferentiation algorithm altogether.

So, in code, our CNN kernel

x = nn.functional.silu(self.cl1(x))
로그인 후 복사

now acts on all four rotated images:

x_0 = x
x_90 = torch.rot90(x, k=1, dims=(2, 3))
x_180 = torch.rot90(x, k=2, dims=(2, 3))
x_270 = torch.rot90(x, k=3, dims=(2, 3))

x_0 = nn.functional.silu(self.cl1(x_0))
x_90 = nn.functional.silu(self.cl1(x_90))
x_180 = nn.functional.silu(self.cl1(x_180))
x_270 = nn.functional.silu(self.cl1(x_270))
로그인 후 복사

Or more compactly written as a 3D convolution:

self.cl1 = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=(1, 3, 3))
...
x = torch.stack([x_0, x_90, x_180, x_270], dim=-3)
x = nn.functional.silu(self.cl1(x))
로그인 후 복사

The resulting equivariant model has just a few lines more compared to the version above:

class EqCNN(nn.Module):

    def __init__(self):
        super(EqCNN, self).__init__()
        self.cl1 = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=(1, 3, 3))
        self.max_1 = nn.MaxPool3d(kernel_size=(1, 2, 2))
        self.cl2 = nn.Conv3d(in_channels=8, out_channels=16, kernel_size=(1, 3, 3))
        self.max_2 = nn.MaxPool3d(kernel_size=(1, 2, 2))
        self.cl3 = nn.Conv3d(in_channels=16, out_channels=16, kernel_size=(1, 5, 5))
        self.dense = nn.Linear(in_features=16, out_features=10)

    def forward(self, x: torch.Tensor):
        x_0 = x
        x_90 = torch.rot90(x, k=1, dims=(2, 3))
        x_180 = torch.rot90(x, k=2, dims=(2, 3))
        x_270 = torch.rot90(x, k=3, dims=(2, 3))

        x = torch.stack([x_0, x_90, x_180, x_270], dim=-3)
        x = nn.functional.silu(self.cl1(x))
        x = self.max_1(x)

        x = nn.functional.silu(self.cl2(x))
        x = self.max_2(x)

        x = nn.functional.silu(self.cl3(x))

        x = x.squeeze()
        x = torch.max(x, dim=-1).values
        logits = self.dense(x)
        return logits
로그인 후 복사

But why is this equivariant to rotations?
First, observe that we get four copies of each feature map at each stage. At the end of the pipeline, we combine all of them with a max operation.

This is key, the max operation is indifferent to which place the rotated version of the feature ends up in.

To understand what is happening, let us plot the feature maps after the first convolution stage.

Figure 2: Feature maps for all four rotations
Figure 2: Feature maps for all four rotations

And now the same features after we rotate the input by 90 degrees.

Figure 3: Feature maps for all four rotations after the input image was rotated
그림 3: 입력 이미지가 회전된 후 4개 회전 모두에 대한 특징 맵

해당 지도를 색상으로 구분했습니다. 각 기능 맵은 하나씩 이동됩니다. 최종 최대값 연산자는 이러한 이동된 특징 맵에 대해 동일한 결과를 계산하므로 동일한 결과를 얻습니다.

내 코드에서는 커널이 이미지를 1차원 배열로 압축하기 때문에 최종 컨볼루션 후에 다시 회전하지 않았습니다. 이 예를 확장하려면 이 사실을 고려해야 합니다.

그룹 작업 또는 "커널 회전"에 대한 설명은 보다 정교한 아키텍처를 설계하는 데 중요한 역할을 합니다.

공짜 점심인가요?

아니요. 계산 속도, 귀납적 편향, 더 복잡한 구현에 대한 대가를 지불합니다.

후자의 문제는 대부분의 무거운 수학이 추상화되는 E3NN과 같은 라이브러리를 사용하면 어느 정도 해결됩니다. 그럼에도 불구하고 건축설계에서는 많은 것을 고려해야 한다.

한 가지 표면적인 약점은 모든 회전된 피처 레이어를 계산하는 데 드는 계산 비용이 4배라는 것입니다. 그러나 대량 병렬화 기능을 갖춘 최신 하드웨어는 이러한 부하를 쉽게 상쇄할 수 있습니다. 대조적으로, 데이터 증강을 사용하여 간단한 CNN을 훈련하는 것은 훈련 시간이 쉽게 10배를 초과합니다. 이는 가능한 모든 회전을 보상하기 위해 데이터 증강에 약 500배의 훈련량이 필요한 3D 회전의 경우 더욱 악화됩니다.

전체적으로 등분산 모델 설계는 안정적인 기능을 원하는 경우 지불할 가치가 있는 가격이 아닌 경우가 많습니다.

다음은 무엇입니까?

최근 몇 년 동안 등가 모델 디자인이 폭발적으로 증가했으며 이 기사에서는 표면적인 내용만 다루었습니다. 사실 우리는 전체 기능을 활용하지도 못했습니다. C4C_44 아직 그룹. 우리는 완전한 3D 커널을 사용할 수도 있었습니다. 그러나 우리 모델은 이미 95% 이상의 정확도를 달성했기 때문에 이 예를 더 이상 사용할 이유가 거의 없습니다.

CNN 외에도 연구자들은 이러한 원칙을 다음을 포함한 연속 그룹으로 성공적으로 번역했습니다. (2) SO(2)SO(2) (평면의 모든 회전 그룹) 및 (3) SE(3)SE(3) (3D 공간의 모든 이동 및 회전 그룹).

제 경험에 따르면 이러한 모델은 완전히 놀랍고 처음부터 훈련했을 때 몇 배 더 큰 데이터 세트에서 훈련된 기초 모델의 성능에 필적하는 성능을 달성합니다.

이 주제에 대해 더 많은 글을 쓰고 싶다면 알려주세요.

추가 참고자료

이 주제에 대해 정식으로 소개하고 싶다면 기계 학습의 등분산 역사 전체를 다루는 훌륭한 논문 모음집을 참조하세요.
아엔

저는 실제로 이 주제에 대한 심층적인 실습 튜토리얼을 만들 계획입니다. 이미 제 메일링 리스트에 가입하실 수 있습니다. 피드백과 Q&A를 위한 직접 채널과 함께 시간이 지나면서 무료 버전을 제공해 드리겠습니다.

주변에서 만나요 :)

위 내용은 정규 등변량 CNN을 구축하는 원리의 상세 내용입니다. 자세한 내용은 PHP 중국어 웹사이트의 기타 관련 기사를 참조하세요!

본 웹사이트의 성명
본 글의 내용은 네티즌들의 자발적인 기여로 작성되었으며, 저작권은 원저작자에게 있습니다. 본 사이트는 이에 상응하는 법적 책임을 지지 않습니다. 표절이나 침해가 의심되는 콘텐츠를 발견한 경우 admin@php.cn으로 문의하세요.

핫 AI 도구

Undresser.AI Undress

Undresser.AI Undress

사실적인 누드 사진을 만들기 위한 AI 기반 앱

AI Clothes Remover

AI Clothes Remover

사진에서 옷을 제거하는 온라인 AI 도구입니다.

Undress AI Tool

Undress AI Tool

무료로 이미지를 벗다

Clothoff.io

Clothoff.io

AI 옷 제거제

AI Hentai Generator

AI Hentai Generator

AI Hentai를 무료로 생성하십시오.

인기 기사

R.E.P.O. 에너지 결정과 그들이하는 일 (노란색 크리스탈)
1 몇 달 전 By 尊渡假赌尊渡假赌尊渡假赌
R.E.P.O. 최고의 그래픽 설정
1 몇 달 전 By 尊渡假赌尊渡假赌尊渡假赌
Will R.E.P.O. 크로스 플레이가 있습니까?
1 몇 달 전 By 尊渡假赌尊渡假赌尊渡假赌

뜨거운 도구

메모장++7.3.1

메모장++7.3.1

사용하기 쉬운 무료 코드 편집기

SublimeText3 중국어 버전

SublimeText3 중국어 버전

중국어 버전, 사용하기 매우 쉽습니다.

스튜디오 13.0.1 보내기

스튜디오 13.0.1 보내기

강력한 PHP 통합 개발 환경

드림위버 CS6

드림위버 CS6

시각적 웹 개발 도구

SublimeText3 Mac 버전

SublimeText3 Mac 버전

신 수준의 코드 편집 소프트웨어(SublimeText3)

Linux 터미널에서 Python 버전을 볼 때 발생하는 권한 문제를 해결하는 방법은 무엇입니까? Linux 터미널에서 Python 버전을 볼 때 발생하는 권한 문제를 해결하는 방법은 무엇입니까? Apr 01, 2025 pm 05:09 PM

Linux 터미널에서 Python 버전을 보려고 할 때 Linux 터미널에서 Python 버전을 볼 때 권한 문제에 대한 솔루션 ... Python을 입력하십시오 ...

한 데이터 프레임의 전체 열을 Python의 다른 구조를 가진 다른 데이터 프레임에 효율적으로 복사하는 방법은 무엇입니까? 한 데이터 프레임의 전체 열을 Python의 다른 구조를 가진 다른 데이터 프레임에 효율적으로 복사하는 방법은 무엇입니까? Apr 01, 2025 pm 11:15 PM

Python의 Pandas 라이브러리를 사용할 때는 구조가 다른 두 데이터 프레임 사이에서 전체 열을 복사하는 방법이 일반적인 문제입니다. 두 개의 dats가 있다고 가정 해

10 시간 이내에 프로젝트 및 문제 중심 방법에서 컴퓨터 초보자 프로그래밍 기본 사항을 가르치는 방법? 10 시간 이내에 프로젝트 및 문제 중심 방법에서 컴퓨터 초보자 프로그래밍 기본 사항을 가르치는 방법? Apr 02, 2025 am 07:18 AM

10 시간 이내에 컴퓨터 초보자 프로그래밍 기본 사항을 가르치는 방법은 무엇입니까? 컴퓨터 초보자에게 프로그래밍 지식을 가르치는 데 10 시간 밖에 걸리지 않는다면 무엇을 가르치기로 선택 하시겠습니까?

중간 독서를 위해 Fiddler를 사용할 때 브라우저에서 감지되는 것을 피하는 방법은 무엇입니까? 중간 독서를 위해 Fiddler를 사용할 때 브라우저에서 감지되는 것을 피하는 방법은 무엇입니까? Apr 02, 2025 am 07:15 AM

Fiddlerevery Where를 사용할 때 Man-in-the-Middle Reading에 Fiddlereverywhere를 사용할 때 감지되는 방법 ...

정규 표현이란 무엇입니까? 정규 표현이란 무엇입니까? Mar 20, 2025 pm 06:25 PM

정규 표현식은 프로그래밍의 패턴 일치 및 텍스트 조작을위한 강력한 도구이며 다양한 응용 프로그램에서 텍스트 처리의 효율성을 높입니다.

Uvicorn은 Serving_forever ()없이 HTTP 요청을 어떻게 지속적으로 듣습니까? Uvicorn은 Serving_forever ()없이 HTTP 요청을 어떻게 지속적으로 듣습니까? Apr 01, 2025 pm 10:51 PM

Uvicorn은 HTTP 요청을 어떻게 지속적으로 듣습니까? Uvicorn은 ASGI를 기반으로 한 가벼운 웹 서버입니다. 핵심 기능 중 하나는 HTTP 요청을 듣고 진행하는 것입니다 ...

인기있는 파이썬 라이브러리와 그 용도는 무엇입니까? 인기있는 파이썬 라이브러리와 그 용도는 무엇입니까? Mar 21, 2025 pm 06:46 PM

이 기사는 Numpy, Pandas, Matplotlib, Scikit-Learn, Tensorflow, Django, Flask 및 요청과 같은 인기있는 Python 라이브러리에 대해 설명하고 과학 컴퓨팅, 데이터 분석, 시각화, 기계 학습, 웹 개발 및 H에서의 사용에 대해 자세히 설명합니다.

문자열을 통해 객체를 동적으로 생성하고 방법을 파이썬으로 호출하는 방법은 무엇입니까? 문자열을 통해 객체를 동적으로 생성하고 방법을 파이썬으로 호출하는 방법은 무엇입니까? Apr 01, 2025 pm 11:18 PM

파이썬에서 문자열을 통해 객체를 동적으로 생성하고 메소드를 호출하는 방법은 무엇입니까? 특히 구성 또는 실행 해야하는 경우 일반적인 프로그래밍 요구 사항입니다.

See all articles