2016-03-01 29 views
3
import tensorflow as tf 
sess = tf.Session() 

def add_to_batch(image): 

    print('Adding to batch') 
    image_batch = tf.train.shuffle_batch([image],batch_size=5,capacity=11,min_after_dequeue=1,num_threads=1) 

    # Add to summary 
    tf.image_summary('images',image_batch) 

    return image_batch 

def get_batch(): 

    # Create filename queue of images to read 
    filenames = [('/media/jessica/Jessica/TensorFlow/Practice/unlabeled_data_%d.png' % i) for i in range(11)] 
    filename_queue = tf.train.string_input_producer(filenames) 
    reader = tf.WholeFileReader() 
    key, value = reader.read(filename_queue) 

    # Read and process image 
    my_image = tf.image.decode_png(value) 
    my_image_float = tf.cast(my_image,tf.float32) 
    image_mean = tf.reduce_mean(my_image_float) 
    my_noise = tf.random_normal([96,96,3],mean=image_mean) 
    my_image_noisy = my_image_float + my_noise 
    print('Reading images') 

    return add_to_batch(my_image_noisy) 

def main(): 

    sess.run(tf.initialize_all_variables()) 
    tf.train.start_queue_runners(sess=sess) 
    writer = tf.train.SummaryWriter('/media/jessica/Jessica/TensorFlow/Practice/summary_logs', graph_def=sess.graph_def) 
    merged = tf.merge_all_summaries() 
    images = get_batch() 
    summary_str = sess.run(merged) 
    writer.add_summary(summary_str) 

嗨,TensorFlow shuffle_batch不工作

我试图建立TensorFlow一个简单的神经网络。我正在尝试分批加载我的输入图像。现在我正在测试11个图像和batch_size = 5的代码。最终我将处理100000个图像。

这段代码是从TensorFlow的cifar10.py例子中修改的。由于某种原因,我的代码停止(不终止,它只是挂在那里)tf.train.shuffle_batch([image],batch_size=5,capacity=1,min_after_dequeue=1,num_threads=1)

我试过batch_size,容量,min_after_dequeue等不同的组合,但我仍然不知道什么是错的。

任何帮助将不胜感激!谢谢!

+0

我编辑你的代码来修复缩进(否则Python解释器不会接受它)。让我知道,如果这是不正确的! – mrry

回答

7

看起来问题出现,因为声明

tf.train.start_queue_runners(sess=sess) 

...执行已创建的任何队列跑步之前。如果您在images = get_batch()之后移动此行,您的程序应该可以正常工作。

这里有什么问题? tf.train.shuffle_batch()函数内部使用tf.RandomShuffleQueue来产生随机批次。目前,将元素放入该队列的唯一方法是运行一个调用q.enqueue()操作的步骤。为了使这更容易,TensorFlow有一个概念"queue runners",它是在您构建图形时隐式收集的,然后通过致电tf.train.start_queue_runners()开始。但是,调用tf.train.start_queue_runners()仅启动在该时间点已定义的队列运行程序,因此它必须在创建队列运行程序的代码之后出现

相关问题