一般来说,我不会建议试图击败NumPy的。很少有人可以竞争(对于长阵列),更难以找到更快的实现。即使速度更快,速度可能也不会超过2倍。所以它很少值得。
但是我最近试图自己做这样的事情,所以我可以分享一些有趣的结果。
我自己并没有想到这件事。我基于我的方法numbas (re-)implementation of np.median
。 他们可能知道他们在做什么。
我最终什么样的主意是:
import numba as nb
import numpy as np
@nb.njit
def _partition(A, low, high):
"""copied from numba source code"""
mid = (low + high) >> 1
if A[mid] < A[low]:
A[low], A[mid] = A[mid], A[low]
if A[high] < A[mid]:
A[high], A[mid] = A[mid], A[high]
if A[mid] < A[low]:
A[low], A[mid] = A[mid], A[low]
pivot = A[mid]
A[high], A[mid] = A[mid], A[high]
i = low
for j in range(low, high):
if A[j] <= pivot:
A[i], A[j] = A[j], A[i]
i += 1
A[i], A[high] = A[high], A[i]
return i
@nb.njit
def _select_lowest(arry, k, low, high):
"""copied from numba source code, slightly changed"""
i = _partition(arry, low, high)
while i != k:
if i < k:
low = i + 1
i = _partition(arry, low, high)
else:
high = i - 1
i = _partition(arry, low, high)
return arry[:k]
@nb.njit
def _nlowest_inner(temp_arry, n, idx):
"""copied from numba source code, slightly changed"""
low = 0
high = n - 1
return _select_lowest(temp_arry, idx, low, high)
@nb.njit
def nlowest(a, idx):
"""copied from numba source code, slightly changed"""
temp_arry = a.flatten() # does a copy! :)
n = temp_arry.shape[0]
return _nlowest_inner(temp_arry, n, idx)
我做包含的定时之前的一些热身电话。热身是为了让编译时间不包括在时序:
rselect(np.random.rand(10), 5)
nlowest(np.random.rand(10), 5)
有一个(非常)慢的电脑我改变了元件的数量和重复的比特数。但结果似乎表明,我(当然,在numba开发者所做的那样)击败NumPy的:
results = pd.DataFrame(
index=pd.Index([100, 500, 1000, 5000, 10000, 50000, 100000, 500000], name='Size'),
columns=pd.Index(['nsmall_np', 'nsmall_pd', 'nsmall_pir', 'nlowest'], name='Method')
)
rselect(np.random.rand(10), 5)
nlowest(np.random.rand(10), 5)
for i in results.index:
x = np.random.rand(i)
n = i // 2
for j in results.columns:
stmt = '{}(x, n)'.format(j)
setp = 'from __main__ import {}, x, n'.format(j)
results.set_value(i, j, timeit(stmt, setp, number=100))
print(results)
Method nsmall_np nsmall_pd nsmall_pir nlowest
Size
100 0.00343059 0.561372 0.00190855 0.000935566
500 0.00428461 1.79398 0.00326862 0.00187225
1000 0.00560669 3.36844 0.00432595 0.00364284
5000 0.0132515 0.305471 0.0142569 0.0108995
10000 0.0255161 0.340215 0.024847 0.0248285
50000 0.105937 0.543337 0.150277 0.118294
100000 0.2452 0.835571 0.333697 0.248473
500000 1.75214 3.50201 2.20235 1.44085
你需要改变多少代码才能使用'njit'? – piRSquared
'_partition'函数被简单地复制,'_select'函数只在最后一行('arry [:k]'而不是'arry [k]')中被改变。另外两个函数被改变了一点:我改变了函数名称,用一个新的'idx'参数替换了'mid'部分,并删除了处理一个偶数长度数组中位数的部分。 'nlowest'函数最初是'median_impl'函数。我也用'@ njit'改变了'@ register_jitable',并且我不需要('想要')'@ overload'。说实话,这个评论可能需要花费更长时间才能改变numba源代码。 :D – MSeifert
是的,看着你链接的代码,看起来他们已经是'numba'的老练用户了。感谢分享:-) – piRSquared