Maison > Périphériques technologiques > IA > Exemples de code pour la distillation des connaissances à l'aide de PyTorch

Exemples de code pour la distillation des connaissances à l'aide de PyTorch

王林
Libérer: 2023-04-11 22:31:13
avant
1003 Les gens l'ont consulté

Alors que les modèles d'apprentissage automatique continuent de gagner en complexité et en capacités. Une technique efficace pour améliorer les performances de modèles volumineux et complexes sur de petits ensembles de données est la distillation des connaissances, qui implique la formation d'un modèle plus petit et plus efficace pour imiter le comportement d'un modèle « enseignant » plus grand.

Exemples de code pour la distillation des connaissances à l'aide de PyTorch

Dans cet article, nous explorerons le concept de distillation des connaissances et comment le mettre en œuvre dans PyTorch. Nous verrons comment il peut être utilisé pour compresser un modèle volumineux et peu maniable en un modèle plus petit et plus efficace tout en conservant la précision et les performances du modèle d'origine.

Nous définissons d'abord le problème à résoudre par distillation des connaissances.

Nous avons formé un vaste réseau neuronal profond pour effectuer des tâches complexes telles que la classification d'images ou la traduction automatique. Ce modèle peut comporter des milliers de couches et des millions de paramètres, ce qui rend difficile son déploiement dans des applications du monde réel, des appareils de périphérie, etc. Et ce modèle très volumineux nécessite également beaucoup de ressources informatiques pour fonctionner, ce qui le rend incapable de fonctionner sur certaines plates-formes aux ressources limitées.

Une façon de résoudre ce problème consiste à utiliser la distillation des connaissances pour compresser de grands modèles en modèles plus petits. Ce processus implique la formation d'un modèle plus petit pour imiter le comportement du modèle plus grand dans une tâche donnée.

Nous utiliserons un exemple de distillation des connaissances en utilisant l'ensemble de données de radiographie pulmonaire de Kaggle pour la classification de la pneumonie. L'ensemble de données que nous avons utilisé est organisé en 3 dossiers (train, test, val) et contient des sous-dossiers pour chaque catégorie d'image (Pneumonie/Normal). Il existe 5 863 images radiographiques (JPEG) et 2 catégories (pneumonie/normale).

Comparez les images de ces deux classes :

Exemples de code pour la distillation des connaissances à laide de PyTorch

Le chargement et le prétraitement des données sont indépendants du fait que nous utilisons la distillation des connaissances ou un modèle spécifique, l'extrait de code pourrait ressembler à ceci :

transforms_train = transforms.Compose([
transforms.Resize((224, 224)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
 
 transforms_test = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])])
 
 train_data = ImageFolder(root=train_dir, transform=transforms_train)
 test_data = ImageFolder(root=test_dir, transform=transforms_test)
 
 train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
 test_loader = DataLoader(test_data, batch_size=32, shuffle=True)
Copier après la connexion

Teacher Model

Dans ce contexte Pour le modèle d'enseignant intermédiaire, nous utilisons Resnet-18 et l'affinons sur cet ensemble de données.

import torch
 import torch.nn as nn
 import torchvision
 
 class TeacherNet(nn.Module):
def __init__(self):
super().__init__()
self.model = torchvision.models.resnet18(pretrained=True)
for params in self.model.parameters():
params.requires_grad_ = False
 
n_filters = self.model.fc.in_features
self.model.fc = nn.Linear(n_filters, 2)
 
def forward(self, x):
x = self.model(x)
return x
Copier après la connexion

Le code pour l'entraînement de réglage fin est le suivant

 def train(model, train_loader, test_loader, optimizer, criterion, device):
dataloaders = {'train': train_loader, 'val': test_loader}
 
for epoch in range(30):
print('Epoch {}/{}'.format(epoch, num_epochs - 1))
print('-' * 10)
 
for phase in ['train', 'val']:
if phase == 'train':
model.train()
else:
model.eval()
 
running_loss = 0.0
running_corrects = 0
 
for inputs, labels in tqdm.tqdm(dataloaders[phase]):
inputs = inputs.to(device)
labels = labels.to(device)
 
optimizer.zero_grad()
 
with torch.set_grad_enabled(phase == 'train'):
outputs = model(inputs)
loss = criterion(outputs, labels)
 
_, preds = torch.max(outputs, 1)
 
if phase == 'train':
loss.backward()
optimizer.step()
 
running_loss += loss.item() * inputs.size(0)
running_corrects += torch.sum(preds == labels.data)
 
