ホームページ > バックエンド開発 > Python チュートリアル > TensorFlow の使用を開始し、tf.train.Saver() を使用してモデルを保存する

TensorFlow の使用を開始し、tf.train.Saver() を使用してモデルを保存する

不言
リリース: 2018-04-24 14:15:07
オリジナル
4142 人が閲覧しました

この記事では、TensorFlow を始めるときに tf.train.Saver() を使用してモデルを保存する方法を主に紹介します。一緒に見てみましょう

モデルの保存に関するいくつかの考え

saver = tf.train.Saver(max_to_keep=3)
ログイン後にコピー

セーバーを定義するとき、通常、保存されるモデルの最大数を定義します。一般に、モデル自体が大きい場合、ハードウェアを考慮する必要があります。ディスクサイズ。現在トレーニングされているモデルに基づいて微調整を実行する必要がある場合は、一度に過剰適合される可能性があるため、後続の微調整は必ずしも最適な ckpt から実行されるとは限りません。ただし、保存するファイルが多すぎると、ハードディスクが圧迫されてしまいます。最良のモデルのみを保持したい場合は、特定のステップ数まで繰り返すたびに検証セットの精度または f1 値を計算し、今回の結果が前回よりも優れている場合は、新しいモデルを保存します。それ以外の場合は、モデルを保存する必要はありません。

異なるエポックに保存されたモデルを融合に使用したい場合は、3 ~ 5 つのモデルで十分です。融合されたモデルが M になると仮定します。このように、M の場合は、より優れたものになる可能性があります。 m_bestよりも。しかし、このモデルを他の構造のモデルと融合すると、M は平均演算に相当し、モデルの「特性」が低下するため、M の効果は m_best ほど良くありません。

しかし、学習率を調整して複数の局所最適点を取得する新しい融合方法があります。つまり、損失を減らすことができない場合は、ckpt を保存し、学習率を上げて次の点を見つけ続けます。これらの ckpt は融合に使用されます。まだ試していませんが、単一のモデルでは確実に改善されますが、他のモデルとの融合では改善しない状況が発生するかどうかはわかりません。改善する。

tf.train.Saver() を使用してモデルを保存する方法

これまでにも、主に不正なコーディングの問題が原因でエラーが発生しました。したがって、ファイル パスに中国語の文字が含まれないように注意してください。

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([1.0, 2.3], name="v1")
v2 = tf.Variable(55.5, name="v2")

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

ckpt_path = './ckpt/test-model.ckpt'
# Later, launch the model, initialize the variables, do some work, save the
# variables to disk.
sess.run(init_op)
save_path = saver.save(sess, ckpt_path, global_step=1)
print("Model saved in file: %s" % save_path)
ログイン後にコピー

モデルはファイルに保存されました: ./ckpt/test-model.ckpt-1

上記のモデルを保存した後で注意してください。次のモデルを使用してインポートする前に、カーネルを再起動する必要があります。そうしないと、「v1」という名前を 2 回付けることになり、名前が間違ってしまいます。

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([11.0, 16.3], name="v1")
v2 = tf.Variable(33.5, name="v2")

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
# Restore variables from disk.
ckpt_path = './ckpt/test-model.ckpt'
saver.restore(sess, ckpt_path + '-'+ str(1))
print("Model restored.")

print sess.run(v1)
print sess.run(v2)
ログイン後にコピー

INFO:tensorflow:Restoringparametersfrom./ckpt/test-model.ckpt-1
モデルが復元されました。
[ 1. 2.29999995]
55.5

モデルをインポートする前、再定義する必要があります再び変数。

ただし、すべての変数を再定義する必要はなく、必要な変数を定義するだけです。

言い換えると、定義する変数はチェックポイントに存在する必要がありますが、チェックポイント内のすべての変数を再定義する必要はありません。

import tensorflow as tf
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)

# Create some variables.
v1 = tf.Variable([11.0, 16.3], name="v1")

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
# Restore variables from disk.
ckpt_path = './ckpt/test-model.ckpt'
saver.restore(sess, ckpt_path + '-'+ str(1))
print("Model restored.")

print sess.run(v1)
ログイン後にコピー

INFO:tensorflow:./ckpt/test-model.ckpt-1からパラメータを復元しています
モデルが復元されました。
[ 1. 2.29999995]

tf.Saver([tensors_to _be] _保存済み]) OKリストと保存するテンソルを渡します。リストが指定されていない場合は、デフォルトで現在のテンソルがすべて保存されます。一般的に、 tf.Saver は tf.variable_scope() とうまく組み合わせることができます: [転移学習] すでに保存されているモデルに新しい変数を追加して微調整する

関連する推奨事項:

Tensorflow についてtf.train.batch 関数

以上がTensorFlow の使用を開始し、tf.train.Saver() を使用してモデルを保存するの詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。

関連ラベル:
ソース:php.cn
このウェブサイトの声明
この記事の内容はネチズンが自主的に寄稿したものであり、著作権は原著者に帰属します。このサイトは、それに相当する法的責任を負いません。盗作または侵害の疑いのあるコンテンツを見つけた場合は、admin@php.cn までご連絡ください。
人気のチュートリアル
詳細>
最新のダウンロード
詳細>
ウェブエフェクト
公式サイト
サイト素材
フロントエンドテンプレート