首頁 > 後端開發 > php教程 > 詳解tensorflow載入資料的三種方式

詳解tensorflow載入資料的三種方式

不言
發布: 2023-03-24 19:12:01
原創
3463 人瀏覽過

這篇文章主要介紹了詳解tensorflow載入資料的三種方式,現在分享給大家,也給大家做個參考。一起來看看吧

Tensorflow資料讀取有三種方式:

  1. Preloaded data: 預先載入資料

  2. Feeding : Python產生數據,再把數據餵給後端。

  3. Reading from file: 從檔案中直接讀取

#這三種有讀取方式有什麼差別呢?我們首先要知道TensorFlow(TF)是怎麼樣的運作。

TF的核心是用C 寫的,這樣的好處是運作快,缺點是呼叫不靈活。而Python則剛好相反,所以結合兩種語言的優勢。涉及計算的核心算符和運行框架是用C 寫的,並提供API給Python。 Python呼叫這些API,設計訓練模型(Graph),再將設計好的Graph給後端執行。簡而言之,Python的角色是Design,C 是Run。

一、預先載入資料:

#
import tensorflow as tf 
# 设计Graph 
x1 = tf.constant([2, 3, 4]) 
x2 = tf.constant([4, 0, 1]) 
y = tf.add(x1, x2) 
# 打开一个session --> 计算y 
with tf.Session() as sess: 
  print sess.run(y)
登入後複製

二、python產生數據,再將數據餵給後端

#
import tensorflow as tf 
# 设计Graph 
x1 = tf.placeholder(tf.int16) 
x2 = tf.placeholder(tf.int16) 
y = tf.add(x1, x2) 
# 用Python产生数据 
li1 = [2, 3, 4] 
li2 = [4, 0, 1] 
# 打开一个session --> 喂数据 --> 计算y 
with tf.Session() as sess: 
  print sess.run(y, feed_dict={x1: li1, x2: li2})
登入後複製

說明:在這裡x1, x2只是佔位符,沒有具體的值,那麼運行的時候去哪取值呢?這時候就要用到sess.run()中的feed_dict參數,將Python產生的資料餵給後端,併計算y。

這兩種方案的缺點:

1、預先載入:將資料直接內嵌到Graph中,再把Graph傳入Session中運作。當資料量比較大時,Graph的傳輸會遇到效率問題。

2、用佔位符取代數據,待運行的時候填入資料。

前兩種方法很方便,但遇到大型資料的時候就會很吃力,即使是Feeding,中間環節的增加也是不小的開銷,例如資料型別轉換等等。最優的方案就是在Graph定義好檔案讀取的方法,讓TF自己去從檔案中讀取數據,解碼成可使用的樣本集。

三、從檔案讀取,簡單來說就是將資料讀取模組的圖搭好

1、準備數據,建構三個檔案,A.csv,B.csv,C.csv

#
$ echo -e "Alpha1,A1\nAlpha2,A2\nAlpha3,A3" > A.csv 
$ echo -e "Bee1,B1\nBee2,B2\nBee3,B3" > B.csv 
$ echo -e "Sea1,C1\nSea2,C2\nSea3,C3" > C.csv
登入後複製

2、單一Reader,單個樣本

#-*- coding:utf-8 -*- 
import tensorflow as tf 
# 生成一个先入先出队列和一个QueueRunner,生成文件名队列 
filenames = ['A.csv', 'B.csv', 'C.csv'] 
filename_queue = tf.train.string_input_producer(filenames, shuffle=False) 
# 定义Reader 
reader = tf.TextLineReader() 
key, value = reader.read(filename_queue) 
# 定义Decoder 
example, label = tf.decode_csv(value, record_defaults=[['null'], ['null']]) 
#example_batch, label_batch = tf.train.shuffle_batch([example,label], batch_size=1, capacity=200, min_after_dequeue=100, num_threads=2) 
# 运行Graph 
with tf.Session() as sess: 
  coord = tf.train.Coordinator() #创建一个协调器,管理线程 
  threads = tf.train.start_queue_runners(coord=coord) #启动QueueRunner, 此时文件名队列已经进队。 
  for i in range(10): 
    print example.eval(),label.eval() 
  coord.request_stop() 
  coord.join(threads)
登入後複製

說明:這裡沒有使用tf.train.shuffle_batch,會導致產生的樣本和label之間對應不上,亂序了。產生結果如下:

Alpha1 A2
Alpha3 B1
Bee2 B3
Sea1 C2
Sea3 A1
Alpha2 A3
Bee1 B2
Bee3 C1
Sea2 C3
Alpha1 A2

解:用tf.train.shuffle_batch,那麼產生的結果就能夠對應上。

