2017-10-12 41 views
0

我想用scikit-learn的cross_val_score()函数对我的Keras神经网络进行交叉验证。如何在scikit-learn的cross_val_score()中每次折叠后运行函数?

问题是,在每次折叠后不仅结果被记住,而且整个Keras模型。所以我想在每次折叠后用K.clear_session()来清除这个模型。但这只是上下文的细节。

我的主要问题是:如何在scikit-learn的cross_val_score()每次折叠后运行自定义函数?换句话说:可以运行在每次折叠后应该运行的回调?或者还有其他解决方法?

回答

0

您可以创建一个自定义回调函数,并重新编写此回调函数的on_train_end(self,logs = {})方法。这种新方法将在每个培训步骤结束时完成。类似的东西:

class CustomCall(Callback): 

    def __init__(self): 
     super(CustomCall, self).__init__() 

    def on_epoch_begin(self, epoch, logs={}): 
     return 

    def on_epoch_end(self, epoch, logs={}): 
     return 

    def on_batch_begin(self, batch, logs={}): 
     return 

    def on_train_end(self, logs={}): 
     # Stuff here 
     print('\n Delete previous trained model : ') 
     K.clear_session() 
     return 
+0

不幸的是,问题是,K.clear_session()必须在评估模型后调用,而不是在cross_val_score()内部训练之后调用。所以我必须在交叉折叠结束时调用K.clear_session(),而不是在Keras训练结束时。 –