epoch_loss = running_loss / len(dataloaders[phase].dataset)
epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
 
print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
Copier après la connexion

Il s'agit d'une étape d'entraînement de réglage fin standard. Après l'entraînement, nous pouvons voir que le modèle a atteint une précision de 91 % sur l'ensemble de test, ce qui signifie que nous n'avons pas choisi. un modèle plus grand. La raison en est que la précision du test 91 est suffisante pour être utilisée comme modèle de classe de base.

Nous savons que le modèle comporte 11,7 millions de paramètres, il ne pourra donc pas nécessairement s'adapter aux appareils de pointe ou à d'autres scénarios spécifiques.

Modèle étudiant

Notre étudiant est un CNN moins profond avec seulement quelques couches et environ 100 000 paramètres.

class StudentNet(nn.Module):
def __init__(self):
super().__init__()
self.layer1 = nn.Sequential(
nn.Conv2d(3, 4, kernel_size=3, padding=1),
nn.BatchNorm2d(4),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2)
)
self.fc = nn.Linear(4 * 112 * 112, 2)
 
def forward(self, x):
out = self.layer1(x)
out = out.view(out.size(0), -1)
out = self.fc(out)
return out
Copier après la connexion

C'est très simple si vous regardez le code, n'est-ce pas.

Si je peux simplement former ce réseau neuronal plus petit, pourquoi devrais-je m'embêter avec la distillation des connaissances ? Nous joindrons enfin les résultats de la formation de ce réseau à partir de zéro grâce à l'ajustement des hyperparamètres et d'autres moyens de comparaison ?

Mais maintenant, nous continuons nos étapes de distillation des connaissances

Formation de distillation des connaissances

Les étapes de base de la formation sont les mêmes, mais la différence est de savoir comment calculer la perte de formation finale, nous utiliserons la perte du modèle de l'enseignant, le modèle de l'étudiant perte et La perte de distillation est calculée avec la perte finale.

class DistillationLoss:
def __init__(self):
self.student_loss = nn.CrossEntropyLoss()
self.distillation_loss = nn.KLDivLoss()
self.temperature = 1
self.alpha = 0.25
 
def __call__(self, student_logits, student_target_loss, teacher_logits):
distillation_loss = self.distillation_loss(F.log_softmax(student_logits / self.temperature, dim=1),
F.softmax(teacher_logits / self.temperature, dim=1))
 
loss = (1 - self.alpha) * student_target_loss + self.alpha * distillation_loss
return loss
Copier après la connexion

La fonction de perte est la somme pondérée des deux choses suivantes :

  • Perte de classification, appelée student_target_loss
  • Perte de distillation, perte d'entropie croisée entre le logarithme de l'élève et le logarithme de l'enseignant

Exemples de code pour la distillation des connaissances à laide de PyTorch

En termes simples, notre modèle d'enseignant doit apprendre aux élèves à « penser », ce qui fait référence à son incertitude ; par exemple, si la probabilité de sortie finale du modèle de l'enseignant est [0,53, 0,47], nous espérons que les élèves obtiendront également les mêmes résultats similaires, la différence entre ces prédictions sont la perte de distillation.

Afin de contrôler la perte, il y a deux paramètres principaux :

  • Le poids de la perte de distillation : 0 signifie qu'on ne considère que la perte de distillation et vice versa.
  • Température : mesure l'incertitude des prédictions des enseignants.

Dans les points ci-dessus, les valeurs d'alpha et de température sont basées sur les meilleurs résultats que nous avons essayés avec quelques combinaisons.

Comparaison des résultats

Ceci est un résumé tabulaire de cette expérience.

Exemples de code pour la distillation des connaissances à laide de PyTorch

Nous pouvons clairement voir les énormes avantages obtenus en utilisant un CNN plus petit (99,14 %), moins profond : 10 points d'amélioration de la précision par rapport à l'entraînement sans distillation, et 11 points plus rapide que Resnet-18 Times En d'autres termes, notre ! le petit modèle a vraiment appris quelque chose d'utile du grand modèle.


Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

Étiquettes associées:
source:51cto.com
Déclaration de ce site Web
Le contenu de cet article est volontairement contribué par les internautes et les droits d'auteur appartiennent à l'auteur original. Ce site n'assume aucune responsabilité légale correspondante. Si vous trouvez un contenu suspecté de plagiat ou de contrefaçon, veuillez contacter admin@php.cn
Tutoriels populaires
Plus>
Derniers téléchargements
Plus>
effets Web
Code source du site Web
Matériel du site Web
Modèle frontal