2017-04-10 39 views
0

我使用Tensorflow r0.12训练了一些模型并保存了它。后来我更新到r1.0.1。某些型号的加载没有任何问题,但如果型号中有RNN单元,则加载将失败,并显示Key layer-5/bidirectional_rnn/bw/multi_rnn_cell/cell_1/basic_rnn_cell/biases not found in checkpoint。 另外,如果我检查model.index文件,我在那里看到类似的条目,例如:5/BiRNN/BW/MultiRNNCell/Cell0/BasicRNNCell/Linear/Bias在张量流中加载较旧的检查点

与RNN细胞的包现在在tf.contrib.rnn(这是tf.nn.rnn_cell在0.12),所以我认为一些命名已经改变。

问题是: 有没有办法加载我的模型,重新映射其张量并保存张量名称与r1.0兼容?

P.S.如果有帮助,我也有model.meta文件。

谢谢!

回答

0

如果有人得到同样的问题,这里是我使用的解决方案。它是inspect_checkpoint.pytensorflow.python.tools中的张量打印功能的修改版本。


def resave_tensors(file_name, rename_map, dry_run=False): 
    """ 
    Updates checkpoint by renaming tensors in it. 
    :param file_name: Filename with checkpoint. 
    :param rename_map: Map from old names to new ones 
    :param dry_run: If True, just print new tensors. 
    """ 
    renames_count = 0 
    reader = pywrap_tensorflow.NewCheckpointReader(file_name) 
    var_to_shape_map = reader.get_variable_to_shape_map() 
    for key in sorted(var_to_shape_map): 
     print("tensor_name: ", key) 
     tensor_val = reader.get_tensor(key) 
     print('shape: {}'.format(tensor_val.shape)) 
     if key in rename_map: 
      renames_count += 1 
      key = rename_map[key] 
     tf.Variable(tensor_val, dtype=tensor_val.dtype, name=key) 
    saver = tf.train.Saver() 
    if not dry_run: 
     with tf.Session() as session: 
      session.run(tf.global_variables_initializer()) 
      saver.save(session, file_name) 
    print('Renamed vars: {}'.format(renames_count)) 
相关问题