Maison > développement back-end > Tutoriel Python > Comment utiliser l'économiseur de Tensorflow

Comment utiliser l'économiseur de Tensorflow

不言
Libérer: 2018-04-23 15:46:31
original
2027 Les gens l'ont consulté

Cet article présente principalement l'utilisation détaillée de Tensorflow's Saver. Maintenant, je le partage avec vous et le donne comme référence. Venez jeter un oeil

Utilisation de Saver

Introduction générale à Saver

Nous souhaitons souvent sauvegarder les résultats de la formation après la formation d'un modèle. Ces résultats font référence aux paramètres du modèle pour la formation ou le test lors de l'itération suivante. Tensorflow fournit la classe Saver pour cette exigence.


La classe Saver fournit des méthodes associées pour enregistrer et restaurer les variables à partir des fichiers de points de contrôle. Le fichier de points de contrôle est un fichier binaire qui mappe les noms de variables aux valeurs de tenseur correspondantes.


Tant qu'un compteur est fourni, la classe Saver peut générer automatiquement un fichier de point de contrôle lorsque le compteur est déclenché. Cela nous permet de sauvegarder plusieurs résultats intermédiaires pendant l'entraînement. Par exemple, nous pouvons sauvegarder les résultats de chaque étape de formation.


Pour éviter de remplir tout le disque, Saver peut gérer automatiquement les fichiers Checkpoints. Par exemple, nous pouvons spécifier de sauvegarder les fichiers N Checkpoints les plus récents.


2. Instance Saver

Ce qui suit est un exemple de la façon d'utiliser la classe Saver


import tensorflow as tf 
import numpy as np  
x = tf.placeholder(tf.float32, shape=[None, 1]) 
y = 4 * x + 4  
w = tf.Variable(tf.random_normal([1], -1, 1)) 
b = tf.Variable(tf.zeros([1])) 
y_predict = w * x + b 
loss = tf.reduce_mean(tf.square(y - y_predict)) 
optimizer = tf.train.GradientDescentOptimizer(0.5) 
train = optimizer.minimize(loss)  
isTrain = False 
train_steps = 100 
checkpoint_steps = 50 
checkpoint_dir = ''  
saver = tf.train.Saver() # defaults to saving all variables - in this case w and b 
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))  
with tf.Session() as sess: 
  sess.run(tf.initialize_all_variables()) 
  if isTrain: 
    for i in xrange(train_steps): 
      sess.run(train, feed_dict={x: x_data}) 
      if (i + 1) % checkpoint_steps == 0: 
        saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1) 
  else: 
    ckpt = tf.train.get_checkpoint_state(checkpoint_dir) 
    if ckpt and ckpt.model_checkpoint_path: 
      saver.restore(sess, ckpt.model_checkpoint_path) 
    else: 
      pass 
    print(sess.run(w)) 
    print(sess.run(b))
Copier après la connexion

  1. isTrain : utilisé pour distinguer la phase d'entraînement et la phase de test, True représente l'entraînement, False représente les tests

  2. train_steps : représente l'entraînement Le nombre de fois, l'exemple utilise 100

  3. checkpoint_steps : indique combien de fois enregistrer les points de contrôle pendant l'entraînement, l'exemple utilise 50

  4. checkpoint_dir : indique le fichier de points de contrôle Le chemin de sauvegarde, le chemin actuel est utilisé dans l'exemple

2.1 Phase d'entraînement

Utilisez la méthode Saver.save() pour enregistrer le modèle :

  1. sess : représente la session en cours, qui enregistre la valeur de la variable actuelle

  2. checkpoint_dir + 'model.ckpt' : représente le nom du fichier stocké

  3. global_step : indique l'étape en cours

Une fois la formation terminée, il y aura 5 fichiers supplémentaires dans le répertoire actuel .

Ouvrez le fichier nommé "checkpoint", vous pouvez voir l'enregistrement de sauvegarde et le dernier emplacement de stockage du modèle.

2.1 Phase de test

La phase de test utilise la méthode saver.restore() pour restaurer les variables :


sess : représente la session en cours, les résultats précédemment enregistrés seront chargés dans cette session


ckpt.model_checkpoint_path : Indique l'emplacement où le modèle est stocké. Il n'est pas nécessaire de fournir le nom du modèle. vérifiera le fichier de point de contrôle pour voir qui est le dernier, comment s'appelle-t-il.


Les résultats d'exécution sont présentés dans la figure ci-dessous, en chargeant les résultats des paramètres w et b précédemment entraînés


Recommandations associées :


Tensorflow utilise des indicateurs pour définir les paramètres de ligne de commande

Sauvegarde et restauration du modèle d'apprentissage Tensorflow1.0 (Saver)_python

Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

Étiquettes associées:
source:php.cn
Déclaration de ce site Web
Le contenu de cet article est volontairement contribué par les internautes et les droits d'auteur appartiennent à l'auteur original. Ce site n'assume aucune responsabilité légale correspondante. Si vous trouvez un contenu suspecté de plagiat ou de contrefaçon, veuillez contacter admin@php.cn
Tutoriels populaires
Plus>
Derniers téléchargements
Plus>
effets Web
Code source du site Web
Matériel du site Web
Modèle frontal