2017-04-10 32 views
1

对于交叉验证,如何保存不同训练集和交叉验证集的训练历史记录?我认为pickle写入的一种附加模式会起作用,但实际上它不起作用。如果可能的话,您能否请您指导保存所有模型的方法,现在我只能使用model.save(file)保存上次训练过的模型。如何保存Keras的训练历史作为交叉验证(循环)?

historyfile = 'history.pickle' 
f = open(historyfile,'w') 
f.close() 
ind = 0 
save = {} 
for train, test in kfold.split(input,output): 
    ind = ind+1 
    #create model 
    model = model_FCN() 
    # fit the model 
    history = model.fit(input[list(train)], output[list(train)], batch_size = 16, epochs = 100, verbose =1, validation_data =(input[list(test)],output[list(test)])) 
    #save to file 
    try: 
     f = open(historyfile,'a') ## appending mode?? 
     save['cv'+ str(ind)]= history.history 
     pickle.dump(save, f, pickle.HIGHEST_PROTOCOL) 
     f.close() 
    except Exception as e: 
     print('Unable to save data to', historyfile, ':', e) 

    scores = model.evaluate(MR_patch[list(test)], CT_patch[list(test)], verbose=0) 
    print("%s: %.2f" % (model.metrics_names[1], scores[1])) 
    cvscores.append(scores[1]) 
    print("cross validation stage: " + str(ind)) 

print("%.2f (+/- %.2f)" % (np.mean(cvscores), np.std(cvscores))) 

回答

0

为每个时间段对某些列车后保存模型和验证数据,可以使用Callback

例如:

from keras.callbacks import ModelCheckpoint 
import os 

output_directory = '' # here should be path to output directory  
model_checkpoint = ModelCheckpoint(os.path.join(output_directory , 'weights.{epoch:02d}-{val_loss:.2f}.hdf5')) 
model.fit(input[list(train)], 
      output[list(train)], 
      batch_size=16, 
      epochs=100, 
      verbose=1, 
      validation_data=(input[list(test)],output[list(test)]), 
      callbacks=[model_checkpoint]) 

后每个时间段你的模型将被保存在文件。这个回调如果你想保存模型中训练的每个折,你可以简单地在添加model.save(文件),你可以在文档(https://keras.io/callbacks/

查找更多信息您的for循环:

model.fit(input[list(train)], 
      output[list(train)], 
      batch_size=16, 
      epochs=100, 
      verbose=1, 
      validation_data=(input[list(test)],output[list(test)])) 
model.save(os.path.join(output_directory, 'fold_{}_model.hdf5'.format(ind))) 

要保存历史记录: 您可以保存一次历史记录,而无需将其追加到每个循环中的文件。 for循环之后,您应该使用键(折痕标记)和值(每个折叠的历史记录)字典并保存此字典,如下所示:

f = open(historyfile, 'wb') 
pickle.dump(save, f, pickle.HIGHEST_PROTOCOL) 
f.close() 
+0

非常感谢您耐心的回答! –