
Code examples for knowledge distillation using PyTorch

王林
Release: 2023-04-11 22:31:13

As machine learning models continue to grow in complexity and capability, they also become harder to deploy. One effective way to carry the performance of a large, complex model over to a smaller one is knowledge distillation. It trains a smaller, more efficient "student" model to mimic the behavior of a larger "teacher" model.


In this article, we will explore the concept of knowledge distillation and how to implement it in PyTorch. We'll see how it can compress a large, unwieldy model into a smaller, more efficient one while retaining much of the original model's accuracy and performance.

First, let's define the problem that knowledge distillation solves.

Suppose we have trained a large deep neural network to perform a complex task such as image classification or machine translation. Such a model may have many layers and millions of parameters, which makes it difficult to deploy in real-world applications, on edge devices, and so on. A very large model also needs a lot of computing resources to run, so it may not work at all on resource-constrained platforms.

One way to solve this problem is to use knowledge distillation to compress large models into smaller models. This process involves training a smaller model to mimic the behavior of the larger model in a given task.

We will work through an example of knowledge distillation using the chest X-ray dataset from Kaggle for pneumonia classification. The dataset is organized into 3 folders (train, test, val), each containing one subfolder per image category (Pneumonia/Normal). There are 5,863 X-ray images (JPEG) in 2 categories (pneumonia/normal).
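The class folders determine the labels, so it can help to check how many images each split and class actually contains before training. The sketch below assumes only the root/split/class/ layout described above. The function name `count_images` and the accepted file suffixes are our own choices and are not part of the dataset.

```python
from pathlib import Path

# File types counted as images (an assumption; the dataset itself is JPEG).
IMAGE_SUFFIXES = {".jpeg", ".jpg", ".png"}


def count_images(root):
    """Count image files per (split, class) in a root/split/class/ layout."""
    counts = {}
    root = Path(root)
    for split_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        for class_dir in sorted(p for p in split_dir.iterdir() if p.is_dir()):
            # Skip stray non-image files such as notes or OS metadata.
            n = sum(1 for f in class_dir.iterdir()
                    if f.is_file() and f.suffix.lower() in IMAGE_SUFFIXES)
            counts[(split_dir.name, class_dir.name)] = n
    return counts
```

For example, `count_images("chest_xray")` returns a dict keyed by `(split, class)`, assuming the archive was extracted to a hypothetical `chest_xray/` directory.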

Compare the pictures of these two classes:


Loading and preprocessing the data are the same whether or not we use knowledge distillation, and regardless of the specific model. The code snippet might look like this:

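import os
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder

# Hypothetical paths: point these at the extracted Kaggle dataset.
data_dir = "chest_xray"
train_dir = os.path.join(data_dir, "train")
test_dir = os.path.join(data_dir, "test")

# ImageNet mean/std, matching the pretrained ResNet-18 teacher.
transforms_train = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])])

transforms_test = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])])

train_data = ImageFolder(root=train_dir, transform=transforms_train)
test_data = ImageFolder(root=test_dir, transform=transforms_test)

train_loader = DataLoader(train_data, batch_size=32, shuffle=True)
# The evaluation set does not need shuffling.
test_loader = DataLoader(test_data, batch_size=32, shuffle=False)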

Teacher Model

For the teacher model, we use a ResNet-18 pretrained on ImageNet and fine-tune it on this dataset.

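import torch
import torch.nn as nn
import torchvision

class TeacherNet(nn.Module):
    def __init__(self):
        super().__init__()
        # ResNet-18 pretrained on ImageNet (weights API, torchvision >= 0.13;
        # older versions use pretrained=True instead).
        self.model = torchvision.models.resnet18(
            weights=torchvision.models.ResNet18_Weights.DEFAULT)
        # Freeze the backbone: set the requires_grad attribute (or call
        # requires_grad_(False)); assigning to requires_grad_ freezes nothing.
        for params in self.model.parameters():
            params.requires_grad = False

        # Replace the 1000-class ImageNet head with a trainable 2-class head.
        n_filters = self.model.fc.in_features
        self.model.fc = nn.Linear(n_filters, 2)

    def forward(self, x):
        x = self.model(x)
        return x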

The fine-tuning code is as follows:

import torch
import tqdm

def train(model, train_loader, test_loader, optimizer, criterion, device, num_epochs=30):
    dataloaders = {'train': train_loader, 'val': test_loader}

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in tqdm.tqdm(dataloaders[phase]):
                inputs = inputs.to(device)
                labels = labels.to(device)

                optimizer.zero_grad()

                # Track gradients only during the training phase.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

This is a standard fine-tuning loop. After training, the model reaches 91% accuracy on the test set. This is also why we did not choose a larger model: 91% test accuracy is good enough for a baseline teacher.
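To measure accuracy on its own, for example on the test loader after training, a small helper is enough. This is a sketch, and `evaluate_accuracy` is our own name.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def evaluate_accuracy(model, loader, device="cpu"):
    """Fraction of samples whose arg-max prediction matches the label."""
    model.eval()
    correct, total = 0, 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        preds = model(inputs).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    return correct / total
```

A typical call would be `evaluate_accuracy(teacher, test_loader, device)`, where `teacher` is a trained `TeacherNet`.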

ResNet-18 has about 11.7 million parameters, so it may not be suitable for edge devices or other resource-constrained scenarios.

Student Model

Our student is a much shallower CNN: a single convolutional block followed by a linear layer, with about 100k parameters.

import torch.nn as nn

class StudentNet(nn.Module):
    def __init__(self):
        super().__init__()
        # One conv block: 3 -> 4 channels; max pooling halves 224x224 to 112x112.
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 4, kernel_size=3, padding=1),
            nn.BatchNorm2d(4),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        # Flattened feature map (4 * 112 * 112) -> 2 class logits.
        self.fc = nn.Linear(4 * 112 * 112, 2)

    def forward(self, x):
        out = self.layer1(x)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

As you can see, the code is very simple.
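A dummy batch run through the same layers shows where the `4 * 112 * 112` in the linear layer comes from and confirms the parameter budget. The 3×3 convolution with `padding=1` keeps the 224×224 size, and the 2×2 max pooling halves it to 112×112 with 4 channels. The layers are repeated here so the snippet runs on its own.

```python
import torch
import torch.nn as nn

# Same layers as StudentNet above, repeated so this snippet is standalone.
layer1 = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),
    nn.BatchNorm2d(4),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),
)
fc = nn.Linear(4 * 112 * 112, 2)

x = torch.randn(2, 3, 224, 224)            # dummy batch of two images
features = layer1(x)
print(features.shape)                       # torch.Size([2, 4, 112, 112])
logits = fc(features.view(features.size(0), -1))
print(logits.shape)                         # torch.Size([2, 2])

n_params = sum(p.numel() for m in (layer1, fc) for p in m.parameters())
print(f"{n_params:,}")                      # 100,474
```

The total breaks down into 112 parameters for the convolution, 8 for the BatchNorm affine parameters, and 100,354 for the linear layer. Almost all of the student's capacity therefore sits in its final layer.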

Why bother with knowledge distillation if we can simply train this small network directly? For comparison, the results at the end include this network trained from scratch, with hyperparameter tuning and other adjustments.

For now, let's continue with the knowledge distillation steps.

Knowledge distillation training

The basic training steps are unchanged. The difference is how the final training loss is computed: the teacher's predictions (logits) are used to compute a distillation loss, which is combined with the student's usual classification loss to form the final loss.

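import torch.nn as nn
import torch.nn.functional as F

class DistillationLoss:
    def __init__(self):
        # Cross-entropy the caller can use to compute student_target_loss.
        self.student_loss = nn.CrossEntropyLoss()
        # 'batchmean' matches the mathematical KL divergence per sample; the
        # default 'mean' also divides by the number of classes (halving it here).
        self.distillation_loss = nn.KLDivLoss(reduction='batchmean')
        self.temperature = 1
        self.alpha = 0.25

    def __call__(self, student_logits, student_target_loss, teacher_logits):
        # KLDivLoss expects log-probabilities as input and probabilities as target.
        distillation_loss = self.distillation_loss(
            F.log_softmax(student_logits / self.temperature, dim=1),
            F.softmax(teacher_logits / self.temperature, dim=1))

        # Weighted sum of the classification loss and the distillation loss.
        loss = (1 - self.alpha) * student_target_loss + self.alpha * distillation_loss
        return loss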
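One way to use this loss in a training loop is sketched below as a single student update, assuming the teacher has already been fine-tuned. The teacher runs in `eval()` mode under `torch.no_grad()`, and the loss is the same weighted sum as in `DistillationLoss`, inlined with `F.kl_div` so the snippet runs on its own. The function name `distillation_step` is our own, and its defaults mirror the alpha and temperature above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def distillation_step(student, teacher, inputs, labels, optimizer,
                      temperature=1.0, alpha=0.25):
    """One optimization step of the student guided by a frozen teacher."""
    teacher.eval()
    student.train()

    # The teacher only provides soft targets; no gradients are needed for it.
    with torch.no_grad():
        teacher_logits = teacher(inputs)

    student_logits = student(inputs)
    student_target_loss = F.cross_entropy(student_logits, labels)
    distill = F.kl_div(
        F.log_softmax(student_logits / temperature, dim=1),
        F.softmax(teacher_logits / temperature, dim=1),
        reduction="batchmean",
    )
    loss = (1 - alpha) * student_target_loss + alpha * distill

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```

Inside each epoch you would call it once per batch of `train_loader`, with an optimizer built from the student's parameters only.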

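The loss function is the weighted sum of the following two terms:

  • Classification loss, called student_target_loss: the usual cross-entropy between the student's predictions and the true labels
  • Distillation loss: the KL divergence between the softened probability distributions computed from the student logits and the teacher logits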
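Written out, the loss computed by `DistillationLoss` is shown below, averaged over the batch. Here z_s and z_t are the student and teacher logits, y is the true label, T is the temperature, α is the distillation weight, and σ is the softmax.

```latex
p_s = \sigma\!\left(\frac{z_s}{T}\right), \qquad p_t = \sigma\!\left(\frac{z_t}{T}\right)

\mathcal{L} = (1 - \alpha)\,\mathrm{CE}\big(y, \sigma(z_s)\big) + \alpha\,\mathrm{KL}\big(p_t \,\|\, p_s\big),
\qquad
\mathrm{KL}\big(p_t \,\|\, p_s\big) = \sum_{c} p_{t,c} \log \frac{p_{t,c}}{p_{s,c}}
```

Hinton et al.'s original formulation also multiplies the KL term by T² so that its gradients keep a comparable scale when T > 1. With T = 1, as here, the two versions coincide.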


Simply put, the teacher model needs to teach the student how to "think", including how uncertain it is. For example, if the teacher's final output probabilities are [0.53, 0.47], we want the student to produce a similar distribution, and the difference between these two predictions is the distillation loss.
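Here is the same idea in numbers, using the teacher's [0.53, 0.47] from the example. With T = 1, the per-sample KL divergence below is what `nn.KLDivLoss` computes from the student's log-probabilities and the teacher's probabilities. The two student outputs are hypothetical.

```python
import math


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions given as lists of probabilities."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


teacher = [0.53, 0.47]            # teacher output from the example above
close_student = [0.55, 0.45]      # hypothetical student that matches well
confident_student = [0.90, 0.10]  # hypothetical overconfident student

print(kl_divergence(teacher, teacher))            # 0.0
print(kl_divergence(teacher, close_student))      # about 0.0008
print(kl_divergence(teacher, confident_student))  # about 0.447
```

The overconfident student is penalized even though it picks the same class as the teacher, because it ignores the teacher's uncertainty.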

Two main parameters control the loss:

  • alpha, the weight of the distillation loss: 0 means we only consider the classification loss, and 1 means we only consider the distillation loss.
  • temperature: controls how soft (uncertain) the teacher's predicted distribution is; higher values flatten the probabilities.
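The effect of temperature is easiest to see on a pair of logits. This sketch divides hypothetical teacher logits by several temperatures before the softmax. At low T the distribution is sharp, and at high T it approaches uniform, which exposes more of the teacher's relative preferences.

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    """Softmax of logits / temperature; higher temperatures give flatter outputs."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)                      # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


teacher_logits = [2.0, 0.5]              # hypothetical teacher logits
for t in (0.5, 1.0, 2.0, 5.0):
    probs = softmax_with_temperature(teacher_logits, t)
    print(t, [round(p, 3) for p in probs])
```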

The values of alpha and temperature used in the code above (0.25 and 1) gave the best results among the combinations we tried.
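One simple way to search these combinations is a grid search. In this sketch, `train_and_evaluate` is a hypothetical callback that trains a fresh student with the given settings and returns its accuracy. Scoring on the val split keeps the test set untouched. The grid values in the usage line are illustrative only.

```python
from itertools import product


def grid_search(train_and_evaluate, alphas, temperatures):
    """Try every (alpha, temperature) pair and return the best one.

    train_and_evaluate(alpha, temperature) is a hypothetical callback that
    returns a score to maximize (e.g. validation accuracy).
    """
    results = {}
    for alpha, temperature in product(alphas, temperatures):
        results[(alpha, temperature)] = train_and_evaluate(alpha, temperature)
    best = max(results, key=results.get)
    return best, results
```

For example: `best, results = grid_search(my_train_and_evaluate, [0.1, 0.25, 0.5, 0.75], [1, 2, 4])`.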

Result comparison

This is a tabular summary of this experiment.


We can clearly see the large benefit of using a much smaller (99.14% fewer parameters), shallower CNN. Compared with training it without distillation, accuracy improves by 10 points, and it runs 11 times faster than ResNet-18. In other words, our small model really did learn something useful from the large model.
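Speed-ups like this depend heavily on hardware, batch size and threading, so it is worth measuring on the target device. Below is a minimal CPU timing sketch; `mean_inference_time` and its default warm-up and run counts are our own choices.

```python
import time

import torch


@torch.no_grad()
def mean_inference_time(model, input_shape=(1, 3, 224, 224), warmup=3, runs=10):
    """Average wall-clock seconds per forward pass on CPU with a random input."""
    model.eval()
    x = torch.randn(*input_shape)
    for _ in range(warmup):              # let one-time setup costs settle
        model(x)
    start = time.perf_counter()
    for _ in range(runs):
        model(x)
    return (time.perf_counter() - start) / runs
```

`mean_inference_time(TeacherNet()) / mean_inference_time(StudentNet())` gives the ratio for single images on your machine. On a GPU you would also call `torch.cuda.synchronize()` before reading the clock, because CUDA kernels run asynchronously.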


source:51cto.com