2017-05-24 16 views
3

我已经使用scikit-learn培训了一个随机森林模型,现在我想将其树形结构保存在文本文件中,以便我可以在其他地方使用它。 根据this link,树对象由许多平行数组组成,每个数组都包含一些关于树的不同节点的信息(例如,左边的孩子,右边的孩子,它检查的功能......)。但是,似乎没有关于每个叶节点对应的类标签的信息!甚至在上面的链接提供的例子中都没有提到。scikit-learn将树结构中每个叶节点的决策标签保存在哪里?

有人知道scikit-learn决策树结构中存储的类标签在哪里?

回答

3

看一看该文档为sklearn.tree.DecisionTreeClassifier.tree_.value

from sklearn.datasets import load_iris 
from sklearn.cross_validation import cross_val_score 
from sklearn.tree import DecisionTreeClassifier 

clf = DecisionTreeClassifier(random_state=0) 
iris = load_iris() 

clf.fit(iris.data, iris.target) 

print(clf.classes_) 

[0, 1, 2] 

print(clf.tree_.value) 

[[[ 50. 50. 50.]] 

[[ 50. 0. 0.]] 

[[ 0. 50. 50.]] 

[[ 0. 49. 5.]] 

[[ 0. 47. 1.]] 

[[ 0. 47. 0.]] 

[[ 0. 0. 1.]] 

[[ 0. 2. 4.]] 

[[ 0. 0. 3.]] 

[[ 0. 2. 1.]] 

[[ 0. 2. 0.]] 

[[ 0. 0. 1.]] 

[[ 0. 1. 45.]] 

[[ 0. 1. 2.]] 

[[ 0. 0. 2.]] 

[[ 0. 1. 0.]] 

[[ 0. 0. 43.]]] 

每一行中clf.tree_.value“包含每个节点,的恒定预测值”,其对应索引到索引clf.classes_help(clf.tree_))。

请参阅this answer(很少)更多的细节。

+2

添加到答案中,对于此数组中的每一行,您都可以执行'clf.classes_ [np.argmax(value)]'来获得预测的类标签。 –

+0

@not_a_robot谢谢。你完美地解释了它。但是我仍然无法找到文档中提到的clf.tree_.value。我想我不再需要它了,因为你的答案正是我正在寻找的。 – user3597574

+0

只是另一个快速问题。看起来像clf.classes_给我标签[0,...,n-1],无论我使用什么标签。我对吗?我期待着[1,...,n]就我而言。 – user3597574