2017-03-02 30 views
0

我有一个RandomForestRegressor,GBTRegressor,我想获取它们的所有参数。如何获取PySpark中估计器的所有参数

from pyspark.ml.regression import RandomForestRegressor, GBTRegressor 
est = RandomForestRegressor() 
est.getMaxDepth() 
est.getSeed() 

RandomForestRegressorGBTRegressor有不同的参数,所以它不是铁杆所有的方法是一个好主意:我发现的唯一方式,它可以与几个get方法等来完成。 一种解决方法可能是这样的:

get_methods = [method for method in dir(est) if method.startswith('get')] 

params_est = {} 
for method in get_methods: 
    try: 
     key = method[3:] 
     params_est[key] = getattr(est, method)() 
    except TypeError: 
     pass 

然后输出将是这样的:

params_est 

{'CacheNodeIds': False, 
'CheckpointInterval': 10, 
'FeatureSubsetStrategy': 'auto', 
'FeaturesCol': 'features', 
'Impurity': 'variance', 
'LabelCol': 'label', 
'MaxBins': 32, 
'MaxDepth': 5, 
'MaxMemoryInMB': 256, 
'MinInfoGain': 0.0, 
'MinInstancesPerNode': 1, 
'NumTrees': 20, 
'PredictionCol': 'prediction', 
'Seed': None, 
'SubsamplingRate': 1.0} 

但我觉得应该有一个更好的方式来做到这一点。

回答

1

extractParamMap可用于从各个估计得到的所有参数,可以例如:

>>> est = RandomForestRegressor() 
>>> {param[0].name: param[1] for param in est.extractParamMap().items()} 
{'numTrees': 20, 'cacheNodeIds': False, 'impurity': 'variance', 'predictionCol': 'prediction', 'labelCol': 'label', 'featuresCol': 'features', 'minInstancesPerNode': 1, 'seed': -5851613654371098793, 'maxDepth': 5, 'featureSubsetStrategy': 'auto', 'minInfoGain': 0.0, 'checkpointInterval': 10, 'subsamplingRate': 1.0, 'maxMemoryInMB': 256, 'maxBins': 32} 
>>> est = GBTRegressor() 
>>> {param[0].name: param[1] for param in est.extractParamMap().items()} 
{'cacheNodeIds': False, 'impurity': 'variance', 'predictionCol': 'prediction', 'labelCol': 'label', 'featuresCol': 'features', 'stepSize': 0.1, 'minInstancesPerNode': 1, 'seed': -6363326153609583521, 'maxDepth': 5, 'maxIter': 20, 'minInfoGain': 0.0, 'checkpointInterval': 10, 'subsamplingRate': 1.0, 'maxMemoryInMB': 256, 'lossType': 'squared', 'maxBins': 32} 
+0

我看到那个方法,但我错过了它与价值观字典...谢谢。 –

0

How to print best model params in pyspark pipeline描述的那样,你可以使用可用在任何模型的原始JVM对象的任何模型参数以下结构

<yourModel>.stages[<yourModelStage>]._java_obj.<getYourParameter>() 

都得到参数都可以在这里 https://spark.apache.org/docs/latest/api/java/org/apache/spark/ml/classification/RandomForestClassificationModel.html

例如,如果你想交叉验证之后,让您的随机森林的MAXDEPTH(getMaxDepth不可用PySpark)您使用

cvModel.bestModel.stages[-1]._java_obj.getMaxDepth() 
相关问题