2017-07-19 23 views
2

我使用的是Tensorflow v1.1,我一直在试图弄清楚如何使用我的EMA'ed权重进行推理,但不管我做什么,我总是得到错误在检查点找不到密钥<变量名> Tensorflow

找不到:重点W/ExponentialMovingAverage检查点未发现

即使当我循环并打印出所有的tf.global_variables主要存在

这是一个可重复脚本重调整从Facenet's单元测试:

import tensorflow as tf 
import numpy as np 


tf.reset_default_graph() 

# Create 100 phony x, y data points in NumPy, y = x * 0.1 + 0.3 
x_data = np.random.rand(100).astype(np.float32) 
y_data = x_data * 0.1 + 0.3 

# Try to find values for W and b that compute y_data = W * x_data + b 
# (We know that W should be 0.1 and b 0.3, but TensorFlow will 
# figure that out for us.) 
W = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name='W') 
b = tf.Variable(tf.zeros([1]), name='b') 
y = W * x_data + b 

# Minimize the mean squared errors. 
loss = tf.reduce_mean(tf.square(y - y_data)) 
optimizer = tf.train.GradientDescentOptimizer(0.5) 
opt_op = optimizer.minimize(loss) 

# Track the moving averages of all trainable variables. 
ema = tf.train.ExponentialMovingAverage(decay=0.9999) 
variables = tf.trainable_variables() 
print(variables) 
averages_op = ema.apply(tf.trainable_variables()) 
with tf.control_dependencies([opt_op]): 
    train_op = tf.group(averages_op) 

# Before starting, initialize the variables. We will 'run' this first. 
init = tf.global_variables_initializer() 

saver = tf.train.Saver(tf.trainable_variables()) 

# Launch the graph. 
sess = tf.Session() 
sess.run(init) 

# Fit the line. 
for _ in range(201): 
    sess.run(train_op) 

w_reference = sess.run('W/ExponentialMovingAverage:0') 
b_reference = sess.run('b/ExponentialMovingAverage:0') 

saver.save(sess, os.path.join("model_ex1")) 

tf.reset_default_graph() 

tf.train.import_meta_graph("model_ex1.meta") 
sess = tf.Session() 

print('------------------------------------------------------') 
for var in tf.global_variables(): 
    print('all variables: ' + var.op.name) 
for var in tf.trainable_variables(): 
    print('normal variable: ' + var.op.name) 
for var in tf.moving_average_variables(): 
    print('ema variable: ' + var.op.name) 
print('------------------------------------------------------') 

mode = 1 
restore_vars = {} 
if mode == 0: 
    ema = tf.train.ExponentialMovingAverage(1.0) 
    for var in tf.trainable_variables(): 
     print('%s: %s' % (ema.average_name(var), var.op.name)) 
     restore_vars[ema.average_name(var)] = var 
elif mode == 1: 
    for var in tf.trainable_variables(): 
     ema_name = var.op.name + '/ExponentialMovingAverage' 
     print('%s: %s' % (ema_name, var.op.name)) 
     restore_vars[ema_name] = var 

saver = tf.train.Saver(restore_vars, name='ema_restore') 

saver.restore(sess, os.path.join("model_ex1")) # error happens here! 

w_restored = sess.run('W:0') 
b_restored = sess.run('b:0') 

print(w_reference) 
print(w_restored) 
print(b_reference) 
print(b_restored) 

回答

2

key not found in checkpoint错误意味着该变量在内存模型,但不是在磁盘上的序列化的检查点文件存在。

您应该使用inspect_checkpoint tool来了解什么张量被保存在检查点中,以及为什么某些指数移动平均值不会在此处保存。

这不是从线应该引发错误

+0

嗨,谢谢,我一定会看看,花了,也更新了我的问题! – YellowPillow

+0

我想我明白你的错误可能来自哪里。您只用可训练变量初始化保存程序。尝试使用默认构建的保护程序。移动平均变量不可训练,所以不会在您的检查点结束。 –

+0

默认构建的保护程序是什么意思? – YellowPillow

0

我想补充到最好使用检查点训练有素的变量的方法您的摄制例子清楚。

请记住,保存程序var_list中的所有变量都应包含在您配置的检查点中。您可以通过检查那些在保护程序:

print(restore_vars) 

,并在检查站这些变量是:在你的情况

vars_in_checkpoint = tf.train.list_variables(os.path.join("model_ex1")) 

如果restore_vars都包括在vars_in_checkpoint那么它不会引发错误,否则首先初始化所有的变量:

all_variables = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES) 
sess.run(tf.variables_initializer(all_variables)) 

的所有变量将被初始化是这样的,或者未在检查点,那么你就可以筛选出不包含在该检查站restore_vars这些变量(假设与ExponentialMovingAverage所有变量在他们的名字没有在检查站):

temp_saver = tf.train.Saver(
    var_list=[v for v in all_variables if "ExponentialMovingAverage" not in v.name]) 
ckpt_state = tf.train.get_checkpoint_state(os.path.join("model_ex1"), lastest_filename) 
print('Loading checkpoint %s' % ckpt_state.model_checkpoint_path) 
temp_saver.restore(sess, ckpt_state.model_checkpoint_path) 

这可以节省相比,训练模型一段时间 从头开始​​。 (在我的情况下,恢复的变量与一开始从头开始的培训相比没有明显的改善,因为所有旧的优化器变量都被放弃了,但它可以显着加速优化过程,我认为,因为它就像是预训练一些变量)

无论如何,一些变量是有用的恢复像嵌入和一些图层等

相关问题