한 가지 원칙은 간단하게 '커널을 회전시키세요'라고 명시되어 있으며 이 기사에서는 이를 아키텍처에 적용하는 방법에 중점을 둘 것입니다.
등변 아키텍처를 사용하면 특정 그룹 활동에 무관한 모델을 훈련할 수 있습니다.
이것이 정확히 무엇을 의미하는지 이해하기 위해 MNIST 데이터세트(0-9의 손으로 쓴 숫자 데이터세트)에서 이 간단한 CNN 모델을 훈련시켜 보겠습니다.
class SimpleCNN(nn.Module): def __init__(self): super(SimpleCNN, self).__init__() self.cl1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, padding=1) self.max_1 = nn.MaxPool2d(kernel_size=2) self.cl2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, padding=1) self.max_2 = nn.MaxPool2d(kernel_size=2) self.cl3 = nn.Conv2d(in_channels=16, out_channels=16, kernel_size=7) self.dense = nn.Linear(in_features=16, out_features=10) def forward(self, x: torch.Tensor): x = nn.functional.silu(self.cl1(x)) x = self.max_1(x) x = nn.functional.silu(self.cl2(x)) x = self.max_2(x) x = nn.functional.silu(self.cl3(x)) x = x.view(len(x), -1) logits = self.dense(x) return logits
Accuracy on test | Accuracy on 90-degree rotated test |
---|---|
97.3% | 15.1% |
표 1: SimpleCNN 모델의 정확도 테스트
예상대로 테스트 데이터 세트에서 95% 이상의 정확도를 얻었습니다. 그런데 이미지를 90도 회전하면 어떻게 될까요? 어떤 대책도 적용하지 않으면 결과는 추측보다 약간 더 나은 수준으로 떨어집니다. 이 모델은 일반 응용 프로그램에는 쓸모가 없습니다.
반대로, 그룹 작업이 정확히 90도 회전하는 동일한 수의 매개변수를 사용하여 유사한 등변 아키텍처를 훈련해 보겠습니다.
Accuracy on test | Accuracy on 90-degree rotated test |
---|---|
96.5% | 96.5% |
표 2: SimpleCNN 모델과 동일한 양의 매개변수를 사용하여 EqCNN 모델의 정확도 테스트
정확도는 그대로 유지되었으며 데이터 증강도 선택하지 않았습니다.
이러한 모델은 3D 데이터로 더욱 인상적이지만 핵심 아이디어를 탐구하기 위해 이 예를 고수하겠습니다.
직접 테스트해보고 싶다면 Github-Repo에서 PyTorch와 JAX로 작성된 모든 코드에 무료로 접근할 수 있으며, 단 두 개의 명령어만으로 Docker나 Podman을 이용한 학습도 가능합니다.
즐거운 시간 보내세요!
등가 아키텍처는 특정 그룹 작업에서 기능의 안정성을 보장합니다. 그룹은 그룹 요소를 결합하거나 뒤집거나 아무것도 할 수 없는 간단한 구조입니다.
관심이 있으시면 Wikipedia에서 공식적인 정의를 찾아보실 수 있습니다.
우리의 목적에 맞게 정사각형 이미지에 작용하는 90도 회전 그룹을 생각해 볼 수 있습니다. 이미지를 90도, 180도, 270도, 360도로 회전할 수 있습니다. 동작을 되돌리려면 각각 270도, 180도, 90도 또는 0도 회전을 적용합니다. 으로 표시된 그룹을 사용하면 결합, 반전 또는 아무것도 수행할 수 없음을 쉽게 알 수 있습니다. ㄷ4 . 이미지는 이미지의 모든 동작을 시각화합니다.
그림 1: MNIST 이미지를 각각 90°, 180°, 270°, 360° 회전
Now, given an input image
x
, our CNN model classifier
fθ
, and an arbitrary 90-degree rotation
g
, the equivariant property can be expressed as
fθ(rotate x by g)=fθ(x)
Generally speaking, we want our image-based model to have the same outputs when rotated.
As such, equivariant models promise us architectures with baked-in symmetries. In the following section, we will see how our principle can be applied to achieve this property.
The problem is the following: When the image rotates, the features rotate too. But as already hinted, we could also compute the features for each rotation upfront by rotating the kernel.
We could actually rotate the kernel, but it is much easier to rotate the feature map itself, thus avoiding interference with PyTorch's autodifferentiation algorithm altogether.
So, in code, our CNN kernel
x = nn.functional.silu(self.cl1(x))
now acts on all four rotated images:
x_0 = x x_90 = torch.rot90(x, k=1, dims=(2, 3)) x_180 = torch.rot90(x, k=2, dims=(2, 3)) x_270 = torch.rot90(x, k=3, dims=(2, 3)) x_0 = nn.functional.silu(self.cl1(x_0)) x_90 = nn.functional.silu(self.cl1(x_90)) x_180 = nn.functional.silu(self.cl1(x_180)) x_270 = nn.functional.silu(self.cl1(x_270))
Or more compactly written as a 3D convolution:
self.cl1 = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=(1, 3, 3)) ... x = torch.stack([x_0, x_90, x_180, x_270], dim=-3) x = nn.functional.silu(self.cl1(x))
The resulting equivariant model has just a few lines more compared to the version above:
class EqCNN(nn.Module): def __init__(self): super(EqCNN, self).__init__() self.cl1 = nn.Conv3d(in_channels=1, out_channels=8, kernel_size=(1, 3, 3)) self.max_1 = nn.MaxPool3d(kernel_size=(1, 2, 2)) self.cl2 = nn.Conv3d(in_channels=8, out_channels=16, kernel_size=(1, 3, 3)) self.max_2 = nn.MaxPool3d(kernel_size=(1, 2, 2)) self.cl3 = nn.Conv3d(in_channels=16, out_channels=16, kernel_size=(1, 5, 5)) self.dense = nn.Linear(in_features=16, out_features=10) def forward(self, x: torch.Tensor): x_0 = x x_90 = torch.rot90(x, k=1, dims=(2, 3)) x_180 = torch.rot90(x, k=2, dims=(2, 3)) x_270 = torch.rot90(x, k=3, dims=(2, 3)) x = torch.stack([x_0, x_90, x_180, x_270], dim=-3) x = nn.functional.silu(self.cl1(x)) x = self.max_1(x) x = nn.functional.silu(self.cl2(x)) x = self.max_2(x) x = nn.functional.silu(self.cl3(x)) x = x.squeeze() x = torch.max(x, dim=-1).values logits = self.dense(x) return logits
But why is this equivariant to rotations?
First, observe that we get four copies of each feature map at each stage. At the end of the pipeline, we combine all of them with a max operation.
This is key, the max operation is indifferent to which place the rotated version of the feature ends up in.
To understand what is happening, let us plot the feature maps after the first convolution stage.
Figure 2: Feature maps for all four rotations
And now the same features after we rotate the input by 90 degrees.
그림 3: 입력 이미지가 회전된 후 4개 회전 모두에 대한 특징 맵
해당 지도를 색상으로 구분했습니다. 각 기능 맵은 하나씩 이동됩니다. 최종 최대값 연산자는 이러한 이동된 특징 맵에 대해 동일한 결과를 계산하므로 동일한 결과를 얻습니다.
내 코드에서는 커널이 이미지를 1차원 배열로 압축하기 때문에 최종 컨볼루션 후에 다시 회전하지 않았습니다. 이 예를 확장하려면 이 사실을 고려해야 합니다.
그룹 작업 또는 "커널 회전"에 대한 설명은 보다 정교한 아키텍처를 설계하는 데 중요한 역할을 합니다.
아니요. 계산 속도, 귀납적 편향, 더 복잡한 구현에 대한 대가를 지불합니다.
후자의 문제는 대부분의 무거운 수학이 추상화되는 E3NN과 같은 라이브러리를 사용하면 어느 정도 해결됩니다. 그럼에도 불구하고 건축설계에서는 많은 것을 고려해야 한다.
한 가지 표면적인 약점은 모든 회전된 피처 레이어를 계산하는 데 드는 계산 비용이 4배라는 것입니다. 그러나 대량 병렬화 기능을 갖춘 최신 하드웨어는 이러한 부하를 쉽게 상쇄할 수 있습니다. 대조적으로, 데이터 증강을 사용하여 간단한 CNN을 훈련하는 것은 훈련 시간이 쉽게 10배를 초과합니다. 이는 가능한 모든 회전을 보상하기 위해 데이터 증강에 약 500배의 훈련량이 필요한 3D 회전의 경우 더욱 악화됩니다.
전체적으로 등분산 모델 설계는 안정적인 기능을 원하는 경우 지불할 가치가 있는 가격이 아닌 경우가 많습니다.
최근 몇 년 동안 등가 모델 디자인이 폭발적으로 증가했으며 이 기사에서는 표면적인 내용만 다루었습니다. 사실 우리는 전체 기능을 활용하지도 못했습니다. ㄷ4 아직 그룹. 우리는 완전한 3D 커널을 사용할 수도 있었습니다. 그러나 우리 모델은 이미 95% 이상의 정확도를 달성했기 때문에 이 예를 더 이상 사용할 이유가 거의 없습니다.
CNN 외에도 연구자들은 이러한 원칙을 다음을 포함한 연속 그룹으로 성공적으로 번역했습니다. SO(2) (평면의 모든 회전 그룹) 및 SE(3) (3D 공간의 모든 이동 및 회전 그룹).
제 경험에 따르면 이러한 모델은 완전히 놀랍고 처음부터 훈련했을 때 몇 배 더 큰 데이터 세트에서 훈련된 기초 모델의 성능에 필적하는 성능을 달성합니다.
이 주제에 대해 더 많은 글을 쓰고 싶다면 알려주세요.
이 주제에 대해 정식으로 소개하고 싶다면 기계 학습의 등분산 역사 전체를 다루는 훌륭한 논문 모음집을 참조하세요.
아엔
저는 실제로 이 주제에 대한 심층적인 실습 튜토리얼을 만들 계획입니다. 이미 제 메일링 리스트에 가입하실 수 있습니다. 피드백과 Q&A를 위한 직접 채널과 함께 시간이 지나면서 무료 버전을 제공해 드리겠습니다.
주변에서 만나요 :)
위 내용은 정규 등변량 CNN을 구축하는 원리의 상세 내용입니다. 자세한 내용은 PHP 중국어 웹사이트의 기타 관련 기사를 참조하세요!