Heim > Backend-Entwicklung > Python-Tutorial > So exportieren Sie das Modellnetzwerk von TensorFlow in eine einzelne Datei

So exportieren Sie das Modellnetzwerk von TensorFlow in eine einzelne Datei

不言
Freigeben: 2018-04-23 15:39:49
Original
1756 Leute haben es durchsucht

In diesem Artikel wird hauptsächlich die Methode zum Exportieren des TensorFlow-Netzwerks in eine einzelne Datei vorgestellt. Jetzt teile ich sie mit Ihnen und gebe sie als Referenz. Werfen wir gemeinsam einen Blick darauf

Manchmal müssen wir das TensorFlow-Modell in eine einzelne Datei exportieren (einschließlich Modellarchitekturdefinition und Gewichtungen), um es an anderen Orten einfach verwenden zu können (z. B. bei der Bereitstellung eines Netzwerks in C++). Die Verwendung von tf.train.write_graph() exportiert standardmäßig nur die Definition des Netzwerks (ohne Gewichtungen), während die mit tf.train.Saver().save() exportierte Datei graph_def von den Gewichten getrennt ist, sodass andere Methoden dies tun müssen Methode verwendet werden.

Wir wissen, dass die graph_def-Datei nicht den Variablenwert im Netzwerk enthält (normalerweise wird das Gewicht gespeichert), aber sie enthält den konstanten Wert. Wenn wir also die Variable in eine Konstante umwandeln können, können wir sie verwenden Eine Datei Das Ziel ist die gleichzeitige Speicherung von Netzwerkarchitektur und -gewichten.

Wir können die Gewichte einfrieren und das Netzwerk auf folgende Weise speichern:

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants

# 构造网络
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
# 一定要给输出tensor取一个名字!!
output = tf.add(a, b, name='out')

# 转换Variable为constant,并将网络写入到文件
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # 这里需要填入输出tensor的名字
  graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
  tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)
Nach dem Login kopieren

Beim Wiederherstellen des Netzwerks können wir das verwenden Folgendermaßen:

import tensorflow as tf
with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, return_elements=['out:0']) 
    print(sess.run(output))
Nach dem Login kopieren

Das Ausgabeergebnis ist:

[array([[ 7.],
[ 8.]], dtype=float32)]

Sie können sehen, dass die vorherigen Gewichte tatsächlich gespeichert werden!!

Das Problem ist, dass unser Netzwerk dies tun muss eine benutzerdefinierte Eingabedatenschnittstelle! Was nützt dieses Ding sonst? . Keine Sorge, natürlich gibt es einen Weg.

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
input_tensor = tf.placeholder(tf.float32, name='input')
output = tf.add((a+b), input_tensor, name='out')

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
  tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)
Nach dem Login kopieren

Verwenden Sie den obigen Code, um das Netzwerk in graph.pb erneut zu speichern. Sehen wir uns an, wie das Netzwerk wiederhergestellt wird Geben Sie es ein.

import tensorflow as tf

with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, input_map={'input:0':4.}, return_elements=['out:0'], name='a') 
    print(sess.run(output))
Nach dem Login kopieren

Das Ausgabeergebnis ist:

[array([[ 11.],
[ 12. ] ], dtype=float32)]

Sie können sehen, dass das Ergebnis kein Problem darstellt. Natürlich kann die input_map durch einen neuen benutzerdefinierten Platzhalter ersetzt werden, wie unten gezeigt:

import tensorflow as tf

new_input = tf.placeholder(tf.float32, shape=())

with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, input_map={'input:0':new_input}, return_elements=['out:0'], name='a') 
    print(sess.run(output, feed_dict={new_input:4}))
Nach dem Login kopieren

Sehen Sie sich die Ausgabe an, es gibt auch kein Problem.

[array([[ 11.],
[ 12.]], dtype=float32)]

Ein weiterer Punkt, der sein muss Die Erklärung lautet: Wenn Sie tf.train.write_graph zum Schreiben der Netzwerkarchitektur verwenden und as_text=True festgelegt ist, müssen Sie beim Importieren des Netzwerks eine kleine Änderung vornehmen.

import tensorflow as tf
from google.protobuf import text_format

with tf.Session() as sess:
  # 不使用'rb'模式
  with open('./graph.pb', 'r') as f:
    graph_def = tf.GraphDef()
    # 不使用graph_def.ParseFromString(f.read())
    text_format.Merge(f.read(), graph_def)
    output = tf.import_graph_def(graph_def, return_elements=['out:0']) 
    print(sess.run(output))
Nach dem Login kopieren

Verwandte Empfehlungen:

TensorFlow-Installation und detaillierte Erklärung der Jupyter-Notebook-Konfiguration


Das obige ist der detaillierte Inhalt vonSo exportieren Sie das Modellnetzwerk von TensorFlow in eine einzelne Datei. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!

Verwandte Etiketten:
Quelle:php.cn
Erklärung dieser Website
Der Inhalt dieses Artikels wird freiwillig von Internetnutzern beigesteuert und das Urheberrecht liegt beim ursprünglichen Autor. Diese Website übernimmt keine entsprechende rechtliche Verantwortung. Wenn Sie Inhalte finden, bei denen der Verdacht eines Plagiats oder einer Rechtsverletzung besteht, wenden Sie sich bitte an admin@php.cn
Beliebte Tutorials
Mehr>
Neueste Downloads
Mehr>
Web-Effekte
Quellcode der Website
Website-Materialien
Frontend-Vorlage