2016-07-28 29 views
0

我使用星火1.6.2,我有以下数据结构:获得每组的前N项pySpark

sample = sqlContext.createDataFrame([ 
        (1,['potato','orange','orange']), 
        (1,['potato','orange','yogurt']), 
        (2,['vodka','beer','vodka']), 
        (2,['vodka','beer','juice', 'vinegar']) 

    ],['cat','terms']) 

我想提取每猫前N个最常见的术语。我开发了以下解决方案,似乎可行,但我想看看是否有更好的方法来做到这一点。

from collections import Counter 
def get_top(it, terms=200): 
    c = Counter(it.__iter__()) 
    return [x[0][1] for x in c.most_common(terms)] 

(sample.select('cat',sf.explode('terms')).rdd.map(lambda x: (x.cat, x.col)) 
.groupBy(lambda x: x[0]) 
.map(lambda x: (x[0], get_top(x[1], 2))) 
.collect() 
) 

它提供了以下的输出:

[(1, ['orange', 'potato']), (2, ['vodka', 'beer'])] 

这是符合我所期待的,但我真的不喜欢,我诉诸使用计数器的事实。我怎样才能用火花一个人做到这一点?

感谢

回答

1

如果这样的工作可能是更好张贴这Code Review

就像一个练习,我做了这个没有柜台,但很大程度上,你只是复制相同的功能。

  • 计数(catterm
  • 组的每次出现由cat
  • 排序基于计数和切片的值,以(2)项的数目

代码:

from operator import add 

(sample.select('cat', sf.explode('terms')) 
.rdd 
.map(lambda x: (x, 1)) 
.reduceByKey(add) 
.groupBy(lambda x: x[0][0]) 
.mapValues(lambda x: [r[1] for r, _ in sorted(x, key=lambda a: -a[1])[:2]]) 
.collect()) 

输出:

[(1, ['orange', 'potato']), (2, ['vodka', 'beer'])] 
+0

不错。直到mapValues之前,我已经掌握了大部分内容,但无法找出这一点。谢谢! – browskie

+0

之前,我诉诸计数器是... – browskie

+0

它可能是一个挑战,跟踪脚手架结构,提取数据返回。 – AChampion