> 백엔드 개발 > 파이썬 튜토리얼 > TensorFlow 모델 저장 및 추출 방법 예시

TensorFlow 모델 저장 및 추출 방법 예시

不言
풀어 주다: 2018-04-26 16:34:11
원래의
2440명이 탐색했습니다.

이 글에서는 주로 TensorFlow 모델 저장 및 추출 방법의 예를 소개하고 참고할 수 있도록 하겠습니다. 함께 살펴볼까요

1. TensorFlow 모델 저장 및 추출 방법

1. TensorFlow는 tf.train.Saver 클래스를 통해 신경망 모델의 저장 및 추출을 구현합니다. tf.train.Saver 객체 보호기의 save 메소드는 TensorFlow 모델을 지정된 경로 saver.save(sess, "Model/model.ckpt")에 저장합니다. 실제로 이 파일 디렉터리에 4개의 개인 파일이 생성됩니다.

체크포인트 파일은 기록된 모델 파일 목록을 저장합니다. model.ckpt.meta는 TensorFlow 계산 그래프의 구조 정보를 저장합니다. model.ckpt는 여기에 기록된 파일 이름을 저장합니다. 다양한 매개변수의 설정은 다양하지만 복원 로드 시 파일 경로 이름은 체크포인트 파일의 "model_checkpoint_path" 값에 따라 결정됩니다.

2 저장된 TensorFlow 모델을 로드하는 방법은 saver.restore(sess, "./Model/model.ckpt")입니다. 모델을 로드하는 코드는 TensorFlow 계산 그래프에 대한 모든 작업을 정의하고 tf를 선언해야 합니다. .train.Saver 클래스의 경우 모델을 로드할 때 변수를 초기화할 필요가 없지만 저장된 모델을 통해 변수의 값이 로드된다는 점이 차이점입니다. 계산 그래프에 작업을 반복적으로 정의하고 싶지 않은 경우 지속형 그래프 saver =tf.train.import_meta_graph("Model/model.ckpt.meta")를 직접 로드할 수 있습니다.

3.tf.train.Saver 클래스는 Saver 클래스 객체를 선언할 때 변수 이름 바꾸기도 지원합니다. 사전 사전을 사용하여 변수 이름을 바꿀 수 있습니다. {"저장된 변수 이름": 변수 이름 바꾸기 name}, saver = tf.train.Saver({"v1":u1, "v2": u2}), 즉 v1이라는 원래 변수가 이제 변수 u1(other-v1이라는 이름)에 로드됩니다.

4. 이전 글의 목적 중 하나는 변수의 슬라이딩 평균을 쉽게 사용하는 것입니다. 모델을 로드할 때 그림자 변수가 변수 자체에 직접 매핑되면 훈련된 모델을 사용할 때 변수의 슬라이딩 평균을 얻기 위해 함수를 호출할 필요가 없습니다. 로드할 때 Saver 클래스 객체를 선언할 때 사전, saver = tf.train.Saver({"v/ExponentialMovingAverage": v}) 및 tf.train.ExponentialMovingAverage를 통해 슬라이딩 평균을 새 변수에 직접 로드합니다. Variable_to_restore() 함수는 변수 이름 바꾸기 사전을 얻습니다.

또한, 계산 그래프에 있는 변수와 해당 값은 Convert_variables_to_constants 함수를 통해 파일에 상수로 저장됩니다.

2. TensorFlow 프로그램 구현

# 本文件程序为配合教材及学习进度渐进进行,请按照注释分段执行 
# 执行时要注意IDE的当前工作过路径,最好每段重启控制器一次,输出结果更准确  
# Part1: 通过tf.train.Saver类实现保存和载入神经网络模型  
# 执行本段程序时注意当前的工作路径 
import tensorflow as tf  
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1") 
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2") 
result = v1 + v2  
saver = tf.train.Saver()  
with tf.Session() as sess: 
  sess.run(tf.global_variables_initializer()) 
  saver.save(sess, "Model/model.ckpt")  
 
# Part2: 加载TensorFlow模型的方法  
import tensorflow as tf  
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1") 
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2") 
result = v1 + v2  
saver = tf.train.Saver()  
with tf.Session() as sess: 
  saver.restore(sess, "./Model/model.ckpt") # 注意此处路径前添加"./" 
  print(sess.run(result)) # [ 3.] 
  
# Part3: 若不希望重复定义计算图上的运算,可直接加载已经持久化的图  
import tensorflow as tf  
saver = tf.train.import_meta_graph("Model/model.ckpt.meta")  
with tf.Session() as sess: 
  saver.restore(sess, "./Model/model.ckpt") # 注意路径写法 
  print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0"))) # [ 3.] 
  
