This article explains in detail how to use TensorFlow's Saver.
How to use Saver
1. Background on Saver
After training a model, we usually want to save the result, namely the model's parameters, so that they can be used for further training or for testing. TensorFlow provides the Saver class for this purpose.
The Saver class provides methods for saving variables to checkpoint files and restoring them from checkpoint files. A checkpoint file is a binary file that maps variable names to the corresponding tensor values.
If you provide a counter (the global_step argument of save()), Saver automatically numbers the checkpoint file names with it. This lets you keep several intermediate results from different points of training; for example, you could save a checkpoint at every training step.
To avoid filling up the disk, Saver can manage checkpoint files automatically. For example, tf.train.Saver(max_to_keep=N) keeps only the N most recent checkpoints (the default is 5) and deletes older ones.
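As a rough illustration, the plain-Python sketch below mimics how the counter becomes part of the file name: when global_step is given, Saver.save() appends "-<global_step>" to the path prefix you pass it. The helper checkpoint_prefix is hypothetical and not part of TensorFlow; on disk, TensorFlow adds further format-specific suffixes to this prefix.

```python
def checkpoint_prefix(save_path, global_step=None):
    """Hypothetical helper mirroring how Saver.save() names checkpoints.

    With a global_step, the step number is appended after a dash;
    without one, the save path is used unchanged.
    """
    if global_step is None:
        return save_path
    return "%s-%d" % (save_path, global_step)


# Saving at steps 50 and 100 yields two distinct checkpoint names.
names = [checkpoint_prefix("model.ckpt", step) for step in (50, 100)]
print(names)  # ['model.ckpt-50', 'model.ckpt-100']
```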
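The keep-the-newest-N policy can be sketched in plain Python as below. CheckpointRotator is a hypothetical class that only imitates the idea behind Saver's max_to_keep argument; it records which checkpoints would be deleted instead of touching real files.

```python
from collections import deque


class CheckpointRotator:
    """Hypothetical keep-the-newest-N policy, similar in spirit to max_to_keep."""

    def __init__(self, max_to_keep=5):
        self.max_to_keep = max_to_keep
        self.kept = deque()   # oldest checkpoint on the left
        self.deleted = []     # checkpoints that would be removed from disk

    def save(self, name):
        self.kept.append(name)
        # Drop the oldest checkpoints once there are more than max_to_keep.
        while len(self.kept) > self.max_to_keep:
            self.deleted.append(self.kept.popleft())


rotator = CheckpointRotator(max_to_keep=2)
for step in (50, 100, 150):
    rotator.save("model.ckpt-%d" % step)

print(list(rotator.kept))  # ['model.ckpt-100', 'model.ckpt-150']
print(rotator.deleted)     # ['model.ckpt-50']
```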
2. Saver example
The following example shows how to use the Saver class:
import tensorflow as tf  # TensorFlow 1.x API; on 2.x use "import tensorflow.compat.v1 as tf" and call tf.disable_v2_behavior()
import numpy as np

# Toy linear-regression problem: learn y = 4x + 4
x = tf.placeholder(tf.float32, shape=[None, 1])
y = 4 * x + 4

w = tf.Variable(tf.random_normal([1], -1, 1))
b = tf.Variable(tf.zeros([1]))
y_predict = w * x + b

loss = tf.reduce_mean(tf.square(y - y_predict))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)

isTrain = False          # True: train and save checkpoints; False: restore and test
train_steps = 100
checkpoint_steps = 50
checkpoint_dir = './'    # save checkpoints in the current directory

saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    if isTrain:
        for i in range(train_steps):
            sess.run(train, feed_dict={x: x_data})
            if (i + 1) % checkpoint_steps == 0:
                # Writes checkpoints named model.ckpt-50 and model.ckpt-100
                saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i + 1)
    else:
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            pass  # no checkpoint found: w and b keep their initial values
    print(sess.run(w))
    print(sess.run(b))
isTrain: distinguishes the training phase from the testing phase; True means training, False means testing.
train_steps: the number of training steps, 100 in the example.
checkpoint_steps: a checkpoint is saved every checkpoint_steps steps, 50 in the example.
checkpoint_dir: the directory where checkpoint files are saved; the example uses the current directory.
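With these values, the condition (i + 1) % checkpoint_steps == 0 in the training loop fires twice. The small helper below (save_steps is a hypothetical name) reproduces that arithmetic in plain Python, without TensorFlow.

```python
def save_steps(train_steps, checkpoint_steps):
    """Return the 1-based steps at which the example loop calls saver.save()."""
    return [i + 1 for i in range(train_steps) if (i + 1) % checkpoint_steps == 0]


print(save_steps(100, 50))  # [50, 100]
print(save_steps(100, 30))  # [30, 60, 90]
```

Note that with checkpoint_steps = 30 the final step 100 is not saved, so the last parameters would only be checkpointed if an extra save were added after the loop.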
2.1 Training phase
During training, the model is saved with the saver.save() method, whose arguments are:
sess: the current session, which holds the current variable values.
checkpoint_dir + 'model.ckpt': the path prefix of the saved files.
global_step: the current step number, which is appended to the file name (for example model.ckpt-50).
After training completes, there will be 5 new files in the current directory (the exact number depends on the checkpoint format used by your TensorFlow version).
The file named "checkpoint" is a text file that records the saved checkpoints and the location of the most recent one.
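For reference, the sketch below lists the files one would expect in newer TensorFlow 1.x releases, which default to the V2 checkpoint format: each save writes an .index file, a .data-00000-of-00001 file and a .meta graph file (write_meta_graph=True is the default), and all saves share one "checkpoint" file. Older releases used the V1 format, which wrote fewer files per save, so the count varies by version. expected_files is a hypothetical helper and assumes a single, unsharded data file.

```python
def expected_files(prefix, steps):
    """Hypothetical list of files left by saver.save() in the V2 format.

    Assumes one unsharded data file per save and the default
    write_meta_graph=True; the shared 'checkpoint' file is listed once.
    """
    files = ["checkpoint"]
    for step in steps:
        base = "%s-%d" % (prefix, step)
        files += [base + ".index",
                  base + ".data-00000-of-00001",
                  base + ".meta"]
    return files


files = expected_files("model.ckpt", [50, 100])
print(len(files))  # 7
for name in files:
    print(name)
```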
2.2 Test phase
During the test phase, the saver.restore() method restores the variables. Its arguments are:
sess: the current session; the previously saved values are loaded into this session.
ckpt.model_checkpoint_path: the location of the stored model. You do not need to supply the model's file name yourself: tf.train.get_checkpoint_state() reads the checkpoint file to find the most recent checkpoint and its name.
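The "checkpoint" file is stored in protocol-buffer text format, with a model_checkpoint_path field for the newest checkpoint and one all_model_checkpoint_paths entry per retained checkpoint. The plain-Python sketch below imitates the lookup that tf.train.get_checkpoint_state() (or tf.train.latest_checkpoint()) performs. latest_checkpoint_path is a hypothetical stand-in, the file content is illustrative, and real TensorFlow parses the file as a protocol buffer rather than with a regular expression.

```python
import os
import re
import tempfile

# Illustrative content of a 'checkpoint' file after saves at steps 50 and 100.
SAMPLE = '''model_checkpoint_path: "model.ckpt-100"
all_model_checkpoint_paths: "model.ckpt-50"
all_model_checkpoint_paths: "model.ckpt-100"
'''


def latest_checkpoint_path(checkpoint_dir):
    """Hypothetical stand-in for the lookup done by get_checkpoint_state().

    Returns the newest checkpoint prefix, or None if there is no
    'checkpoint' file. Relative paths are resolved against checkpoint_dir.
    """
    state_file = os.path.join(checkpoint_dir, "checkpoint")
    if not os.path.exists(state_file):
        return None
    with open(state_file) as f:
        match = re.search(r'^model_checkpoint_path:\s*"([^"]*)"', f.read(), re.M)
    if match is None:
        return None
    path = match.group(1)
    if not os.path.isabs(path):
        path = os.path.join(checkpoint_dir, path)
    return path


with tempfile.TemporaryDirectory() as d:
    print(latest_checkpoint_path(d))  # None: nothing saved yet
    with open(os.path.join(d, "checkpoint"), "w") as f:
        f.write(SAMPLE)
    print(os.path.basename(latest_checkpoint_path(d)))  # model.ckpt-100
```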
The result of running the test phase is shown in the figure below: the previously trained values of w and b are loaded from the checkpoint and printed.