Populate a queue from a Python iterator

I want to create a queue that is populated from an iterator. In the following MWE, however, the same value is always enqueued:

import tensorflow as tf 
import numpy as np 

# data 
imgs = [np.random.randn(i,i) for i in [2,3,4,5]] 

# iterate through data infinitely 
def data_iterator(): 
    while True: 
        for img in imgs: 
            yield img 

it = data_iterator() 

# create queue for data 
q = tf.FIFOQueue(capacity=5, dtypes=[tf.float64]) 

# feed next element from iterator 
enqueue_op = q.enqueue(list(next(it))) 

# setup queue runner 
numberOfThreads = 1 
qr = tf.train.QueueRunner(q, [enqueue_op] * numberOfThreads) 
tf.train.add_queue_runner(qr) 

# dequeue 
dequeue_op = q.dequeue() 
dequeue_op = tf.Print(dequeue_op, data=[dequeue_op], message="dequeue()") 

# We start the session as usual ... 
with tf.Session() as sess: 
    coord = tf.train.Coordinator() 
    threads = tf.train.start_queue_runners(coord=coord) 

    for i in range(10): 
        data = sess.run(dequeue_op) 
        print(data) 
    coord.request_stop() 
    coord.join(threads) 

Do I really have to use feed_dict? If so, how do I use it in combination with a QueueRunner?

Answer


When running

enqueue_op = q.enqueue(list(next(it))) 

TensorFlow executes list(next(it)) exactly once, at graph-construction time. From then on it keeps that first array and enqueues the same value into q every time enqueue_op is run. To avoid this you have to feed the data through a placeholder; feeding placeholders is not compatible with tf.train.QueueRunner, so you have to run the enqueue op from your own thread instead.
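The problematic line behaves as if it had been written like this (a sketch using the same q and it as above; the Python expression runs while the graph is being built, not when the op is executed):

first_value = list(next(it))         # evaluated exactly once, during graph construction 
enqueue_op = q.enqueue(first_value)  # every sess.run(enqueue_op) enqueues this same array 

Instead, enqueue through a placeholder that is fed from a plain Python thread: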

import tensorflow as tf 
import numpy as np 
import threading 

# data 
imgs = [np.random.randn(i,i) for i in [2,3,4,5]] 

# iterate through data infinitely 
def data_iterator(): 
    while True: 
        for img in imgs: 
            yield img 

it = data_iterator() 

# create queue for data 
q = tf.FIFOQueue(capacity=5, dtypes=[tf.float64]) 

# feed next element from iterator 

img_p = tf.placeholder(tf.float64, [None, None]) 
enqueue_op = q.enqueue(img_p) 

dequeue_op = q.dequeue() 


with tf.Session() as sess: 
    coord = tf.train.Coordinator() 

    def enqueue_thread(): 
        with coord.stop_on_exception(): 
            while not coord.should_stop(): 
                sess.run(enqueue_op, feed_dict={img_p: list(next(it))}) 

    numberOfThreads = 1 
    for i in range(numberOfThreads): 
        threading.Thread(target=enqueue_thread).start() 

    for i in range(3): 
        data = sess.run(dequeue_op) 
        print(data)
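
As written, the dequeue loop exits while the enqueue thread is still running, and that thread is only stopped by the error it gets once the session is closed. To shut it down cleanly instead, one option (a sketch; closing the queue with cancel_pending_enqueues=True unblocks an enqueue call that is waiting on a full queue) is to add, at the end of the with block:

    coord.request_stop()                             # ask the enqueue thread to leave its loop 
    sess.run(q.close(cancel_pending_enqueues=True))  # wake up an enqueue blocked on the full queue 
    coord.join()                                     # re-raise anything caught by stop_on_exception() 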