2017-06-01 78 views
3

我使用Keras Sequential模型来训练多个多类分类器。将类标签附加到Keras模型

在评估中,Keras输出一个置信度向量,我可以从argmax中推断出正确的类ID。然后,我可以使用查找表来接收实际的类标签(例如字符串)。

到目前为止,解决方案是加载训练好的模型,然后单独加载查找表。由于我有相当多的分类器,我宁愿将两个结构都保存在一个文件中。

所以我在寻找的是一种将实际标签查找矢量集成到Keras模型中的方法。这将允许我有一个分类器文件,它能够获取一些输入数据并返回该数据的正确类标签。

解决这个问题的一种方法是将模型和查找表存储在一个元组中,并将该元组写入一个pickle中,但这看起来不太优雅。

回答

5

所以我尝试了一下我自己的解决方案,这似乎工作。尽管我希望更简单一些。

第二次打开模型文件并不是真的最佳我认为。如果任何人都可以做得更好,无论如何,请做。

import h5py 

from keras.models import load_model 
from keras.models import save_model 


def load_model_ext(filepath, custom_objects=None): 
    model = load_model(filepath, custom_objects=None) 
    f = h5py.File(filepath, mode='r') 
    meta_data = None 
    if 'my_meta_data' in f.attrs: 
     meta_data = f.attrs.get('my_meta_data') 
    f.close() 
    return model, meta_data 


def save_model_ext(model, filepath, overwrite=True, meta_data=None): 
    save_model(model, filepath, overwrite) 
    if meta_data is not None: 
     f = h5py.File(filepath, mode='a') 
     f.attrs['my_meta_data'] = meta_data 
     f.close() 
+0

接受我自己的答案缺乏替代品。如果有人提出更好的解决方案,我会接受他们的。 – Cerno

+1

我试图解决的同样的问题。但你的解决方案不适用于我: '''save_model_ext(mod1,filepath ='test_model.h5',meta_data = {0:'c1',1:'c2'})''' 产生一个错误: ''' TypeError:Object dtype dtype('O')没有原生HDF5等价物 ''' 您的函数期望得到'meta_data'是什么类型? – slymore

+0

你好。您必须使用可以转换为HDF5的数据。 dtype =“O”表示您的数据包含一个显然无效的Python对象。如果我记得,我用python字典没有问题。这真的是你尝试的代码还是事实更复杂? – Cerno