2017-09-04 23 views
1

我用Normalization图层训练了一个模型。该代码是这样的:如何预测何时在Keras中使用规范化或丢失层进行训练?

在训练阶段:

model=Sequential() 
model.add() 

... 

k.set_learning_phase(1) 
ModelCheckpoint(weights_file) 
model.fit() 

在推理时间:

k.set_learning_phase(0) 
model.load_weights(weights_file) 
model.predict_classes() 

... 

Keras的版本:2.0.8。这是正确的,还是需要一些特殊的代码来计算经过训练后的BN,例如在Caffe中使用SegNet?

回答

2

不,您在使用BatchNormalization或Dropout图层时不需要做任何特别的事情。凯拉斯已经跟踪学习/测试阶段,所以当使用predictpredict_classes时,它是正确的。

你甚至不需要手动设置学习阶段,凯拉斯已经做到了。

+0

感谢您的及时回复。我知道了。 – spider

相关问题