2017-05-11 35 views
0

我一直在研究这一段时间,似乎无法破解它。在其他问题中,我看到他们使用这些代码示例为了保存和恢复使用metagraph和检查点文件的模型,但是当我做类似的事情时,它说w1未定义,当我将savemodel和restore模型分开时python文件。当我在保存部分的末尾恢复时,它可以正常工作,但它无法在一个单独的文件中重新定义所有内容。我查看了检查点文件,看起来奇怪,它只有两行,它似乎没有引用任何变量或有任何值。它只有1kb。我曾尝试将'w1'作为字符串放入打印函数中,而不是返回值,而是返回值。这是否适用于其他人?如果是这样,你的检查点文件是什么样的?在张量流中加载metagraph和检查点

#Saving 
import tensorflow as tf 
w1 = tf.Variable(tf.random_normal(shape=[2]), name='w1') 
w2 = tf.Variable(tf.random_normal(shape=[5]), name='w2') 
saver = tf.train.Saver([w1,w2]) 
sess = tf.Session() 
sess.run(tf.global_variables_initializer()) 
saver.save(sess, 'my_test_model',global_step=1000) 

#restoring 
with tf.Session() as sess:  
    saver = tf.train.import_meta_graph('my_test_model-1000.meta',clear_devices=True) 
    saver.restore(sess,tf.train.latest_checkpoint('./')) 
    print sess.run(w1) 
+0

重复http://stackoverflow.com/questions/33759623/tensorflow-how-to-save-restore-a-model – hars

回答

2

您的图保存正确,但恢复它不会恢复包含图的节点的变量。 w1是一个python变量,你从来没有声明过在'恢复'部分代码。为了得到一个手柄上的权重,

  • 可以在TF图形使用他们的名字:w1=get_variable(name='w1')。问题是你必须密切关注你的名字范围,并确保你没有多个相同名称的变量(在这种情况下,TF为他们的名字添加'_1',所以你可能会得到错误的)。如果你这样做,张量板可以帮助你知道每个变量的确切名称。

  • 您可以使用集合:将有趣的节点保存到集合中,并在还原后从它们中取回。在构建图表时,在保存之前,请执行以下操作:例如:tf.add_to_collection('weights', w1)tf.add_to_collection('weights', w2),并在您的还原代码中:[w1, w2] = tf.get_collection('weights1')。那么你可以正常使用w1和w2。

我认为后者虽然更详细,但对于未来架构的变化可能会更好。我知道所有这些看起来都很冗长,但请记住,通常情况下,您不必在所有变量上取回句柄,但只有少数变量:输入,输出和训练步骤通常已足够。

+0

谢谢。我现在就试试看。你有机会知道为什么这个代码片段是共享的吗?如果它只在同一个文件中工作,它似乎不是特别有用。 –

+0

得到它的工作。只是对您的代码进行一次小小的更正它是'tf.get_collection'而不是'tf.get_from_collection'。感谢您的帮助。我拉着头发不确定为什么它为其他人工作,但不是我。 –

+0

我认为旧的做事方式是重新创建图形,然后加载检查点,而不使用元图。在这种情况下,您可以直接使用w1(因为它是在您构建图形时定义的)。我们实际上仍然可以这样做,我忘了在我的答案中提到它。但是,当你加载它时,你仍然需要访问图形构建函数,并且你需要与保存的图形完全一样,所以它又不如当前方法强壮:如果在代码中更改1个节点名称实例,你将无法加载旧的ckpt ... – gdelab

相关问题