#-*- coding:utf-8 -*- 
import tensorflow as tf 
# 生成一个先入先出队列和一个QueueRunner,生成文件名队列 
filenames = ['A.csv', 'B.csv', 'C.csv'] 
filename_queue = tf.train.string_input_producer(filenames, shuffle=False) 
# 定义Reader 
reader = tf.TextLineReader() 
key, value = reader.read(filename_queue) 
# 定义Decoder 
example, label = tf.decode_csv(value, record_defaults=[['null'], ['null']]) 
example_batch, label_batch = tf.train.shuffle_batch([example,label], batch_size=1, capacity=200, min_after_dequeue=100, num_threads=2) 
# 运行Graph 
with tf.Session() as sess: 
  coord = tf.train.Coordinator() #创建一个协调器,管理线程 
  threads = tf.train.start_queue_runners(coord=coord) #启动QueueRunner, 此时文件名队列已经进队。 
  for i in range(10): 
    e_val,l_val = sess.run([example_batch, label_batch]) 
    print e_val,l_val 
  coord.request_stop() 
  coord.join(threads)
登入後複製

3、單一Reader,多個樣本,主要也是透過tf.train.shuffle_batch來實作 

#
#-*- coding:utf-8 -*- 
import tensorflow as tf 
filenames = ['A.csv', 'B.csv', 'C.csv'] 
filename_queue = tf.train.string_input_producer(filenames, shuffle=False) 
reader = tf.TextLineReader() 
key, value = reader.read(filename_queue) 
example, label = tf.decode_csv(value, record_defaults=[['null'], ['null']]) 
# 使用tf.train.batch()会多加了一个样本队列和一个QueueRunner。 
#Decoder解后数据会进入这个队列,再批量出队。 
# 虽然这里只有一个Reader,但可以设置多线程,相应增加线程数会提高读取速度,但并不是线程越多越好。 
example_batch, label_batch = tf.train.batch( 
   [example, label], batch_size=5) 
with tf.Session() as sess: 
  coord = tf.train.Coordinator() 
  threads = tf.train.start_queue_runners(coord=coord) 
  for i in range(10): 
    e_val,l_val = sess.run([example_batch,label_batch]) 
    print e_val,l_val 
  coord.request_stop() 
  coord.join(threads)
登入後複製

說明:下面這種寫法,提取出來的batch_size個樣本,特徵和label之間也是不同步的

#-*- coding:utf-8 -*- 
import tensorflow as tf 
filenames = ['A.csv', 'B.csv', 'C.csv'] 
filename_queue = tf.train.string_input_producer(filenames, shuffle=False) 
reader = tf.TextLineReader() 
key, value = reader.read(filename_queue) 
example, label = tf.decode_csv(value, record_defaults=[['null'], ['null']]) 
# 使用tf.train.batch()会多加了一个样本队列和一个QueueRunner。 
#Decoder解后数据会进入这个队列,再批量出队。 
# 虽然这里只有一个Reader,但可以设置多线程,相应增加线程数会提高读取速度,但并不是线程越多越好。 
example_batch, label_batch = tf.train.batch( 
   [example, label], batch_size=5) 
with tf.Session() as sess: 
  coord = tf.train.Coordinator() 
  threads = tf.train.start_queue_runners(coord=coord) 
  for i in range(10): 
    print example_batch.eval(), label_batch.eval() 
  coord.request_stop() 
  coord.join(threads)
登入後複製

說明:輸出結果如下:可以看出feature和label之間是不對應的

['Alpha1' 'Alpha2' 'Alpha3' 'Bee1 ' 'Bee2'] ['B3' 'C1' 'C2' 'C3' 'A1']
['Alpha2' 'Alpha3' 'Bee1' 'Bee2' 'Bee3'] ['C1' 'C2' ' C3' 'A1' 'A2']
['Alpha3' 'Bee1' 'Bee2' 'Bee3' 'Sea1'] ['C2' 'C3' 'A1' 'A2' 'A3']

4、多個reader,多個樣本

#-*- coding:utf-8 -*- 
import tensorflow as tf 
filenames = ['A.csv', 'B.csv', 'C.csv'] 
filename_queue = tf.train.string_input_producer(filenames, shuffle=False) 
reader = tf.TextLineReader() 
key, value = reader.read(filename_queue) 
record_defaults = [['null'], ['null']] 
#定义了多种解码器,每个解码器跟一个reader相连 
example_list = [tf.decode_csv(value, record_defaults=record_defaults) 
         for _ in range(2)] # Reader设置为2 
# 使用tf.train.batch_join(),可以使用多个reader,并行读取数据。每个Reader使用一个线程。 
example_batch, label_batch = tf.train.batch_join( 
   example_list, batch_size=5) 
with tf.Session() as sess: 
  coord = tf.train.Coordinator() 
  threads = tf.train.start_queue_runners(coord=coord) 
  for i in range(10): 
    e_val,l_val = sess.run([example_batch,label_batch]) 
    print e_val,l_val 
  coord.request_stop() 
  coord.join(threads)
登入後複製

