2017-02-17 56 views
2

我读星火文档的OHE项,如何解释星火结果OneHotEncoder

一热编码映射标签指数的二进制矢量的一列一列,用最多的单个值。此编码允许期望连续特征的算法(如Logistic回归)使用分类特征。

但遗憾的是他们没有给出OHE结果的完整解释。于是跑去给定的代码:

from pyspark.ml.feature import OneHotEncoder, StringIndexer 

df = sqlContext.createDataFrame([ 
(0, "a"), 
(1, "b"), 
(2, "c"), 
(3, "a"), 
(4, "a"), 
(5, "c") 
], ["id", "category"]) 

stringIndexer = StringIndexer(inputCol="category",  outputCol="categoryIndex") 
model = stringIndexer.fit(df) 
indexed = model.transform(df) 

encoder = OneHotEncoder(inputCol="categoryIndex", outputCol="categoryVec") 
encoded = encoder.transform(indexed) 
encoded.show() 

,并得到了结果:

+---+--------+-------------+-------------+ 
    | id|category|categoryIndex| categoryVec| 
    +---+--------+-------------+-------------+ 
    | 0|  a|   0.0|(2,[0],[1.0])| 
    | 1|  b|   2.0| (2,[],[])| 
    | 2|  c|   1.0|(2,[1],[1.0])| 
    | 3|  a|   0.0|(2,[0],[1.0])| 
    | 4|  a|   0.0|(2,[0],[1.0])| 
    | 5|  c|   1.0|(2,[1],[1.0])| 
    +---+--------+-------------+-------------+ 

我怎么能解释OHE(最后一栏)的结果?

回答

6

一个热编码变换中categoryIndex的值入二元载体,其中在最大的一个值可以是1。由于有三个值,所述载体是长度为2和所述映射如下:

0 -> 10 
1 -> 01 
2 -> 00 

(为什么是这样的映射?看this question约一热编码器丢弃最后一类)。

categoryVec列中的值是完全相同,但这些在稀疏的格式表示。在这种格式下,矢量的零点不被打印。第一个值(2)显示向量的长度,第二个值是一个数组,其中列出了零个索引,其中找到了非零条目,第三个值是另一个数组,该数组指示在这些索引处找到哪些数字。所以(2,[0],[1.0])意味着长度为2的矢量,其中位置0处为1.0,其他位置处为0。

参见:https://spark.apache.org/docs/latest/mllib-data-types.html#local-vector