Eine kurze Diskussion zum Speichern und Wiederherstellen des Ladens von Tensorflow-Modellen

不言
Freigeben: 2018-04-26 16:40:54
Original
2332 Leute haben es durchsucht

In diesem Artikel wird hauptsächlich das Speichern und Wiederherstellen des Tensorflow-Modells vorgestellt. Jetzt teile ich es mit Ihnen und gebe es als Referenz. Schauen wir mal vorbei

Kürzlich haben wir einige Anti-Spam-Arbeiten durchgeführt. Zusätzlich zur Verwendung häufig verwendeter Regelabgleichs- und Filtermethoden verwenden wir auch einige Methoden des maschinellen Lernens zur Klassifizierungsvorhersage. Wir verwenden TensorFlow, um das Modell zu trainieren. In der Vorhersagephase müssen wir das Modell laden und wiederherstellen, was das Speichern und Wiederherstellen des TensorFlow-Modells umfasst.

Fassen Sie die häufig verwendeten Modellspeichermethoden von Tensorflow zusammen.

Checkpoint-Modelldatei (.ckpt) speichern

Zuallererst bietet TensorFlow eine sehr praktische API, tf.train.Saver() um ein Modell für maschinelles Lernen zu speichern und wiederherzustellen.

Modellspeicherung

Es ist sehr praktisch, tf.train.Saver() zum Speichern von Modelldateien zu verwenden:


import tensorflow as tf
import os

def save_model_ckpt(ckpt_file_path):
  x = tf.placeholder(tf.int32, name='x')
  y = tf.placeholder(tf.int32, name='y')
  b = tf.Variable(1, name='b')
  xy = tf.multiply(x, y)
  op = tf.add(xy, b, name='op_to_store')

  sess = tf.Session()
  sess.run(tf.global_variables_initializer())

  path = os.path.dirname(os.path.abspath(ckpt_file_path))
  if os.path.isdir(path) is False:
    os.makedirs(path)

  tf.train.Saver().save(sess, ckpt_file_path)
  
  # test
  feed_dict = {x: 2, y: 3}
  print(sess.run(op, feed_dict))
Nach dem Login kopieren


Das Programm generiert und speichert vier Dateien (vor Version 0.11 wurden nur drei Dateien generiert: checkpoint, model.ckpt , model.ckpt.meta)

  1. Checkpoint-Textdatei, die die Pfadinformationsliste der Modelldatei

  2. model.ckpt.data aufzeichnet -00000 -of-00001 Netzwerkgewichtsinformationen

  3. model.ckpt.index Die beiden Dateien .data und .index sind Binärdateien, die die variablen Parameterinformationen (Gewicht) im Modell speichern

  4. model.ckpt.meta-Binärdatei, die die Strukturinformationen des Rechendiagramms des Modells (die Netzwerkstruktur des Modells) protobuf

speichert Das Obige ist die grundlegende Verwendung von tf.train .Saver().save(). Die Methode save() verfügt auch über viele konfigurierbare Parameter:


tf.train.Saver().save(sess, ckpt_file_path, global_step=1000)
Nach dem Login kopieren


Das Hinzufügen des Parameters global_step bedeutet, dass das Modell alle 1000 Iterationen gespeichert wird. „-1000“ wird nach der Modelldatei model.ckpt-1000.index, model.ckpt-1000.meta, model.ckpt hinzugefügt .data- 1000-00000-of-00001

Speichern Sie das Modell alle 1000 Iterationen, aber die Strukturinformationsdatei des Modells ändert sich nicht. Sie wird nur alle 1000 Iterationen gespeichert, nicht alle 1000 Einmal speichern. Wenn wir die Metadatei nicht speichern müssen, können wir den Parameter write_meta_graph=False wie folgt hinzufügen:


Code kopieren Code wie folgt:

tf.train.Saver().save(sess, ckpt_file_path, global_step=1000, write_meta_graph=False)
Nach dem Login kopieren

