2013-11-26 84 views
10

假设我有两个2-d阵列如下:查找匹配行的索引中两个2 d阵列

array([[3, 3, 1, 0], 
     [2, 3, 1, 3], 
     [0, 2, 3, 1], 
     [1, 0, 2, 3], 
     [3, 1, 0, 2]], dtype=int8) 

array([[0, 3, 3, 1], 
     [0, 2, 3, 1], 
     [1, 0, 2, 3], 
     [3, 1, 0, 2], 
     [3, 3, 1, 0]], dtype=int8) 

每个阵列中的一些行具有由值相匹配(但不一定是一个对应的行通过索引)在另一个阵列中,有些则不。

我想找到一种有效的方式来返回两个数组中对应于匹配行的索引对。如果他们是元组我希望回到

(0,4) 
(2,1) 
(3,2) 
(4,3) 

回答

6

这是一个全部的numpy解决方案 - 不一定比迭代Python更好。它仍然需要考虑所有组合。

In [53]: np.array(np.all((x[:,None,:]==y[None,:,:]),axis=-1).nonzero()).T.tolist() 
Out[53]: [[0, 4], [2, 1], [3, 2], [4, 3]] 

中间数组为(5,5,4)。该np.all它简化为:

array([[False, False, False, False, True], 
     [False, False, False, False, False], 
     [False, True, False, False, False], 
     [False, False, True, False, False], 
     [False, False, False, True, False]], dtype=bool) 

剩下的就是提取指数在那里,这是True

在原油的测试中,该倍47.8我们;另一个答案与L1字典在38.3 us;第三个在496美分处有一个双环。

5

我想不出一个numpy的具体办法做到这一点,但这里是我会与常规列表做:

>>> L1= [[3, 3, 1, 0], 
...  [2, 3, 1, 3], 
...  [0, 2, 3, 1], 
...  [1, 0, 2, 3], 
...  [3, 1, 0, 2]] 
>>> L2 = [[0, 3, 3, 1], 
...  [0, 2, 3, 1], 
...  [1, 0, 2, 3], 
...  [3, 1, 0, 2], 
...  [3, 3, 1, 0]] 
>>> L1 = {tuple(row):i for i,row in enumerate(L1)} 
>>> answer = [] 
>>> for i,row in enumerate(L2): 
... if tuple(row) in L1: 
...  answer.append((L1[tuple(row)], i)) 
... 
>>> answer 
[(2, 1), (3, 2), (4, 3), (0, 4)] 
+0

O(n)!尼斯。但是没有一种方法可以做到吗? – slider

+0

@slider:'我想不出一种颠簸的方式去做',主要是因为我不会使用numpy(它在我的待办事项列表中已经超过我很自豪地承认) – inspectorG4dget

+0

这可能是对'L2'只有一行的情况进行了推广,我们希望在'L1'中获得匹配行的'行索引','L1'中的行不一定是唯一的? – sodiumnitrate

4

您可以使用void数据类型技巧在两个数组的行上使用1D函数。 a_viewb_view是一维向量,每个条目代表一整行。然后,我选择对数组进行排序,并使用np.searchsorted来查找该数组中的另一个数组的项目。如果我们排序的数组长度为m而另一个长度为n,排序需要时间m * log(m),并且二进制搜索np.searchsorted确实需要时间n * log(m),总计为(n + m) * log(m)。因此,你要排序中最短的两个数组:

def find_rows(a, b): 
    dt = np.dtype((np.void, a.dtype.itemsize * a.shape[1])) 

    a_view = np.ascontiguousarray(a).view(dt).ravel() 
    b_view = np.ascontiguousarray(b).view(dt).ravel() 

    sort_b = np.argsort(b_view) 
    where_in_b = np.searchsorted(b_view, a_view, 
           sorter=sort_b) 
    where_in_b = np.take(sort_b, where_in_b) 
    which_in_a = np.take(b_view, where_in_b) == a_view 
    where_in_b = where_in_b[which_in_a] 
    which_in_a = np.nonzero(which_in_a)[0] 
    return np.column_stack((which_in_a, where_in_b)) 

随着ab你的两个样本阵列:

In [14]: find_rows(a, b) 
Out[14]: 
array([[0, 4], 
     [2, 1], 
     [3, 2], 
     [4, 3]], dtype=int64) 

In [15]: %timeit find_rows(a, b) 
10000 loops, best of 3: 29.7 us per loop 

在我的系统的字典方法时钟在大约22我们更快地为您的测试数据,但是对于1000x4的数组,这种numpy方法比纯Python方法快6倍(483 us vs 2.54 ms)。 O(n)!

+0

这太棒了。我花了整整一个小时才弄清楚你在做什么。 尽管searchsorted有一个小小的错误,可能会返回该项目应插入到最后,导致索引超出界限错误。例如 – Dalupus

+0

只是将数组的最后一行更改为[3,3,3,3],您将得到'IndexError:索引5超出大小5的范围' – Dalupus