2013-04-16 142 views
1

我想优化下面的代码,可能通过在Cython中重写它:它只需要一个低维但相对较长的numpy数组,查看其值为0的列,并将它们标记为-1。代码是:优化索引和检索Python中numpy数组中的元素?

import numpy as np 

def get_data(): 
    data = np.array([[1,5,1]] * 5000 + [[1,0,5]] * 5000 + [[0,0,0]] * 5000) 
    return data 

def get_cols(K): 
    cols = np.array([2] * K) 
    return cols 

def test_nonzero(data): 
    K = len(data) 
    result = np.array([1] * K) 
    # Index into columns of data 
    cols = get_cols(K) 
    # Mark zero points with -1 
    idx = np.nonzero(data[np.arange(K), cols] == 0)[0] 
    result[idx] = -1 

import time 
t_start = time.time() 
data = get_data() 
for n in range(5000): 
    test_nonzero(data) 
t_end = time.time() 
print (t_end - t_start) 

data是数据。 cols是查找非零值的数据列数组(为简单起见,我将它全部放在同一列中)。我们的目标是计算一个numpy数组,result,其中感兴趣列非零的每行的值为1,并且感兴趣的相应列的值为零的行的值为-1。

在15,000行3列不太大的数组上运行5000次大约需要20秒。有没有办法可以加快速度?看起来大部分工作都是寻找非零元素并用索引检索它们(调用nonzero并随后使用它的索引)。这可以优化吗?或者这是最好的可以完成的吗? Cython实现如何在这方面获得更快的速度?

+0

非零是一个很好的尝试(不知道它是否帮助很大或所有虽然)。如果你绝望并知道cols是有效的,你可以尝试制作一个线性索引。如果K在循环中不变,则不能每次都重做np.arange ... – seberg

回答

2
cols = np.array([2] * K) 

这会很慢。这是创建一个非常大的Python列表,然后将其转换为一个numpy数组。相反,这样做:

cols = np.ones(K, int)*2 

那将是这样快

result = np.array([1] * K) 

在这里,你应该做的:

result = np.ones(K, int) 

这将直接产生numpy的阵列。

idx = np.nonzero(data[np.arange(K), cols] == 0)[0] 
result[idx] = -1 

cols是一个数组,但您可以传递一个2.此外,使用非零增加一个额外的步骤。

idx = data[np.arange(K), 2] == 0 
result[idx] = -1 

应该有同样的效果。

+0

感谢但cols不是计算的一部分,它只是一个示例。正如我所提到的,cols并不总是像2这样的一个价值 - 我只是为了简单而做出来的。它通常是一个向量将不同的列。所以我不认为这些建议加快了速度 – user248237dfsf

+2

@ user248237dfsf,好吧,我错过了关于cols不真实的部分。然而,'numpy.array([2] * K)'确实非常慢,并且它代码的原因很慢。没有看到你真实的代码,我真的不知道为什么它会很慢。 –