


Warum behält die Feinabstimmung eines MLP-Modells anhand eines kleinen Datensatzes immer noch die gleiche Testgenauigkeit bei wie vorab trainierte Gewichte?
Ich habe ein einfaches MLP-Modell entworfen, um anhand von 6.000 Datenproben zu trainieren.
class mlp(nn.module): def __init__(self,input_dim=92, hidden_dim = 150, num_classes=2): super().__init__() self.input_dim = input_dim self.num_classes = num_classes self.hidden_dim = hidden_dim #self.softmax = nn.softmax(dim=1) self.layers = nn.sequential( nn.linear(self.input_dim, self.hidden_dim), nn.relu(), nn.linear(self.hidden_dim, self.hidden_dim), nn.relu(), nn.linear(self.hidden_dim, self.hidden_dim), nn.relu(), nn.linear(self.hidden_dim, self.num_classes), ) def forward(self, x): x = self.layers(x) return x
und das Modell wird instanziiert
model = mlp(input_dim=input_dim, hidden_dim=hidden_dim, num_classes=num_classes).to(device) optimizer = optimizer.adam(model.parameters(), lr=learning_rate, weight_decay=1e-4) criterion = nn.crossentropyloss()
und Hyperparameter:
num_epoch = 300 # 200e3//len(train_loader) learning_rate = 1e-3 batch_size = 64 device = torch.device("cuda") seed = 42 torch.manual_seed(42)
Meine Implementierung folgt hauptsächlich dieser Frage. Ich speichere das Modell als vortrainierte Gewichte model_weights.pth
.
model
在测试数据集上的准确率是96.80%
.
Dann habe ich weitere 50 Proben (in finetune_loader
), an denen ich versuche, das Modell zu verfeinern:
model_finetune = MLP() model_finetune.load_state_dict(torch.load('model_weights.pth')) model_finetune.to(device) model_finetune.train() # train the network for t in tqdm(range(num_epoch)): for i, data in enumerate(finetune_loader, 0): #def closure(): # Get and prepare inputs inputs, targets = data inputs, targets = inputs.float(), targets.long() inputs, targets = inputs.to(device), targets.to(device) # Zero the gradients optimizer.zero_grad() # Perform forward pass outputs = model_finetune(inputs) # Compute loss loss = criterion(outputs, targets) # Perform backward pass loss.backward() #return loss optimizer.step() # a model_finetune.eval() with torch.no_grad(): outputs2 = model_finetune(test_data) #predicted_labels = outputs.squeeze().tolist() _, preds = torch.max(outputs2, 1) prediction_test = np.array(preds.cpu()) accuracy_test_finetune = accuracy_score(y_test, prediction_test) accuracy_test_finetune Output: 0.9680851063829787
Ich habe es überprüft, die Genauigkeit bleibt dieselbe wie vor der Feinabstimmung des Modells auf 50 Stichproben, und auch die Ausgabewahrscheinlichkeiten sind dieselben.
Was könnte der Grund sein? Habe ich bei der Feinabstimmung des Codes Fehler gemacht?
Richtige Antwort
Sie müssen den Optimierer mit einem neuen Modell (model_finetune-Objekt) neu initialisieren. Derzeit scheint es, wie ich in Ihrem Code sehen kann, immer noch den Optimierer zu verwenden, der mit den alten Modellgewichten initialisiert wurde – model.parameters().
Das obige ist der detaillierte Inhalt vonWarum behält die Feinabstimmung eines MLP-Modells anhand eines kleinen Datensatzes immer noch die gleiche Testgenauigkeit bei wie vorab trainierte Gewichte?. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!

Heiße KI -Werkzeuge

Undresser.AI Undress
KI-gestützte App zum Erstellen realistischer Aktfotos

AI Clothes Remover
Online-KI-Tool zum Entfernen von Kleidung aus Fotos.

Undress AI Tool
Ausziehbilder kostenlos

Clothoff.io
KI-Kleiderentferner

AI Hentai Generator
Erstellen Sie kostenlos Ai Hentai.

Heißer Artikel

Heiße Werkzeuge

Notepad++7.3.1
Einfach zu bedienender und kostenloser Code-Editor

SublimeText3 chinesische Version
Chinesische Version, sehr einfach zu bedienen

Senden Sie Studio 13.0.1
Leistungsstarke integrierte PHP-Entwicklungsumgebung

