PyTorch を使用した知識蒸留のコード例

王林
リリース: 2023-04-11 22:31:13
転載
949 人が閲覧しました

機械学習モデルの複雑さと機能は増加し続けています。小規模なデータセット上で大規模で複雑なモデルのパフォーマンスを向上させる効果的な手法は、より小規模で効率的なモデルをトレーニングして、より大きな「教師」モデルの動作を模倣することを含む知識蒸留です。

PyTorch を使用した知識蒸留のコード例

この記事では、知識の蒸留の概念と、それを PyTorch で実装する方法について説明します。元のモデルの精度とパフォーマンスを維持しながら、大きくて扱いにくいモデルをより小さく効率的なモデルに圧縮するためにどのように使用できるかを見ていきます。

まず、知識の蒸留によって解決すべき問題を定義します。

私たちは、画像分類や機械翻訳などの複雑なタスクを実行するために大規模なディープ ニューラル ネットワークをトレーニングしました。このモデルには数千のレイヤーと数百万のパラメーターが含まれる可能性があるため、現実世界のアプリケーションやエッジ デバイスなどに展開することが困難になります。また、この非常に大規模なモデルの実行には大量のコンピューティング リソースも必要となるため、リソースに制約のある一部のプラットフォームでは動作できません。

この問題を解決する 1 つの方法は、知識の蒸留を使用して、大きなモデルを小さなモデルに圧縮することです。このプロセスには、特定のタスクにおける大きなモデルの動作を模倣するために小さなモデルをトレーニングすることが含まれます。

Kaggle の胸部 X 線データセットを肺炎分類に使用して知識を蒸留する例を実行します。使用したデータセットは 3 つのフォルダー (train、test、val) に編成されており、各画像カテゴリ (肺炎/正常) のサブフォルダーが含まれています。 5,863 枚の X 線画像 (JPEG) と 2 つのカテゴリ (肺炎/正常) があります。

これら 2 つのクラスの図を比較してください:

PyTorch を使用した知識蒸留のコード例

#データの読み込みと前処理は、知識の蒸留を使用するか特定のモデルを使用するかとは関係ありません。

transforms_train = transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
 
 transforms_test = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
 
 train_data = ImageFolder(root=train_dir, transform=transforms_train)
 test_data = ImageFolder(root=test_dir, transform=transforms_test)
 
 train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
 test_loader = DataLoader(test_data, batch_size=32, shuffle=True)
ログイン後にコピー

教師モデル

このバックグラウンド教師モデルでは、Resnet-18 を使用し、このデータセットで微調整しています。

import torch
 import torch.nn as nn
 import torchvision
 
 class TeacherNet(nn.Module):
def __init__(self):
super().__init__()
self.model = torchvision.models.resnet18(pretrained=True)
for params in self.model.parameters():
params.requires_grad_ = False
 
n_filters = self.model.fc.in_features
self.model.fc = nn.Linear(n_filters, 2)
 
def forward(self, x):
x = self.model(x)
return x
ログイン後にコピー

微調整トレーニングのコードは次のとおりです

 def train(model, train_loader, test_loader, optimizer, criterion, device):
dataloaders = {'train': train_loader, 'val': test_loader}
 
for epoch in range(30):
print('Epoch {}/{}'.format(epoch, num_epochs - 1))
print('-' * 10)
 
for phase in ['train', 'val']:
if phase == 'train':
model.train()
else:
model.eval()
 
running_loss = 0.0
running_corrects = 0
 
for inputs, labels in tqdm.tqdm(dataloaders[phase]):
inputs = inputs.to(device)
labels = labels.to(device)
 
optimizer.zero_grad()
 
with torch.set_grad_enabled(phase == 'train'):
outputs = model(inputs)
loss = criterion(outputs, labels)
 
_, preds = torch.max(outputs, 1)
 
if phase == 'train':
loss.backward()
optimizer.step()
 
running_loss += loss.item() * inputs.size(0)
running_corrects += torch.sum(preds == labels.data)
 
epoch_loss = running_loss / len(dataloaders[phase].dataset)
epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
 