Wenn Sie das Modell alle zwei Stunden und nur die letzten 4 Modelle speichern möchten, können Sie max_to_keep hinzufügen (der Standardwert ist 5). Wenn Sie möchten Speichern Sie es in jeder Trainingsepoche, Sie können es auf „Keine“ oder „0“ setzen, aber es ist nutzlos und wird nicht empfohlen. Der Parameter „keep_checkpoint_every_n_hours“ lautet wie folgt:


Kopieren Sie den Code Der Code lautet wie folgt:

tf.train.Saver().save(sess, ckpt_file_path, max_to_keep=4, keep_checkpoint_every_n_hours=2)
Nach dem Login kopieren


Gleichzeitig in der Klasse tf.train.Saver(), wenn wir Geben Sie keine Informationen an, alle Parameterinformationen werden gespeichert. Der Inhalt kann beispielsweise nur x- und y-Parameter speichern (Parameterliste oder Diktat können übergeben werden):


tf.train.Saver([x, y]).save(sess, ckpt_file_path)
Nach dem Login kopieren


ps Während des Modelltrainingsprozesses kann der Name des Variablen- oder Parameternamens, der nach dem Speichern abgerufen werden muss, nicht gefunden werden verloren, andernfalls kann das Modell nach der Wiederherstellung nicht über get_tensor_by_name() abgerufen werden.

Laden und Wiederherstellen des Modells

Für das obige Beispiel zum Speichern des Modells ist der Prozess der Wiederherstellung des Modells wie folgt:


import tensorflow as tf

def restore_model_ckpt(ckpt_file_path):
  sess = tf.Session()
  saver = tf.train.import_meta_graph('./ckpt/model.ckpt.meta') # 加载模型结构
  saver.restore(sess, tf.train.latest_checkpoint('./ckpt')) # 只需要指定目录就可以恢复所有变量信息

  # 直接获取保存的变量
  print(sess.run('b:0'))

  # 获取placeholder变量
  input_x = sess.graph.get_tensor_by_name('x:0')
  input_y = sess.graph.get_tensor_by_name('y:0')
  # 获取需要进行计算的operator
  op = sess.graph.get_tensor_by_name('op_to_store:0')

  # 加入新的操作
  add_on_op = tf.multiply(op, 2)

  ret = sess.run(add_on_op, {input_x: 5, input_y: 5})
  print(ret)
Nach dem Login kopieren


Zuerst die Modellstruktur wiederherstellen, dann die variablen (Parameter-)Informationen wiederherstellen und schließlich können wir verschiedene Informationen im trainierten Modell erhalten ( gespeicherte Variablen, Platzhaltervariablen, Operatoren usw.) und den erhaltenen Variablen können verschiedene neue Operationen hinzugefügt werden (siehe die obigen Codekommentare).
Darüber hinaus können wir auf dieser Basis auch einige Modelle laden und andere Operationen hinzufügen. Einzelheiten finden Sie in der offiziellen Dokumentation und Demo.

Zum Speichern und Wiederherstellen von ckpt-Modelldateien gibt es eine Antwort auf Stackoverflow mit einer klaren Erklärung, auf die Sie sich beziehen können.

Gleichzeitig ist das Tutorial zum Speichern und Wiederherstellen von TensorFlow-Modellen auf cv-tricks.com auch sehr gut, Sie können darauf verweisen.

"Tensorflow 1.0 Learning: Model Saving and Restoration (Saver)" enthält einige Tipps zur Saver-Nutzung.

Eine einzelne Modelldatei (.pb) speichern

Ich habe die Demo von Tensorflows inception-v3 selbst ausgeführt und festgestellt, dass eine .pb generiert wird Nach der Ausführung wird diese Datei für die anschließende Vorhersage oder das Migrationslernen verwendet. Es ist nur eine Datei, sehr cool und sehr praktisch.

Die Hauptidee dieses Prozesses besteht darin, dass die Datei graph_def nicht den Variablenwert im Netzwerk enthält (normalerweise wird das Gewicht gespeichert), sondern den konstanten Wert, also wenn wir Sie können die Variable in eine Konstante konvertieren (mithilfe der Funktion graph_util.convert_variables_to_constants()). Sie können das Ziel erreichen, eine Datei zum Speichern sowohl der Netzwerkarchitektur als auch der Gewichte zu verwenden.

