Cet article présente principalement la méthode d'exportation du réseau TensorFlow dans un seul fichier. Maintenant, je le partage avec vous et le donne comme référence. Jetons un coup d'œil ensemble
Parfois, nous devons exporter le modèle TensorFlow vers un seul fichier (y compris la définition de l'architecture du modèle et les pondérations) pour une utilisation facile à d'autres endroits (comme le déploiement d'un réseau en C++). L'utilisation de tf.train.write_graph() exporte uniquement la définition du réseau (sans poids) par défaut, tandis que le fichier graph_def exporté à l'aide de tf.train.Saver().save() est séparé des poids, donc d'autres méthodes doivent être utilisée.
Nous savons que le fichier graph_def ne contient pas la valeur de la variable dans le réseau (généralement le poids est stocké), mais il contient la valeur constante, donc si nous pouvons convertir la variable en constante, nous pouvons utiliser un fichier L'objectif de stocker simultanément l'architecture et les poids du réseau.
Nous pouvons geler les poids et sauvegarder le réseau de la manière suivante :
import tensorflow as tf from tensorflow.python.framework.graph_util import convert_variables_to_constants # 构造网络 a = tf.Variable([[3],[4]], dtype=tf.float32, name='a') b = tf.Variable(4, dtype=tf.float32, name='b') # 一定要给输出tensor取一个名字!! output = tf.add(a, b, name='out') # 转换Variable为constant,并将网络写入到文件 with tf.Session() as sess: sess.run(tf.global_variables_initializer()) # 这里需要填入输出tensor的名字 graph = convert_variables_to_constants(sess, sess.graph_def, ["out"]) tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)
Lors de la restauration du réseau, nous pouvons utiliser le de la manière suivante :
import tensorflow as tf with tf.Session() as sess: with open('./graph.pb', 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) output = tf.import_graph_def(graph_def, return_elements=['out:0']) print(sess.run(output))
Le résultat de sortie est :
[array([[ 7.],
[ 8.]], dtype=float32)]
Vous pouvez voir que les poids précédents sont bien enregistrés !!
Le problème est que notre réseau doit avoir une interface de données d'entrée personnalisée ! Sinon, à quoi sert ce truc. . Ne vous inquiétez pas, il existe bien sûr un moyen.
import tensorflow as tf from tensorflow.python.framework.graph_util import convert_variables_to_constants a = tf.Variable([[3],[4]], dtype=tf.float32, name='a') b = tf.Variable(4, dtype=tf.float32, name='b') input_tensor = tf.placeholder(tf.float32, name='input') output = tf.add((a+b), input_tensor, name='out') with tf.Session() as sess: sess.run(tf.global_variables_initializer()) graph = convert_variables_to_constants(sess, sess.graph_def, ["out"]) tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)
Utilisez le code ci-dessus pour réenregistrer le réseau dans graph.pb. Cette fois, nous avons un espace réservé d'entrée. Voyons comment restaurer le réseau et. saisissez-le. Données personnalisées.
import tensorflow as tf with tf.Session() as sess: with open('./graph.pb', 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) output = tf.import_graph_def(graph_def, input_map={'input:0':4.}, return_elements=['out:0'], name='a') print(sess.run(output))
Le résultat de sortie est :
[array([[ 11.],
[ 12. ] ], dtype=float32)]
Vous pouvez voir qu'il n'y a aucun problème avec le résultat. Bien sûr, le input_map peut être remplacé par un nouvel espace réservé personnalisé, comme indiqué ci-dessous :
import tensorflow as tf new_input = tf.placeholder(tf.float32, shape=()) with tf.Session() as sess: with open('./graph.pb', 'rb') as f: graph_def = tf.GraphDef() graph_def.ParseFromString(f.read()) output = tf.import_graph_def(graph_def, input_map={'input:0':new_input}, return_elements=['out:0'], name='a') print(sess.run(output, feed_dict={new_input:4}))
Regardez le résultat, il n'y a pas de problème non plus.
[array([[ 11.],
[ 12.]], dtype=float32)]
Un autre point qui doit être expliqué est que, lors de l'utilisation de tf.train.write_graph pour écrire l'architecture du réseau, si as_text=True est défini, une petite modification doit être apportée lors de l'importation du réseau.
import tensorflow as tf from google.protobuf import text_format with tf.Session() as sess: # 不使用'rb'模式 with open('./graph.pb', 'r') as f: graph_def = tf.GraphDef() # 不使用graph_def.ParseFromString(f.read()) text_format.Merge(f.read(), graph_def) output = tf.import_graph_def(graph_def, return_elements=['out:0']) print(sess.run(output))
Recommandations associées :
Installation de TensorFlow et explication détaillée de la configuration du notebook Jupyter
Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!