下面死锁的代码:死锁在tensorflow的MonitoredTrainingSession和切片投入生产
import tensorflow as tf
def train():
"""Stripped down and modified from cifar10.cifar10_train.train"""
global_step = tf.contrib.framework.get_or_create_global_step() # for StopAtStepHook
images = tf.constant([[1, 2, 3], [1, 2, 3]])
labels = tf.constant([[1, 2, 3], [1, 2, 3]])
images, labels = tf.train.slice_input_producer([images, labels],
shuffle=False)
# input_var = tf.Variable([0, 0, 0])
# images = input_var.assign(images) # TODO placeholder would work ?
# input_batch = tf.scatter_nd_update(images, [[1, 2]], [77])
input_batch = tf.scatter_nd_update(tf.Variable(images), [[1, 2]], [77])
tf_print = tf.Print(input_batch, [input_batch])
with tf.train.MonitoredTrainingSession(
hooks=[tf.train.StopAtStepHook(last_step=3)]) as mon_sess:
while not mon_sess.should_stop():
mon_sess.run(tf_print)
if __name__ == '__main__':
train()
然而,如果我注释掉input_batch = tf.scatter_nd_update(tf.Variable(images), [[1, 2]], [77])
并取消注释行计划保持印刷:
I c:\tf_jenkins\home\workspace\release-win\device\cpu\os\windows\tensorflow\core\kernels\logging_ops.cc:79] [1 2 3]
- 为什么它僵局?是否正确地使用额外的变量来解决这个问题?或者我应该以某种方式使用占位符?
- 我错过了什么,它不会在3个步骤后终止?
有道理的感谢!没有其他办法吗?我应该以某种方式使用占位符吗? (还没有调查global_step API,希望它在StopAtStepHook中神奇地更新) –
'placeholder'不会像'scatter_nd_update'文档非常具体:'ref:A mutable'Tensor'一样工作。一个可变的张量。应该来自一个变量节点。“这一点由错误消息证实:'TypeError:'ScatterNdUpdate'的输入'ref'需要输入l值。 此外,你总是有能力不使用局部变量,例如: 'input_batch = tf.scatter_nd_update(tf.Variable([0,0,0])。assign(images),[[1,2] ],[77])'如果你不喜欢。 – npf
哦谢谢 - 我的意思是一个占位符而不是任务,我基本上是在寻找一个正确的方法。欢迎来到SO :) –