首頁 > 後端開發 > Python教學 > 將TensorFlow的模型網路匯出為單一檔案的方法

將TensorFlow的模型網路匯出為單一檔案的方法

不言
發布: 2018-04-23 15:39:49
原創
1756 人瀏覽過

本篇文章主要介紹了將TensorFlow的網路匯出為單一檔案的方法,現在分享給大家,也為大家做個參考。一起來看看吧

有時候,我們需要將TensorFlow的模型匯出為單一檔案(同時包含模型架構定義與權重),方便在其他地方使用(如在c 中部署網路)。利用tf.train.write_graph()預設只導出了網路的定義(沒有權重),而利用tf.train.Saver().save()導出的檔案graph_def與權重是分離的,因此需要採用別的方法。

我們知道,graph_def檔案中沒有包含網路中的Variable值(通常情況儲存了權重),但是卻包含了constant值,所以如果我們能把Variable轉換為constant,即可達到使用一個文件同時儲存網路架構與權重的目標。

我們可以採用以下方式凍結權重並保存網路:

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants

# 构造网络
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
# 一定要给输出tensor取一个名字!!
output = tf.add(a, b, name='out')

# 转换Variable为constant,并将网络写入到文件
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # 这里需要填入输出tensor的名字
  graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
  tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)
登入後複製

當恢復網路時,可以使用以下方式:

import tensorflow as tf
with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, return_elements=['out:0']) 
    print(sess.run(output))
登入後複製

輸出結果為:

[array([[ 7.],
       [ 8.]], dtype=float32) ]

可以看到先前的權重確實保存了下來!!

問題來了,我們的網路需要能有一個輸入自訂資料的介面啊!不然這玩意有什麼用。 。別急,當然有辦法。

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
input_tensor = tf.placeholder(tf.float32, name='input')
output = tf.add((a+b), input_tensor, name='out')

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
  tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)
登入後複製

用上述程式碼重新儲存網路至graph.pb,這次我們有了一個輸入placeholder,下面來看看怎麼恢復網路並輸入自訂數據。

import tensorflow as tf

with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, input_map={'input:0':4.}, return_elements=['out:0'], name='a') 
    print(sess.run(output))
登入後複製

輸出結果為:

[array([[ 11.],
       [ 12.]], dtype=float32)]

可以看到結果沒有問題,當然在input_map可以替換為新的自訂的placeholder,如下所示:

#
import tensorflow as tf

new_input = tf.placeholder(tf.float32, shape=())

with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, input_map={'input:0':new_input}, return_elements=['out:0'], name='a') 
    print(sess.run(output, feed_dict={new_input:4}))
登入後複製

看看輸出,同樣沒有問題。

[array([[ 11.],
       [ 12.]], dtype=float32)]

#另外需要說明的一點是,在利用tf.train.write_graph寫網路架構的時候,如果令as_text=True了,則在導入網路的時候,需要做一點小修改。

import tensorflow as tf
from google.protobuf import text_format

with tf.Session() as sess:
  # 不使用'rb'模式
  with open('./graph.pb', 'r') as f:
    graph_def = tf.GraphDef()
    # 不使用graph_def.ParseFromString(f.read())
    text_format.Merge(f.read(), graph_def)
    output = tf.import_graph_def(graph_def, return_elements=['out:0']) 
    print(sess.run(output))
登入後複製

相關推薦:

#TensorFlow安裝以及對jupyter notebook配置詳解


##

以上是將TensorFlow的模型網路匯出為單一檔案的方法的詳細內容。更多資訊請關注PHP中文網其他相關文章!

相關標籤:
來源:php.cn
本網站聲明
本文內容由網友自願投稿,版權歸原作者所有。本站不承擔相應的法律責任。如發現涉嫌抄襲或侵權的內容,請聯絡admin@php.cn
作者最新文章
熱門教學
更多>
最新下載
更多>
網站特效
網站源碼
網站素材
前端模板