Saving and Restoring Trained TensorFlow Models
TensorFlow can write a trained model's graph and variable values to disk and load them back later, so you can resume training, run inference, or share a model without retraining it.
Saving the Model
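In TensorFlow 1.x graph mode (still available in TensorFlow 2 through tf.compat.v1), you save a trained model with the tf.train.Saver class. Calling saver.save() writes a checkpoint holding the variable values plus a .meta file holding the graph definition. Here's an example: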
```python
# TF1-style graph/session code; compat.v1 lets it run under TensorFlow 2 as well
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

# Prepare placeholders and variables
w1 = tf.placeholder(tf.float32, name="w1")
w2 = tf.placeholder(tf.float32, name="w2")
b1 = tf.Variable(2.0, name="bias")
feed_dict = {w1: 4, w2: 8}

# Define an operation to be restored
w3 = tf.add(w1, w2)
w4 = tf.multiply(w3, b1, name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Create a saver object
saver = tf.train.Saver()

# Run the operation and save the graph
print(sess.run(w4, feed_dict))  # Prints 24.0 ((4 + 8) * 2)

# Writes my_test_model-1000.meta, .index and .data-* files, plus a "checkpoint" state file
saver.save(sess, 'my_test_model', global_step=1000)
```
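To see what saver.save() actually produced, you can list the files it wrote and read variables straight out of the checkpoint with tf.train.list_variables and tf.train.load_variable, without building a graph. The sketch below is a minimal, self-contained illustration under the assumption that a throwaway temporary directory (created with tempfile, not part of the original example) stands in for the working directory.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # use TF1-style graphs and sessions

# Hypothetical scratch directory for this demo
ckpt_dir = tempfile.mkdtemp()
prefix = os.path.join(ckpt_dir, "my_test_model")

with tf.Graph().as_default():
    b1 = tf.Variable(2.0, name="bias")
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # save() returns the checkpoint prefix it wrote (ending in "my_test_model-1000")
        saved_path = saver.save(sess, prefix, global_step=1000)

# Expect a .meta graph file, an .index file, .data shard(s) and a "checkpoint" state file
print(sorted(os.listdir(ckpt_dir)))

# Inspect variable names, shapes and values directly from the checkpoint
for name, shape in tf.train.list_variables(saved_path):
    print(name, shape, tf.train.load_variable(saved_path, name))
```

This is handy for checking which variable names a checkpoint contains before trying to restore it into a graph.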
Restoring the Model
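To restore a previously saved model, first rebuild the graph from the .meta file with tf.train.import_meta_graph, then load the variable values with saver.restore. Tensors and operations can then be looked up by the names they were given when the graph was built: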
```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

sess = tf.Session()

# Load the meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))

# Access saved variables directly
print(sess.run('bias:0'))  # Prints 2.0 (the bias value)

# Access and create feed-dict for new input data
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 17.0}

# Access the desired operation
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")
print(sess.run(op_to_restore, feed_dict))  # Prints 60.0 ((w1 + w2) * b1)
```
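If the Python code that built the graph is still available, import_meta_graph is optional: you can rebuild the same graph in code, create a fresh Saver, and call restore. Saver matches checkpoint entries to variables by name, so the variable must again be called "bias". The sketch below assumes a hypothetical helper, build_graph, and a temporary directory; it saves and then restores in one script only so that it is self-contained.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

ckpt_dir = tempfile.mkdtemp()  # hypothetical scratch directory


def build_graph(bias_init):
    """Hypothetical helper that rebuilds the article's graph with the same names."""
    w1 = tf.placeholder(tf.float32, name="w1")
    w2 = tf.placeholder(tf.float32, name="w2")
    b1 = tf.Variable(bias_init, name="bias")
    w4 = tf.multiply(tf.add(w1, w2), b1, name="op_to_restore")
    return w1, w2, w4


# 1) Training side: build, initialize (bias = 2.0) and save
with tf.Graph().as_default():
    build_graph(2.0)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, os.path.join(ckpt_dir, "my_test_model"), global_step=1000)

# 2) Later: rebuild the same graph in a fresh Graph and restore into it
with tf.Graph().as_default():
    w1, w2, w4 = build_graph(0.0)  # initial value is overwritten by restore()
    saver = tf.train.Saver()
    with tf.Session() as sess:
        # No initializer needed: restore() assigns every saved variable
        saver.restore(sess, tf.train.latest_checkpoint(ckpt_dir))
        result = sess.run(w4, {w1: 13.0, w2: 17.0})

print(result)  # (13 + 17) * 2.0 = 60.0
```

Rebuilding in code keeps the graph definition under version control, while import_meta_graph is convenient when only the saved files are available.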