Cet article présente principalement l'utilisation détaillée de Tensorflow's Saver. Maintenant, je le partage avec vous et le donne comme référence. Venez jeter un oeil
Utilisation de Saver
Introduction générale à Saver
Nous souhaitons souvent sauvegarder les résultats de la formation après la formation d'un modèle. Ces résultats font référence aux paramètres du modèle pour la formation ou le test lors de l'itération suivante. Tensorflow fournit la classe Saver pour cette exigence.
La classe Saver fournit des méthodes associées pour enregistrer et restaurer les variables à partir des fichiers de points de contrôle. Le fichier de points de contrôle est un fichier binaire qui mappe les noms de variables aux valeurs de tenseur correspondantes.
Tant qu'un compteur est fourni, la classe Saver peut générer automatiquement un fichier de point de contrôle lorsque le compteur est déclenché. Cela nous permet de sauvegarder plusieurs résultats intermédiaires pendant l'entraînement. Par exemple, nous pouvons sauvegarder les résultats de chaque étape de formation.
Pour éviter de remplir tout le disque, Saver peut gérer automatiquement les fichiers Checkpoints. Par exemple, nous pouvons spécifier de sauvegarder les fichiers N Checkpoints les plus récents.
2. Instance Saver
Ce qui suit est un exemple de la façon d'utiliser la classe Saver
import tensorflow as tf
import numpy as np
x = tf.placeholder(tf.float32, shape=[None, 1])
y = 4 * x + 4
w = tf.Variable(tf.random_normal([1], -1, 1))
b = tf.Variable(tf.zeros([1]))
y_predict = w * x + b
loss = tf.reduce_mean(tf.square(y - y_predict))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)
isTrain = False
train_steps = 100
checkpoint_steps = 50
checkpoint_dir = ''
saver = tf.train.Saver() # defaults to saving all variables - in this case w and b
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
if isTrain:
for i in xrange(train_steps):
sess.run(train, feed_dict={x: x_data})
if (i + 1) % checkpoint_steps == 0:
saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1)
else:
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
else:
pass
print(sess.run(w))
print(sess.run(b))
Copier après la connexion
- isTrain : utilisé pour distinguer la phase d'entraînement et la phase de test, True représente l'entraînement, False représente les tests
- train_steps : représente l'entraînement Le nombre de fois, l'exemple utilise 100
- checkpoint_steps : indique combien de fois enregistrer les points de contrôle pendant l'entraînement, l'exemple utilise 50
- checkpoint_dir : indique le fichier de points de contrôle Le chemin de sauvegarde, le chemin actuel est utilisé dans l'exemple
2.1 Phase d'entraînement
Utilisez la méthode Saver.save() pour enregistrer le modèle :
- sess : représente la session en cours, qui enregistre la valeur de la variable actuelle
- checkpoint_dir + 'model.ckpt' : représente le nom du fichier stocké
- global_step : indique l'étape en cours
Une fois la formation terminée, il y aura 5 fichiers supplémentaires dans le répertoire actuel .
Ouvrez le fichier nommé "checkpoint", vous pouvez voir l'enregistrement de sauvegarde et le dernier emplacement de stockage du modèle.
2.1 Phase de test
La phase de test utilise la méthode saver.restore() pour restaurer les variables :
sess : représente la session en cours, les résultats précédemment enregistrés seront chargés dans cette session
ckpt.model_checkpoint_path : Indique l'emplacement où le modèle est stocké. Il n'est pas nécessaire de fournir le nom du modèle. vérifiera le fichier de point de contrôle pour voir qui est le dernier, comment s'appelle-t-il.
Les résultats d'exécution sont présentés dans la figure ci-dessous, en chargeant les résultats des paramètres w et b précédemment entraînés
Recommandations associées :
Tensorflow utilise des indicateurs pour définir les paramètres de ligne de commande
Sauvegarde et restauration du modèle d'apprentissage Tensorflow1.0 (Saver)_python
Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!