# Part4: tf.train.Saver类也支持在保存和加载时给变量重命名  
import tensorflow as tf  
# 声明的变量名称name与已保存的模型中的变量名称name不一致 
u1 = tf.Variable(tf.constant(1.0, shape=[1]), name="other-v1") 
u2 = tf.Variable(tf.constant(2.0, shape=[1]), name="other-v2") 
result = u1 + u2  
# 若直接生命Saver类对象,会报错变量找不到 
# 使用一个字典dict重命名变量即可,{"已保存的变量的名称name": 重命名变量名} 
# 原来名称name为v1的变量现在加载到变量u1(名称name为other-v1)中 
saver = tf.train.Saver({"v1": u1, "v2": u2})  
with tf.Session() as sess: 
  saver.restore(sess, "./Model/model.ckpt") 
  print(sess.run(result)) # [ 3.] 
  
# Part5: 保存滑动平均模型  
import tensorflow as tf  
v = tf.Variable(0, dtype=tf.float32, name="v") 
for variables in tf.global_variables(): 
  print(variables.name) # v:0  
ema = tf.train.ExponentialMovingAverage(0.99) 
maintain_averages_op = ema.apply(tf.global_variables()) 
for variables in tf.global_variables(): 
  print(variables.name) # v:0 
             # v/ExponentialMovingAverage:0  
saver = tf.train.Saver()  
with tf.Session() as sess: 
  sess.run(tf.global_variables_initializer()) 
  sess.run(tf.assign(v, 10)) 
  sess.run(maintain_averages_op) 
  saver.save(sess, "Model/model_ema.ckpt") 
  print(sess.run([v, ema.average(v)])) # [10.0, 0.099999905]  
 
# Part6: 通过变量重命名直接读取变量的滑动平均值  
import tensorflow as tf  
v = tf.Variable(0, dtype=tf.float32, name="v") 
saver = tf.train.Saver({"v/ExponentialMovingAverage": v}) 
 with tf.Session() as sess: 
  saver.restore(sess, "./Model/model_ema.ckpt") 
  print(sess.run(v)) # 0.0999999 
  
# Part7: 通过tf.train.ExponentialMovingAverage的variables_to_restore()函数获取变量重命名字典  
import tensorflow as tf  
v = tf.Variable(0, dtype=tf.float32, name="v") 
# 注意此处的变量名称name一定要与已保存的变量名称一致 
ema = tf.train.ExponentialMovingAverage(0.99) 
print(ema.variables_to_restore()) 
# {&#39;v/ExponentialMovingAverage&#39;: <tf.Variable &#39;v:0&#39; shape=() dtype=float32_ref>} 
# 此处的v取自上面变量v的名称name="v"  
saver = tf.train.Saver(ema.variables_to_restore()) 
 with tf.Session() as sess: 
  saver.restore(sess, "./Model/model_ema.ckpt") 
  print(sess.run(v)) # 0.0999999 
 
# Part8: 通过convert_variables_to_constants函数将计算图中的变量及其取值通过常量的方式保存于一个文件中  
import tensorflow as tf 
from tensorflow.python.framework import graph_util  
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1") 
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2") 
result = v1 + v2  
with tf.Session() as sess: 
  sess.run(tf.global_variables_initializer()) 
  # 导出当前计算图的GraphDef部分,即从输入层到输出层的计算过程部分 
  graph_def = tf.get_default_graph().as_graph_def() 
  output_graph_def = graph_util.convert_variables_to_constants(sess, 
                            graph_def, [&#39;add&#39;])  
  with tf.gfile.GFile("Model/combined_model.pb", &#39;wb&#39;) as f: 
    f.write(output_graph_def.SerializeToString()) 
  
# Part9: 载入包含变量及其取值的模型  
import tensorflow as tf 
from tensorflow.python.platform import gfile  
with tf.Session() as sess: 
  model_filename = "Model/combined_model.pb" 
  with gfile.FastGFile(model_filename, &#39;rb&#39;) as f: 
    graph_def = tf.GraphDef() 
    graph_def.ParseFromString(f.read())  
  result = tf.import_graph_def(graph_def, return_elements=["add:0"]) 
  print(sess.run(result)) # [array([ 3.], dtype=float32)]
로그인 후 복사

관련 권장사항:

tensorflow에서 데이터를 로드하는 세 가지 방법을 설명하세요

tensorflow는 플래그를 사용하여 정의합니다. 명령줄 매개변수

위 내용은 TensorFlow 모델 저장 및 추출 방법 예시의 상세 내용입니다. 자세한 내용은 PHP 중국어 웹사이트의 기타 관련 기사를 참조하세요!

관련 라벨:
원천:php.cn
본 웹사이트의 성명
본 글의 내용은 네티즌들의 자발적인 기여로 작성되었으며, 저작권은 원저작자에게 있습니다. 본 사이트는 이에 상응하는 법적 책임을 지지 않습니다. 표절이나 침해가 의심되는 콘텐츠를 발견한 경우 admin@php.cn으로 문의하세요.
인기 튜토리얼
더>
최신 다운로드
더>
웹 효과
웹사이트 소스 코드
웹사이트 자료
프론트엔드 템플릿