2015-10-09 62 views
2

考虑一个简单的记录阵列结构:如何聚合NumPy记录数组(总和,最小值,最大值等)?

import numpy as np 
ijv_dtype = [ 
    ('I', 'i'), 
    ('J', 'i'), 
    ('v', 'd'), 
] 
ijv = np.array([ 
    (0, 0, 3.3), 
    (0, 1, 1.1), 
    (0, 1, 4.4), 
    (1, 1, 2.2), 
    ], ijv_dtype) 
print(ijv) # [(0, 0, 3.3) (0, 1, 1.1) (0, 1, 4.4) (1, 1, 2.2)] 

我想从vaggregate某些统计(总和,最小值,最大值等)通过分组的IJ独特组合。从SQL思维,预期的结果是:

select i, j, sum(v) as v from ijv group by i, j; 
i | j | v 
---+---+----- 
0 | 0 | 3.3 
0 | 1 | 5.5 
1 | 1 | 2.2 

(顺序并不重要)

我能想起来的NumPy的是丑陋的最好的,和我没有信心,我已经下令结果正确(虽然它似乎在这里工作):

# Get unique groups, index and inverse 
u_ij, idx_ij, inv_ij = np.unique(ijv[['I', 'J']], return_index=True, return_inverse=True) 
# Assemble aggregate 
a_ijv = np.zeros(len(u_ij), ijv_dtype) 
a_ijv['I'] = u_ij['I'] 
a_ijv['J'] = u_ij['J'] 
a_ijv['v'] = [ijv['v'][inv_ij == i].sum() for i in range(len(u_ij))] 
print(a_ijv) # [(0, 0, 3.3) (0, 1, 5.5) (1, 1, 2.2)] 

我想有一个更好的方法来做到这一点!我正在使用NumPy 1.4.1。

+1

我第一次尝试将使用'(i,j)'元组作为关键字来收集'collections.default_dict(list)'中的数据。然后,我可以在每个列表上预制所需的统计数据。 – hpaulj

回答

1

numpy对于这样的任务来说太低级了。我认为您的解决方案是好的,如果你必须使用纯numpy,但如果你不介意使用的东西与抽象的更高层次,尝试pandas

import pandas as pd 

df = pd.DataFrame({ 
    'I': (0, 0, 0, 1), 
    'J': (0, 1, 1, 1), 
    'v': (3.3, 1.1, 4.4, 2.2)}) 

print(df) 
print(df.groupby(['I', 'J']).sum()) 

输出:

I J v 
0 0 0 3.3 
1 0 1 1.1 
2 0 1 4.4 
3 1 1 2.2 
     v 
I J  
0 0 3.3 
    1 5.5 
1 1 2.2 
+0

随着'numpy'的早期版本,'熊猫'可能不是一个选项。 – hpaulj

相关问题