2017-10-06 119 views
0

我尝试通过import_graph_def()读取RNN网络并进行推理。 但我不能使用tf.trainable_variables()来获取任何变量。Tensorflow,使用import_graph_def()加载模型错误

在下面的代码,tf.trainable_variables()返回[](带有没有列表) 此外,当我使用saver = tf.train.Saver(),tensorflow报告“没有变量保存”

def eval_on_test(graph_path): 
batch_size = 80 
train_begin = 0 
train_end = 3000 
with tf.Graph().as_default() as graph: 
    with open(graph_path, 'rb') as f: 
     tf_graph = tf.GraphDef() 
     print("Loading graph_def from {}".format(graph_path)) 
     tf_graph.ParseFromString(f.read()) 

     return_elements = tf.import_graph_def(tf_graph, name="", return_elements=['input_x:0', 'output_y:0', 'pred:0', 'loss:0']) 
     X = return_elements[0] 
     Y = return_elements[1] 
     pred = return_elements[2] 
     loss = return_elements[3] 

    tf_config = tf.ConfigProto() 
    tf_config.gpu_options.allow_growth = True 

    print("graph loaded, start testing") 
    with tf.Session(config=tf_config) as sess: 

     init_op = sess.graph.get_operation_by_name('init') 
     sess.run(init_op) 
     print(tf.trainable_variables()) 
     batch_index,train_x,train_y=get_train_data(batch_size,time_step,train_begin,train_end) 
     for batch in range(len(batch_index)-1): 
      loss_ = sess.run(loss, feed_dict={X:train_x[batch_index[batch]:batch_index[batch+1]],Y:train_y[batch_index[batch]:batch_index[batch+1]]}) 
      print(batch, loss_) 

任何帮助,将不胜感激。

回答

0

import_graph_def将仅恢复图形,但不会恢复集合,如GLOBAL_VARIABLES,这就是为什么Saver无法找到图中的任何变量,来解决这个问题,你可以尝试tf.train.import_meta_graph这也将恢复所有集合。

+0

感谢您的信息。这解决了我的问题。我试图用graphdef保存和恢复图形,该后解决问题,https://www.tensorflow.org/extend/tool_developers/ – zhangc