我需要基于具有类成员关系信息的另一个数组(labels
)对1D numpy
数组(data
)中的元素进行总结。我在下面的代码中使用numba
来加速它。但是,如果我斑点没有明确的线ret[int(find(labels, g))] += y
投与int()
,我reveice的错误消息:使用numba对numpy数组进行索引时的TypeError
TypeError: unsupported array index type ?int64
有没有更好的解决方法是显式转换?
import numpy as np
from numba import jit
labels = np.array([45, 85, 99, 89, 45, 86, 348, 764])
n = int(1e3)
data = np.random.random(n)
groups = np.random.choice(a=labels, size=n, replace=True)
@jit(nopython=True)
def find(seq, value):
for ct, x in enumerate(seq):
if x == value:
return ct
@jit(nopython=True)
def subsumNumba(data, groups, labels):
ret = np.zeros(len(labels))
for y, g in zip(data, groups):
# not working without casting with int()
ret[int(find(labels, g))] += y
return ret
此代码与我的机器上的Numba 0.28.1一起使用时没有错误。你使用的是哪个版本的Numba。另外作为一个附注,你可能想要避免使用'zip'和'enumerate'并明确使用索引计数器出于性能原因。你必须测试一下,看看它是否对你的用例产生了影响,但在过去,根据我的经验,它确实如此。 – JoshAdel
@JoshAdel我有版本0.26.0(将尝试更新现在)。你的意思是代码在你的机器上没有* int()强制转换? – NoBackingDown
@JoshAdel它没有枚举测试函数'find',性能增益最小。进一步优化代码时,我会牢记它。 – NoBackingDown