首頁 後端開發 Python教學 TensorFlow模型保存和提取方法範例

TensorFlow模型保存和提取方法範例

Apr 26, 2018 pm 04:34 PM
tensorflow 提取 方法

本篇文章主要介紹了TensorFlow模型保存和提取方法範例,現在分享給大家,也給大家做個參考。一起來看看吧

一、TensorFlow模型保存和提取方法

1. TensorFlow透過tf.train.Saver類別實作神經網絡模型的保存與提取。 tf.train.Saver物件saver的save方法將TensorFlow模型儲存到指定路徑中,saver.save(sess,"Model/model.ckpt"),實際上在這個檔案目錄下會產生4個人檔案:

checkpoint檔案保存了一個錄下多有的模型檔案列表,model.ckpt.meta保存了TensorFlow計算圖的結構信息,model.ckpt保存每個變數的取值,此處檔案名稱的寫入方式會因不同參數的設定而不同,但載入restore時的檔案路徑名稱是以checkpoint檔案中的「model_checkpoint_path」值決定的。

2. 載入這個已儲存的TensorFlow模型的方法是saver.restore(sess,"./Model/model.ckpt"),載入模型的程式碼也要定義TensorFlow計算圖上的所有運算並且宣告一個tf.train.Saver類,不同的是載入模型時不需要進行變數的初始化,而是將變數的取值透過保存的模型載入進來,注意載入路徑的寫法。若不希望重複定義計算圖上的運算,可直接載入已經持久化的圖,saver =tf.train.import_meta_graph("Model/model.ckpt.meta")。

3.tf.train.Saver類別也支援在儲存和載入時給變數重新命名,宣告Saver類別物件的時候使用一個字典dict重命名變數即可,{"已儲存的變數的名稱name": 重新命名變數名稱},saver = tf.train.Saver({"v1":u1, "v2": u2})即原來名稱name為v1的變數現在載入到變數u1(名稱name為other- v1)中。

4. 上一條做的目的之一就是方便使用變數的滑動平均值。如果在載入模型時直接將影子變數對應到變數自身,則在使用訓練好的模型時就不需要再呼叫函數來取得變數的滑動平均值了。載入時,宣告Saver類別物件時透過一個字典將滑動平均值直接載入到新的變數中,saver = tf.train.Saver({"v/ExponentialMovingAverage": v}),另透過tf.train.ExponentialMovingAverage的variables_to_restore()函數取得變數重新命名字典。

此外,透過convert_variables_to_constants函數將計算圖中的變數及其取值透過常數的方式保存於一個檔案中。

二、TensorFlow程式實作

#
# 本文件程序为配合教材及学习进度渐进进行,请按照注释分段执行 
# 执行时要注意IDE的当前工作过路径,最好每段重启控制器一次,输出结果更准确  
# Part1: 通过tf.train.Saver类实现保存和载入神经网络模型  
# 执行本段程序时注意当前的工作路径 
import tensorflow as tf  
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1") 
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2") 
result = v1 + v2  
saver = tf.train.Saver()  
with tf.Session() as sess: 
  sess.run(tf.global_variables_initializer()) 
  saver.save(sess, "Model/model.ckpt")  
 
# Part2: 加载TensorFlow模型的方法  
import tensorflow as tf  
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1") 
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2") 
result = v1 + v2  
saver = tf.train.Saver()  
with tf.Session() as sess: 
  saver.restore(sess, "./Model/model.ckpt") # 注意此处路径前添加"./" 
  print(sess.run(result)) # [ 3.] 
  
# Part3: 若不希望重复定义计算图上的运算,可直接加载已经持久化的图  
import tensorflow as tf  
saver = tf.train.import_meta_graph("Model/model.ckpt.meta")  
with tf.Session() as sess: 
  saver.restore(sess, "./Model/model.ckpt") # 注意路径写法 
  print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0"))) # [ 3.] 
  
