我想优化下面的代码,可能通过在Cython中重写它:它只需要一个低维但相对较长的numpy数组,查看其值为0的列,并将它们标记为-1。代码是:优化索引和检索Python中numpy数组中的元素?
import numpy as np
def get_data():
data = np.array([[1,5,1]] * 5000 + [[1,0,5]] * 5000 + [[0,0,0]] * 5000)
return data
def get_cols(K):
cols = np.array([2] * K)
return cols
def test_nonzero(data):
K = len(data)
result = np.array([1] * K)
# Index into columns of data
cols = get_cols(K)
# Mark zero points with -1
idx = np.nonzero(data[np.arange(K), cols] == 0)[0]
result[idx] = -1
import time
t_start = time.time()
data = get_data()
for n in range(5000):
test_nonzero(data)
t_end = time.time()
print (t_end - t_start)
data
是数据。 cols
是查找非零值的数据列数组(为简单起见,我将它全部放在同一列中)。我们的目标是计算一个numpy数组,result
,其中感兴趣列非零的每行的值为1,并且感兴趣的相应列的值为零的行的值为-1。
在15,000行3列不太大的数组上运行5000次大约需要20秒。有没有办法可以加快速度?看起来大部分工作都是寻找非零元素并用索引检索它们(调用nonzero
并随后使用它的索引)。这可以优化吗?或者这是最好的可以完成的吗? Cython实现如何在这方面获得更快的速度?
非零是一个很好的尝试(不知道它是否帮助很大或所有虽然)。如果你绝望并知道cols是有效的,你可以尝试制作一个线性索引。如果K在循环中不变,则不能每次都重做np.arange ... – seberg