Distributed TensorFlow: supervisor does not initialize?

I want to convert the MNIST tensorflow example so that it runs in distributed form, using the template given on the documentation page. Here is my code:

import tensorflow as tf 

# Flags for defining the tf.train.ClusterSpec 
tf.app.flags.DEFINE_string("ps_hosts", "", 
          "Comma-separated list of hostname:port pairs") 
tf.app.flags.DEFINE_string("worker_hosts", "", 
          "Comma-separated list of hostname:port pairs") 

# Flags for defining the tf.train.Server 
tf.app.flags.DEFINE_string("job_name", "", "One of 'ps', 'worker'") 
tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job") 
tf.app.flags.DEFINE_string("data_dir", "/tmp/data", "Directory for storing the MNIST data") 
FLAGS = tf.app.flags.FLAGS 


# Import data 
from tensorflow.examples.tutorials.mnist import input_data 


def main(_): 
    ps_hosts = FLAGS.ps_hosts.split(",") 
    worker_hosts = FLAGS.worker_hosts.split(",") 

    # Create a cluster from the parameter server and worker hosts. 
    cluster = tf.train.ClusterSpec({"ps": ps_hosts, "worker": worker_hosts}) 

    # Create and start a server for the local task. 
    server = tf.train.Server(cluster, 
          job_name=FLAGS.job_name, 
          task_index=FLAGS.task_index) 

    if FLAGS.job_name == "ps":
        server.join()
    elif FLAGS.job_name == "worker":

        # Assigns ops to the local worker by default.
        with tf.device(tf.train.replica_device_setter(
                worker_device="/job:worker/task:%d" % FLAGS.task_index,
                cluster=cluster)):

            # Build model...
            #loss = ...
            #global_step = tf.Variable(0)

            # Create the model
            x = tf.placeholder(tf.float32, [None, 784])
            W = tf.Variable(tf.zeros([784, 10]))
            b = tf.Variable(tf.zeros([10]))
            y = tf.nn.softmax(tf.matmul(x, W) + b)

            global_step = tf.Variable(0)

            #train_op = tf.train.AdagradOptimizer(0.01).minimize(
            # loss, global_step=global_step)

            # Define loss and optimizer
            y_ = tf.placeholder(tf.float32, [None, 10])
            cross_entropy = tf.reduce_mean(
                -tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
            train_step = tf.train.GradientDescentOptimizer(0.5).minimize(
                cross_entropy, global_step=global_step)

            saver = tf.train.Saver()
            summary_op = tf.merge_all_summaries()
            init_op = tf.initialize_all_variables()

        mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

        # Create a "supervisor", which oversees the training process.
        sv = tf.train.Supervisor(is_chief=(FLAGS.task_index == 0),
                                 logdir="/tmp/train_logs",
                                 init_op=init_op,
                                 summary_op=summary_op,
                                 saver=saver,
                                 global_step=global_step,
                                 save_model_secs=600)

        # The supervisor takes care of session initialization, restoring from
        # a checkpoint, and closing when done or an error occurs.
        with sv.managed_session(server.target) as sess:
            # Loop until the supervisor shuts down or 10000 steps have completed.
            step = 0
            while not sv.should_stop() and step < 10000:
                # Run a training step asynchronously.
                # See `tf.train.SyncReplicasOptimizer` for additional details on how to
                # perform *synchronous* training.
                #_, step = sess.run([train_op, global_step])
                print(step)
                batch_xs, batch_ys = mnist.train.next_batch(100)
                train_feed = {x: batch_xs, y_: batch_ys}
                _, step = sess.run([train_step, global_step], feed_dict=train_feed)

        # Ask for all the services to stop.
        sv.stop()

if __name__ == "__main__": 
    tf.app.run() 

First, I start the 2 parameter servers and then the 2 worker nodes. The server comes up correctly in all 4 processes, but the supervisor never starts training.
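
For reference, the four processes are launched with the flags defined above, roughly as follows (the script name and all host:port values here are placeholders, not the literal commands used):

python mnist_distributed.py --ps_hosts=host1:2220,host1:2221 --worker_hosts=host2:2230,host3:2230 --job_name=ps --task_index=0
python mnist_distributed.py --ps_hosts=host1:2220,host1:2221 --worker_hosts=host2:2230,host3:2230 --job_name=ps --task_index=1
python mnist_distributed.py --ps_hosts=host1:2220,host1:2221 --worker_hosts=host2:2230,host3:2230 --job_name=worker --task_index=0
python mnist_distributed.py --ps_hosts=host1:2220,host1:2221 --worker_hosts=host2:2230,host3:2230 --job_name=worker --task_index=1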

Here is the output from the supervisor node:

I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:206] Initialize HostPortsGrpcChannelCache for job ps -> {url1:2220, url1:2221} 
I tensorflow/core/distributed_runtime/rpc/grpc_channel.cc:206] Initialize HostPortsGrpcChannelCache for job worker -> {localhost:2230, url2:2230} 
I tensorflow/core/distributed_runtime/rpc/grpc_server_lib.cc:202] Started server with target: grpc://localhost:2230 

Extracting /tmp/data/train-images-idx3-ubyte.gz 
Extracting /tmp/data/train-labels-idx1-ubyte.gz 
Extracting /tmp/data/t10k-images-idx3-ubyte.gz 
Extracting /tmp/data/t10k-labels-idx1-ubyte.gz 

Answers


I had this problem before. Here are two suggestions for you. First, you should make sure the nodes are actually communicating with each other. Second, you should check your TensorFlow version: I hit this problem on v12, but on v10 the same code worked fine.
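
A quick way to check both points is to print tf.__version__ on every node, and to open a plain session against each server's gRPC target and run a trivial op; if any of these calls hangs, the nodes are not reaching each other. A minimal sketch, where the host:port values are placeholders for whatever you pass in --ps_hosts/--worker_hosts:

import tensorflow as tf

print(tf.__version__)  # confirm which TensorFlow release each node runs

# Placeholder targets: use the same host:port pairs as in --ps_hosts/--worker_hosts.
targets = ["grpc://host1:2220", "grpc://host1:2221", "grpc://host2:2230"]

for target in targets:
    # Running a trivial op against a remote server will hang or raise an
    # error if that node cannot be reached over gRPC.
    with tf.Session(target) as sess:
        print(target, sess.run(tf.constant("ok")))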