由于这个页面经常出现在谷歌搜索结果中,我想试着改进 Yaroslav 的答案,希望能为刚接触分布式 TensorFlow 的人提供一个更完整、更明确的答案。
import tensorflow as tf
import threading
def main(job_name, task):
    """Run one member of the cluster: a parameter server or a worker.

    Each parameter server owns a shared "done" queue and blocks until every
    worker has pushed a token into it, then exits. Each worker pushes one
    token into every parameter server's queue and exits immediately.

    Args:
        job_name: Either 'ps' or 'worker'.
        task: Integer task index within that job.
    """
    cluster = tf.train.ClusterSpec({
        'ps': ['localhost:22222', 'localhost:22223'],
        # NOTE: the original had 'localhost: 22224' etc. with a stray space
        # after the colon; addresses must be plain host:port strings.
        'worker': ['localhost:22224', 'localhost:22225', 'localhost:22226'],
    })
    server = tf.train.Server(cluster, job_name=job_name, task_index=task)
    if job_name == 'ps':
        # Create a shared queue on the parameter server which is visible on
        # /job:ps/task:%d; workers attach to it by shared_name.
        with tf.device('/job:ps/task:%d' % task):
            queue = tf.FIFOQueue(cluster.num_tasks('worker'), tf.int32,
                                 shared_name='done_queue%d' % task)
        # Block until one token per worker has been enqueued.
        with tf.Session(server.target) as sess:
            for i in range(cluster.num_tasks('worker')):
                sess.run(queue.dequeue())
                print('ps:%d received "done" from worker:%d' % (task, i))
            print('ps:%d quitting' % task)
    elif job_name == 'worker':
        queues = []
        # Create the shared "done" queues, one pinned to each ps task so the
        # shared_name resolves on /job:ps/task:%d.
        for i in range(cluster.num_tasks('ps')):
            with tf.device('/job:ps/task:%d' % i):
                queues.append(
                    tf.FIFOQueue(cluster.num_tasks('worker'), tf.int32,
                                 shared_name='done_queue%d' % i))
        # Signal every parameter server that this worker is finished.
        with tf.Session(server.target) as sess:
            for i in range(cluster.num_tasks('ps')):
                _, size = sess.run([queues[i].enqueue(task),
                                    queues[i].size()])
                print('Worker:%d sending "done" to ps:%d [elements=%d]'
                      % (task, i, size))
if __name__ == '__main__':
    # Simulate the whole cluster inside a single process: one thread per
    # task (2 parameter servers + 3 workers). The join() calls only return
    # once every ps has received "done" from every worker.
    threads = [
        threading.Thread(target=main, args=('ps', 0)),
        threading.Thread(target=main, args=('ps', 1)),
        threading.Thread(target=main, args=('worker', 0)),
        threading.Thread(target=main, args=('worker', 1)),
        threading.Thread(target=main, args=('worker', 2)),
    ]
    for thread in threads:
        thread.start()
    for thread in threads:
        thread.join()
要扩展"规范的" Distributed TensorFlow 示例,只需用下面这段代码替换其中 worker 部分的代码即可:
# create a worker that does nothing
elif job_name == 'worker':
with tf.device(tf.train.replica_device_setter(worker_device='/job:worker/task:%d' % task, cluster=cluster)):
global_step = tf.train.get_or_create_global_step()
no_op = tf.no_op()
done_ops = []
# create a shared queue on the worker which is visible on /job:ps/task:%d
for i in range(cluster.num_tasks('ps')):
with tf.device('/job:ps/task:%d' % i):
done_queue = tf.FIFOQueue(cluster.num_tasks('worker'), tf.int32, shared_name='done_queue' + str(i))
done_ops.append(done_queue.enqueue(task))
hooks=[tf.train.StopAtStepHook(last_step=1),
tf.train.FinalOpsHook([done_ops])]
with tf.train.MonitoredTrainingSession(master=server.target,
is_chief=(task == 0),
hooks=hooks) as sess:
sess.run([no_op])
注意,MonitoredTrainingSession 版本似乎要慢得多,因为它要等待所有 worker 连接到一起。
感谢您的回答。它工作得很好。但是当我尝试使用'tf.Supervisor'(TF网站上的那个)来适应这个例子时,我遇到了一些问题。一旦我实例化了一个“supervisor”对象,该图将被“终结”。因此我们不能在训练后排队。使用两个图表可能会工作,但根据[本文](http://stackoverflow.com/a/34249940/4811003),它可能会影响性能。有没有一个好的解决方案? – fois
'queue0.enqueue(1)' 实际上会创建一个 enqueue 操作并修改图。你可以改为 'op1 = queue0.enqueue(1); sess.run(op1)' –
你说得对。我很愚蠢。 非常感谢。 – fois