How to export a TensorFlow model (network and weights) to a single file

不言
Release: 2018-04-23 15:39:49

This article shows how to export a TensorFlow network, including its weights, to a single file, and is shared here for reference.

Sometimes we need to export a TensorFlow model to a single file that contains both the model architecture and the weights, so that it can easily be used elsewhere (for example, to deploy the network from C). By default, tf.train.write_graph() exports only the network definition, without the weights, while tf.train.Saver().save() writes the graph_def to a file separate from the weights. So another method is needed.
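
To see that separation concretely, here is a minimal sketch. It assumes a TensorFlow version that provides the tf.compat.v1 module (so it runs on 1.x or 2.x), and the temporary directory and the file names graph_only.pb and model.ckpt are arbitrary choices. It saves the same small network both ways and lists the files: write_graph produces a single GraphDef file, while Saver.save produces a .meta graph file plus separate .index and .data-* files that hold the variable values.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()

# Arbitrary scratch directory for this illustration.
export_dir = tempfile.mkdtemp()

graph = tf.Graph()
with graph.as_default():
    a = tf1.Variable([[3.0], [4.0]], name='a')
    b = tf1.Variable(4.0, name='b')
    out = tf.add(a, b, name='out')
    saver = tf1.train.Saver()

with tf1.Session(graph=graph) as sess:
    sess.run(tf1.global_variables_initializer())
    # write_graph stores only the GraphDef (the ops and their attributes).
    tf1.train.write_graph(sess.graph_def, export_dir, 'graph_only.pb', as_text=False)
    # Saver.save writes the graph (.meta) and the variable values (.index / .data-*) to separate files.
    saver.save(sess, os.path.join(export_dir, 'model.ckpt'))

saved_files = sorted(os.listdir(export_dir))
print(saved_files)
```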

A graph_def does not contain the values of Variables (which is where the weights are usually stored), but it does contain the values of constants. So if we convert the Variables into constants, a single file can hold both the network architecture and the weights.
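
The effect of this conversion can be checked by listing the op types in the GraphDef before and after freezing. The sketch below is written against tf.compat.v1 (an assumption about your TensorFlow version). Before freezing, the graph contains variable ops (VariableV2 for classic variables, or VarHandleOp/ReadVariableOp for resource variables, depending on the version); after freezing, no variable ops remain and the weights live in Const nodes.

```python
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()


def op_types(graph_def):
    """Return the sorted set of op types used by the nodes of a GraphDef."""
    return sorted({node.op for node in graph_def.node})


graph = tf.Graph()
with graph.as_default():
    a = tf1.Variable([[3.0], [4.0]], name='a')
    b = tf1.Variable(4.0, name='b')
    out = tf.add(a, b, name='out')

with tf1.Session(graph=graph) as sess:
    sess.run(tf1.global_variables_initializer())
    before = op_types(sess.graph_def)
    # Replace every variable that 'out' depends on with a constant holding its current value.
    frozen = tf1.graph_util.convert_variables_to_constants(sess, sess.graph_def, ['out'])
    after = op_types(frozen)

print('before:', before)
print('after: ', after)
```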

We can freeze the weights and save the network in the following way:

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants

# Build the network
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
# Be sure to give the output tensor a name!!
output = tf.add(a, b, name='out')

# Convert the Variables to constants and write the network to a file
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # Pass the output node's name here (the op name 'out', not the tensor name 'out:0')
  graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
  tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)
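
Why must the output be named? convert_variables_to_constants takes output node names (such as "out") and keeps only the nodes those outputs depend on. The sketch below (tf.compat.v1; the extra op named unused is hypothetical, added only for illustration) shows that an op which does not feed into out is pruned from the frozen graph.

```python
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()

graph = tf.Graph()
with graph.as_default():
    a = tf1.Variable([[3.0], [4.0]], name='a')
    b = tf1.Variable(4.0, name='b')
    out = tf.add(a, b, name='out')
    # Hypothetical op that does not feed into 'out'; freezing with ['out'] should drop it.
    unused = tf.multiply(a, b, name='unused')

with tf1.Session(graph=graph) as sess:
    sess.run(tf1.global_variables_initializer())
    # Output names are node (op) names such as 'out'.
    frozen = tf1.graph_util.convert_variables_to_constants(sess, sess.graph_def, ['out'])

node_names = [node.name for node in frozen.node]
print(node_names)
```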

To restore the network, we can do the following:

import tensorflow as tf
with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
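    # Parse the binary GraphDef (written with as_text=False)
    graph_def.ParseFromString(f.read())
    # Import into the default graph and fetch the tensor 'out:0'
    output = tf.import_graph_def(graph_def, return_elements=['out:0'])
    print(sess.run(output))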

The output is:

[array([[ 7.],
       [ 8.]], dtype=float32)]

As you can see, the weights were indeed saved.
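
The restore code above imports into the session's default graph. A slightly more explicit variant is sketched below with tf.compat.v1 and a temporary directory (both assumptions, not part of the original): it loads the file into a fresh tf.Graph and fetches the output by name. Since no name argument is passed, import_graph_def prefixes every imported op with import/, so the output tensor is import/out:0.

```python
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()

export_dir = tempfile.mkdtemp()

# Freeze and save the network, as in the article.
graph = tf.Graph()
with graph.as_default():
    a = tf1.Variable([[3.0], [4.0]], name='a')
    b = tf1.Variable(4.0, name='b')
    tf.add(a, b, name='out')
with tf1.Session(graph=graph) as sess:
    sess.run(tf1.global_variables_initializer())
    frozen = tf1.graph_util.convert_variables_to_constants(sess, sess.graph_def, ['out'])
    # write_graph returns the path of the file it wrote.
    path = tf1.train.write_graph(frozen, export_dir, 'graph.pb', as_text=False)

# Read the binary file back into a GraphDef.
graph_def = tf1.GraphDef()
with open(path, 'rb') as f:
    graph_def.ParseFromString(f.read())

# Import into a brand-new graph; nothing from the original Python objects is reused.
restored = tf.Graph()
with restored.as_default():
    tf1.import_graph_def(graph_def)  # default prefix: 'import/'
with tf1.Session(graph=restored) as sess:
    result = sess.run('import/out:0')

print(result)
```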

The problem is that the network needs an interface for feeding in custom data; otherwise it is of little use. Fortunately, there is a way:

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
# The placeholder is the entry point for custom input data
input_tensor = tf.placeholder(tf.float32, name='input')
output = tf.add((a+b), input_tensor, name='out')

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
  tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)
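
Freezing only touches variables. A placeholder is not a variable, so it survives as a Placeholder node and becomes the input of the frozen graph. A quick check, sketched with tf.compat.v1:

```python
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()

graph = tf.Graph()
with graph.as_default():
    a = tf1.Variable([[3.0], [4.0]], name='a')
    b = tf1.Variable(4.0, name='b')
    input_tensor = tf1.placeholder(tf.float32, name='input')
    tf.add(a + b, input_tensor, name='out')

with tf1.Session(graph=graph) as sess:
    sess.run(tf1.global_variables_initializer())
    frozen = tf1.graph_util.convert_variables_to_constants(sess, sess.graph_def, ['out'])

# Collect the names of the placeholders that remain in the frozen GraphDef.
placeholders = [node.name for node in frozen.node if node.op == 'Placeholder']
print(placeholders)
```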

Running the code above saves the network to graph.pb again, this time with an input placeholder. Now let's see how to restore the network and feed it custom data.

import tensorflow as tf

with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    # Replace the placeholder 'input:0' with the constant 4.0; imported ops get the prefix 'a/'
    output = tf.import_graph_def(graph_def, input_map={'input:0':4.}, return_elements=['out:0'], name='a')
    print(sess.run(output))

The output is as expected, since (3 + 4) + 4 = 11 and (4 + 4) + 4 = 12:

[array([[ 11.],
       [ 12.]], dtype=float32)]
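
The input_map is not strictly required. If it is omitted, the imported placeholder stays a placeholder and can be fed by its prefixed name. The sketch below uses tf.compat.v1, freezes in memory instead of reading graph.pb, and uses the prefix frozen (an arbitrary choice); it feeds frozen/input:0 directly and gets the same result.

```python
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()

# Build and freeze the placeholder version of the network in memory.
graph = tf.Graph()
with graph.as_default():
    a = tf1.Variable([[3.0], [4.0]], name='a')
    b = tf1.Variable(4.0, name='b')
    input_tensor = tf1.placeholder(tf.float32, name='input')
    tf.add(a + b, input_tensor, name='out')
with tf1.Session(graph=graph) as sess:
    sess.run(tf1.global_variables_initializer())
    frozen = tf1.graph_util.convert_variables_to_constants(sess, sess.graph_def, ['out'])

restored = tf.Graph()
with restored.as_default():
    # No input_map: the imported placeholder is kept and named 'frozen/input'.
    tf1.import_graph_def(frozen, name='frozen')
with tf1.Session(graph=restored) as sess:
    # Tensors can be fetched and fed by name.
    result = sess.run('frozen/out:0', feed_dict={'frozen/input:0': 4.0})

print(result)
```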

Instead of a constant, input_map can also map the input to a new custom placeholder, as shown below:

import tensorflow as tf

new_input = tf.placeholder(tf.float32, shape=())

with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    # Route the graph's 'input:0' to the new placeholder, which is fed at run time
    output = tf.import_graph_def(graph_def, input_map={'input:0':new_input}, return_elements=['out:0'], name='a')
    print(sess.run(output, feed_dict={new_input:4}))

The output is again as expected:

[array([[ 11.],
       [ 12.]], dtype=float32)]
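
Because the replacement placeholder can carry its own shape, input_map can also make the network accept several input values at once. In this sketch (tf.compat.v1; the name batch_input and the shape [None] are illustrative assumptions), a vector is fed, and broadcasting the (2, 1) tensor a + b against the (n,) input yields a (2, n) result.

```python
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()

# Build and freeze the placeholder version of the network in memory.
graph = tf.Graph()
with graph.as_default():
    a = tf1.Variable([[3.0], [4.0]], name='a')
    b = tf1.Variable(4.0, name='b')
    input_tensor = tf1.placeholder(tf.float32, name='input')
    tf.add(a + b, input_tensor, name='out')
with tf1.Session(graph=graph) as sess:
    sess.run(tf1.global_variables_initializer())
    frozen = tf1.graph_util.convert_variables_to_constants(sess, sess.graph_def, ['out'])

restored = tf.Graph()
with restored.as_default():
    # Hypothetical vector-shaped replacement input: (2, 1) + (n,) broadcasts to (2, n).
    batch_input = tf1.placeholder(tf.float32, shape=[None], name='batch_input')
    (output,) = tf1.import_graph_def(
        frozen, input_map={'input:0': batch_input}, return_elements=['out:0'], name='frozen')
with tf1.Session(graph=restored) as sess:
    result = sess.run(output, feed_dict={batch_input: [4.0, 5.0, 6.0]})

print(result)
```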

One more point: if the network architecture was written with tf.train.write_graph and as_text=True, the import code needs a small modification:

import tensorflow as tf
from google.protobuf import text_format

with tf.Session() as sess:
  # Open in text mode, not 'rb'
  with open('./graph.pb', 'r') as f:
    graph_def = tf.GraphDef()
    # Parse the text format instead of calling graph_def.ParseFromString(f.read())
    text_format.Merge(f.read(), graph_def)
    output = tf.import_graph_def(graph_def, return_elements=['out:0'])
    print(sess.run(output))
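
For completeness, here is a full round trip through the text format, sketched with tf.compat.v1. The file name graph.pbtxt is an arbitrary choice that simply makes the format obvious; the file is read in text mode and parsed with text_format.Merge before importing.

```python
import tempfile

import tensorflow as tf
from google.protobuf import text_format

tf1 = tf.compat.v1
tf1.disable_eager_execution()

export_dir = tempfile.mkdtemp()

graph = tf.Graph()
with graph.as_default():
    a = tf1.Variable([[3.0], [4.0]], name='a')
    b = tf1.Variable(4.0, name='b')
    tf.add(a, b, name='out')
with tf1.Session(graph=graph) as sess:
    sess.run(tf1.global_variables_initializer())
    frozen = tf1.graph_util.convert_variables_to_constants(sess, sess.graph_def, ['out'])
    # as_text=True writes the human-readable protobuf text format.
    path = tf1.train.write_graph(frozen, export_dir, 'graph.pbtxt', as_text=True)

graph_def = tf1.GraphDef()
with open(path, 'r') as f:
    # Text format must be parsed with text_format, not ParseFromString.
    text_format.Merge(f.read(), graph_def)

restored = tf.Graph()
with restored.as_default():
    (output,) = tf1.import_graph_def(graph_def, return_elements=['out:0'])
with tf1.Session(graph=restored) as sess:
    result = sess.run(output)

print(result)
```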

Source: php.cn