2017-07-07 186 views
2

在DataFrameGroupBy基地行,我有以下数据框(称为DF):熊猫选择在分级指数

user_id product_id probReorder 
0  1   196   1.0 
1  1  10258   0.9 
2  1  10326   0.1 
3  1  12427   1.0 
4  1  13032   0.3 
... 

对于每个DF user_id说明,我想只保留N个在“probReorder”列中具有最大值的行。另外,我想N取决于user_id。 在我目前的做法,我有一个字典“lastReordNumber”,其键值对是(USER_ID,INT),我选择如下行:

predictions = [] 
for usr,data in df.groupby(by="user_id"): 
    data = data.nlargest(lastReordNumber[usr], "probReorder") 
    predictions.append(data) 
df = pd.concat(predictions) 

的问题是,这实在是太慢了。该数据帧有大约13M行和200k独特user_id's。有更快/更好的方法吗?

编辑:先前的代码产生当存在对于给定的USER_IDprobReorder列重复的值意外输出。例如:

lastReordNumber = {1:2, 2:3} 
df = pd.DataFrame({"user_id":[1,1,1,2,2,2,2],"probReorder":[0.9,0.6,0.9,0.1,1,0.5,0.4],\ 
    "product_id":[1,2,3,4,5,6,7]}) 

我得到的输出:

probReorder product_id user_id 
0   0.9   1  1 
1   0.9   3  1 
2   0.9   1  1 
3   0.9   3  1 
4   1.0   5  2 
5   0.5   6  2 
6   0.4   7  2 

这对于USER_ID = 2是我所期望的,但对于USER_ID = 1有重复的行。 我的预期输出是:

probReorder product_id user_id 
0   0.9   1  1 
1   0.9   3  1 
2   1.0   5  2 
3   0.5   6  2 
4   0.4   7  2 

这可以通过使用简单的一段代码

predictions = [] 
for usr,data in df.groupby(by="user_id"): 
    predictions.append(data.sort_values('probReorder', ascending=False).head(lastReordNumber[usr])) 
predictions = pd.concat(predictions, ignore_index=True) 

,其中每一列被完全排序,然后截断而获得。这也是相当高效的。 虽然我还没有理解如何解释nlargest()方法的结果。

+0

当你有等于最大的两个或多个行会发生什么? –

+0

@BobHaffner好问题。看起来nlargest不像我预期的那样行事,并且正在复制一些行。我应该发布测试用例的输出吗? – chubecca

+0

我会发布一些其他示例数据,其中包含另一个user_id。并发布您的期望输出 –

回答

2

您可以使用sort_valuesgroupbyhead

df1 = df.sort_values('probReorder', ascending=False) 
     .groupby('user_id', group_keys=False) 
     .apply(lambda x: x.head([x.name])) 
print (df1) 
    probReorder product_id user_id 
0   0.9   1  1 
2   0.9   3  1 
4   1.0   5  2 
5   0.5   6  2 
6   0.4   7  2 

nlargest另一种解决方案:

df1 = df.groupby('user_id', group_keys=False) 
     .apply(lambda x: x.nlargest(lastReordNumber[x.name], 'probReorder')) 
print (df1) 
    probReorder product_id user_id 
0   0.9   1  1 
2   0.9   3  1 
4   1.0   5  2 
5   0.5   6  2 
6   0.4   7  2 
+0

感谢您的答案。一些评论:drop_duplicates()在这种情况下不做任何事情,因为没有重复的(user_id,product_id)对。您的第一个解决方案应该与我在编辑中提供的解决方案相同,但它更优雅,也许更有效。您的第二个解决方案在我的机器上无法正常工作,它会产生上述我提供的相同“错误”输出。这可能是nlargest()中的一个错误,我必须查看它。 – chubecca

+0

我看到您的数据,似乎有一些重复。如果不是,那会更好。如果我的回答有帮助,请不要忘记[接受](http://meta.stackexchange.com/a/5235/295067) - 点击答案旁边的复选标记('✓')将其从灰色出来填补。谢谢。 – jezrael

+0

正如我所说的(“user_id”,“product_id”)列中没有重复项,如果我错了,请纠正我,因此您对drop_duplicates的调用不会执行任何操作。您的两个解决方案与我的两个解决方案相同,但其中一个解决方案在我的系统上的行为不如预期。我认为我的原始问题已经解决,但我仍然不明白nlargest()的问题。 – chubecca