Maison > développement back-end > Tutoriel Python > Sauvegarde et restauration du modèle appris par tensorflow1.0 (Saver)_python

Sauvegarde et restauration du modèle appris par tensorflow1.0 (Saver)_python

不言
Libérer: 2018-04-23 15:42:50
original
1857 Les gens l'ont consulté

Cet article présente principalement la sauvegarde et la récupération (Saver) du modèle d'apprentissage tensorflow1.0. Maintenant, je le partage avec vous et le donne comme référence. Jetons un coup d'œil ensemble

Enregistrons les paramètres du modèle entraîné pour une vérification ou des tests ultérieurs. C'est quelque chose que nous faisons souvent. Le module tf.train.Saver() qui permet la sauvegarde du modèle dans tf.

Pour enregistrer le modèle, vous devez d'abord créer un objet Saver : tel que

saver=tf.train.Saver()
Copier après la connexion

Lors de la création de cet objet Saver, il y a un paramètre we Le paramètre max_to_keep est souvent utilisé pour définir le nombre de modèles enregistrés. La valeur par défaut est 5, c'est-à-dire max_to_keep=5, qui enregistre les 5 derniers modèles. Si vous souhaitez enregistrer le modèle à chaque génération (époque) d'entraînement, vous pouvez définir max_to_keep sur Aucun ou 0, par exemple :

saver=tf.train.Saver(max_to_keep=0)
Copier après la connexion

Mais comme ça En plus d’occuper plus de disque dur, il n’a aucune utilité pratique, il n’est donc pas recommandé.

Bien sûr, si vous souhaitez uniquement enregistrer le modèle de dernière génération, il vous suffit de définir max_to_keep sur 1, c'est-à-dire

saver=tf.train.Saver(max_to_keep=1)
Copier après la connexion

Après avoir créé l'objet économiseur, vous pouvez enregistrer le modèle entraîné, tel que :

saver.save(sess,'ckpt/mnist.ckpt',global_step=step)
Copier après la connexion

La première session de paramètres, il va sans dire. Le deuxième paramètre définit le chemin et le nom enregistrés, et le troisième paramètre ajoute le nombre de temps de formation comme suffixe au nom du modèle.

saver.save(sess, 'mon-modèle', global_step=0) ==> nom de fichier : 'mon-modèle-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> nom de fichier : 'my-model-1000'

Regardez un exemple de mnist :

