Warum behält die Feinabstimmung eines MLP-Modells anhand eines kleinen Datensatzes immer noch die gleiche Testgenauigkeit bei wie vorab trainierte Gewichte?

WBOY
Freigeben: 2024-02-10 21:36:04
nach vorne
568 Leute haben es durchsucht

为什么在小数据集上微调 MLP 模型,仍然保持与预训练权重相同的测试精度?

Frageninhalt

Ich habe ein einfaches MLP-Modell entworfen, um anhand von 6.000 Datenproben zu trainieren.

class mlp(nn.module):
    def __init__(self,input_dim=92, hidden_dim = 150, num_classes=2):
        super().__init__()
        self.input_dim = input_dim
        self.num_classes = num_classes
        self.hidden_dim = hidden_dim
        #self.softmax = nn.softmax(dim=1)

        self.layers = nn.sequential(
            nn.linear(self.input_dim, self.hidden_dim),
            nn.relu(),
            nn.linear(self.hidden_dim, self.hidden_dim),
            nn.relu(),
            nn.linear(self.hidden_dim, self.hidden_dim),
            nn.relu(),
            nn.linear(self.hidden_dim, self.num_classes),

        )

    def forward(self, x):
        x = self.layers(x)
        return x
Nach dem Login kopieren

und das Modell wird instanziiert

model = mlp(input_dim=input_dim, hidden_dim=hidden_dim, num_classes=num_classes).to(device)

optimizer = optimizer.adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
criterion = nn.crossentropyloss()
Nach dem Login kopieren

und Hyperparameter:

num_epoch = 300   # 200e3//len(train_loader)
learning_rate = 1e-3
batch_size = 64
device = torch.device("cuda")
seed = 42
torch.manual_seed(42)
Nach dem Login kopieren

Meine Implementierung folgt hauptsächlich dieser Frage. Ich speichere das Modell als vortrainierte Gewichte model_weights.pth.

model在测试数据集上的准确率是96.80%.

Dann habe ich weitere 50 Proben (in finetune_loader), an denen ich versuche, das Modell zu verfeinern:

model_finetune = MLP()
model_finetune.load_state_dict(torch.load('model_weights.pth'))
model_finetune.to(device)
model_finetune.train()
# train the network
for t in tqdm(range(num_epoch)):
  for i, data in enumerate(finetune_loader, 0):
    #def closure():
      # Get and prepare inputs
      inputs, targets = data
      inputs, targets = inputs.float(), targets.long()
      inputs, targets = inputs.to(device), targets.to(device)
      
      # Zero the gradients
      optimizer.zero_grad()
      # Perform forward pass
      outputs = model_finetune(inputs)
      # Compute loss
      loss = criterion(outputs, targets)
      # Perform backward pass
      loss.backward()
      #return loss
      optimizer.step()     # a

model_finetune.eval()
with torch.no_grad():
    outputs2 = model_finetune(test_data)
    #predicted_labels = outputs.squeeze().tolist()

    _, preds = torch.max(outputs2, 1)
    prediction_test = np.array(preds.cpu())
    accuracy_test_finetune = accuracy_score(y_test, prediction_test)
    accuracy_test_finetune
    
    Output: 0.9680851063829787
Nach dem Login kopieren

Ich habe es überprüft, die Genauigkeit bleibt dieselbe wie vor der Feinabstimmung des Modells auf 50 Stichproben, und auch die Ausgabewahrscheinlichkeiten sind dieselben.

Was könnte der Grund sein? Habe ich bei der Feinabstimmung des Codes Fehler gemacht?


Richtige Antwort


Sie müssen den Optimierer mit einem neuen Modell (model_finetune-Objekt) neu initialisieren. Derzeit scheint es, wie ich in Ihrem Code sehen kann, immer noch den Optimierer zu verwenden, der mit den alten Modellgewichten initialisiert wurde – model.parameters().

Das obige ist der detaillierte Inhalt vonWarum behält die Feinabstimmung eines MLP-Modells anhand eines kleinen Datensatzes immer noch die gleiche Testgenauigkeit bei wie vorab trainierte Gewichte?. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!

Quelle:stackoverflow.com
Erklärung dieser Website
Der Inhalt dieses Artikels wird freiwillig von Internetnutzern beigesteuert und das Urheberrecht liegt beim ursprünglichen Autor. Diese Website übernimmt keine entsprechende rechtliche Verantwortung. Wenn Sie Inhalte finden, bei denen der Verdacht eines Plagiats oder einer Rechtsverletzung besteht, wenden Sie sich bitte an admin@php.cn
Beliebte Tutorials
Mehr>
Neueste Downloads
Mehr>
Web-Effekte
Quellcode der Website
Website-Materialien
Frontend-Vorlage
Über uns Haftungsausschluss Sitemap
Chinesische PHP-Website:Online-PHP-Schulung für das Gemeinwohl,Helfen Sie PHP-Lernenden, sich schnell weiterzuentwickeln!