2014-01-18 137 views
3

快速插值这个问题类似于前问题回答Fast interpolation over 3D array,但解决不了我的问题。在三维阵列三维原点X

我有尺寸(时间,高度,纬度,经度),标记为y.shape=(nt, nalt, nlat, nlon)一个四维阵列。 x是高度并随(时间,纬度,经度)而变化,这意味着x.shape = (nt, nalt, nlat, nlon)。我想插入每个(nt,nlat,nlon)的高度。插值的x_new应该是1d,不会随(时间,纬度,经度)而改变。

我用numpy.interp,一样scipy.interpolate.interp1d,想想在原职的答案。这些答案我不能减少循环。

我只能这样做:

# y is a 4D ndarray 
# x is a 4D ndarray 
# new_y is a 4D array 
for i in range(nlon): 
    for j in range(nlat): 
     for k in range(nt): 
      y_new[k,:,j,i] = np.interp(new_x, x[k,:,j,i], y[k,:,j,i]) 

这些循环使这个插值太慢了计算。会有人有好点子吗?帮助将不胜感激。

+0

如果一个new_x是按升序排列,我想你可以写一个用Cython功能来提高计算速度。 – HYRY

+0

@HYRY new_x按升序排列。但是原来的x可能不是从小到大的顺序。这是否可用?我不熟悉Cython,也许需要一些时间来学习它,或者可以在前一篇文章中学习tiago的代码。 – Hao

+0

但是,'numpy.interp'需要'x'按升序排列,所以需要先将x和y按x排序。 – HYRY

回答

1

这是通过使用numba我的解决方案,它是关于快三倍。

创建测试数据第一,x需要按升序排列:

import numpy as np 
rows = 200000 
cols = 66 
new_cols = 69 
x = np.random.rand(rows, cols) 
x.sort(axis=-1) 
y = np.random.rand(rows, cols) 
nx = np.random.rand(new_cols) 
nx.sort() 

做20万次口译在numpy的:

%%time 
ny = np.empty((x.shape[0], len(nx))) 
for i in range(len(x)): 
    ny[i] = np.interp(nx, x[i], y[i]) 

我用的合并方法,而不是二进制搜索方法,因为nx是顺序的,并且nx的长度与x大致相同。

  • interp()使用二进制搜索,时间复杂度为O(len(nx)*log2(len(x))
  • 合并方法:时间复杂度为O(len(nx) + len(x))

这里是numba代码:

import numba 

@numba.jit("f8[::1](f8[::1], f8[::1], f8[::1], f8[::1])") 
def interp2(x, xp, fp, f): 
    n = len(x) 
    n2 = len(xp) 
    j = 0 
    i = 0 
    while x[i] <= xp[0]: 
     f[i] = fp[0] 
     i += 1 

    slope = (fp[j+1] - fp[j])/(xp[j+1] - xp[j])   
    while i < n: 
     if x[i] >= xp[j] and x[i] < xp[j+1]: 
      f[i] = slope*(x[i] - xp[j]) + fp[j] 
      i += 1 
      continue 
     j += 1 
     if j + 1 == n2: 
      break 
     slope = (fp[j+1] - fp[j])/(xp[j+1] - xp[j]) 

    while i < n: 
     f[i] = fp[n2-1] 
     i += 1 

@numba.jit("f8[:, ::1](f8[::1], f8[:, ::1], f8[:, ::1])") 
def multi_interp(x, xp, fp): 
    nrows = xp.shape[0] 
    f = np.empty((nrows, x.shape[0])) 
    for i in range(nrows): 
     interp2(x, xp[i, :], fp[i, :], f[i, :]) 
    return f 

然后调用numba功能:

%%time 
ny2 = multi_interp(nx, x, y) 

检查结果:

np.allclose(ny, ny2) 

在我的电脑,时间为:

python version: 3.41 s 
numba version: 1.04 s 

这种方法需要一个数组,其中最后一个轴是要interp()轴。

+0

非常感谢。那很棒!我明天会检查它。 – Hao