2016-08-25 97 views
2

我试图从检查点文件恢复一些变量,如果相同的变量名称在当前模型中。
而且我发现有一些方法,如Tensorfow Github如何在Tensorflow中使用CheckpointReader恢复变量

所以我想要做如下使用has_tensor("variable.name")正在检查在检查点文件中的变量名,

...  
reader = tf.train.NewCheckpointReader(ckpt_path) 
for v in tf.trainable_variables(): 
    print v.name 
    if reader.has_tensor(v.name): 
     print 'has tensor' 
... 

但是我发现v.name同时返回变量namecolon+number。例如,我有变量名称W_ob_o,然后v.name返回W_o:0, b_o:0

但是reader.has_tensor()要求name没有colonnumber作为W_o, b_o

我的问题是:如何在变量名的末尾删除colonnumber以读取变量?
有没有更好的方法来恢复这些变量?

回答

4

你可以使用string.split()得到伸张名称:

...  
reader = tf.train.NewCheckpointReader(ckpt_path) 
for v in tf.trainable_variables(): 
    tensor_name = v.name.split(':')[0] 
    print tensor_name 
    if reader.has_tensor(tensor_name): 
     print 'has tensor' 
... 

接下来,让我用一个例子来说明我将如何从一个.cpkt文件还原每一个可能的变量。首先,让我们节省v2v3tmp.ckpt

import tensorflow as tf 

v1 = tf.Variable(tf.ones([1]), name='v1') 
v2 = tf.Variable(2 * tf.ones([1]), name='v2') 
v3 = tf.Variable(3 * tf.ones([1]), name='v3') 

saver = tf.train.Saver({'v2': v2, 'v3': v3}) 

with tf.Session() as sess: 
    sess.run(tf.initialize_all_variables()) 
    saver.save(sess, 'tmp.ckpt') 

这就是我会恢复每一个变量(属于一个新的图)显示tmp.ckpt起来:

with tf.Graph().as_default(): 
    assert len(tf.trainable_variables()) == 0 
    v1 = tf.Variable(tf.zeros([1]), name='v1') 
    v2 = tf.Variable(tf.zeros([1]), name='v2') 

    reader = tf.train.NewCheckpointReader('tmp.ckpt') 
    restore_dict = dict() 
    for v in tf.trainable_variables(): 
     tensor_name = v.name.split(':')[0] 
     if reader.has_tensor(tensor_name): 
      print('has tensor ', tensor_name) 
      restore_dict[tensor_name] = v 

    saver = tf.train.Saver(restore_dict) 
    with tf.Session() as sess: 
     sess.run(tf.initialize_all_variables()) 
     saver.restore(sess, 'tmp.ckpt') 
     print(sess.run([v1, v2])) # prints [array([ 0.], dtype=float32), array([ 2.], dtype=float32)] 

此外,您可能希望确保形状和dtype匹配。

+0

感谢rvinas !我认为这是最明显的例子,如何恢复我所看到的所需变量。 – user270700

+0

不客气! – rvinas

1

tf.train.NewCheckpointReader是一个创建CheckpointReader对象的漂亮方法。 CheckpointReader有几个非常有用的方法。与你的问题最相关的方法是get_variable_to_shape_map()。

  • get_variable_to_shape_map()提供了变量名称和形状的字典:

saved_shapes = reader.get_variable_to_shape_map() 
 
print 'fire9/squeeze1x1/kernels:', saved_shapes['fire9/squeeze1x1/kernels']

请看看下面这个快速教程: Loading Variables from Existing Checkpoints

+0

感谢您的回答!我会检查你的教程并尝试你的方法。 – user270700

相关问题