2017-06-02 20 views

回答

1

一般来说,我不会建议试图击败NumPy的。很少有人可以竞争(对于长阵列),更难以找到更快的实现。即使速度更快,速度可能也不会超过2倍。所以它很少值得。

但是我最近试图自己做这样的事情,所以我可以分享一些有趣的结果。

我自己并没有想到这件事。我基于我的方法numbas (re-)implementation of np.median他们可能知道他们在做什么。

我最终什么样的主意是:

import numba as nb 
import numpy as np 

@nb.njit 
def _partition(A, low, high): 
    """copied from numba source code""" 
    mid = (low + high) >> 1 
    if A[mid] < A[low]: 
     A[low], A[mid] = A[mid], A[low] 
    if A[high] < A[mid]: 
     A[high], A[mid] = A[mid], A[high] 
     if A[mid] < A[low]: 
      A[low], A[mid] = A[mid], A[low] 
    pivot = A[mid] 

    A[high], A[mid] = A[mid], A[high] 

    i = low 
    for j in range(low, high): 
     if A[j] <= pivot: 
      A[i], A[j] = A[j], A[i] 
      i += 1 

    A[i], A[high] = A[high], A[i] 
    return i 

@nb.njit 
def _select_lowest(arry, k, low, high): 
    """copied from numba source code, slightly changed""" 
    i = _partition(arry, low, high) 
    while i != k: 
     if i < k: 
      low = i + 1 
      i = _partition(arry, low, high) 
     else: 
      high = i - 1 
      i = _partition(arry, low, high) 
    return arry[:k] 

@nb.njit 
def _nlowest_inner(temp_arry, n, idx): 
    """copied from numba source code, slightly changed""" 
    low = 0 
    high = n - 1 
    return _select_lowest(temp_arry, idx, low, high) 

@nb.njit 
def nlowest(a, idx): 
    """copied from numba source code, slightly changed""" 
    temp_arry = a.flatten() # does a copy! :) 
    n = temp_arry.shape[0] 
    return _nlowest_inner(temp_arry, n, idx) 

我做包含的定时之前的一些热身电话。热身是为了让编译时间不包括在时序:

rselect(np.random.rand(10), 5) 
nlowest(np.random.rand(10), 5) 

有一个(非常)慢的电脑我改变了元件的数量和重复的比特数。但结果似乎表明,我(当然,在numba开发者所做的那样)击败NumPy的:

results = pd.DataFrame(
    index=pd.Index([100, 500, 1000, 5000, 10000, 50000, 100000, 500000], name='Size'), 
    columns=pd.Index(['nsmall_np', 'nsmall_pd', 'nsmall_pir', 'nlowest'], name='Method') 
) 

rselect(np.random.rand(10), 5) 
nlowest(np.random.rand(10), 5) 

for i in results.index: 
    x = np.random.rand(i) 
    n = i // 2 
    for j in results.columns: 
     stmt = '{}(x, n)'.format(j) 
     setp = 'from __main__ import {}, x, n'.format(j) 
     results.set_value(i, j, timeit(stmt, setp, number=100)) 

print(results) 

Method nsmall_np nsmall_pd nsmall_pir  nlowest 
Size             
100  0.00343059 0.561372 0.00190855 0.000935566 
500  0.00428461 1.79398 0.00326862 0.00187225 
1000 0.00560669 3.36844 0.00432595 0.00364284 
5000  0.0132515 0.305471 0.0142569 0.0108995 
10000 0.0255161 0.340215 0.024847 0.0248285 
50000  0.105937 0.543337 0.150277  0.118294 
100000  0.2452 0.835571 0.333697  0.248473 
500000  1.75214 3.50201  2.20235  1.44085 

enter image description here

+0

你需要改变多少代码才能使用'njit'? – piRSquared

+1

