2013-11-01 164 views
5

我想找出找到两个np.arrays的行相交的有效方法。高效地找到两个二维numpy阵列的行相交

两个数组具有相同的形状,并且每行中的重复值不会发生。

例如:

import numpy as np 

a = np.array([[2,5,6], 
       [8,2,3], 
       [4,1,5], 
       [1,7,9]]) 

b = np.array([[2,3,4], # one element(2) in common with a[0] -> 1 
       [7,4,3], # one element(3) in common with a[1] -> 1 
       [5,4,1], # three elements(5,4,1) in common with a[2] -> 3 
       [7,6,9]]) # two element(9,7) in common with a[3] -> 2 

我的期望的输出是:np.array([1,1,3,2])

这是很容易与循环来做到这一点:

def get_intersect1ds(a, b): 
    result = np.empty(a.shape[0], dtype=np.int) 
    for i in xrange(a.shape[0]): 
     result[i] = (len(np.intersect1d(a[i], b[i]))) 
    return result 

结果:

>>> get_intersect1ds(a, b) 
array([1, 1, 3, 2]) 

但是有没有更有效的方法来做到这一点?

+0

嗯。 a'和'b'可以在每一行中有重复的值吗? – YXD

+0

@MrE好点,重复不会发生。谢谢。 – Akavall

+0

你期望输入数组有多大? –

回答

6

如果你有行内没有重复,你可以尝试复制什么np.intersect1d引擎盖下做(看源代码here):

>>> c = np.hstack((a, b)) 
>>> c 
array([[2, 5, 6, 2, 3, 4], 
     [8, 2, 3, 7, 4, 3], 
     [4, 1, 5, 5, 4, 1], 
     [1, 7, 9, 7, 6, 9]]) 
>>> c.sort(axis=1) 
>>> c 
array([[2, 2, 3, 4, 5, 6], 
     [2, 3, 3, 4, 7, 8], 
     [1, 1, 4, 4, 5, 5], 
     [1, 6, 7, 7, 9, 9]]) 
>>> c[:, 1:] == c[:, :-1] 
array([[ True, False, False, False, False], 
     [False, True, False, False, False], 
     [ True, False, True, False, True], 
     [False, False, True, False, True]], dtype=bool) 
>>> np.sum(c[:, 1:] == c[:, :-1], axis=1) 
array([1, 1, 3, 2]) 
+1

明智的答案 – YXD

1

我想不出一个干净的纯numpy的解决方案,但以下建议应该加快速度,有可能显着:

  1. 使用numba。当你调用intersect1d
+0

不幸的是,我没有进入numba,但我在想cython。我认为它也应该可以工作。感谢您的建议。 – Akavall

2

这个答案可能并不可行,因为如果输入具有形状(N,M),它会生成一个中间数组是与@autojit

  • assume_unique = True装饰你get_intersect1ds功能一样简单大小为(N,M,M),但它总是有趣的,看看你可以用广播做什么:

    In [43]: a 
    Out[43]: 
    array([[2, 5, 6], 
         [8, 2, 3], 
         [4, 1, 5], 
         [1, 7, 9]]) 
    
    In [44]: b 
    Out[44]: 
    array([[2, 3, 4], 
         [7, 4, 3], 
         [5, 4, 1], 
         [7, 6, 9]]) 
    
    In [45]: (np.expand_dims(a, -1) == np.expand_dims(b, 1)).sum(axis=-1).sum(axis=-1) 
    Out[45]: array([1, 1, 3, 2]) 
    

    对于大型阵列,该方法可以作出更多的内存友好通过分批所做的操作。