看一看该文档为sklearn.tree.DecisionTreeClassifier.tree_.value
:
from sklearn.datasets import load_iris
from sklearn.cross_validation import cross_val_score
from sklearn.tree import DecisionTreeClassifier
clf = DecisionTreeClassifier(random_state=0)
iris = load_iris()
clf.fit(iris.data, iris.target)
print(clf.classes_)
[0, 1, 2]
print(clf.tree_.value)
[[[ 50. 50. 50.]]
[[ 50. 0. 0.]]
[[ 0. 50. 50.]]
[[ 0. 49. 5.]]
[[ 0. 47. 1.]]
[[ 0. 47. 0.]]
[[ 0. 0. 1.]]
[[ 0. 2. 4.]]
[[ 0. 0. 3.]]
[[ 0. 2. 1.]]
[[ 0. 2. 0.]]
[[ 0. 0. 1.]]
[[ 0. 1. 45.]]
[[ 0. 1. 2.]]
[[ 0. 0. 2.]]
[[ 0. 1. 0.]]
[[ 0. 0. 43.]]]
每一行中clf.tree_.value
“包含每个节点,的恒定预测值”,其对应索引到索引clf.classes_
(help(clf.tree_)
)。
请参阅this answer(很少)更多的细节。
添加到答案中,对于此数组中的每一行,您都可以执行'clf.classes_ [np.argmax(value)]'来获得预测的类标签。 –
@not_a_robot谢谢。你完美地解释了它。但是我仍然无法找到文档中提到的clf.tree_.value。我想我不再需要它了,因为你的答案正是我正在寻找的。 – user3597574
只是另一个快速问题。看起来像clf.classes_给我标签[0,...,n-1],无论我使用什么标签。我对吗?我期待着[1,...,n]就我而言。 – user3597574