2017-06-07 150 views
0

永远困我想通过批量培养我的模型批,因为我无法找到任何例子,如何正确地做到这一点。就我所能做的事情而言,我的任务是在Tensorflow中逐批地训练模型。Tensorflow:教育训练由一批sess.run

queue=tf.FIFOQueue(capacity=50,dtypes=[tf.float32,tf.float32],shapes=[[10],[2]]) 
enqueue_op=queue.enqueue_many([X,Y]) 
dequeue_op=queue.dequeue() 

qr=tf.train.QueueRunner(queue,[enqueue_op]*2) 

with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    X_train_batch,y_train_batch=tf.train.batch(dequeue_op,batch_size=2) 
    coord=tf.train.Coordinator() 
    enqueue_threads=qr.create_threads(sess,coord,start=True) 
    sess.run(tf.local_variables_initializer()) 
    for epoch in range(100): 
     print("inside loop1") 
     for iter in range(5): 
      print("inside loop2") 
      if coord.should_stop(): 
       break 
      batch_x,batch_y=sess.run([X_train_batch,y_train_batch]) 
      print("after sess.run") 
      print(batch_x.shape) 
      _=sess.run(optimizer,feed_dict={x_place:batch_x,y_place:batch_y}) 
     coord.request_stop() 
     coord.join(enqueue_threads) 

,输出,

inside loop1 
inside loop2 

正如你所看到的, 这永远困在它运行batch_x,batch_y=sess.run([X_train_batch,y_train_batch])线。 我不知道我该如何解决这个问题,或者这是逐批地训练模型的正确方法?

+0

是输出真的“内循环1,内循环1”或是“内循环1,内循环2”?其次,在我看来,你最后两行缩进了一点,应该与“纪元”一致。 – Wontonimo

+0

抱歉的错字,现在编辑,我找到了解决方案,现在编辑问题.. –

回答

1

经过几个小时的搜索,我自己找到了解决方案。所以,我现在在下面回答我自己的问题。 的队列由后台线程,其创建,当你调用tf.train.start_queue_runners()如果你不调用这个方法,后台线程将无法启动,队列将保持为空,训练运会无限期地阻塞等待输入填补。

FIX: 就在训练循环之前调用tf.train.start_queue_runners(sess)。 像我这样做:

queue=tf.FIFOQueue(capacity=50,dtypes=[tf.float32,tf.float32],shapes=[[10],[2]]) 
enqueue_op=queue.enqueue_many([X,Y]) 
dequeue_op=queue.dequeue() 

qr=tf.train.QueueRunner(queue,[enqueue_op]*2) 

with tf.Session() as sess: 
    sess.run(tf.global_variables_initializer()) 
    X_train_batch,y_train_batch=tf.train.batch(dequeue_op,batch_size=2) 
    coord=tf.train.Coordinator() 
    enqueue_threads=qr.create_threads(sess,coord,start=True) 
    tf.train.start_queue_runners(sess) 
    for epoch in range(100): 
     print("inside loop1") 
     for iter in range(5): 
      print("inside loop2") 
      if coord.should_stop(): 
       break 
      batch_x,batch_y=sess.run([X_train_batch,y_train_batch]) 
      print("after sess.run") 
      print(batch_x.shape) 
      _=sess.run(optimizer,feed_dict={x_place:batch_x,y_place:batch_y}) 
     coord.request_stop() 
     coord.join(enqueue_threads)