
TensorFlow 张量的打印值:综合指南
在 TensorFlow 中,Tensor 对象表示多维数据数组。要访问张量中存储的实际值,您需要在会话中对其进行评估。
Session.run() 方法
最直接的方法是使用 Session.run() 方法来评估 Tensor 并检索其值:
1 2 3 4 5 6 7 | import tensorflow as tf
sess = tf.Session()
matrix1 = tf.constant([[3., 3.]])
matrix2 = tf.constant([[2.],[2.]])
product = tf.matmul(matrix1, matrix2)
print (sess.run(product))
|
登录后复制
这会将 Tensor 的值打印为 NumPy 数组。
Tensor .eval() 方法
您还可以使用 Tensor.eval() 方法在默认 Session 内评估 Tensor:
1 2 | with tf.Session():
print (product. eval ())
|
登录后复制
交互式会话
为了更方便,您可以使用 tf.InteractiveSession 为整个程序打开默认会话:
1 2 3 4 5 6 7 8 | import tensorflow as tf
tf.InteractiveSession()
matrix1 = tf.constant([[3., 3.]])
matrix2 = tf.constant([[2.],[2.]])
product = tf.matmul(matrix1, matrix2)
print (product. eval ())
|
登录后复制
注释
- 为了提高效率,TensorFlow 将计算的定义(构建数据流图)与执行(评估图并生成值)分开。
- tf.print() 运算符也可用于打印 Tensor 的值,但这需要使用 Session.run() 手动执行。
- 如果可以有效计算的话,tf.get_static_value() 函数可用于获取 Tensor 的常量值。
以上是如何打印 TensorFlow 张量的值?的详细内容。更多信息请关注PHP中文网其他相关文章!