# Part4: tf.train.Saver类也支持在保存和加载时给变量重命名  
import tensorflow as tf  
# 声明的变量名称name与已保存的模型中的变量名称name不一致 
u1 = tf.Variable(tf.constant(1.0, shape=[1]), name="other-v1") 
u2 = tf.Variable(tf.constant(2.0, shape=[1]), name="other-v2") 
result = u1 + u2  
# 若直接生命Saver类对象,会报错变量找不到 
# 使用一个字典dict重命名变量即可,{"已保存的变量的名称name": 重命名变量名} 
# 原来名称name为v1的变量现在加载到变量u1(名称name为other-v1)中 
saver = tf.train.Saver({"v1": u1, "v2": u2})  
with tf.Session() as sess: 
  saver.restore(sess, "./Model/model.ckpt") 
  print(sess.run(result)) # [ 3.] 
  
# Part5: 保存滑动平均模型  
import tensorflow as tf  
v = tf.Variable(0, dtype=tf.float32, name="v") 
for variables in tf.global_variables(): 
  print(variables.name) # v:0  
ema = tf.train.ExponentialMovingAverage(0.99) 
maintain_averages_op = ema.apply(tf.global_variables()) 
for variables in tf.global_variables(): 
  print(variables.name) # v:0 
             # v/ExponentialMovingAverage:0  
saver = tf.train.Saver()  
with tf.Session() as sess: 
  sess.run(tf.global_variables_initializer()) 
  sess.run(tf.assign(v, 10)) 
  sess.run(maintain_averages_op) 
  saver.save(sess, "Model/model_ema.ckpt") 
  print(sess.run([v, ema.average(v)])) # [10.0, 0.099999905]  
 
# Part6: 通过变量重命名直接读取变量的滑动平均值  
import tensorflow as tf  
v = tf.Variable(0, dtype=tf.float32, name="v") 
saver = tf.train.Saver({"v/ExponentialMovingAverage": v}) 
 with tf.Session() as sess: 
  saver.restore(sess, "./Model/model_ema.ckpt") 
  print(sess.run(v)) # 0.0999999 
  
# Part7: 通过tf.train.ExponentialMovingAverage的variables_to_restore()函数获取变量重命名字典  
import tensorflow as tf  
v = tf.Variable(0, dtype=tf.float32, name="v") 
# 注意此处的变量名称name一定要与已保存的变量名称一致 
ema = tf.train.ExponentialMovingAverage(0.99) 
print(ema.variables_to_restore()) 
# {&#39;v/ExponentialMovingAverage&#39;: <tf.Variable &#39;v:0&#39; shape=() dtype=float32_ref>} 
# 此处的v取自上面变量v的名称name="v"  
saver = tf.train.Saver(ema.variables_to_restore()) 
 with tf.Session() as sess: 
  saver.restore(sess, "./Model/model_ema.ckpt") 
  print(sess.run(v)) # 0.0999999 
 
