Maison > développement back-end > Tutoriel Python > Comment puis-je enregistrer et restaurer des modèles TensorFlow entraînés ?

Comment puis-je enregistrer et restaurer des modèles TensorFlow entraînés ?

DDD
Libérer: 2024-12-19 17:41:09
original
636 Les gens l'ont consulté

How Can I Save and Restore Trained TensorFlow Models?

Enregistrement et restauration des modèles TensorFlow entraînés

TensorFlow offre des fonctionnalités transparentes pour enregistrer et restaurer des modèles TensorFlow entraînés, vous permettant de conserver et de réutiliser vos modèles dans divers scénarios.

Sauvegarder le Modèle

Pour enregistrer un modèle entraîné dans TensorFlow, vous pouvez utiliser la classe tf.train.Saver. Voici un exemple :

import tensorflow as tf

# Prepare placeholders and variables
w1 = tf.placeholder(tf.float32, name="w1")
w2 = tf.placeholder(tf.float32, name="w2")
b1 = tf.Variable(2.0, name="bias")
feed_dict = {w1: 4, w2: 8}

# Define an operation to be restored
w3 = tf.add(w1, w2)
w4 = tf.multiply(w3, b1, name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Create a saver object
saver = tf.train.Saver()

# Run the operation and save the graph
print(sess.run(w4, feed_dict))
saver.save(sess, 'my_test_model', global_step=1000)
Copier après la connexion

Restauration du modèle

Pour restaurer un modèle précédemment enregistré, vous pouvez utiliser le processus suivant :

import tensorflow as tf

sess = tf.Session()

# Load the meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))

# Access saved variables directly
print(sess.run('bias:0'))  # Prints 2 (the bias value)

# Access and create feed-dict for new input data
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 17.0}

# Access the desired operation
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

print(sess.run(op_to_restore, feed_dict))  # Prints 60 ((w1 + w2) * b1)
Copier après la connexion

Pour des scénarios et des cas d'utilisation supplémentaires, reportez-vous aux ressources fournies dans les réponses fournies, qui approfondissent la sauvegarde et la restauration de TensorFlow. modèles.

Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

source:php.cn
Déclaration de ce site Web
Le contenu de cet article est volontairement contribué par les internautes et les droits d'auteur appartiennent à l'auteur original. Ce site n'assume aucune responsabilité légale correspondante. Si vous trouvez un contenu suspecté de plagiat ou de contrefaçon, veuillez contacter admin@php.cn
Tutoriels populaires
Plus>
Derniers téléchargements
Plus>
effets Web
Code source du site Web
Matériel du site Web
Modèle frontal