2016-09-04 37 views
1

我需要基于具有类成员关系信息的另一个数组(labels)对1D numpy数组(data)中的元素进行总结。我在下面的代码中使用numba来加速它。但是,如果我斑点没有明确的线ret[int(find(labels, g))] += y投与int(),我reveice的错误消息:使用numba对numpy数组进行索引时的TypeError

TypeError: unsupported array index type ?int64

有没有更好的解决方法是显式转换?

import numpy as np 
from numba import jit 

labels = np.array([45, 85, 99, 89, 45, 86, 348, 764]) 
n = int(1e3) 
data = np.random.random(n) 
groups = np.random.choice(a=labels, size=n, replace=True) 

@jit(nopython=True) 
def find(seq, value): 
    for ct, x in enumerate(seq): 
     if x == value: 
      return ct 

@jit(nopython=True) 
def subsumNumba(data, groups, labels): 
    ret = np.zeros(len(labels)) 
    for y, g in zip(data, groups): 
     # not working without casting with int() 
     ret[int(find(labels, g))] += y 
    return ret 
+0

此代码与我的机器上的Numba 0.28.1一起使用时没有错误。你使用的是哪个版本的Numba。另外作为一个附注,你可能想要避免使用'zip'和'enumerate'并明确使用索引计数器出于性能原因。你必须测试一下,看看它是否对你的用例产生了影响,但在过去,根据我的经验,它确实如此。 – JoshAdel

+0

@JoshAdel我有版本0.26.0(将尝试更新现在)。你的意思是代码在你的机器上没有* int()强制转换? – NoBackingDown

+0

@JoshAdel它没有枚举测试函数'find',性能增益最小。进一步优化代码时,我会牢记它。 – NoBackingDown

回答

1

的问题是,find可以返回一个intNone如果它没有发现任何东西,所以我觉得?int64错误。为了避免投射,当find退出时,您需要提供int返回值,但不会找到所需的值,然后在调用者中处理它。

+0

就是这样!我没有想到我,因为'find'保证找到我的问题的结构。现在我只是返回一个虚拟整数的理论情况下,没有击中,它的作品! – NoBackingDown