'_partition'函数被简单地复制,'_select'函数只在最后一行('arry [:k]'而不是'arry [k]')中被改变。另外两个函数被改变了一点:我改变了函数名称,用一个新的'idx'参数替换了'mid'部分,并删除了处理一个偶数长度数组中位数的部分。 'nlowest'函数最初是'median_impl'函数。我也用'@ njit'改变了'@ register_jitable',并且我不需要('想要')'@ overload'。说实话,这个评论可能需要花费更长时间才能改变numba源代码。 :D – MSeifert

+0

是的,看着你链接的代码,看起来他们已经是'numba'的老练用户了。感谢分享:-) – piRSquared

2

更新
@ user2357112指出在我的功能在现场操纵的评论中。转过来就是我的表现提升来自的地方。所以最后,我们与quickselectnumba的粗略实现具有非常相似的性能。仍然没有什么可以打喷嚏,但不是我所希望的。


正如我在质询时说,我与numba瞎搞,想和大家分享我已经找到。

请注意,我已导入njit而不是jit。这是一个装饰器,可以自动防止本身回退到本地python对象。意思是说,当它加快速度时,它只会使用它实际上可以加速的东西。这反过来意味着我的功能失败了很多,而我找出什么是允许的,什么是不允许的。

到目前为止,这是我的看法,与numba小号jitnjit写东西是挑剔和困难,但那种值得的,当你看到一个不俗的表现回报。

这是我的快速和肮脏的quickselect功能

import numpy as np 
from numba import njit 
import pandas as pd 
import numexpr as ne 

@njit 
def rselect(a, k): 
    n = len(a) 
    if n <= 1: 
     return a 
    elif k > n: 
     return a 
    else: 
     p = np.random.randint(n) 
     pivot = a[p] 
     a[0], a[p] = a[p], a[0] 
     i = j = 1 
     while j < n: 
      if a[j] < pivot: 
       a[j], a[i] = a[i], a[j] 
       i += 1 
      j += 1 
     a[i-1], a[0] = a[0], a[i-1] 
     if i - 1 <= k <= i: 
      return a[:k] 
     elif k > i: 
      return np.concatenate((a[:i], rselect(a[i:], k - i))) 
     else: 
      return rselect(a[:i-1], k) 

你会发现它返回相同的元素以问题的方法。

rselect(x, 5) 

array([2, 1, 0, 3, 4]) 

什么速度?

def nsmall_np(x, n): 
    return np.partition(x, n)[:n] 

def nsmall_pd(x, n): 
    pd.Series(x).nsmallest().values 

def nsmall_pir(x, n): 
    return rselect(x.copy(), n) 


from timeit import timeit 


results = pd.DataFrame(
    index=pd.Index([100, 1000, 3000, 6000, 10000, 100000, 1000000], name='Size'), 
    columns=pd.Index(['nsmall_np', 'nsmall_pd', 'nsmall_pir'], name='Method') 
) 

for i in results.index: 
    x = np.random.rand(i) 
    n = i // 2 
    for j in results.columns: 
     stmt = '{}(x, n)'.format(j) 
     setp = 'from __main__ import {}, x, n'.format(j) 
     results.set_value(
      i, j, timeit(stmt, setp, number=1000) 
     ) 

results 

Method nsmall_np nsmall_pd nsmall_pir 
Size          
100  0.003873 0.336693 0.002941 
1000  0.007683 1.170193 0.011460 
3000  0.016083 0.309765 0.029628 
6000  0.050026 0.346420 0.059591 
10000  0.106036 0.435710 0.092076 
100000 1.064301 2.073206 0.936986 
1000000 11.864195 27.447762 12.755983 

results.plot(title='Selection Speed', colormap='jet', figsize=(10, 6)) 

[1]: https://i.stack.imgur.com/hKo2o png格式

+2

你似乎变异的输入,而'numpy.partition'进行复印。你是否定时执行了'ndarray.partition'方法的性能? – user2357112

+0

@ user2357112好眼睛...看着它 – piRSquared

+0

@ user2357112和** PooF **有所有的性能好处。谢谢......看到乱搞已经教会了我一些东西。 – piRSquared

相关问题