
Why does fine-tuning an MLP model on a small dataset keep the same test accuracy as the pre-trained weights?

WBOY
Release: 2024-02-10 21:36:04


Question content

I designed a simple MLP model and trained it on 6k data samples.

import torch
import torch.nn as nn

class MLP(nn.Module):
    def __init__(self, input_dim=92, hidden_dim=150, num_classes=2):
        super().__init__()
        self.input_dim = input_dim
        self.num_classes = num_classes
        self.hidden_dim = hidden_dim
        # self.softmax = nn.Softmax(dim=1)

        self.layers = nn.Sequential(
            nn.Linear(self.input_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.ReLU(),
            nn.Linear(self.hidden_dim, self.num_classes),
        )

    def forward(self, x):
        # Returns raw logits; nn.CrossEntropyLoss applies log-softmax itself
        x = self.layers(x)
        return x
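The forward pass returns raw logits (the softmax line is commented out), which is what nn.CrossEntropyLoss expects, since it applies log-softmax internally. A minimal sketch, assuming the same layer sizes as above and random inputs and labels purely for illustration, of computing the loss on logits and converting to probabilities only for inspection:

```python
import torch
import torch.nn as nn

class MLP(nn.Module):
    # Same architecture as in the question: three hidden ReLU layers, logits out.
    def __init__(self, input_dim=92, hidden_dim=150, num_classes=2):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x):
        return self.layers(x)

torch.manual_seed(0)
model = MLP()
x = torch.randn(4, 92)                # hypothetical batch of 4 samples
targets = torch.tensor([0, 1, 1, 0])  # hypothetical labels

logits = model(x)                               # shape (4, 2), unnormalized scores
loss = nn.CrossEntropyLoss()(logits, targets)   # takes logits directly

probs = torch.softmax(logits, dim=1)  # probabilities, for inspection only
preds = probs.argmax(dim=1)           # same classes as logits.argmax(dim=1)
```

Because softmax is monotonic, taking the argmax of the logits (as torch.max(outputs, 1) does later in the question) gives the same predictions as taking it over the probabilities.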

The model is then instantiated:

import torch.optim as optim

model = MLP(input_dim=input_dim, hidden_dim=hidden_dim, num_classes=num_classes).to(device)

optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
criterion = nn.CrossEntropyLoss()
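An optimizer stores references to the exact parameter tensors it receives at construction time; it does not follow a model by variable name. A small sketch, using a single nn.Linear layer as a stand-in for the MLP (an assumption for brevity), makes this visible:

```python
import torch
import torch.nn as nn

model = nn.Linear(3, 2)  # stand-in for the MLP
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)

# The optimizer holds the very same tensor objects as `model`.
tracked = optimizer.param_groups[0]["params"]
print(all(p is q for p, q in zip(tracked, model.parameters())))  # True

# A second model loaded from the same weights has different tensor objects,
# so this optimizer knows nothing about them.
copy = nn.Linear(3, 2)
copy.load_state_dict(model.state_dict())
print(any(p is q for p in tracked for q in copy.parameters()))  # False
```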

with the following hyperparameters:

num_epoch = 300   # 200e3//len(train_loader)
learning_rate = 1e-3
batch_size = 64
device = torch.device("cuda")
seed = 42
torch.manual_seed(seed)
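With device = torch.device("cuda") hard-coded, moving the model to the device fails on a machine without a CUDA GPU. A common pattern, shown here as a sketch (the helper make_layer is hypothetical), falls back to the CPU and passes the seed variable to torch.manual_seed, which also makes weight initialization reproducible:

```python
import torch
import torch.nn as nn

seed = 42
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def make_layer(seed):
    # Reseeding right before construction reproduces the same initial weights.
    torch.manual_seed(seed)
    return nn.Linear(92, 150).to(device)

a = make_layer(seed)
b = make_layer(seed)
print(torch.equal(a.weight, b.weight))  # True
```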

My implementation mainly follows this question. I saved the trained model's weights to model_weights.pth.
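Saving a state_dict stores the parameter tensors, not the Python model object, so loading requires building a model with the same architecture first. A sketch of the round trip, with a temporary file standing in for model_weights.pth and a small nn.Sequential standing in for the MLP (both assumptions for illustration):

```python
import os
import tempfile
import torch
import torch.nn as nn

def make_model():
    # Stand-in for the MLP; the architecture must match when loading.
    return nn.Sequential(nn.Linear(92, 150), nn.ReLU(), nn.Linear(150, 2))

model = make_model()
path = os.path.join(tempfile.mkdtemp(), "model_weights.pth")
torch.save(model.state_dict(), path)

restored = make_model()
# map_location="cpu" lets GPU-saved weights load on a CPU-only machine.
restored.load_state_dict(torch.load(path, map_location="cpu"))

same = all(torch.equal(p, q) for p, q in zip(model.parameters(), restored.parameters()))
print(same)  # True
```

The restored model has equal values but its own, separate tensors.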

The model's accuracy on the test set is 96.80%.

Then, I have another 50 samples (in finetune_loader) on which I am trying to fine-tune the model:

import numpy as np
import torch
from sklearn.metrics import accuracy_score
from tqdm import tqdm

model_finetune = MLP()
model_finetune.load_state_dict(torch.load('model_weights.pth'))
model_finetune.to(device)
model_finetune.train()

# Train the network
for t in tqdm(range(num_epoch)):
    for i, data in enumerate(finetune_loader, 0):
        # Get and prepare inputs
        inputs, targets = data
        inputs, targets = inputs.float(), targets.long()
        inputs, targets = inputs.to(device), targets.to(device)

        # Zero the gradients
        optimizer.zero_grad()
        # Perform forward pass
        outputs = model_finetune(inputs)
        # Compute loss
        loss = criterion(outputs, targets)
        # Perform backward pass
        loss.backward()
        optimizer.step()

model_finetune.eval()
with torch.no_grad():
    outputs2 = model_finetune(test_data)
    _, preds = torch.max(outputs2, 1)
    prediction_test = np.array(preds.cpu())
    accuracy_test_finetune = accuracy_score(y_test, prediction_test)
    print(accuracy_test_finetune)

# Output: 0.9680851063829787

I checked: after fine-tuning on the 50 samples, the test accuracy is exactly the same as before, and the output probabilities are identical too.
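Beyond accuracy, a direct way to check whether fine-tuning changed anything is to snapshot the parameters before the loop and compare them afterwards. A minimal sketch; snapshot and changed_params are hypothetical helper names, and a single nn.Linear with one SGD step stands in for the real MLP and training loop:

```python
import torch
import torch.nn as nn

def snapshot(model):
    # Detached copies, so later in-place updates don't alter the snapshot.
    return {name: p.detach().clone() for name, p in model.named_parameters()}

def changed_params(model, before):
    # Names of parameters whose values differ from the snapshot.
    return [name for name, p in model.named_parameters()
            if not torch.equal(p.detach(), before[name])]

model = nn.Linear(4, 2)
before = snapshot(model)
print(changed_params(model, before))  # []

# One optimizer step on this model's own parameters.
opt = torch.optim.SGD(model.parameters(), lr=0.1)
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
opt.step()
print(changed_params(model, before))
```

An empty list after a full fine-tuning loop means no parameter was ever updated, which points at the optimizer rather than the data.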

What could be the reason? Did I make a mistake in the fine-tuning code?


Correct answer


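You need to re-create the optimizer for the new model (the model_finetune object). As far as I can see in your code, you are still using the optimizer that was initialized with the old model's parameters, model.parameters(). loss.backward() does compute gradients for model_finetune, but optimizer.step() only updates the tensors the optimizer was constructed with, which belong to model. As a result, model_finetune keeps exactly the pre-trained weights and produces the same outputs.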
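A minimal sketch of the problem and the fix, with random tensors standing in for the 50 fine-tuning samples and a deep-copied nn.Sequential standing in for the MLP loaded from model_weights.pth (both assumptions for illustration). The only substantive difference between the two runs is which parameters the optimizer is built from:

```python
import copy
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Stand-ins for the pre-trained MLP and the 50-sample fine-tuning set.
model = nn.Sequential(nn.Linear(92, 150), nn.ReLU(), nn.Linear(150, 2))
X, y = torch.randn(50, 92), torch.randint(0, 2, (50,))
finetune_loader = DataLoader(TensorDataset(X, y), batch_size=16, shuffle=True)
criterion = nn.CrossEntropyLoss()

def finetune(net, optimizer, epochs=5):
    net.train()
    for _ in range(epochs):
        for inputs, targets in finetune_loader:
            optimizer.zero_grad()
            loss = criterion(net(inputs.float()), targets.long())
            loss.backward()
            optimizer.step()

def same_weights(a, b):
    return all(torch.equal(p, q) for p, q in zip(a.parameters(), b.parameters()))

# Buggy: the optimizer is still bound to the original model's parameters.
stale_opt = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
model_finetune = copy.deepcopy(model)      # separate tensors, like load_state_dict into MLP()
reference = copy.deepcopy(model_finetune)  # untouched copy for comparison
finetune(model_finetune, stale_opt)
buggy_unchanged = same_weights(model_finetune, reference)
print(buggy_unchanged)  # True: nothing was updated

# Fixed: build a new optimizer from model_finetune's own parameters.
model_finetune = copy.deepcopy(reference)
finetune_opt = torch.optim.Adam(model_finetune.parameters(), lr=1e-3, weight_decay=1e-4)
finetune(model_finetune, finetune_opt)
fixed_unchanged = same_weights(model_finetune, reference)
print(fixed_unchanged)  # False: the weights moved
```

In the question's code, the equivalent change is to create the optimizer with model_finetune.parameters() after loading the weights and before the training loop.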


Source: stackoverflow.com