tf.train.batch與tf.train.shuffle_batch函數是單個Reader讀取,但是可以多執行緒。 tf.train.batch_join與tf.train.shuffle_batch_join可設定多Reader讀取,每個Reader使用一個執行緒。至於兩種方法的效率,單Reader時,2個執行緒就達到了速度的極限。多Reader時,2個Reader達到了極限。所以並不是線程越多越快,甚至更多的線程反而會使效率下降。

5、迭代控制,設定epoch參數,並指定我們的樣本在訓練的時候只能用多少回合

#-*- coding:utf-8 -*- 
import tensorflow as tf 
filenames = ['A.csv', 'B.csv', 'C.csv'] 
#num_epoch: 设置迭代数 
filename_queue = tf.train.string_input_producer(filenames, shuffle=False,num_epochs=3) 
reader = tf.TextLineReader() 
key, value = reader.read(filename_queue) 
record_defaults = [['null'], ['null']] 
#定义了多种解码器,每个解码器跟一个reader相连 
example_list = [tf.decode_csv(value, record_defaults=record_defaults) 
         for _ in range(2)] # Reader设置为2 
# 使用tf.train.batch_join(),可以使用多个reader,并行读取数据。每个Reader使用一个线程。 
example_batch, label_batch = tf.train.batch_join( 
   example_list, batch_size=1) 
#初始化本地变量 
init_local_op = tf.initialize_local_variables() 
with tf.Session() as sess: 
  sess.run(init_local_op) 
  coord = tf.train.Coordinator() 
  threads = tf.train.start_queue_runners(coord=coord) 
  try: 
    while not coord.should_stop(): 
      e_val,l_val = sess.run([example_batch,label_batch]) 
      print e_val,l_val 
  except tf.errors.OutOfRangeError: 
      print('Epochs Complete!') 
  finally: 
      coord.request_stop() 
  coord.join(threads) 
  coord.request_stop() 
  coord.join(threads)
登入後複製

在迭代控制中,記得加入tf.initialize_local_variables(),官網教學沒有說明,但是如果不初始化,運行就會報錯。

對於傳統的機器學習而言,比方說分類問題,[x1 x2 x3]是feature。對於二分類問題,label經過one-hot編碼之後就會是[0,1]或[1,0]。一般情況下,我們會考慮將資料組織在csv檔案中,一行代表一個sample。然後使用佇列的方式去讀取資料

说明:对于该数据,前三列代表的是feature,因为是分类问题,后两列就是经过one-hot编码之后得到的label

使用队列读取该csv文件的代码如下:

#-*- coding:utf-8 -*- 
import tensorflow as tf 
# 生成一个先入先出队列和一个QueueRunner,生成文件名队列 
filenames = ['A.csv'] 
filename_queue = tf.train.string_input_producer(filenames, shuffle=False) 
# 定义Reader 
reader = tf.TextLineReader() 
key, value = reader.read(filename_queue) 
# 定义Decoder 
record_defaults = [[1], [1], [1], [1], [1]] 
col1, col2, col3, col4, col5 = tf.decode_csv(value,record_defaults=record_defaults) 
features = tf.pack([col1, col2, col3]) 
label = tf.pack([col4,col5]) 
example_batch, label_batch = tf.train.shuffle_batch([features,label], batch_size=2, capacity=200, min_after_dequeue=100, num_threads=2) 
# 运行Graph 
with tf.Session() as sess: 
  coord = tf.train.Coordinator() #创建一个协调器,管理线程 
  threads = tf.train.start_queue_runners(coord=coord) #启动QueueRunner, 此时文件名队列已经进队。 
  for i in range(10): 
    e_val,l_val = sess.run([example_batch, label_batch]) 
    print e_val,l_val 
  coord.request_stop() 
  coord.join(threads)
登入後複製

输出结果如下:

说明:

record_defaults = [[1], [1], [1], [1], [1]]
登入後複製

代表解析的模板,每个样本有5列,在数据中是默认用‘,'隔开的,然后解析的标准是[1],也即每一列的数值都解析为整型。[1.0]就是解析为浮点,['null']解析为string类型

相关推荐:

TensorFlow入门使用 tf.train.Saver()保存模型

关于Tensorflow中的tf.train.batch函数

以上是詳解tensorflow載入資料的三種方式的詳細內容。更多資訊請關注PHP中文網其他相關文章!

相關標籤:
來源:php.cn
本網站聲明
本文內容由網友自願投稿,版權歸原作者所有。本站不承擔相應的法律責任。如發現涉嫌抄襲或侵權的內容,請聯絡admin@php.cn
作者最新文章
最新問題
python - tensorflow中TFRecord是怎麼用的?
來自於 1970-01-01 08:00:00
0
0
0
java - springboot新手學習
來自於 1970-01-01 08:00:00
0
0
0
spring - JavaWeb中 Service 層的事務問題
來自於 1970-01-01 08:00:00
0
0
0
熱門教學
更多>
最新下載
更多>
網站特效
網站源碼
網站素材
前端模板