ps: Hier ist .pb der Suffixname der Modelldatei. Natürlich können wir auch andere Suffixe verwenden (verwenden Sie .pb, um mit Google konsistent zu sein╮(╯▽╰)╭). 🎜>

Modellspeicherung

Ähnlich basierend auf dem obigen Beispiel, eine einfache Demo:



import tensorflow as tf
import os
from tensorflow.python.framework import graph_util

def save_mode_pb(pb_file_path):
  x = tf.placeholder(tf.int32, name='x')
  y = tf.placeholder(tf.int32, name='y')
  b = tf.Variable(1, name='b')
  xy = tf.multiply(x, y)
  # 这里的输出需要加上name属性
  op = tf.add(xy, b, name='op_to_store')

  sess = tf.Session()
  sess.run(tf.global_variables_initializer())

  path = os.path.dirname(os.path.abspath(pb_file_path))
  if os.path.isdir(path) is False:
    os.makedirs(path)

  # convert_variables_to_constants 需要指定output_node_names,list(),可以多个
  constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])
  with tf.gfile.FastGFile(pb_file_path, mode='wb') as f:
    f.write(constant_graph.SerializeToString())

  # test
  feed_dict = {x: 2, y: 3}
  print(sess.run(op, feed_dict))
Nach dem Login kopieren


程序生成并保存一个文件

model.pb 二进制文件,同时保存了模型网络结构和参数(权重)信息

模型加载还原

针对上面的模型保存例子,还原模型的过程如下:


import tensorflow as tf
from tensorflow.python.platform import gfile

def restore_mode_pb(pb_file_path):
  sess = tf.Session()
  with gfile.FastGFile(pb_file_path, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    sess.graph.as_default()
    tf.import_graph_def(graph_def, name='')

  print(sess.run('b:0'))

  input_x = sess.graph.get_tensor_by_name('x:0')
  input_y = sess.graph.get_tensor_by_name('y:0')

  op = sess.graph.get_tensor_by_name('op_to_store:0')

  ret = sess.run(op, {input_x: 5, input_y: 5})
  print(ret)
Nach dem Login kopieren


模型的还原过程与checkpoint差不多一样。

《将TensorFlow的网络导出为单个文件》上介绍了TensorFlow保存单个模型文件的方式,大同小异,可以看看。

思考

模型的保存与加载只是TensorFlow中最基础的部分之一,虽然简单但是也必不可少,在实际运用中还需要注意模型何时保存,哪些变量需要保存,如何设计加载实现迁移学习等等问题。

同时TensorFlow的函数和类都在一直变化更新,以后也有可能出现更丰富的模型保存和还原的方法。

选择保存为checkpoint或单个pb文件视业务情况而定,没有特别大的差别。checkpoint保存感觉会更加灵活一些,pb文件更适合线上部署吧(个人看法)。

以上完整代码:github https://github.com/liuyan731/tf_demo

相关推荐:

TensorFlow模型保存和提取方法示例


Das obige ist der detaillierte Inhalt vonEine kurze Diskussion zum Speichern und Wiederherstellen des Ladens von Tensorflow-Modellen. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!

Verwandte Etiketten:
Quelle:php.cn
Erklärung dieser Website
Der Inhalt dieses Artikels wird freiwillig von Internetnutzern beigesteuert und das Urheberrecht liegt beim ursprünglichen Autor. Diese Website übernimmt keine entsprechende rechtliche Verantwortung. Wenn Sie Inhalte finden, bei denen der Verdacht eines Plagiats oder einer Rechtsverletzung besteht, wenden Sie sich bitte an admin@php.cn
Beliebte Tutorials
Mehr>
Neueste Downloads
Mehr>
Web-Effekte
Quellcode der Website
Website-Materialien
Frontend-Vorlage