ホームページ > バックエンド開発 > Python チュートリアル > TensorFlow のモデル ネットワークを単一のファイルにエクスポートする方法

TensorFlow のモデル ネットワークを単一のファイルにエクスポートする方法

不言
リリース: 2018-04-23 15:39:49
オリジナル
1756 人が閲覧しました

この記事では主に TensorFlow ネットワークを単一のファイルにエクスポートする方法を紹介し、参考として提供します。一緒に見てみましょう

場合によっては、他の場所 (C++ でのネットワークのデプロイなど) で簡単に使用できるように、TensorFlow モデルを単一のファイル (モデル アーキテクチャの定義と重みを含む) にエクスポートする必要があります。 tf.train.write_graph() を使用すると、デフォルトではネットワークの定義 (重みなし) のみがエクスポートされますが、tf.train.Saver().save() を使用してエクスポートされたファイルgraph_def は重みから分離されるため、他のメソッドは次のことを行う必要があります。方法が使用されます。

graph_def ファイルにはネットワーク内の変数値が含まれていないことがわかっています (通常は重みが保存されています)。ただし、定数値は含まれているため、変数を定数に変換できれば、1 つのファイルを使用して保存できます。ネットワーク アーキテクチャと同時に重みを付けてターゲットを設定します。

次の方法で重みをフリーズしてネットワークを保存できます:

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants

# 构造网络
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
# 一定要给输出tensor取一个名字!!
output = tf.add(a, b, name='out')

# 转换Variable为constant,并将网络写入到文件
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # 这里需要填入输出tensor的名字
  graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
  tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)
ログイン後にコピー

ネットワークを復元するときは、次の方法を使用できます:

import tensorflow as tf
with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, return_elements=['out:0']) 
    print(sess.run(output))
ログイン後にコピー

出力結果は次のとおりです:

[array([[ 7 .],
[ 8.]], dtype=float32)]

以前の重みが実際に保存されていることがわかります!!

問題は、ネットワークにインターフェイスが必要であることです。カスタムデータを入力するためです!そうでなければ、これは何の役に立つのでしょう。 。心配しないでください、もちろん方法はあります。

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
input_tensor = tf.placeholder(tf.float32, name='input')
output = tf.add((a+b), input_tensor, name='out')

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
  tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)
ログイン後にコピー

上記のコードを使用して、ネットワークをgraph.pbに再保存します。今回は、ネットワークを復元してカスタム データを入力する方法を見てみましょう。

import tensorflow as tf

with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, input_map={'input:0':4.}, return_elements=['out:0'], name='a') 
    print(sess.run(output))
ログイン後にコピー

出力結果は、

[array([[ 11.],
[ 12.]], dtype=float32)]

で問題ないことがわかります。もちろん、input_map 内の結果です。以下に示すように、新しいカスタム プレースホルダーに置き換えることができます。

import tensorflow as tf

new_input = tf.placeholder(tf.float32, shape=())

with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, input_map={'input:0':new_input}, return_elements=['out:0'], name='a') 
    print(sess.run(output, feed_dict={new_input:4}))
ログイン後にコピー

出力を見てください。問題はありません。

[array([[ 11.],
[ 12.]], dtype=float32)]

もう 1 つ説明する必要がある点は、 tf.train.write_graph を使用してネットワーク アーキテクチャを記述する場合、 as_text=True の場合、ネットワークにインポートするときに若干の変更を加える必要があります。

import tensorflow as tf
from google.protobuf import text_format

with tf.Session() as sess:
  # 不使用'rb'模式
  with open('./graph.pb', 'r') as f:
    graph_def = tf.GraphDef()
    # 不使用graph_def.ParseFromString(f.read())
    text_format.Merge(f.read(), graph_def)
    output = tf.import_graph_def(graph_def, return_elements=['out:0']) 
    print(sess.run(output))
ログイン後にコピー

関連する推奨事項:

TensorFlow のインストールと jupyter Notebook 構成の詳細な説明


以上がTensorFlow のモデル ネットワークを単一のファイルにエクスポートする方法の詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。

関連ラベル:
ソース:php.cn
このウェブサイトの声明
この記事の内容はネチズンが自主的に寄稿したものであり、著作権は原著者に帰属します。このサイトは、それに相当する法的責任を負いません。盗作または侵害の疑いのあるコンテンツを見つけた場合は、admin@php.cn までご連絡ください。
人気のチュートリアル
詳細>
最新のダウンロード
詳細>
ウェブエフェクト
公式サイト
サイト素材
フロントエンドテンプレート