Dreamweaver CS6
Visuelle Webentwicklungstools

SublimeText3 Mac-Version
Codebearbeitungssoftware auf Gottesniveau (SublimeText3)

Heiße Themen

Dieses Tutorial zeigt, wie man Python verwendet, um das statistische Konzept des Zipf -Gesetzes zu verarbeiten, und zeigt die Effizienz des Lesens und Sortierens großer Textdateien von Python bei der Bearbeitung des Gesetzes. Möglicherweise fragen Sie sich, was der Begriff ZiPF -Verteilung bedeutet. Um diesen Begriff zu verstehen, müssen wir zunächst das Zipf -Gesetz definieren. Mach dir keine Sorgen, ich werde versuchen, die Anweisungen zu vereinfachen. Zipf -Gesetz Das Zipf -Gesetz bedeutet einfach: In einem großen natürlichen Sprachkorpus erscheinen die am häufigsten vorkommenden Wörter ungefähr doppelt so häufig wie die zweiten häufigen Wörter, dreimal wie die dritten häufigen Wörter, viermal wie die vierten häufigen Wörter und so weiter. Schauen wir uns ein Beispiel an. Wenn Sie sich den Brown Corpus in amerikanischem Englisch ansehen, werden Sie feststellen, dass das häufigste Wort "Th ist

In diesem Artikel wird erklärt, wie man schöne Suppe, eine Python -Bibliothek, verwendet, um HTML zu analysieren. Es beschreibt gemeinsame Methoden wie find (), find_all (), select () und get_text () für die Datenextraktion, die Behandlung verschiedener HTML -Strukturen und -Anternativen (SEL)

Der Umgang mit lauten Bildern ist ein häufiges Problem, insbesondere bei Mobiltelefonen oder mit geringen Auflösungskamera-Fotos. In diesem Tutorial wird die Bildfilterungstechniken in Python unter Verwendung von OpenCV untersucht, um dieses Problem anzugehen. Bildfilterung: Ein leistungsfähiges Werkzeug Bildfilter

PDF-Dateien sind für ihre plattformübergreifende Kompatibilität beliebt, wobei Inhalte und Layout für Betriebssysteme, Lesegeräte und Software konsistent sind. Im Gegensatz zu Python Processing -Klartextdateien sind PDF -Dateien jedoch binäre Dateien mit komplexeren Strukturen und enthalten Elemente wie Schriftarten, Farben und Bilder. Glücklicherweise ist es nicht schwierig, PDF -Dateien mit Pythons externen Modulen zu verarbeiten. In diesem Artikel wird das PYPDF2 -Modul verwendet, um zu demonstrieren, wie Sie eine PDF -Datei öffnen, eine Seite ausdrucken und Text extrahieren. Die Erstellung und Bearbeitung von PDF -Dateien finden Sie in einem weiteren Tutorial von mir. Vorbereitung Der Kern liegt in der Verwendung von externem Modul PYPDF2. Installieren Sie es zunächst mit PIP: pip ist p

Dieses Tutorial zeigt, wie man Redis Caching nutzt, um die Leistung von Python -Anwendungen zu steigern, insbesondere innerhalb eines Django -Frameworks. Wir werden Redis -Installation, Django -Konfiguration und Leistungsvergleiche abdecken, um den Vorteil hervorzuheben

Dieser Artikel vergleicht TensorFlow und Pytorch für Deep Learning. Es beschreibt die beteiligten Schritte: Datenvorbereitung, Modellbildung, Schulung, Bewertung und Bereitstellung. Wichtige Unterschiede zwischen den Frameworks, insbesondere bezüglich des rechnerischen Graps

Python, ein Favorit für Datenwissenschaft und Verarbeitung, bietet ein reichhaltiges Ökosystem für Hochleistungs-Computing. Die parallele Programmierung in Python stellt jedoch einzigartige Herausforderungen dar. Dieses Tutorial untersucht diese Herausforderungen und konzentriert sich auf die globale Interprete

Dieses Tutorial zeigt, dass eine benutzerdefinierte Pipeline -Datenstruktur in Python 3 erstellt wird, wobei Klassen und Bedienerüberladungen für verbesserte Funktionen genutzt werden. Die Flexibilität der Pipeline liegt in ihrer Fähigkeit, eine Reihe von Funktionen auf einen Datensatz GE anzuwenden
