2014-10-20 68 views
0

我有一个很大的无向数组。我想迭代它来检查一个条件是否在本地满足。下一段代码解释了我的问题。在多维numpy数组上迭代的快速条件检查

a = np.random.randint(2, size=(60,80,3,3)) 

test = np.array([[1,0,0],[0,1,0],[0,0,0]]) 

for i in xrange(a.shape[0]): 
    for j in xrange(b.shape[1]): 
     if (a[i,j] == test).all(): 
      # Do something with indices i and j 

该代码显然非常慢。我尝试使用numpy.where,但它不起作用,因为它在四个指标中的每一个都寻求平等。

编辑:我也需要存储满足

回答

1
np.apply_over_axes(np.prod, a == test, [3,2]) == 1 

给你大小(60,80,1,1)的数组是True徘徊无论该条件成立的情况的指数(i,j)。由线程起动机中发现的短,preferrable版本是

(a == test).all(axis=(2,3)) 

两者都是等效的,但后者避免了布尔→整数→布尔转换。在该阵列上使用np.where以获得索引(i, j)

+0

哇,似乎工作。你认为它与'np.where(a == test).all(axis =(2,3)),1,0'完全相同吗?我发现这个工作,至少看起来像。 – fmonegaglia 2014-10-20 10:51:55

+1

你的意思是'np.where(a == test).all(axis =(2,3))'?是的,这是相同的。看起来好多了。 – Phillip 2014-10-20 10:56:08