print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
ログイン後にコピー

これは標準的な微調整トレーニング ステップです。トレーニング後、モデルがテストで 91% の精度を達成したことがわかります。テスト 91 の精度はベース モデルとして使用するのに十分であるため、より大きなモデルを選択しなかったのはそのためです。

モデルには 1,170 万個のパラメーターがあることがわかっているため、必ずしもエッジ デバイスやその他の特定のシナリオに適応できるとは限りません。

学生モデル

私たちの学生は、わずか数層と約 100k パラメータを持つ浅い CNN です。

class StudentNet(nn.Module):
def __init__(self):
super().__init__()
self.layer1 = nn.Sequential(
nn.Conv2d(3, 4, kernel_size=3, padding=1),
nn.BatchNorm2d(4),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2)
)
self.fc = nn.Linear(4 * 112 * 112, 2)
 
def forward(self, x):
out = self.layer1(x)
out = out.view(out.size(0), -1)
out = self.fc(out)
return out
ログイン後にコピー

コードを見るととてもシンプルですね。

この小さなニューラル ネットワークを単純にトレーニングできるのに、なぜわざわざ知識の蒸留をする必要があるのでしょうか? 比較のために、ハイパーパラメータ調整やその他の手段を通じてこのネットワークを最初からトレーニングした結果を最後に添付します。

しかし、知識の蒸留ステップを継続します

知識の蒸留トレーニング

トレーニングの基本的なステップは変わりませんが、違いは最終的なトレーニング損失の計算方法です。教師モデルの損失、生徒モデルの損失、蒸留損失を組み合わせて最終的な損失を計算します。

class DistillationLoss:
def __init__(self):
self.student_loss = nn.CrossEntropyLoss()
self.distillation_loss = nn.KLDivLoss()
self.temperature = 1
self.alpha = 0.25
 
def __call__(self, student_logits, student_target_loss, teacher_logits):
distillation_loss = self.distillation_loss(F.log_softmax(student_logits / self.temperature, dim=1),
F.softmax(teacher_logits / self.temperature, dim=1))
 
loss = (1 - self.alpha) * student_target_loss + self.alpha * distillation_loss
return loss
ログイン後にコピー

損失関数は、次の 2 つの重み付き合計です。

  • Student_target_loss と呼ばれる分類損失
  • 蒸留損失、スチューデントの対数とスチューデントの対数の合計教師の対数

PyTorch を使用した知識蒸留のコード例

間のクロスエントロピー損失 簡単に言うと、教師モデルは、不確実性を指す「考える」方法を生徒に教える必要があります。たとえば、教師モデルの最終出力確率が [0.53, 0.47] の場合、生徒も同様の結果が得られることが期待されます。これらの予測の差は蒸留損失です。

損失を制御するには、2 つの主なパラメータがあります:

  • 蒸留損失の重み: 0 は蒸留損失のみを考慮することを意味し、その逆も同様です。
  • 温度: 教師の予測の不確実性を測定します。

上記の点において、アルファと温度の値は、私たちが試したいくつかの組み合わせによって得られた最良の結果に基づいています。

結果の比較

これは、この実験の概要を表にまとめたものです。

PyTorch を使用した知識蒸留のコード例

より小さく (99.14%) 浅い CNN を使用することで得られる大きな利点が明らかにわかります。精度は、蒸留なしのトレーニングと比較して 10 ポイント、11 倍向上しました。 Resnet-18 よりも速い! つまり、私たちの小さなモデルは実際に大きなモデルから有益なことを学びました。


以上がPyTorch を使用した知識蒸留のコード例の詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。

関連ラベル:
ソース:51cto.com
このウェブサイトの声明
この記事の内容はネチズンが自主的に寄稿したものであり、著作権は原著者に帰属します。このサイトは、それに相当する法的責任を負いません。盗作または侵害の疑いのあるコンテンツを見つけた場合は、admin@php.cn までご連絡ください。
最新の問題
人気のチュートリアル
詳細>
最新のダウンロード
詳細>
ウェブエフェクト
公式サイト
サイト素材
フロントエンドテンプレート