2017-06-03 42 views
0

我有以下代码来初始化变量v3和v4。初始化变量V3和V4后,我修改这些变量并保存在一个检查点文件:为什么tensorflow train.Saver()保存初始变量值而不是修改后的值?

import tensorflow as tf 
sess = tf.Session() 
v3 = tf.Variable(tf.random_uniform([4,2]), name="v3") 
init_op = tf.global_variables_initializer() 
sess.run(init_op) 
with sess.as_default(): 
    print(v3.eval()) 
    v3 = tf.transpose(v3) 
    print(v3.eval()) 
    sess.run(v3) 
    v4 = tf.Variable(v3+3, name="v4") 
    v4 = v4 + 5 
    init_op = tf.global_variables_initializer() 
    sess.run(init_op) 
    print(v4.eval()) 
    saver = tf.train.Saver() 
    saver.save(sess, "pktest_ckpt") 

本刊V3的下列值,换位v3和v4值:

[[ 0.90765333 0.61777163] 
[ 0.5102632 0.45610023] 
[ 0.36511779 0.5465256 ] 
[ 0.61696458 0.86357415]] 
[[ 0.90765333 0.5102632 0.36511779 0.61696458] 
[ 0.61777163 0.45610023 0.5465256 0.86357415]] 
[[ 8.96951866 8.24961662 8.30669975 8.54586029] 
[ 8.55886841 8.16989517 8.48039341 8.06889534]] 

后距离检查点文件恢复变量,我看到变量的值是初始初始化的人,而不是修改后的值:

tf.reset_default_graph() 
mg = tf.train.import_meta_graph("pktest_ckpt.meta") 
with tf.Session() as sess: 
    for v in tf.global_variables(): 
     print(v) 
    saver = tf.train.Saver(tf.global_variables()) 
    saver.restore(sess, tf.train.latest_checkpoint("./")) 
    print(sess.run('v4:0')) 

它打印:

<tf.Variable 'v3:0' shape=(4, 2) dtype=float32_ref> 
<tf.Variable 'v4:0' shape=(2, 4) dtype=float32_ref> 
INFO:tensorflow:Restoring parameters from ./pktest_ckpt 
[[ 3.75337863 3.52812386 3.97137022 3.76210618] 
[ 3.81927872 3.41938591 3.82610369 3.20377684]] 

我的期望是让形状(2,4)和v4值V3作为 [8.96951866 8.24961662 8.30669975 8.54586029] [8.55886841 8.16989517 8.48039341 8.06889534]

谁能解释这是为什么发生?

回答

0

我想我得到了答案。我应该这样保存变量:

import tensorflow as tf 
sess = tf.Session() 
v3 = tf.Variable(tf.random_uniform([4,2]), name="v3") 
init_op = tf.global_variables_initializer() 
sess.run(init_op) 
with sess.as_default(): 
    v4 = tf.Variable(v3+3, name="v4") 
    init_op = tf.variables_initializer([v3,v4]) 
    sess.run(init_op) 
    do_add = v4.assign(tf.add(v4,5)) 
    sess.run(do_add) 

    print(v4.eval()) 
    saver = tf.train.Saver() 
    saver.save(sess, "pktest_ckpt") 
0

这可能是因为您在第二次初始化后没有真正运行操作。尝试加入session.run([v3, v4])后您的第二个session.run(init_op)

+0

尝试运行。新值不会保存到模型中。我想我们必须使用分配操作符,因为我在我的答案中发布了。但是,编写额外的代码行来创建操作步骤并在会话中运行它是非常直观的。 Tensorflow应该自动执行此操作。 –

相关问题