2017-06-02 18 views
0

下面死锁的代码:死锁在tensorflow的MonitoredTrainingSession和切片投入生产

import tensorflow as tf 

def train(): 
    """Stripped down and modified from cifar10.cifar10_train.train""" 
    global_step = tf.contrib.framework.get_or_create_global_step() # for StopAtStepHook 
    images = tf.constant([[1, 2, 3], [1, 2, 3]]) 
    labels = tf.constant([[1, 2, 3], [1, 2, 3]]) 
    images, labels = tf.train.slice_input_producer([images, labels], 
                shuffle=False) 
    # input_var = tf.Variable([0, 0, 0]) 
    # images = input_var.assign(images) # TODO placeholder would work ? 
    # input_batch = tf.scatter_nd_update(images, [[1, 2]], [77]) 
    input_batch = tf.scatter_nd_update(tf.Variable(images), [[1, 2]], [77]) 
    tf_print = tf.Print(input_batch, [input_batch]) 
    with tf.train.MonitoredTrainingSession(
      hooks=[tf.train.StopAtStepHook(last_step=3)]) as mon_sess: 
     while not mon_sess.should_stop(): 
      mon_sess.run(tf_print) 

if __name__ == '__main__': 
    train() 

然而,如果我注释掉input_batch = tf.scatter_nd_update(tf.Variable(images), [[1, 2]], [77])并取消注释行计划保持印刷:

I c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\kernels\logging_ops.cc:79] [1 2 3]

  • 为什么它僵局?是否正确地使用额外的变量来解决这个问题?或者我应该以某种方式使用占位符?
  • 我错过了什么,它不会在3个步骤后终止?

回答

1
  1. 我不知道关于你的第一个问题,但我相信会发生什么情况是,当你创建MonitoredTrainingSession它试图初始化你图的变量。但在你的情况下,其中一个变量的初始值依赖于隐藏在tf.train.slice_input_producer之后的出列操作。由于队列尚未启动,因此代码会死锁,等待队列排入队列。 在您评论的实现中,init_op确实运行,因此队列可以启动并使您的代码正常工作。

  2. 下面是对第二个问题的解释。 StopAtStepHook依赖于global_step张量正在更新,这是不是你的脚本的情况。这段代码 tf_print = tf.group(tf.Print(input_batch, [input_batch]), tf.assign_add(global_step,1))将工作:基本上它会将tf.Print操作和global_step增量分组在一起,所以每次运行tf_print时,global_step都会递增。

    import tensorflow as tf 
    
    def train(): 
        """Stripped down and modified from cifar10.cifar10_train.train""" 
        global_step = tf.contrib.framework.get_or_create_global_step() # for StopAtStepHook 
        images = tf.constant([[1, 2, 3], [1, 2, 3]]) 
        labels = tf.constant([[1, 2, 3], [1, 2, 3]]) 
        images, labels = tf.train.slice_input_producer([images, labels], shuffle=False) 
        input_var = tf.Variable([0, 0, 0]) 
        images = input_var.assign(images) # TODO placeholder would work ? 
        input_batch = tf.scatter_nd_update(images, [[1, 2]], [77]) 
        tf_print = tf.group(tf.Print(input_batch, [input_batch]), 
             tf.assign_add(global_step, 1)) 
        with tf.train.MonitoredTrainingSession(
          hooks=[tf.train.StopAtStepHook(last_step=3)]) as mon_sess: 
         while not mon_sess.should_stop(): 
          mon_sess.run(tf_print) 
    
    if __name__ == '__main__': 
        train() 
    
+0

有道理的感谢!没有其他办法吗?我应该以某种方式使用占位符吗? (还没有调查global_step API,希望它在StopAtStepHook中神奇地更新) –

+0

'placeholder'不会像'scatter_nd_update'文档非常具体:'ref:A mutable'Tensor'一样工作。一个可变的张量。应该来自一个变量节点。“这一点由错误消息证实:'TypeError:'ScatterNdUpdate'的输入'ref'需要输入l值。 此外,你总是有能力不使用局部变量,例如: 'input_batch = tf.scatter_nd_update(tf.Variable([0,0,0])。assign(images),[[1,2] ],[77])'如果你不喜欢。 – npf

+0

哦谢谢 - 我的意思是一个占位符而不是任务,我基本上是在寻找一个正确的方法。欢迎来到SO :) –