2013-04-26 28 views
1

我想使用weave.blitz改善以下numpy的代码的性能:闪电代码产生不同的输出

def fastIteration(self): 
    g = self.grid 
    nx,ny = g.ux.shape 

    uxold = g.old_ux 
    ux = g.ux 
    ux[0:,1:-1] = uxold[0:,1:-1] + ReI* (uxold[0:,2:] - 2*uxold[0:,1:-1] + uxold[0:,0:-2]) 

    g.setBC() 
    g.old_ux = ux.copy() 

在此代码g是计算网格。它由两个不同的领域ux和uxold组成。旧的只是用于临时存储变量。在完整的代码中,大约95%的运行时间用于fastIteration方法,因此即使简单的性能增益也会显着减少执行此代码的时间。

的numpy的方法的输出看起来好像:

numpy result

由于这个代码是我的瓶颈,我想用编织热捧提高速度。这种方法看起来像:

def blitzIteration(self): 
    ### does not work correct so far 
    g = self.grid 
    nx,ny = g.ux.shape 

    uxold = g.old_ux 
    ux = g.ux 
    expr = "ux[0:,1:-1] = uxold[0:,1:-1] + ReI* (uxold[0:,2:] - 2*uxold[0:,1:-1] + uxold[0:,0:-2])" 
    weave.blitz(expr, check_size=0) 
    g.setBC() 
    g.old_ux = ux.copy() 

然而,这并不产生正确的输出:(fixed转载,提交并有一个关于实际错误的详细信息那里) output for blitz code

回答

2

它看起来像在weave.blitz的错误。

我认为这是奇怪的写0:而不是更短的:得到一个完整的切片,所以我取代了所有这些片和voilà,它的工作。

我真的不知道哪里的错误所在,但weave.blitz产生的expr_code略有不同:

  • 当使用0:

    ipdb> expr_code 
    'ux_blitz_buggy(blitz::Range(0,_end),blitz::Range(1,Nux_blitz_buggy(1)-1-1))=uxold(blitz::Range(0,_end),blitz::Range(1,Nuxold(1)-1-1))+ReI*(uxold(blitz::Range(0,_end),blitz::Range(2,_end))-2*uxold(blitz::Range(0,_end),blitz::Range(1,Nuxold(1)-1-1))+uxold(blitz::Range(0,_end),blitz::Range(0,Nuxold(1)-2-1)));\n' 
    
  • 当使用:

    ipdb> expr_code 
    'ux_blitz_not_buggy(_all,blitz::Range(1,Nux_blitz_not_buggy(1)-1-1))=uxold(_all,blitz::Range(1,Nuxold(1)-1-1))+ReI*(uxold(_all,blitz::Range(2,_end))-2*uxold(_all,blitz::Range(1,Nuxold(1)-1-1))+uxold(_all,blitz::Range(0,Nuxold(1)-2-1)));\n' 
    

因此,blitz::Range(0,_end)变成_all并且它们的行为方式不同。

为方便起见,下面是一个完整的脚本,它重现了问题,只会在问题存在时成功。

import numpy as np 
from scipy.weave import blitz 


def test_blitz_bug(N=4): 
    ReI = 1.2 
    ux_blitz_buggy, ux_blitz_not_buggy, ux_np = np.zeros((N, N)), np.zeros((N, N)), np.zeros((N, N)) 
    uxold = np.random.randn(N, N) 
    ux_np[0:,1:-1] = uxold[0:,1:-1] + ReI* (uxold[0:,2:] - 2*uxold[0:,1:-1] + uxold[0:,0:-2]) 
    expr_buggy = 'ux_blitz_buggy[0:,1:-1] = uxold[0:,1:-1] + ReI* (uxold[0:,2:] - 2*uxold[0:,1:-1] + uxold[0:,0:-2])' 
    expr_not_buggy = 'ux_blitz_not_buggy[:,1:-1] = uxold[:,1:-1] + ReI* (uxold[:,2:] - 2*uxold[:,1:-1] + uxold[:,0:-2])' 
    blitz(expr_buggy) 
    blitz(expr_not_buggy) 
    assert not np.allclose(ux_blitz_buggy, ux_np) 
    assert np.allclose(ux_blitz_not_buggy, ux_np) 

if __name__ == '__main__': 
    test_blitz_bug() 
+1

@jordeca:这里是: '$蟒蛇blitz_bug.py' '$蟒蛇-c “进口SciPy的;打印SciPy的.__版本__”' 0.13.0.dev-639ef30 '$蟒蛇 - c“import numpy; print numpy .__ version __”' 1.7.1 '$ uname -a' Linux ratatoskr 2.6.32-45-generic#104-Ubuntu SMP Tue Feb 19 21:20:09 UTC 2013 x86_64 GNU/Linux – 2013-04-26 21:58:43

+0

@ Zhenya谢谢! – jorgeca 2013-04-27 14:13:22