2014-02-06 82 views
1

比方说,我有两个numpy数组a = (n x m)b = (z x m)其中列m是一些公共标识符。按列合并两个数组

a = np.array([[1, 0, 0, 1], [0, 1, 0, 1], [0, 0, 0, 1], [1, 1, 0, 1]]) 
b = np.array([[1, 0, 0, 1], [1, 1, 0, 1], [0, 1, 1, 0]]) 

有一个numpy的上下的方式来获得c = (n x z)其中c_ij = 1 if (any element in (row i of a AND row j of b) is equal to 1) else 0没有循环,所以在这种情况下

c = np.array([[1, 1, 0], [1, 1, 1], [1, 1, 0], [1, 1, 1]]) 
+0

你能后的这一个循环基于版本返回相同的输出? –

回答

2

IIUC,你可以把它看作一个矩阵乘法:

>>> a = np.array([[1, 0, 0, 1], [0, 1, 0, 1], [0, 0, 0, 1], [1, 1, 0, 1]]) 
>>> b = np.array([[1, 0, 0, 1], [1, 1, 0, 1], [0, 1, 1, 0]]) 
>>> (a.dot(b.T) > 0).astype(int) 
array([[1, 1, 0], 
     [1, 1, 1], 
     [1, 1, 0], 
     [1, 1, 1]]) 

不足之处是这种方法比需要做更多的工作,因为它完成了整个乘法。如果性能确实非常关键(并且关键比人们认为关键的关键要少很多),那么您可以编写一些cython或使用numba来获得类C速度的短路行为。不过,本地的numpy巫师可能会想到一些聪明的东西。 :^)

+0

如果numpy使用体面的BLAS进行编译,这将会非常难以用纯粹的numpy来打败,因为numpy的“any”函数(IIRC)在第一个true时不会破坏,而是会在整个阵列中破译。如果你通过scipy直接调用'SGEMM',你可能会做得更好一些。 – Daniel

0

我知道你没有循环说,但这个工程:

np.array([[np.any(x&y) for x in a] for y in b],dtype=int).T