Heim > Backend-Entwicklung > Python-Tutorial > Erste Schritte mit TensorFlow und Verwendung von tf.train.Saver() zum Speichern des Modells

Erste Schritte mit TensorFlow und Verwendung von tf.train.Saver() zum Speichern des Modells

不言
Freigeben: 2018-04-24 14:15:07
Original
4141 Leute haben es durchsucht

In diesem Artikel wird hauptsächlich die Verwendung von tf.train.Saver() zum Speichern des Modells beim Einstieg in TensorFlow vorgestellt. Jetzt teile ich es mit Ihnen und gebe es als Referenz. Werfen wir gemeinsam einen Blick darauf

Einige Gedanken zum Modellspeichern

saver = tf.train.Saver(max_to_keep=3)
Nach dem Login kopieren

Bei der Definition des Sparers wird in der Regel die maximale Anzahl der gespeicherten Modelle festgelegt. Wenn das Modell selbst groß ist, müssen wir im Allgemeinen die Festplattengröße berücksichtigen. Wenn Sie eine Feinabstimmung auf der Grundlage des aktuell trainierten Modells durchführen müssen, speichern Sie so viele Modelle wie möglich. Die anschließende Feinabstimmung wird möglicherweise nicht unbedingt vom besten Cckpt aus durchgeführt, da es möglicherweise auf einmal überangepasst wird. Wenn Sie jedoch zu viele Dateien speichern, steht die Festplatte unter Druck. Wenn Sie nur das beste Modell behalten möchten, berechnen Sie die Genauigkeit oder den f1-Wert des Verifizierungssatzes bei jeder Iteration bis zu einer bestimmten Anzahl von Schritten. Wenn das Ergebnis dieses Mal besser ist als beim letzten Mal, speichern Sie das neue Andernfalls besteht keine Notwendigkeit, es zu speichern.

Wenn Sie in verschiedenen Epochen gespeicherte Modelle für die Fusion verwenden möchten, reichen 3 bis 5 Modelle aus. Nehmen Sie an, dass die fusionierten Modelle M werden und das beste Einzelmodell m_best heißt, sodass Fusion tatsächlich besser sein kann m_best für M. Wenn Sie dieses Modell jedoch mit Modellen anderer Strukturen fusionieren, ist der Effekt von M nicht so gut wie der von m_best, da M einer Durchschnittsoperation entspricht, was die „Eigenschaften“ des Modells verringert.

Aber es gibt eine neue Fusionsmethode, bei der die Lernrate angepasst wird, um mehrere lokale optimale Punkte zu erhalten. Das heißt, wenn der Verlust nicht reduziert werden kann, speichern Sie einen Cckpt und erhöhen Sie dann die Lernrate Finden Sie weiterhin den nächsten lokalen optimalen Punkt und verwenden Sie ihn dann für die Fusion. Das einzelne Modell wird definitiv verbessert, aber ich weiß nicht, ob es eine Situation geben wird, in der das oben genannte passiert Die Verbesserung wird in Kombination mit anderen Modellen nicht verbessert.

So verwenden Sie tf.train.Saver(), um das Modell zu speichern

Ich habe schon früher Fehler erhalten, hauptsächlich aufgrund von Cheat-Codierungsproblemen. Achten Sie daher darauf, dass der Dateipfad keine chinesischen Zeichen enthält.

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([1.0, 2.3], name="v1")
v2 = tf.Variable(55.5, name="v2")

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

ckpt_path = './ckpt/test-model.ckpt'
# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
sess.run(init_op)
save_path = saver.save(sess, ckpt_path, global_step=1)
print("Model saved in file: %s" % save_path)
Nach dem Login kopieren

Modell in Datei gespeichert: ./ckpt/test-model.ckpt-1

Beachten Sie, dass nach dem Speichern des Modells oben. Sie sollten den Kernel neu starten, bevor Sie das folgende Modell zum Importieren verwenden. Andernfalls ist der Name falsch, wenn „v1“ zweimal genannt wird.

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([11.0, 16.3], name="v1")
v2 = tf.Variable(33.5, name="v2")

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
# Restore variables from disk.
ckpt_path = './ckpt/test-model.ckpt'
saver.restore(sess, ckpt_path + '-'+ str(1))
print("Model restored.")

print sess.run(v1)
print sess.run(v2)
Nach dem Login kopieren

INFO:tensorflow:Parameter werden von ./ckpt/test-model.ckpt-1 wiederhergestellt
Modell wiederhergestellt.
[ 1.               2.29999995]
55.5

Vor dem Import des Modells müssen die Variablen neu definiert werden.

Aber es ist nicht notwendig, alle Variablen neu zu definieren, sondern nur die Variablen, die wir brauchen.

Mit anderen Worten: Die von Ihnen definierten Variablen müssen im Prüfpunkt vorhanden sein, aber nicht alle Variablen im Prüfpunkt müssen neu definiert werden.

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([11.0, 16.3], name="v1")

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
# Restore variables from disk.
ckpt_path = './ckpt/test-model.ckpt'
saver.restore(sess, ckpt_path + '-'+ str(1))
print("Model restored.")

print sess.run(v1)
Nach dem Login kopieren

INFO:tensorflow:Parameter werden von ./ckpt/test-model.ckpt-1 wiederhergestellt
Modell wiederhergestellt.
[ 1.           2.29999995]

tf.Saver([tensors_to_be_saved]) Sie können eine Liste und die zu speichernden Tensoren übergeben. Wenn diese Liste nicht angegeben ist, wird sie standardmäßig verwendet Speichern Sie alle aktuellen Tensoren. Im Allgemeinen kann tf.Saver geschickt mit tf.variable_scope() kombiniert werden: [Lernen übertragen] Neue Variablen zu einem bereits gespeicherten Modell hinzufügen und optimieren

Verwandte Empfehlungen:

Über die tf.train.batch-Funktion in Tensorflow

Das obige ist der detaillierte Inhalt vonErste Schritte mit TensorFlow und Verwendung von tf.train.Saver() zum Speichern des Modells. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!

Verwandte Etiketten:
Quelle:php.cn
Erklärung dieser Website
Der Inhalt dieses Artikels wird freiwillig von Internetnutzern beigesteuert und das Urheberrecht liegt beim ursprünglichen Autor. Diese Website übernimmt keine entsprechende rechtliche Verantwortung. Wenn Sie Inhalte finden, bei denen der Verdacht eines Plagiats oder einer Rechtsverletzung besteht, wenden Sie sich bitte an admin@php.cn
Beliebte Tutorials
Mehr>
Neueste Downloads
Mehr>
Web-Effekte
Quellcode der Website
Website-Materialien
Frontend-Vorlage