2014-05-09 39 views
14

我有两个问题,了解从scikit学习决策树的结果。例如,这是我的决策树之一:如何解释从决策树scikit学习

enter image description here 我的问题是,我该如何使用树?

第一个问题是:若样品满足的条件,然后将其转到LEFT分支(如果存在的话),否则它会RIGHT。在我的情况下,如果一个样本的X [7]> 63521.3984。然后样品将进入绿色框。正确?

第二个问题是:当一个样品到达叶节点,我怎样才能知道它所属的类别?在这个例子中,我有三个类别进行分类。在红色框中,分别有91,212和113个样本满足条件。但是,我怎样才能决定这个类别呢? 我知道有一个函数clf.predict(样品)告诉类别。我可以做图吗? 非常感谢。

+1

出于好奇,你是如何绘制决策树的? – Matt

+4

首先将树导出为JSON格式(参见[链接](http://www.garysieling.com/blog/rending-scikit-decision-trees-d3-js)),然后使用d3.js绘制该树。或者你可以直接使用嵌入式函数:'tree.export_graphviz(clf,out_file = your_out_file,feature_names = your_feature_names)'希望它能起作用,@Matt –

回答

21

value线在每个盒子告诉你很多样品在该节点落入每个类别,为了如何。这就是为什么在每个框中,value中的数字合计为sample中显示的数字。例如,在你的红色框中,91 + 212 + 113 = 416。因此,这意味着如果到达此节点,则类别1中有91个数据点,类别2中有212个数据点,类别3中有113个。

如果您要预测到达该叶节点的新数据点的结果在决策树中,您会预测类别2,因为这是该节点上样本的最常见类别。

+0

我有兴趣知道哪个值属于哪个类。 'DecisionTreeClassifier.classes'持有这个信息。 – ezdazuzena

+0

(有用的答案:为了澄清使用python索引,尽管:红色框中的样本登陆将被预测(计数212)为类别1,而不是类别0(91)或类别2(113):-)) –

0

根据这本书“学习scikit学习:机器在Python学习”,决策树表示一系列的基于训练数据进行决策。

!(http://i.imgur.com/vM9fJLy.png

为实例进行分类,我们应该回答每个节点的问题。例如,性别< = 0.5? (我们是在谈论一个女人?)。 如果答案是肯定的,则转到树中的左侧子节点;否则你去右边的子节点。你一直在回答问题(她是在第三堂课吗?她是在第一堂课吗?她是13岁以下的?),直到你到达一片叶子。 当您在那里时,预测对应于具有大多数实例的目标类

2

第一个问题: 是的,你的逻辑是正确的。左边的节点是True,右边的节点是False。这是违反直觉的;真实通常意味着一个较小的值。

第二个问题: 这个问题最好通过用pydotplus将图形可视化为图来解决。 tree.export_graphviz()的'class_names'属性将为每个节点的大多数类添加一个类声明。代码在iPython中执行。

from sklearn.datasets import load_iris 
from sklearn import tree 
iris = load_iris() 
clf2 = tree.DecisionTreeClassifier() 
clf2 = clf2.fit(iris.data, iris.target) 

with open("iris.dot", 'w') as f: 
    f = tree.export_graphviz(clf, out_file=f) 

import os 
os.unlink('iris.dot') 

import pydotplus 
dot_data = tree.export_graphviz(clf2, out_file=None) 
graph2 = pydotplus.graph_from_dot_data(dot_data) 
graph2.write_pdf("iris.pdf") 

from IPython.display import Image 
dot_data = tree.export_graphviz(clf2, out_file=None, 
        feature_names=iris.feature_names, 
        class_names=iris.target_names, 
        filled=True, rounded=True, # leaves_parallel=True, 
        special_characters=True) 
graph2 = pydotplus.graph_from_dot_data(dot_data) 

## Color of nodes 
nodes = graph2.get_node_list() 

for node in nodes: 
    if node.get_label(): 
     values = [int(ii) for ii in node.get_label().split('value = [')[1].split(']')[0].split(',')]; 
     color = {0: [255,255,224], 1: [255,224,255], 2: [224,255,255],} 
     values = color[values.index(max(values))]; # print(values) 
     color = '#{:02x}{:02x}{:02x}'.format(values[0], values[1], values[2]); # print(color) 
     node.set_fillcolor(color) 
# 

Image(graph2.create_png()) 

enter image description here

作为用于确定的类别在叶,你的例子不具有叶与单个类,如虹膜数据集一样。这很常见,可能需要过度拟合模型才能获得这样的结果。类的离散分布是许多交叉验证模型的最佳结果。

享受代码!