這篇文章主要介紹了關於Tensorflow中的tf.train.batch函數的使用,現在分享給大家,也給大家做個參考。一起來看看吧
這兩天一直在看tensorflow中的讀取資料的佇列,說實話,真的是很難懂。也可能我之前沒這方面的經驗吧,最早我都使用的theano,什麼都是自己寫。經過這兩天的文檔以及相關資料,並且請教了國內的師弟。今天算是有點小感受了。簡單的說,就是計算圖是從一個管道中讀取資料的,錄入管道是用的現成的方法,讀取也是。為了確保多執行緒的時候從一個管道讀取資料不會亂吧,所以這種時候 讀取的時候需要執行緒管理的相關操作。今天我實驗室了一個簡單的操作,就是給一個有序的數據,看看讀出來是不是有序的,結果發現是有序的,所以直接給代碼:
import tensorflow as tf import numpy as np def generate_data(): num = 25 label = np.asarray(range(0, num)) images = np.random.random([num, 5, 5, 3]) print('label size :{}, image size {}'.format(label.shape, images.shape)) return label, images def get_batch_data(): label, images = generate_data() images = tf.cast(images, tf.float32) label = tf.cast(label, tf.int32) input_queue = tf.train.slice_input_producer([images, label], shuffle=False) image_batch, label_batch = tf.train.batch(input_queue, batch_size=10, num_threads=1, capacity=64) return image_batch, label_batch image_batch, label_batch = get_batch_data() with tf.Session() as sess: coord = tf.train.Coordinator() threads = tf.train.start_queue_runners(sess, coord) i = 0 try: while not coord.should_stop(): image_batch_v, label_batch_v = sess.run([image_batch, label_batch]) i += 1 for j in range(10): print(image_batch_v.shape, label_batch_v[j]) except tf.errors.OutOfRangeError: print("done") finally: coord.request_stop() coord.join(threads)
記得那個slice_input_producer方法,預設是要shuffle的哈。
Besides, I would like to comment this code.
1: there is a parameter 'num_epochs' in slice_input_producer, which controls how many epochs the slice_input_producer, which controls how many epochs the slice_input_producer method wod method runs the specified epochs, it would report the OutOfRangeRrror. I think it would be useful for our control the training epochs.
2: the output of this method is one s
tf.train.batch([example, label], batch_size=batch_size, capacity=capacity):[example, label ]表示樣本和樣本標籤,這個可以是一個樣本和一個樣本標籤,batch_size是一個傳回的一個batch樣本集的樣本個數。 capacity是隊列中的容量。這主要是依序組合成一個batch
tf.train.shuffle_batch([example, label], batch_size=batch_size, capacity=capacity, min_after_dequeue)。這裡面的參數和上面的一樣的意思。不一樣的是這個參數min_after_dequeue,一定要確保這參數小於capacity參數的值,否則會出錯。這個代表佇列中的元素大於它的時候就輸出亂的順序的batch。也就是說這個函數的輸出結果是一個亂序的樣本排列的batch,不是按照順序排列的。
上面的函數回傳值都是一個batch的樣本和樣本標籤,只是一個是按照順序,另外一個是隨機的
相關推薦:
############################################################# ######
以上是關於Tensorflow中的tf.train.batch函數的詳細內容。更多資訊請關注PHP中文網其他相關文章!