2017-01-15 40 views
0

在线文档中说,moving_average和moving_variance都是model_variables,而tf.model_variables()返回local_variables类型的张量。这意味着,当我保存我的状态时,model_variables不会被保存吗?Tensorflow的batch_norm中的模型变量

我正在尝试将批量归一化应用于几个3D卷积和完全连接的图层。我用batch_norm训练了我的网络并保存了一个检查点文件,但是当我去恢复我保存的状态时,它说move_mean找不到。确切的错误是,当TF将恢复的值分配给moving_mean时,lhs张量的形状[]不能与rhs的形状一致,[20]。

当我不在图层周围添加batch_norm时,图恢复正常。 我打算在训练结束时添加一个全局变量,以节省我的move_mean和moving_variance值。这是TF为我准备使用batch_norm的方式吗?

谢谢!

回答

1

变量moving_mean和moving_variance不在我保存的声明中,因为我已将updates_collections设置为默认值。由于我在运行图层时从未包含控件依赖项,因此这些变量从未更新过。

的代码,包括是:

from tensorflow.python import control_flow_ops 

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) 
if update_ops: 
    updates = tf.tuple(update_ops) 
    total_loss = control_flow_ops.with_dependencies(updates, total_loss) 

或者设置

updates_collection=None 

为就地更新。

有关更多信息,请参见the API descriptioncurrent github discussion