2017-06-06

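Python inheritance: new class is not initialized correctly

I am writing a class named XGB that inherits from XGBClassifier (from the Python library xgboost.sklearn). I wrote an __init__ method and a fit method; the __init__ is shown below: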

from xgboost.sklearn import XGBClassifier
from balanceSmote import BalanceSmote
from balance import Balance


class XGB(XGBClassifier):

    def __init__(self, learning_rate=0.5, max_depth=3, colsample_bytree=0.5, n_estimators=300,
                 frac=None, k_neighbors=None, m_neighbors=None, out_step=None):

        # These are the additional arguments that are not in XGBClassifier
        if k_neighbors:
            self.balancingStrategy = 'smote'
            self.k_neighbors = k_neighbors
            self.m_neighbors = m_neighbors
            self.out_step = out_step
        elif frac:
            self.balancingStrategy = 'normal'
            self.frac = frac
        else:
            self.balancingStrategy = 'false'

        # Utilize the motherClass
        super(XGB, self).__init__(seed=500,
                                  learning_rate=learning_rate,
                                  max_depth=max_depth,
                                  colsample_bytree=colsample_bytree,
                                  n_estimators=n_estimators)

Here is my test code:

xgb4 = XGB(learning_rate = 0.1, max_depth = 3, colsample_bytree = 1, n_estimators = 1000) 

xgb4.fit(trainData,trainLabel) 

Initialization seems to go fine, but when I call fit(), which is inherited from XGBClassifier, I get an error saying that a parameter is missing:

File "<ipython-input-3-47344b7fbc76>", line 1, in <module> 
runfile('/Users/celsloaner/Project/SPUDS/code/testSpark.py', wdir='/Users/celsloaner/Project/SPUDS/code') 

File "/anaconda/envs/SPUDS/lib/python3.5/site-packages/spyder/utils/site/sitecustomize.py", line 880, in runfile 
execfile(filename, namespace) 

File "/anaconda/envs/SPUDS/lib/python3.5/site-packages/spyder/utils/site/sitecustomize.py", line 102, in execfile 
exec(compile(f.read(), filename, 'exec'), namespace) 

File "/Users/celsloaner/Project/SPUDS/code/testSpark.py", line 50, in <module> 
xgb4.fit(predictor.trainData,predictor.trainLabel) 

File "/anaconda/envs/SPUDS/lib/python3.5/site-packages/xgboost/sklearn.py", line 396, in fit 
xgb_options = self.get_xgb_params() 

File "/anaconda/envs/SPUDS/lib/python3.5/site-packages/xgboost/sklearn.py", line 177, in get_xgb_params 
xgb_params = self.get_params() 

File "/anaconda/envs/SPUDS/lib/python3.5/site-packages/xgboost/sklearn.py", line 169, in get_params 
if params['missing'] is np.nan: 

KeyError: 'missing' 

The problem comes from the parent class, which should have been initialized correctly. Here is the parent-class function in question:

def get_params(self, deep=False):
    """Get parameter.s"""
    params = super(XGBModel, self).get_params(deep=deep)
    if params['missing'] is np.nan:
        params['missing'] = None  # sklearn doesn't handle nan. see #4725
    if not params.get('eval_metric', True):
        del params['eval_metric']  # don't give as None param to Booster
    return params

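The params dictionary is evidently not populated correctly (the key 'missing' does not exist) when the XGBClassifier initializer is called from the XGB initializer. Do you have any idea what the problem is, or how to track it down?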

Thanks


If `params` has no `'missing'` entry, accessing it raises a `KeyError`. Change the condition `if params['missing'] is np.nan:` to `if params.get('missing', np.nan) is np.nan:`, and do the same for `eval_metric`. – alfasin
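To illustrate the difference between the two lookups, here is a minimal sketch in plain Python. The `params` dictionary and the `NAN` sentinel are hypothetical stand-ins (for the real parameter dict and for `np.nan`), so this only shows the dictionary behaviour and not xgboost itself:

```python
# Hypothetical parameter dict that lacks the 'missing' key, as in the traceback.
params = {"max_depth": 3, "learning_rate": 0.1}

# Stand-in for np.nan so the sketch needs no third-party imports (assumption).
NAN = float("nan")

# Direct indexing raises KeyError when the key is absent.
try:
    params["missing"]
    lookup_failed = False
except KeyError:
    lookup_failed = True

# dict.get returns the supplied default instead of raising.
fallback_is_nan = params.get("missing", NAN) is NAN

print(lookup_failed, fallback_is_nan)
```

Note that this would mean patching the library's own `get_params`, and it only hides the symptom: the key would still be absent from the dictionary.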


The problem is `params = super(XGBModel, self).get_params(deep=deep)`, which returns a dict that does not contain the entry you want. –


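I see, but why? That dictionary is created and managed inside the parent class and I did not touch it, so I don't understand why it is not properly defined in the parent class when I use it from the inheriting class. – Salamandre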

Answer


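Well, it works when I pass all of the parent class's parameters explicitly, even though the parent class constructor should already have default values for them: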

class XGB(XGBClassifier):

    def __init__(self, max_depth=3, learning_rate=0.1,
                 n_estimators=100, silent=True,
                 objective="binary:logistic",
                 nthread=-1, gamma=0, min_child_weight=1,
                 max_delta_step=0, subsample=1, colsample_bytree=1, colsample_bylevel=1,
                 reg_alpha=0, reg_lambda=1, scale_pos_weight=1,
                 base_score=0.5, seed=0, missing=None,
                 frac=None, k_neighbors=None, m_neighbors=None, out_step=None):

        if k_neighbors:
            self.balancingStrategy = 'smote'
            self.k_neighbors = k_neighbors
            self.m_neighbors = m_neighbors
            self.out_step = out_step
        elif frac:
            self.balancingStrategy = 'normal'
            self.frac = frac
        else:
            self.balancingStrategy = 'false'

        super(XGB, self).__init__(max_depth, learning_rate,
                                  n_estimators, silent, objective,
                                  nthread, gamma, min_child_weight,
                                  max_delta_step, subsample,
                                  colsample_bytree, colsample_bylevel,
                                  reg_alpha, reg_lambda,
                                  scale_pos_weight, base_score, seed, missing)

I'm not sure I understand the logic here, but it's always good to know :)
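A likely explanation, as far as I know: scikit-learn's `BaseEstimator.get_params` builds its dictionary from the argument names in the `__init__` signature of the instance's own class, not the parent's. If the subclass `__init__` does not list `missing`, the key never appears, even though the parent constructor did set the attribute. The sketch below mimics that behaviour with a simplified stand-in; `MiniEstimator`, `Parent`, `NarrowChild` and `FullChild` are hypothetical names, and this is not the real scikit-learn or xgboost code:

```python
import inspect


class MiniEstimator:
    """Simplified stand-in for sklearn.base.BaseEstimator (assumption)."""

    def get_params(self):
        # Parameter names are read from the __init__ signature of the
        # instance's actual class, so a subclass signature decides the keys.
        sig = inspect.signature(type(self).__init__)
        names = [n for n in sig.parameters if n != "self"]
        return {n: getattr(self, n, None) for n in names}


class Parent(MiniEstimator):
    def __init__(self, max_depth=3, missing=None):
        self.max_depth = max_depth
        self.missing = missing


class NarrowChild(Parent):
    # 'missing' is not in this signature, so get_params() omits it,
    # even though Parent.__init__ sets self.missing.
    def __init__(self, max_depth=3, frac=None):
        self.frac = frac
        super().__init__(max_depth=max_depth)


class FullChild(Parent):
    # Repeating the parent's arguments makes them visible to get_params().
    def __init__(self, max_depth=3, missing=None, frac=None):
        self.frac = frac
        super().__init__(max_depth=max_depth, missing=missing)


narrow_params = NarrowChild().get_params()
full_params = FullChild().get_params()
print(sorted(narrow_params))  # ['frac', 'max_depth']
print(sorted(full_params))    # ['frac', 'max_depth', 'missing']
```

Under that reading, listing every parent argument in the subclass signature works because it puts `missing` back among the names that `get_params` reports.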