Mengapakah penalaan halus model MLP pada set data kecil masih mengekalkan ketepatan ujian yang sama seperti pemberat pra-latihan?

WBOY
Lepaskan: 2024-02-10 21:36:04
ke hadapan
568 orang telah melayarinya

为什么在小数据集上微调 MLP 模型,仍然保持与预训练权重相同的测试精度?

Kandungan soalan

Saya mereka bentuk model mlp mudah untuk melatih 6k sampel data.

class mlp(nn.module):
    def __init__(self,input_dim=92, hidden_dim = 150, num_classes=2):
        super().__init__()
        self.input_dim = input_dim
        self.num_classes = num_classes
        self.hidden_dim = hidden_dim
        #self.softmax = nn.softmax(dim=1)

        self.layers = nn.sequential(
            nn.linear(self.input_dim, self.hidden_dim),
            nn.relu(),
            nn.linear(self.hidden_dim, self.hidden_dim),
            nn.relu(),
            nn.linear(self.hidden_dim, self.hidden_dim),
            nn.relu(),
            nn.linear(self.hidden_dim, self.num_classes),

        )

    def forward(self, x):
        x = self.layers(x)
        return x
Salin selepas log masuk

dan model dibuat instantiated

model = mlp(input_dim=input_dim, hidden_dim=hidden_dim, num_classes=num_classes).to(device)

optimizer = optimizer.adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
criterion = nn.crossentropyloss()
Salin selepas log masuk

dan hiperparameter:

num_epoch = 300   # 200e3//len(train_loader)
learning_rate = 1e-3
batch_size = 64
device = torch.device("cuda")
seed = 42
torch.manual_seed(42)
Salin selepas log masuk

Pelaksanaan saya terutamanya mengikut soalan ini. Saya simpan model sebagai pemberat pralatihan model_weights.pth.

model在测试数据集上的准确率是96.80%.

Kemudian, saya mempunyai 50 sampel lagi (dalam finetune_loader) yang mana saya cuba memperhalusi model:

model_finetune = MLP()
model_finetune.load_state_dict(torch.load('model_weights.pth'))
model_finetune.to(device)
model_finetune.train()
# train the network
for t in tqdm(range(num_epoch)):
  for i, data in enumerate(finetune_loader, 0):
    #def closure():
      # Get and prepare inputs
      inputs, targets = data
      inputs, targets = inputs.float(), targets.long()
      inputs, targets = inputs.to(device), targets.to(device)
      
      # Zero the gradients
      optimizer.zero_grad()
      # Perform forward pass
      outputs = model_finetune(inputs)
      # Compute loss
      loss = criterion(outputs, targets)
      # Perform backward pass
      loss.backward()
      #return loss
      optimizer.step()     # a

model_finetune.eval()
with torch.no_grad():
    outputs2 = model_finetune(test_data)
    #predicted_labels = outputs.squeeze().tolist()

    _, preds = torch.max(outputs2, 1)
    prediction_test = np.array(preds.cpu())
    accuracy_test_finetune = accuracy_score(y_test, prediction_test)
    accuracy_test_finetune
    
    Output: 0.9680851063829787
Salin selepas log masuk

Saya semak, ketepatan tetap sama seperti sebelum memperhalusi model kepada 50 sampel, dan kebarangkalian keluarannya juga sama.

Apakah sebabnya? Adakah saya membuat beberapa kesilapan dalam memperhalusi kod?


Jawapan betul


Anda mesti memulakan semula pengoptimum dengan model baharu (objek model_finetune). Pada masa ini, seperti yang saya dapat lihat dalam kod anda, nampaknya masih menggunakan pengoptimum yang dimulakan dengan berat model lama - model.parameters().

Atas ialah kandungan terperinci Mengapakah penalaan halus model MLP pada set data kecil masih mengekalkan ketepatan ujian yang sama seperti pemberat pra-latihan?. Untuk maklumat lanjut, sila ikut artikel berkaitan lain di laman web China PHP!

sumber:stackoverflow.com
Kenyataan Laman Web ini
Kandungan artikel ini disumbangkan secara sukarela oleh netizen, dan hak cipta adalah milik pengarang asal. Laman web ini tidak memikul tanggungjawab undang-undang yang sepadan. Jika anda menemui sebarang kandungan yang disyaki plagiarisme atau pelanggaran, sila hubungi admin@php.cn
Tutorial Popular
Lagi>
Muat turun terkini
Lagi>
kesan web
Kod sumber laman web
Bahan laman web
Templat hujung hadapan
Tentang kita Penafian Sitemap
Laman web PHP Cina:Latihan PHP dalam talian kebajikan awam,Bantu pelajar PHP berkembang dengan cepat!