# -*- coding: utf-8 -*-
"""
Created on Sun Jun 4 10:29:48 2017

@author: Administrator
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)

x = tf.placeholder(tf.float32, [None, 784])
y_=tf.placeholder(tf.int32,[None,])

dense1 = tf.layers.dense(inputs=x, 
           units=1024, 
           activation=tf.nn.relu,
           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
           kernel_regularizer=tf.nn.l2_loss)
dense2= tf.layers.dense(inputs=dense1, 
           units=512, 
           activation=tf.nn.relu,
           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
           kernel_regularizer=tf.nn.l2_loss)
logits= tf.layers.dense(inputs=dense2, 
            units=10, 
            activation=None,
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
            kernel_regularizer=tf.nn.l2_loss)

loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)
train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)  
acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess=tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())

saver=tf.train.Saver(max_to_keep=1)
for i in range(100):
 batch_xs, batch_ys = mnist.train.next_batch(100)
 sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
 val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
 print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
 saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
sess.close()
Copier après la connexion

La partie rouge dans le code est le code pour sauvegarder le modèle. Bien que je le sauvegarde après chaque génération d'entraînement, le modèle enregistré la prochaine fois écrasera le précédent, et seule la dernière fois sera enregistrée. . Par conséquent, nous pouvons gagner du temps et mettre le code de sauvegarde en dehors de la boucle (ne s'applique qu'à max_to_keep=1, sinon il doit quand même être placé à l'intérieur de la boucle).

Dans l'expérience, la dernière génération peut ne pas être la génération avec la précision de vérification la plus élevée, nous ne voulons donc pas enregistrer la dernière génération par défaut, mais nous voulons enregistrer la génération avec la précision de vérification la plus élevée, il suffit donc d'ajouter une variable intermédiaire et une déclaration de jugement.

saver=tf.train.Saver(max_to_keep=1)
max_acc=0
for i in range(100):
 batch_xs, batch_ys = mnist.train.next_batch(100)
 sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
 val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
 print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
 if val_acc>max_acc:
   max_acc=val_acc
   saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
sess.close()
Copier après la connexion

Si nous voulons enregistrer les trois générations avec la précision de vérification la plus élevée, et également enregistrer la précision de vérification de chaque fois, nous pouvons générer un txt fichier à sauvegarder.

saver=tf.train.Saver(max_to_keep=3)
max_acc=0
f=open('ckpt/acc.txt','w')
for i in range(100):
 batch_xs, batch_ys = mnist.train.next_batch(100)
 sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
 val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
 print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
 f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
 if val_acc>max_acc:
   max_acc=val_acc
   saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
f.close()
sess.close()
Copier après la connexion

Le modèle est restauré à l'aide de la fonction restaurer(), qui nécessite deux paramètres de restauration (sess, save_path), save_path fait référence au chemin du modèle enregistré . Nous pouvons utiliser tf.train.latest_checkpoint() pour obtenir automatiquement le dernier modèle enregistré. Par exemple :

model_file=tf.train.latest_checkpoint('ckpt/')
saver.restore(sess,model_file)
Copier après la connexion

Ensuite, nous pouvons changer la seconde moitié du programme en :

sess=tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())
is_train=False
saver=tf.train.Saver(max_to_keep=3)

#训练阶段
if is_train:
  max_acc=0
  f=open('ckpt/acc.txt','w')
  for i in range(100):
   batch_xs, batch_ys = mnist.train.next_batch(100)
   sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
   val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
   print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
   f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
   if val_acc>max_acc:
     max_acc=val_acc
     saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
  f.close()

#验证阶段
else:
  model_file=tf.train.latest_checkpoint('ckpt/')
  saver.restore(sess,model_file)
  val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
  print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))
sess.close()
Copier après la connexion

La zone marquée en rouge est le code lié à la sauvegarde et à la restauration du modèle. Utilisez une variable booléenne is_train pour contrôler les phases de formation et de vérification.

Programme source complet :

# -*- coding: utf-8 -*-
"""
Created on Sun Jun 4 10:29:48 2017

@author: Administrator
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)

x = tf.placeholder(tf.float32, [None, 784])
y_=tf.placeholder(tf.int32,[None,])

dense1 = tf.layers.dense(inputs=x, 
           units=1024, 
           activation=tf.nn.relu,
           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
           kernel_regularizer=tf.nn.l2_loss)
dense2= tf.layers.dense(inputs=dense1, 
           units=512, 
           activation=tf.nn.relu,
           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
           kernel_regularizer=tf.nn.l2_loss)
logits= tf.layers.dense(inputs=dense2, 
            units=10, 
            activation=None,
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
            kernel_regularizer=tf.nn.l2_loss)

loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)
train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)  
acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess=tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())

is_train=True
saver=tf.train.Saver(max_to_keep=3)

#训练阶段
if is_train:
  max_acc=0
  f=open('ckpt/acc.txt','w')
  for i in range(100):
   batch_xs, batch_ys = mnist.train.next_batch(100)
   sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
   val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
   print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
   f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
   if val_acc>max_acc:
     max_acc=val_acc
     saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
  f.close()

#验证阶段
else:
  model_file=tf.train.latest_checkpoint('ckpt/')
  saver.restore(sess,model_file)
  val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
  print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))
sess.close()
Copier après la connexion

Recommandations associées :

Exporter le réseau de modèles TensorFlow en tant que fichier unique méthode

Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

Étiquettes associées:
source:php.cn
Déclaration de ce site Web
Le contenu de cet article est volontairement contribué par les internautes et les droits d'auteur appartiennent à l'auteur original. Ce site n'assume aucune responsabilité légale correspondante. Si vous trouvez un contenu suspecté de plagiat ou de contrefaçon, veuillez contacter admin@php.cn
Tutoriels populaires
Plus>
Derniers téléchargements
Plus>
effets Web
Code source du site Web
Matériel du site Web
Modèle frontal