tensorflow1.0で学習したモデルの保存と復元(Saver)_python

不言
リリース: 2018-04-23 15:42:50
オリジナル
1805 人が閲覧しました

この記事では、tensorflow1.0 学習モデルの保存と回復 (Saver) を主に紹介しますので、参考として共有します。ぜひ一緒に見てみましょう

後で検証またはテストできるように、トレーニング済みのモデルのパラメーターを保存します。 tf.train.Saver() モジュールは、tf でのモデルの保存を提供します。

モデルを保存するには、まず Saver オブジェクトを作成する必要があります:

saver=tf.train.Saver()
ログイン後にコピー

この Saver オブジェクトを作成するときに、よく使用するパラメータがあります。これは、max_to_keep パラメータを設定するために使用されます。モデルを保存するためのパラメーターの数。デフォルトは 5、つまり max_to_keep=5 で、最新の 5 つのモデルを保存します。トレーニング世代 (エポック) ごとにモデルを保存したい場合は、次のように max_to_keep を None または 0 に設定できます:

saver=tf.train.Saver(max_to_keep=0)
ログイン後にコピー

ただし、これはより多くのハードディスクを占有すること以外に実用的ではないため、お勧めできません。

もちろん、最後の世代のモデルのみを保存したい場合は、max_to_keep を 1 に設定するだけです。つまり、

saver=tf.train.Saver(max_to_keep=1)
ログイン後にコピー

セーバー オブジェクトを作成した後、トレーニングされたモデルを次のように保存できます。 :

saver.save(sess,'ckpt/mnist.ckpt',global_step=step)
ログイン後にコピー

最初のパラメータsess、これは言うまでもありません。 2 番目のパラメーターは保存されたパスと名前を設定し、3 番目のパラメーターはトレーニング回数をサフィックスとしてモデル名に追加します。

saver.save(sess, 'my-model', global_step=0) ==> ファイル名: 'my-model-0'
...
saver(sess, 'my-model', global_step= 1000) ==> ファイル名: 'my-model-1000'

mnist の例を見てください:

# -*- coding: utf-8 -*-
"""
Created on Sun Jun 4 10:29:48 2017

@author: Administrator
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)

x = tf.placeholder(tf.float32, [None, 784])
y_=tf.placeholder(tf.int32,[None,])

dense1 = tf.layers.dense(inputs=x, 
           units=1024, 
           activation=tf.nn.relu,
           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
           kernel_regularizer=tf.nn.l2_loss)
dense2= tf.layers.dense(inputs=dense1, 
           units=512, 
           activation=tf.nn.relu,
           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
           kernel_regularizer=tf.nn.l2_loss)
logits= tf.layers.dense(inputs=dense2, 
            units=10, 
            activation=None,
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
            kernel_regularizer=tf.nn.l2_loss)

loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)
train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)  
acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess=tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())

saver=tf.train.Saver(max_to_keep=1)
for i in range(100):
 batch_xs, batch_ys = mnist.train.next_batch(100)
 sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
 val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
 print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
 saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
sess.close()
ログイン後にコピー

コードの赤い部分はモデルを保存するコードです。後で保存されたモデルは前のモデルを上書きし、最後のモデルのみが保存されます。したがって、時間を節約し、保存コードをループの外に置くことができます (max_to_keep=1 にのみ適用されます。それ以外の場合は、ループ内に置く必要があります

実験では、最後の世代が、その世代ではない可能性があります)。デフォルトでは最後の世代を保存せず、最も検証精度の高い世代を保存したい場合は、中間変数と判定ステートメントを追加するだけです。

saver=tf.train.Saver(max_to_keep=1)
max_acc=0
for i in range(100):
 batch_xs, batch_ys = mnist.train.next_batch(100)
 sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
 val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
 print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
 if val_acc>max_acc:
   max_acc=val_acc
   saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
sess.close()
ログイン後にコピー

最も検証精度の高い3世代を保存し、各回の検証精度も保存したい場合は、保存用のtxtファイルを生成できます。

saver=tf.train.Saver(max_to_keep=3)
max_acc=0
f=open('ckpt/acc.txt','w')
for i in range(100):
 batch_xs, batch_ys = mnist.train.next_batch(100)
 sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
 val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
 print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
 f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
 if val_acc>max_acc:
   max_acc=val_acc
   saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
f.close()
sess.close()
ログイン後にコピー

restore() 関数はモデルを復元するために使用されます。これには 2 つのパラメーターが必要です。restore(sess、save_path)、save_path は保存されたモデル パスを指します。 tf.train.latest_checkpoint() を使用して、最後に保存されたモデルを自動的に取得できます。例:

model_file=tf.train.latest_checkpoint('ckpt/')
saver.restore(sess,model_file)
ログイン後にコピー

次に、プログラムの後半のコードを次のように変更できます:

sess=tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())
is_train=False
saver=tf.train.Saver(max_to_keep=3)

#训练阶段
if is_train:
  max_acc=0
  f=open('ckpt/acc.txt','w')
  for i in range(100):
   batch_xs, batch_ys = mnist.train.next_batch(100)
   sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
   val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
   print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
   f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
   if val_acc>max_acc:
     max_acc=val_acc
     saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
  f.close()

#验证阶段
else:
  model_file=tf.train.latest_checkpoint('ckpt/')
  saver.restore(sess,model_file)
  val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
  print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))
sess.close()
ログイン後にコピー

赤でマークされた場所は、モデルの保存と復元に関連するコードです。ブール変数 is_train を使用して、トレーニング フェーズと検証フェーズを制御します。

ソースプログラム全体:

# -*- coding: utf-8 -*-
"""
Created on Sun Jun 4 10:29:48 2017

