2016-10-01 122 views
4

当我们要使用分布式TensorFlow,我们将使用关闭服务器TensorFlow

tf.train.Server.join() 

然而列表服务器,我无法找到任何方式关闭服务器,除了查杀处理。对于加入TensorFlow文档()是

Blocks until the server has shut down. 
This method currently blocks forever. 

这是很困扰我,因为我想为计算创造众多服务器和关闭它们,当一切结束。

有没有可能的解决方案。

感谢

回答

9

通过使用session.run(dequeue_op)而不是server.join(),您可以根据需要让参数服务器进程死亡,并让另一个进程在您希望此进程死亡时将某些内容排入该队列。

所以对于k参数服务器碎片您可以创建k队列,具有独特的shared_name属性,并从该队列尝试dequeue。当您想关闭服务器时,您会将所有队列和enqueue令牌循环到每个队列中。这会导致session.run解锁并且Python进程将运行到最后并退出,从而关闭服务器。

下面是2块碎片一个独立的例子摘自: https://gist.github.com/yaroslavvb/82a5b5302449530ca5ff59df520c369e

(多工/多碎片例如,见https://gist.github.com/yaroslavvb/ea1b1bae0a75c4aae593df7eca72d9ca

import subprocess 
import tensorflow as tf 
import time 
import sys 

flags = tf.flags 
flags.DEFINE_string("port1", "12222", "port of worker1") 
flags.DEFINE_string("port2", "12223", "port of worker2") 
flags.DEFINE_string("task", "", "internal use") 
FLAGS = flags.FLAGS 

# setup local cluster from flags 
host = "127.0.0.1:" 
cluster = {"worker": [host+FLAGS.port1, host+FLAGS.port2]} 
clusterspec = tf.train.ClusterSpec(cluster).as_cluster_def() 

if __name__=='__main__': 
    if not FLAGS.task: # start servers and run client 

     # launch distributed service 
     def runcmd(cmd): subprocess.Popen(cmd, shell=True, stderr=subprocess.STDOUT) 
     runcmd("python %s --task=0"%(sys.argv[0])) 
     runcmd("python %s --task=1"%(sys.argv[0])) 
     time.sleep(1) 

     # bring down distributed service 
     sess = tf.Session("grpc://"+host+FLAGS.port1) 
     queue0 = tf.FIFOQueue(1, tf.int32, shared_name="queue0") 
     queue1 = tf.FIFOQueue(1, tf.int32, shared_name="queue1") 
     with tf.device("/job:worker/task:0"): 
      add_op0 = tf.add(tf.ones(()), tf.ones(())) 
     with tf.device("/job:worker/task:1"): 
      add_op1 = tf.add(tf.ones(()), tf.ones(())) 

     print("Running computation on server 0") 
     print(sess.run(add_op0)) 
     print("Running computation on server 1") 
     print(sess.run(add_op1)) 

     print("Bringing down server 0") 
     sess.run(queue0.enqueue(1)) 
     print("Bringing down server 1") 
     sess.run(queue1.enqueue(1)) 

    else: # Launch TensorFlow server 
    server = tf.train.Server(clusterspec, config=None, 
          job_name="worker", 
          task_index=int(FLAGS.task)) 
    print("Starting server "+FLAGS.task) 
    sess = tf.Session(server.target) 
    queue = tf.FIFOQueue(1, tf.int32, shared_name="queue"+FLAGS.task) 
    sess.run(queue.dequeue()) 
    print("Terminating server"+FLAGS.task) 
+0

感谢您的回答。它工作得很好。但是当我尝试使用'tf.Supervisor'(TF网站上的那个)来适应这个例子时,我遇到了一些问题。一旦我实例化了一个“supervisor”对象,该图将被“终结”。因此我们不能在训练后排队。使用两个图表可能会工作,但根据[本文](http://stackoverflow.com/a/34249940/4811003),它可能会影响性能。有没有一个好的解决方案? – fois

+0

'queue0.enqueue(1)'实际上创建一个enqueue操作并修改图形。你可以改为'op1 = queue0.enqueue(1); ; sess.run(op1)' –

+0

你说得对。我很愚蠢。 非常感谢。 – fois

3

目前还没有任何清晰的方式来关闭TensorFlow GRPC服务器。它可能到shut down a gRPC server,但安全地执行此操作需要对所有正在进行的请求和响应缓冲区进行额外的内存管理,这需要大量额外的管道工作(最糟糕的一种:异步共享内存管理...)对于目前为止没有人要求—的功能!

在实践中,您应该可以使用相同的tf.train.Server对象进行许多不同的计算。如果这不适用于您的用例,请随时拨打open an GitHub issue并告诉我们更多关于您的使用案例。

+0

感谢您的回答。但是如果你在分布式Tensorflow的文档中使用这个例子,你会怎么做?我的意思是,在计算之后,两台工作服务器完成,而两台参数服务器仍在运行。 – fois

+0

目前,我从命令行中终止参数服务器的进程。我想知道它是否安全? – fois

+0

但是,如果我在完成训练后不杀死服务器,那么ps中的变量将影响下一次训练。是这样吗? – fois

0

才会出现此页面很常在谷歌,所以我想我会试着改进Yaroslav's answer,我希望对于那些刚刚进入分布式Tensorflow的人来说,我希望这是一个更明确的答案。

import tensorflow as tf 
import threading 

def main(job_name, task): 
    cluster = tf.train.ClusterSpec({ 
     'ps': ['localhost:22222', 'localhost:22223'], 
     'worker': ['localhost: 22224','localhost: 22225','localhost: 22226'] 
    }) 

    server = tf.train.Server(cluster, job_name=job_name, task_index=task) 

    if job_name == 'ps': 
     # create a shared queue on the parameter server which is visible on /job:ps/task:%d 
     with tf.device('/job:ps/task:%d' % task): 
      queue = tf.FIFOQueue(cluster.num_tasks('worker'), tf.int32, shared_name='done_queue%d' % task) 

     # wait for the queue to be filled 
     with tf.Session(server.target) as sess: 
      for i in range(cluster.num_tasks('worker')): 
       sess.run(queue.dequeue()) 
       print('ps:%d received "done" from worker:%d' % (task, i)) 
      print('ps:%d quitting' % task) 

    elif job_name == 'worker': 
     queues = [] 
     # create a shared queue on the worker which is visible on /job:ps/task:%d 
     for i in range(cluster.num_tasks('ps')): 
      with tf.device('/job:ps/task:%d' % i): 
       queues.append(tf.FIFOQueue(cluster.num_tasks('worker'), tf.int32, shared_name='done_queue%d' % i)) 

     # fill the queue 
     with tf.Session(server.target) as sess: 
      for i in range(cluster.num_tasks('ps')): 
       _, size = sess.run([queues[i].enqueue(task), queues[i].size()]) 
       print('Worker:%d sending "done" to ps:%d [elements=%d]' % (task, i, size)) 

if __name__ == '__main__': 
    threads = [ 
     threading.Thread(target=main, args=('ps', 0)), 
     threading.Thread(target=main, args=('ps', 1)), 
     threading.Thread(target=main, args=('worker', 0)), 
     threading.Thread(target=main, args=('worker', 1)), 
     threading.Thread(target=main, args=('worker', 2))] 
    for thread in threads: 
     thread.start() 
    for thread in threads: 
     thread.join() 

这是很简单的用这个片段替换代码的工人节在“规范” Distributed Tensorflow example扩展:

# create a worker that does nothing 
    elif job_name == 'worker': 
     with tf.device(tf.train.replica_device_setter(worker_device='/job:worker/task:%d' % task, cluster=cluster)): 
      global_step = tf.train.get_or_create_global_step() 
      no_op = tf.no_op() 

     done_ops = [] 
     # create a shared queue on the worker which is visible on /job:ps/task:%d 
     for i in range(cluster.num_tasks('ps')): 
      with tf.device('/job:ps/task:%d' % i): 
       done_queue = tf.FIFOQueue(cluster.num_tasks('worker'), tf.int32, shared_name='done_queue' + str(i)) 
       done_ops.append(done_queue.enqueue(task)) 

     hooks=[tf.train.StopAtStepHook(last_step=1), 
       tf.train.FinalOpsHook([done_ops])] 

     with tf.train.MonitoredTrainingSession(master=server.target, 
               is_chief=(task == 0), 
               hooks=hooks) as sess: 
      sess.run([no_op]) 

注意,MonitoredTrainingSession版本似乎是在连接所有慢得多的工人在一起。