Enregistrement et restauration des modèles TensorFlow entraînés
TensorFlow offre des fonctionnalités transparentes pour enregistrer et restaurer des modèles TensorFlow entraînés, vous permettant de conserver et de réutiliser vos modèles dans divers scénarios.
Sauvegarder le Modèle
Pour enregistrer un modèle entraîné dans TensorFlow, vous pouvez utiliser la classe tf.train.Saver. Voici un exemple :
import tensorflow as tf # Prepare placeholders and variables w1 = tf.placeholder(tf.float32, name="w1") w2 = tf.placeholder(tf.float32, name="w2") b1 = tf.Variable(2.0, name="bias") feed_dict = {w1: 4, w2: 8} # Define an operation to be restored w3 = tf.add(w1, w2) w4 = tf.multiply(w3, b1, name="op_to_restore") sess = tf.Session() sess.run(tf.global_variables_initializer()) # Create a saver object saver = tf.train.Saver() # Run the operation and save the graph print(sess.run(w4, feed_dict)) saver.save(sess, 'my_test_model', global_step=1000)
Restauration du modèle
Pour restaurer un modèle précédemment enregistré, vous pouvez utiliser le processus suivant :
import tensorflow as tf sess = tf.Session() # Load the meta graph and restore weights saver = tf.train.import_meta_graph('my_test_model-1000.meta') saver.restore(sess, tf.train.latest_checkpoint('./')) # Access saved variables directly print(sess.run('bias:0')) # Prints 2 (the bias value) # Access and create feed-dict for new input data graph = tf.get_default_graph() w1 = graph.get_tensor_by_name("w1:0") w2 = graph.get_tensor_by_name("w2:0") feed_dict = {w1: 13.0, w2: 17.0} # Access the desired operation op_to_restore = graph.get_tensor_by_name("op_to_restore:0") print(sess.run(op_to_restore, feed_dict)) # Prints 60 ((w1 + w2) * b1)
Pour des scénarios et des cas d'utilisation supplémentaires, reportez-vous aux ressources fournies dans les réponses fournies, qui approfondissent la sauvegarde et la restauration de TensorFlow. modèles.
Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!