2016-03-16 66 views
1

我想用TensorFlow使用Python线程来实现异步梯度下降。在主代码,我定义图表,包括训练操作,它得到一个变量来保持global_step的计数:在Tensorflow中的线程之间共享变量

with tf.variable_scope("scope_global_step") as scope_global_step: 
    global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False) 
optimizer = tf.train.GradientDescentOptimizer(FLAGS.learning_rate) 
train_op = optimizer.minimize(loss, global_step=global_step) 

如果我打印的global_step的名字,我得到:

scope_global_step/global_step:0

主要的代码也可以启动多个线程执行training方法:

threads = [threading.Thread(target=training, args=(sess, train_op, loss, scope_global_step)) for i in xrange(NUM_TRAINING_THREADS)] 
for t in threads: t.start() 

我想每个线程如果该值停止执行global_step大于或等于FLAGS.max_steps。为此,我建立training方法,因为它如下:

def training(sess, train_op, loss, scope_global_step): 
    while (True): 
     _, loss_value = sess.run([train_op, loss]) 
     with tf.variable_scope(scope_global_step, reuse=True): 
      global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False) 
      global_step = global_step.eval(session=sess) 
      if global_step >= FLAGS.max_steps: break 

这失败消息:

ValueError: Under-sharing: Variable scope_global_step/global_step does not exist, disallowed. Did you mean to set reuse=None in VarScope?

我可以看到,:0被加到变量的名称是首次创建时,当我尝试检索它时,不使用该后缀。为什么是这样? 如果我手动将后缀添加到变量的名称,当我尝试检索它时,它仍然声称该变量不存在。为什么TensorFlow找不到变量?不应该自动在线程间共享变量吗?我的意思是,所有线程都在同一个会话中运行,对吧?

而且关系到我training方法了另一个问题:global_step.eval(session=sess)再次执行图表,或者它只是获取train_oploss操作的执行后分配到gloabl_step价值?一般来说,从Python代码中使用变量获取值的推荐方法是什么?

回答

1

TL; DR:传递您的第一个代码片段的培训螺纹参数中创建的global_steptf.Variable对象,并呼吁传入的变量sess.run(global_step)

作为一般规则,您的训练循环(尤其是单独线程中的训练循环)不应修改图形。 tf.variable_scope()上下文管理器和tf.get_variable()可以修改图(即使它们不总是),所以你不应该在你的训练循环中使用它们。最安全的做法是在创建训练线程时将global_step对象(您首先创建的对象)作为args元组之一。然后,你可以简单地重写你的训练功能:

def training(sess, train_op, loss, global_step): 
    while (True): 
     _, loss_value = sess.run([train_op, loss]) 
     current_step = sess.run(global_step) 
     if current_step >= FLAGS.max_steps: break 

为了回答您的其他问题,运行global_step.eval(session=sess)sess.run(global_step)只取了global_step变量的当前值,并且不会重新执行您的图形的其余部分。这是获取tf.Variable值以供在Python代码中使用的推荐方式。

+0

谢谢@mrry。你的解决方案当然更清洁。然而,我仍然想知道为什么'tf.get_variable()'找不到变量。你能解释为什么这样吗?谢谢! – nicolas

+0

我认为这实际上失败了,因为在另一个线程中运行'tf.variable_scope()'实际上引用了另一个'tf。Graph'实例来自您最初创建变量的实例。如果你使用了'with sess.graph.as_default(),tf.variable_scope(“scope_global_step”,reuse = True):'它会,但只有当你有一个训练线程**时。该图不是用于写入的线程安全的,并且输入变量范围会导致一些图形内部的数据结构被更新,因此您绝对不应该这样做:)。 – mrry

+0

再次感谢@mrry。我想我学到了一两件重要的事情:-) – nicolas