Maison > développement back-end > Tutoriel Python > Comment exporter le réseau de modèles TensorFlow vers un seul fichier

Comment exporter le réseau de modèles TensorFlow vers un seul fichier

不言
Libérer: 2018-04-23 15:39:49
original
1776 Les gens l'ont consulté

Cet article présente principalement la méthode d'exportation du réseau TensorFlow dans un seul fichier. Maintenant, je le partage avec vous et le donne comme référence. Jetons un coup d'œil ensemble

Parfois, nous devons exporter le modèle TensorFlow vers un seul fichier (y compris la définition de l'architecture du modèle et les pondérations) pour une utilisation facile à d'autres endroits (comme le déploiement d'un réseau en C++). L'utilisation de tf.train.write_graph() exporte uniquement la définition du réseau (sans poids) par défaut, tandis que le fichier graph_def exporté à l'aide de tf.train.Saver().save() est séparé des poids, donc d'autres méthodes doivent être utilisée.

Nous savons que le fichier graph_def ne contient pas la valeur de la variable dans le réseau (généralement le poids est stocké), mais il contient la valeur constante, donc si nous pouvons convertir la variable en constante, nous pouvons utiliser un fichier L'objectif de stocker simultanément l'architecture et les poids du réseau.

Nous pouvons geler les poids et sauvegarder le réseau de la manière suivante :

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants

# 构造网络
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
# 一定要给输出tensor取一个名字!!
output = tf.add(a, b, name='out')

# 转换Variable为constant,并将网络写入到文件
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  # 这里需要填入输出tensor的名字
  graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
  tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)
Copier après la connexion

Lors de la restauration du réseau, nous pouvons utiliser le de la manière suivante :

import tensorflow as tf
with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, return_elements=['out:0']) 
    print(sess.run(output))
Copier après la connexion

Le résultat de sortie est :

[array([[ 7.],
[ 8.]], dtype=float32)]

Vous pouvez voir que les poids précédents sont bien enregistrés !!

Le problème est que notre réseau doit avoir une interface de données d'entrée personnalisée ! Sinon, à quoi sert ce truc. . Ne vous inquiétez pas, il existe bien sûr un moyen.

import tensorflow as tf
from tensorflow.python.framework.graph_util import convert_variables_to_constants
a = tf.Variable([[3],[4]], dtype=tf.float32, name='a')
b = tf.Variable(4, dtype=tf.float32, name='b')
input_tensor = tf.placeholder(tf.float32, name='input')
output = tf.add((a+b), input_tensor, name='out')

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  graph = convert_variables_to_constants(sess, sess.graph_def, ["out"])
  tf.train.write_graph(graph, '.', 'graph.pb', as_text=False)
Copier après la connexion

Utilisez le code ci-dessus pour réenregistrer le réseau dans graph.pb. Cette fois, nous avons un espace réservé d'entrée. Voyons comment restaurer le réseau et. saisissez-le. Données personnalisées.

import tensorflow as tf

with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, input_map={'input:0':4.}, return_elements=['out:0'], name='a') 
    print(sess.run(output))
Copier après la connexion

Le résultat de sortie est :

[array([[ 11.],
[ 12. ] ], dtype=float32)]

Vous pouvez voir qu'il n'y a aucun problème avec le résultat. Bien sûr, le input_map peut être remplacé par un nouvel espace réservé personnalisé, comme indiqué ci-dessous :

import tensorflow as tf

new_input = tf.placeholder(tf.float32, shape=())

with tf.Session() as sess:
  with open('./graph.pb', 'rb') as f: 
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read()) 
    output = tf.import_graph_def(graph_def, input_map={'input:0':new_input}, return_elements=['out:0'], name='a') 
    print(sess.run(output, feed_dict={new_input:4}))
Copier après la connexion

Regardez le résultat, il n'y a pas de problème non plus.

[array([[ 11.],
[ 12.]], dtype=float32)]

Un autre point qui doit être expliqué est que, lors de l'utilisation de tf.train.write_graph pour écrire l'architecture du réseau, si as_text=True est défini, une petite modification doit être apportée lors de l'importation du réseau.

import tensorflow as tf
from google.protobuf import text_format

with tf.Session() as sess:
  # 不使用'rb'模式
  with open('./graph.pb', 'r') as f:
    graph_def = tf.GraphDef()
    # 不使用graph_def.ParseFromString(f.read())
    text_format.Merge(f.read(), graph_def)
    output = tf.import_graph_def(graph_def, return_elements=['out:0']) 
    print(sess.run(output))
Copier après la connexion

Recommandations associées :

Installation de TensorFlow et explication détaillée de la configuration du notebook Jupyter


Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

Étiquettes associées:
source:php.cn
Déclaration de ce site Web
Le contenu de cet article est volontairement contribué par les internautes et les droits d'auteur appartiennent à l'auteur original. Ce site n'assume aucune responsabilité légale correspondante. Si vous trouvez un contenu suspecté de plagiat ou de contrefaçon, veuillez contacter admin@php.cn
Tutoriels populaires
Plus>
Derniers téléchargements
Plus>
effets Web
Code source du site Web
Matériel du site Web
Modèle frontal