@author: Administrator
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)

x = tf.placeholder(tf.float32, [None, 784])
y_=tf.placeholder(tf.int32,[None,])

dense1 = tf.layers.dense(inputs=x, 
           units=1024, 
           activation=tf.nn.relu,
           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
           kernel_regularizer=tf.nn.l2_loss)
dense2= tf.layers.dense(inputs=dense1, 
           units=512, 
           activation=tf.nn.relu,
           kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
           kernel_regularizer=tf.nn.l2_loss)
logits= tf.layers.dense(inputs=dense2, 
            units=10, 
            activation=None,
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
            kernel_regularizer=tf.nn.l2_loss)

loss=tf.losses.sparse_softmax_cross_entropy(labels=y_,logits=logits)
train_op=tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss)
correct_prediction = tf.equal(tf.cast(tf.argmax(logits,1),tf.int32), y_)  
acc= tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

sess=tf.InteractiveSession() 
sess.run(tf.global_variables_initializer())

is_train=True
saver=tf.train.Saver(max_to_keep=3)

#训练阶段
if is_train:
  max_acc=0
  f=open('ckpt/acc.txt','w')
  for i in range(100):
   batch_xs, batch_ys = mnist.train.next_batch(100)
   sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
   val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
   print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
   f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
   if val_acc>max_acc:
     max_acc=val_acc
     saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
  f.close()

#验证阶段
else:
  model_file=tf.train.latest_checkpoint('ckpt/')
  saver.restore(sess,model_file)
  val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
  print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))
sess.close()
ログイン後にコピー

関連する推奨事項:

TensorFlow のモデルネットワークを単一のファイルにエクスポートする方法

以上がtensorflow1.0で学習したモデルの保存と復元(Saver)_pythonの詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。

関連ラベル:
ソース:php.cn
このウェブサイトの声明
この記事の内容はネチズンが自主的に寄稿したものであり、著作権は原著者に帰属します。このサイトは、それに相当する法的責任を負いません。盗作または侵害の疑いのあるコンテンツを見つけた場合は、admin@php.cn までご連絡ください。
人気のチュートリアル
詳細>
最新のダウンロード
詳細>
ウェブエフェクト
公式サイト
サイト素材
フロントエンドテンプレート