2017-03-22 118 views
1

我有许多树的xgboost.dump文本文件。 我想查找所有路径以获取每个路径的值。 这是一棵树。找到来自xgboost.dump的二叉树的所有路径

tree[0]: 
0:[a<0.966398] yes=1,no=2,missing=1 
    1:[b<0.323071] yes=3,no=4,missing=3 
     3:[c<0.461248] yes=7,no=8,missing=7 
      7:leaf=0.00972768 
      8:leaf=-0.0179376 
     4:[a<0.379082] yes=9,no=10,missing=9 
      9:leaf=0.0146003 
      10:leaf=0.0454369 
    2:[b<0.322352] yes=5,no=6,missing=5 
     5:[c<0.674868] yes=11,no=12,missing=11 
      11:leaf=0.0497964 
      12:leaf=0.00953781 
     6:[f<0.598267] yes=13,no=14,missing=13 
      13:leaf=0.0504545 
      14:leaf=0.0867654 

我想所有的路径转换成

path1, a<0.966398, b<0.323071, c<0.461248, leaf = 0.00097268 
path2, a<0.966398, b<0.323071, c>0.461248, leaf = -0.0179376 
path3, a<0.966398, b>0.323071, a<0.379082, leaf = 0.0146003 
path4, a<0.966398, b>0.323071, a>0.379082, leaf = 0.0454369 
path5, a>0.966398, b<0.322352, c<0.674868, leaf = 0.0497964 
path6, a>0.966398, b<0.322352, c>0.674868, leaf = 0.00953781 
path7, a>0.966398, b>0.322352, f<0.598267, leaf = 0.0504545 
path8, a>0.966398, b>0.322352, f>0.598267, leaf = 0.0864654 

我已经尝试列出像

array([[ 0, 1, 3, 7], 
     [ 0, 1, 3, 8], 
     [ 0, 1, 4, 9], 
     [ 0, 1, 4, 10], 
     [ 0, 2, 5, 11], 
     [ 0, 2, 5, 12], 
     [ 0, 2, 6, 13], 
     [ 0, 2, 6, 14]]) 

所有可能的路径,但一旦MAX_DEPTH较高这样会导致错误,一些分支将停止增长,路径将错误。 所以我需要解析文本文件中的yes,no来生成真实的,正确的路径。 有什么建议吗? 谢谢!

回答

0

下面是我用R实现来解决这个问题的方法。其他语言的用户可以遵循逻辑和实物复制。

首先,我从由xgb.model.dt.tree()生成的模型转储文件开始。

然后,我写了一个函数来解析从任意节点到转储模型的单个树内的最终父节点的有效路径。

后来,我使用purrr :: by_row()将该函数应用于模型转储的所有终端节点“叶”记录,并将结果转换用于目的。

该函数有两个参数,一个用于正在测试的树,另一个用于终端节点的标识。它遵循以下一般步骤:

  1. 与每个树为单位的目标(终端)节点开始,查找具有目标节点为有效孩子的C中(行“是”,“否” ,“失踪”)决定分裂。
  2. 将此有效父节点ID连接到一个向量中,该向量将用于跟踪从目标节点到最终父节点的路径的每个步骤。该向量在函数完成时返回。
  3. 接下来,为链上的每个节点重复“谁是我的父母”步骤,直到路径碰到最终父母为止(此节点ID始终以“-0”结尾),同时更新每个新步骤的路径向量连锁,链条。
  4. 一旦函数命中终端节点,返回()该路径。

在我的情况下,我将这个函数应用到模型转储中的所有“Leaf”节点上,使用purrr :: by_row()while .collat​​ing =“rows”来表示通道作为输出中的附加行。

这也很可能不是最快的方式。

xgb.booster模型中nrounds或max_depth的增加将导致此过程的运行时间增加。您可以使用树的子集(xgb.model.dt.tree()的参数n_first_tree = N)来开发您的方法,以便您可以估算解析整个最终模型中的终端节点路径所需的时间。在我的情况下,在max_depth = 5时有500棵树的模型可能需要30分钟以上。