首页 > 后端开发 > Python教程 > 如何保存和恢复经过训练的 TensorFlow 模型?

如何保存和恢复经过训练的 TensorFlow 模型?

DDD
发布: 2024-12-19 17:41:09
原创
636 人浏览过

How Can I Save and Restore Trained TensorFlow Models?

保存和恢复经过训练的 TensorFlow 模型

TensorFlow 提供了保存和恢复经过训练的模型的无缝功能,允许您在以下环境中保存和重用您的模型各种场景。

保存模型

要在 TensorFlow 中保存经过训练的模型,您可以使用 tf.train.Saver 类。下面是一个示例:

import tensorflow as tf

# Prepare placeholders and variables
w1 = tf.placeholder(tf.float32, name="w1")
w2 = tf.placeholder(tf.float32, name="w2")
b1 = tf.Variable(2.0, name="bias")
feed_dict = {w1: 4, w2: 8}

# Define an operation to be restored
w3 = tf.add(w1, w2)
w4 = tf.multiply(w3, b1, name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Create a saver object
saver = tf.train.Saver()

# Run the operation and save the graph
print(sess.run(w4, feed_dict))
saver.save(sess, 'my_test_model', global_step=1000)
登录后复制

恢复模型

要恢复以前保存的模型,您可以使用以下过程:

import tensorflow as tf

sess = tf.Session()

# Load the meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))

# Access saved variables directly
print(sess.run('bias:0'))  # Prints 2 (the bias value)

# Access and create feed-dict for new input data
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 17.0}

# Access the desired operation
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

print(sess.run(op_to_restore, feed_dict))  # Prints 60 ((w1 + w2) * b1)
登录后复制

有关其他场景和用例,请参阅所提供答案中提供的资源,其中深入探讨了保存和恢复 TensorFlow模型。

以上是如何保存和恢复经过训练的 TensorFlow 模型?的详细内容。更多信息请关注PHP中文网其他相关文章!

来源:php.cn
本站声明
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn
热门教程
更多>
最新下载
更多>
网站特效
网站源码
网站素材
前端模板