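Python inheritance: new class is not correctly initialized
I am writing a class named XGB that inherits from XGBClassifier (from the Python library xgboost.sklearn). I wrote an __init__ method and a fit method, as shown below: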
from xgboost.sklearn import XGBClassifier
from balanceSmote import BalanceSmote
from balance import Balance

class XGB(XGBClassifier):
    def __init__(self, learning_rate=0.5, max_depth=3, colsample_bytree=0.5, n_estimators=300,
                 frac=None, k_neighbors=None, m_neighbors=None, out_step=None):
        # These are the additional arguments that are not in XGBClassifier
        if k_neighbors:
            self.balancingStrategy = 'smote'
            self.k_neighbors = k_neighbors
            self.m_neighbors = m_neighbors
            self.out_step = out_step
        elif frac:
            self.balancingStrategy = 'normal'
            self.frac = frac
        else:
            self.balancingStrategy = 'false'
        # Utilize the motherClass
        super(XGB, self).__init__(seed=500,
                                  learning_rate = learning_rate,
                                  max_depth = max_depth,
                                  colsample_bytree = colsample_bytree,
                                  n_estimators = n_estimators)
Here is my test code:
xgb4 = XGB(learning_rate = 0.1, max_depth = 3, colsample_bytree = 1, n_estimators = 1000)
xgb4.fit(trainData,trainLabel)
Initialization seems to go smoothly, but when I call fit() (a method inherited from XGBClassifier), I get an error message telling me that a parameter is missing:
  File "<ipython-input-3-47344b7fbc76>", line 1, in <module>
    runfile('/Users/celsloaner/Project/SPUDS/code/testSpark.py', wdir='/Users/celsloaner/Project/SPUDS/code')
  File "/anaconda/envs/SPUDS/lib/python3.5/site-packages/spyder/utils/site/sitecustomize.py", line 880, in runfile
    execfile(filename, namespace)
  File "/anaconda/envs/SPUDS/lib/python3.5/site-packages/spyder/utils/site/sitecustomize.py", line 102, in execfile
    exec(compile(f.read(), filename, 'exec'), namespace)
  File "/Users/celsloaner/Project/SPUDS/code/testSpark.py", line 50, in <module>
    xgb4.fit(predictor.trainData,predictor.trainLabel)
  File "/anaconda/envs/SPUDS/lib/python3.5/site-packages/xgboost/sklearn.py", line 396, in fit
    xgb_options = self.get_xgb_params()
  File "/anaconda/envs/SPUDS/lib/python3.5/site-packages/xgboost/sklearn.py", line 177, in get_xgb_params
    xgb_params = self.get_params()
  File "/anaconda/envs/SPUDS/lib/python3.5/site-packages/xgboost/sklearn.py", line 169, in get_params
    if params['missing'] is np.nan:
KeyError: 'missing'
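The problem comes from inside the parent class, which should have been initialized correctly. Here is the parent-class method where the error occurs: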
def get_params(self, deep=False):
    """Get parameter.s"""
    params = super(XGBModel, self).get_params(deep=deep)
    if params['missing'] is np.nan:
        params['missing'] = None  # sklearn doesn't handle nan. see #4725
    if not params.get('eval_metric', True):
        del params['eval_metric']  # don't give as None param to Booster
    return params
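The params dictionary is evidently not defined correctly (the key 'missing' does not exist) when the XGBClassifier initializer is called from the XGB initializer. Do you have any idea what the problem is, or how to track it down?
Thanks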
If `params` has no `'missing'` entry, you will get a `KeyError` when you try to access it. Change the condition `if params['missing'] is np.nan:` to `if params.get('missing', np.nan) is np.nan:`, and do the same for `eval_metric` – alfasin
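The change suggested in this comment can be sketched as follows. This is only an illustration: `clean_params` is a hypothetical helper, not part of xgboost, and `math.nan` stands in for `np.nan` so that the snippet needs only the standard library.

```python
import math

NAN = math.nan  # stand-in for np.nan (assumption made for this sketch)


def clean_params(params):
    """Hypothetical helper mirroring the checks in get_params, but tolerant of absent keys."""
    params = dict(params)  # work on a copy so the caller's dict is untouched
    # dict.get returns the default instead of raising KeyError when the key is absent
    if params.get('missing', NAN) is NAN:
        params['missing'] = None
    # Drop eval_metric only when it is present and falsy (for example None)
    if not params.get('eval_metric', True):
        del params['eval_metric']
    return params


print(clean_params({'max_depth': 3}))
```

Note that this only suppresses the `KeyError`; it does not explain why the key is absent in the first place, and it would mean patching the installed library.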
The problem is `params = super(XGBModel, self).get_params(deep=deep)`, which returns a dict that does not contain the key you want. –
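I see, but why? That dictionary is created and managed inside the parent class and I did not touch it, so I don't understand why it is not well defined in the parent class when I use it from the inheriting class. – Salamandre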
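One likely explanation, offered as a sketch under assumptions: in the xgboost version shown in the traceback, XGBModel appears to delegate to scikit-learn's BaseEstimator.get_params, which, as far as I know, collects parameter names by inspecting the `__init__` signature of the instance's own class rather than the parent's. Because `XGB.__init__` does not list `missing`, the key never makes it into the dict. The classes below (`MiniEstimator`, `Parent`, `Child`, `FixedChild`) are hypothetical stand-ins that imitate this mechanism; they are not the real scikit-learn or xgboost code.

```python
import inspect


class MiniEstimator:
    """Simplified stand-in for sklearn.base.BaseEstimator (not the real class)."""

    @classmethod
    def _param_names(cls):
        # Parameter names come from the __init__ of the concrete class
        sig = inspect.signature(cls.__init__)
        return sorted(
            name for name, p in sig.parameters.items()
            if name != 'self' and p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
        )

    def get_params(self):
        # Each name is looked up as an attribute on the instance
        return {name: getattr(self, name, None) for name in self._param_names()}


class Parent(MiniEstimator):
    def __init__(self, max_depth=3, missing=None):
        self.max_depth = max_depth
        self.missing = missing


class Child(Parent):
    # 'missing' is absent from this signature, as in XGB.__init__
    def __init__(self, max_depth=3, frac=None):
        self.frac = frac
        super().__init__(max_depth=max_depth)


class FixedChild(Parent):
    # Re-declares the parent's parameter and forwards it
    def __init__(self, max_depth=3, missing=None, frac=None):
        self.frac = frac
        super().__init__(max_depth=max_depth, missing=missing)


print(sorted(Parent().get_params()))      # includes 'missing'
print(sorted(Child().get_params()))       # 'missing' is gone
print(sorted(FixedChild().get_params()))  # 'missing' is back
```

If this is indeed the cause, declaring `missing` (and any other parent parameters that get_params relies on) in `XGB.__init__` and passing them through to `super().__init__` should restore the key. Storing every constructor argument as an attribute of the same name, unconditionally, also matches scikit-learn's estimator conventions.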