0

我在Spark中有一个RandomForestClassifierModel。使用.toDebugString()输出以下如何查看Spark中的随机森林统计(斯卡拉)

Tree 0 (weight 1.0): 
    If (feature 0 in {1.0,2.0,3.0}) 
    If (feature 3 in {2.0,3.0}) 
    If (feature 8 <= 55.3) 
. 
. 
    Else (feature 0 not in {1.0,2.0,3.0}) 
. 
. 
Tree 1 (weight 1.0): 
. 
. 
...etc 

我想因为它不用通过模型来查看实际的数据,像

Tree 0 (weight 1.0): 
    If (feature 0 in {1.0,2.0,3.0}) 60% 
    If (feature 3 in {2.0,3.0}) 57% 
    If (feature 8 <= 55.3) 22% 
. 
. 
    Else (feature 0 not in {1.0,2.0,3.0}) 40% 
. 
. 
Tree 1 (weight 1.0): 
. 
...etc 

通过观察每个节点的标签的概率,我可以通过数据(数千条记录)看到树中最可能遵循哪条路径,这将是非常好的见解!

我发现一个真棒答案在这里:Spark MLib Decision Trees: Probability of labels by features?

不幸的是在回答该方法使用MLlib API,和大量的努力后,我没能复制它使用数据帧API,它有不同的实现类节点和拆分:(

+0

也许您可以尝试通过查看ml包的代码来调整原始答案:https://github.com/apache/spark/blob/branch-2.0/mllib/src/main/scala/org/ apache/spark/ml/tree/Node.scala –

+1

我看了一下代码。我不确定这是否可能(至少与其他答案没有相同的策略),因为每个节点中的拆分比率都在ml杂志封装的privateStats属性中。也许可以用ImpurityCalculator使用Node的可见属性创建此属性,但是我找不到方法。 –

+0

@DanieldePaula感谢您关注此事。我宁愿不重构我的整个管道使用mllib。我会尽力用你最后的建议找到一种方法。到目前为止,我可以获得数组中的每一棵树。我想用API来做到这一点,而不必重新编写大量的类。如果您碰巧想出任何其他解决方案,请告诉我! – rtcode

回答

0

昨天发现有用的一种方法是我可以使用spark.read.parquet()函数来读取模型/数据文件的输出。这样可以检索关于某个节点的所有信息作为整个数据帧。

`val modelPath = "some/path/to/your/model" 
val dataPath = modelPath + "/data"  
val nodeData: DataFrame = spark.read.parquet(dataPath) 
nodeData.show(500,false) 
nodeData.printSchema()` 

然后你可以用信息重建树。希望能帮助到你。