ホームページ > バックエンド開発 > Python チュートリアル > TensorFlow モデルを保存および復元するにはどうすればよいですか?

TensorFlow モデルを保存および復元するにはどうすればよいですか?

Barbara Streisand
リリース: 2024-12-26 16:08:10
オリジナル
268 人が閲覧しました

How Can I Save and Restore TensorFlow Models?

Tensorflow モデルの保存と復元

Tensorflow では、モデルの保存と復元により、トレーニング済みモデルを保存し、将来の使用に活用することができます。関係する手順は次のとおりです:

モデルの保存 (Tensorflow 0.11 以降):

  1. プレースホルダーを作成し、モデルの TensorFlow 操作を定義します。
  2. TensorFlow を初期化する変数。
  3. tf.train.Saver オブジェクトを作成します。
  4. セッションとモデルのパスを指定して saver.save メソッドを呼び出します。

例:

# Define placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")

# Define operations
w3 = tf.add(w1, w2)
w4 = tf.multiply(w3, 2.0, name="op_to_restore")

# Initialize variables
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Create a saver
saver = tf.train.Saver()

# Save the model
saver.save(sess, 'my_model', global_step=1000)
ログイン後にコピー

を復元していますモデル:

  1. tf.train.import_meta_graph 関数を使用してメタ グラフをロードし、重みを復元します。
  2. 保存された変数に直接アクセスします。
  3. プレースホルダーを作成します
  4. 必要なデータにアクセスして実行します。操作。

例:

# Load the meta graph
sess = tf.Session()
saver = tf.train.import_meta_graph('my_model-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))

# Access saved variables
print(sess.run('bias:0'))  # Prints the saved bias value

# Create placeholders and feed new data
w1 = tf.get_default_graph().get_tensor_by_name("w1:0")
w2 = tf.get_default_graph().get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 17.0}

# Access and run the operation
op_to_restore = tf.get_default_graph().get_tensor_by_name("op_to_restore:0")
print(sess.run(op_to_restore, feed_dict))  # Prints the result of the restored operation
ログイン後にコピー

以上がTensorFlow モデルを保存および復元するにはどうすればよいですか?の詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。

ソース:php.cn
このウェブサイトの声明
この記事の内容はネチズンが自主的に寄稿したものであり、著作権は原著者に帰属します。このサイトは、それに相当する法的責任を負いません。盗作または侵害の疑いのあるコンテンツを見つけた場合は、admin@php.cn までご連絡ください。
著者別の最新記事
人気のチュートリアル
詳細>
最新のダウンロード
詳細>
ウェブエフェクト
公式サイト
サイト素材
フロントエンドテンプレート