2016-04-22 277 views
1

我是TensorFlow的初学者,目前正在培训CNN。Tensorflow:保存和恢复模型参数

我为了节省模型使用的参数,使用节电器,但我有担忧这是否会本身存储在模型中使用的所有变量,以及足以将值恢复为重新运行在训练的网络上执行分类/测试的程序。

让我们看一下TensorFlow给出的着名示例MNIST。

在这个例子中,我们有一堆卷积块,所有块都有权重,并且偏置变量在程序运行时被初始化。

W_conv1 = init_weight([5,5,1,32]) 
b_conv1 = init_bias([32]) 

处理好几个图层后,我们创建一个会话,并初始化添加到图形中的所有变量。

sess = tf.Session() 
sess.run(tf.initialize_all_variables()) 
saver = tf.train.Saver() 

这里,是否有可能作出评论saver.save代码,并通过培训后saver.restore(SESS,FILE_PATH)取代它,以恢复体重,偏见等,后面的参数到图表?这是应该如何?

for i in range(1000): 
... 

    if i%500 == 0: 
    saver.save(sess,"model%d.cpkt"%(i)) 

我目前的训练大数据集,因此终止,并重新启动培训是浪费时间和资源,所以我要求有人请澄清之前,我开始训练。

+0

这是有点不清楚你在问什么。 “评论saver.save代码,并用saver.restore(sess,file_path)替换它”你不想存储你的训练值并且重新设置以前的训练(通过恢复)? “如此终止,重新开始培训是一种浪费”。这意味着当您完成所有培训后,您想要保存一次模型? –

+0

@Sung Kim:你的后一个问题的答案是肯定的。我不打算用存储的值重新开始训练,而是在完成训练后简单地将模型保存一次。因为在Matlab中,这非常简单,事实上,这是我第一次编程Python和TensorFlow,所以我不知道是否有其他优雅的方式来保存参数。 –

回答

4

如果你想保存最后的结果只有一次,你可以这样做:

with tf.Session() as sess: 
    for i in range(1000): 
    ... 


    path = saver.save(sess, "model.ckpt") # out of the loop 
    print "Saved:", path 

在其他程序中,可以使用从saver.save返回预测什么的路径加载模型。你可以在https://github.com/sugyan/tensorflow-mnist看到一些例子。

1

根据here和Sung Kim解决方案的说明,我写了一个非常简单的模型来准确解决这个问题。基本上这样你需要从同一个类创建一个对象,并从保存器中恢复它的变量。你可以找到这个解决方案的例子here