# Part8: 通过convert_variables_to_constants函数将计算图中的变量及其取值通过常量的方式保存于一个文件中  
import tensorflow as tf 
from tensorflow.python.framework import graph_util  
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1") 
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2") 
result = v1 + v2  
with tf.Session() as sess: 
  sess.run(tf.global_variables_initializer()) 
  # 导出当前计算图的GraphDef部分,即从输入层到输出层的计算过程部分 
  graph_def = tf.get_default_graph().as_graph_def() 
  output_graph_def = graph_util.convert_variables_to_constants(sess, 
                            graph_def, [&#39;add&#39;])  
  with tf.gfile.GFile("Model/combined_model.pb", &#39;wb&#39;) as f: 
    f.write(output_graph_def.SerializeToString()) 
  
# Part9: 载入包含变量及其取值的模型  
import tensorflow as tf 
from tensorflow.python.platform import gfile  
with tf.Session() as sess: 
  model_filename = "Model/combined_model.pb" 
  with gfile.FastGFile(model_filename, &#39;rb&#39;) as f: 
    graph_def = tf.GraphDef() 
    graph_def.ParseFromString(f.read())  
  result = tf.import_graph_def(graph_def, return_elements=["add:0"]) 
  print(sess.run(result)) # [array([ 3.], dtype=float32)]
登入後複製

相關推薦:

詳解tensorflow載入資料的三種方式

tensorflow 使用flags定義指令列參數的方法

以上是TensorFlow模型保存和提取方法範例的詳細內容。更多資訊請關注PHP中文網其他相關文章!

本網站聲明
本文內容由網友自願投稿,版權歸原作者所有。本站不承擔相應的法律責任。如發現涉嫌抄襲或侵權的內容,請聯絡admin@php.cn

熱AI工具

Undresser.AI Undress

Undresser.AI Undress

人工智慧驅動的應用程序,用於創建逼真的裸體照片

AI Clothes Remover

AI Clothes Remover

用於從照片中去除衣服的線上人工智慧工具。

Undress AI Tool

Undress AI Tool

免費脫衣圖片

Clothoff.io

Clothoff.io

AI脫衣器

Video Face Swap

Video Face Swap

使用我們完全免費的人工智慧換臉工具,輕鬆在任何影片中換臉!

熱門文章

<🎜>:泡泡膠模擬器無窮大 - 如何獲取和使用皇家鑰匙
3 週前 By 尊渡假赌尊渡假赌尊渡假赌
北端:融合系統,解釋
3 週前 By 尊渡假赌尊渡假赌尊渡假赌
Mandragora:巫婆樹的耳語 - 如何解鎖抓鉤
3 週前 By 尊渡假赌尊渡假赌尊渡假赌

熱工具

記事本++7.3.1

記事本++7.3.1

好用且免費的程式碼編輯器

SublimeText3漢化版

SublimeText3漢化版

中文版,非常好用

禪工作室 13.0.1

禪工作室 13.0.1

強大的PHP整合開發環境

Dreamweaver CS6

Dreamweaver CS6

視覺化網頁開發工具

SublimeText3 Mac版

SublimeText3 Mac版

神級程式碼編輯軟體(SublimeText3)

熱門話題

Java教學
1665
14
CakePHP 教程
1424
52
Laravel 教程
1322
25
PHP教程
1270
29
C# 教程
1250
24
怎麼刪除微信好友?刪除微信好友的方法 怎麼刪除微信好友?刪除微信好友的方法 Mar 04, 2024 am 11:10 AM

微信是主流的聊天工具之一,我們可以透過微信認識新的朋友,聯絡老的朋友,維繫朋友之間的友誼。正如天下沒有不散的宴席,人與人之間的相處難免會發生意見不合的時候。當一個人極度影響你的情緒,或是在相處的時候發現三觀不合,沒辦法再繼續溝通,那麼我們可能需要刪除微信好友的方法。怎麼刪除微信好友?刪除微信好友的方法第一步:在微信主介面輕觸【通訊錄】;第二步:點選對應要刪除的好友,進入【詳細資料】;第三步:點選右上角【...】;第四步:點選下方【刪除】即可;第五步:了解後頁面提示後,點選【刪除聯絡人】即可;溫馨

微信刪除的人如何找回(簡單教學告訴你如何恢復被刪除的聯絡人) 微信刪除的人如何找回(簡單教學告訴你如何恢復被刪除的聯絡人) May 01, 2024 pm 12:01 PM

而後悔莫及、人們常常會因為一些原因不小心刪除某些聯絡人、微信作為一款廣泛使用的社群軟體。幫助用戶解決這個問題,本文將介紹如何透過簡單的方法找回被刪除的聯絡人。 1.了解微信聯絡人刪除機制這為我們找回被刪除的聯絡人提供了可能性、微信中的聯絡人刪除機制是將其從通訊錄中移除,但並未完全刪除。 2.使用微信內建「通訊錄恢復」功能微信提供了「通訊錄恢復」節省時間和精力,使用者可以透過此功能快速找回先前刪除的聯絡人,功能。 3.進入微信設定頁面點選右下角,開啟微信應用程式「我」再點選右上角設定圖示、進入設定頁面,,

七彩虹主機板怎麼進入bios?教你兩種方法 七彩虹主機板怎麼進入bios?教你兩種方法 Mar 13, 2024 pm 06:01 PM

  七彩虹主機板在中國國內市場享有較高的知名度和市場佔有率,但是有些七彩虹主機板的用戶還不清楚怎麼進入bios進行設定呢?針對這一情況,小編專門為大家帶來了兩種進入七彩虹主機板bios的方法,快來試試吧!方法一:使用u盤啟動快捷鍵直接進入u盤裝系統七彩虹主機板一鍵啟動u盤的快捷鍵是ESC或F11,首先使用黑鯊裝機大師製作一個黑鯊U盤啟動盤,然後開啟電腦,當看到開機畫面的時候,連續按下鍵盤上的ESC或F11鍵以後將會進入到一個啟動項順序選擇的窗口,將遊標移到顯示“USB”的地方,然

怎麼在番茄免費小說app中寫小說 分享番茄小說寫小說方法教程 怎麼在番茄免費小說app中寫小說 分享番茄小說寫小說方法教程 Mar 28, 2024 pm 12:50 PM

番茄小說是一款非常熱門的小說閱讀軟體,我們在番茄小說中經常會有新的小說和漫畫可以去閱讀,每一本小說和漫畫都很有意思,很多小伙伴也想著要去寫小說來賺取賺取零用錢,在把自己想要寫的小說內容編輯成文字,那麼我們要怎麼樣在這裡面去寫小說呢?小伙伴們都不知道,那就讓我們一起到本站本站中花點時間來看寫小說的方法介紹。分享番茄小說寫小說方法教學  1、先在手機上打開番茄免費小說app,點擊個人中心——作家中心  2、跳到番茄作家助手頁面——點擊創建新書在小說的結

手機版龍蛋孵化方法大揭密(一步一步教你如何成功孵化手機版龍蛋) 手機版龍蛋孵化方法大揭密(一步一步教你如何成功孵化手機版龍蛋) May 04, 2024 pm 06:01 PM

手機遊戲成為了人們生活中不可或缺的一部分,隨著科技的發展。它以其可愛的龍蛋形象和有趣的孵化過程吸引了眾多玩家的關注,而其中一款備受矚目的遊戲就是手機版龍蛋。幫助玩家們在遊戲中更好地培養和成長自己的小龍,本文將向大家介紹手機版龍蛋的孵化方法。 1.選擇合適的龍蛋種類玩家需要仔細選擇自己喜歡並且適合自己的龍蛋種類,根據遊戲中提供的不同種類的龍蛋屬性和能力。 2.提升孵化機的等級玩家需要透過完成任務和收集道具來提升孵化機的等級,孵化機的等級決定了孵化速度和孵化成功率。 3.收集孵化所需的資源玩家需要在遊戲中

Win11管理員權限取得方法總計 Win11管理員權限取得方法總計 Mar 09, 2024 am 08:45 AM

Win11管理員權限取得方法匯總在Windows11作業系統中,管理員權限是非常重要的權限之一,可以讓使用者對系統進行各種操作。有時候,我們可能需要取得管理員權限來完成一些操作,例如安裝軟體、修改系統設定等。下面就為大家總結了一些取得Win11管理員權限的方法,希望能幫助大家。 1.使用快捷鍵在Windows11系統中,可以透過快捷鍵的方式快速開啟命令提

Oracle版本查詢方法詳解 Oracle版本查詢方法詳解 Mar 07, 2024 pm 09:21 PM

Oracle版本查詢方法詳解Oracle是目前世界上最受歡迎的關聯式資料庫管理系統之一,它提供了豐富的功能和強大的效能,廣泛應用於企業。在進行資料庫管理和開發過程中,了解Oracle資料庫的版本是非常重要的。本文將詳細介紹如何查詢Oracle資料庫的版本信息,並給出具體的程式碼範例。查詢資料庫版本的SQL語句在Oracle資料庫中,可以透過執行簡單的SQL語句

手機字體大小設定方法(輕鬆調整手機字體大小) 手機字體大小設定方法(輕鬆調整手機字體大小) May 07, 2024 pm 03:34 PM

字體大小的設定成為了重要的個人化需求,隨著手機成為人們日常生活的重要工具。以滿足不同使用者的需求、本文將介紹如何透過簡單的操作,提升手機使用體驗,調整手機字體大小。為什麼需要調整手機字體大小-調整字體大小可以使文字更清晰易讀-適合不同年齡段用戶的閱讀需求-方便視力不佳的用戶使用手機系統自帶字體大小設置功能-如何進入系統設置界面-在在設定介面中找到並進入"顯示"選項-找到"字體大小"選項並進行調整第三方應用調整字體大小-下載並安裝支援字體大小調整的應用程式-開啟應用程式並進入相關設定介面-根據個人

See all articles