2015-03-13 65 views
3

我需要移动二维数组字段,即我有一个“previous_data”数组,我通过移位索引访问以创建我的“new_data”数组。通过移位索引在二维数组中移位数据

我可以在nonpythonic(和慢)循环中做到这一点,但非常感谢一些帮助找到pythonic(和更快)的解决方案!

任何帮助和提示都非常感谢!

import numpy as np 
import matplotlib.pyplot as plt 
from matplotlib import mpl 

def nonpythonic(): 
    #this works, but is slow (for large arrays)   
    new_data = np.zeros((ny,nx)) 
    for j in xrange(ny): 
     for i in xrange(nx): 
      #go through each item, check if it is within the bounds 
      #and assign the data to the new_data array 
      i_new = ix[j,i] 
      j_new = iy[j,i] 
      if ((i_new>=0) and (i_new<nx) and (j_new>=0) and (j_new<ny)): 
       new_data[j,i]=previous_data[j_new,i_new] 

    ef, axar = plt.subplots(1,2) 
    im = axar[0].pcolor(previous_data, vmin=0,vmax=2) 
    ef.colorbar(im, ax=axar[0], shrink=0.9) 
    im = axar[1].pcolor(new_data, vmin=0,vmax=2) 
    ef.colorbar(im, ax=axar[1], shrink=0.9) 

    plt.show() 

def pythonic(): 
    #tried a few things here, but none are working 
    #-tried assigning NaNs to indices (ix,iy) which are out of bounds, but NaN's don't work for indices 
    #-tried masked arrays, but they also don't work as indices 
    #-tried boolean arrays, but ended in shape mismatches 
    #just as in the nonworking code below 
    ind_y_good = np.where(iy>=0) and np.where(iy<ny) 
    ind_x_good = np.where(ix>=0) and np.where(ix<nx) 

    new_data = np.zeros((ny,nx)) 

    new_data[ind_y_good,ind_x_good] = previous_data[iy[ind_y_good],ix[ind_x_good]] 

#some 2D array: 
nx = 20 
ny = 30  
#array indices: 
iy, ix = np.indices((ny,nx)) 
#modify indices (shift): 
iy = iy + 1 
ix = ix - 4 
#create some out of range indices (which might happen in my real scenario) 
iy[0,2:7] = -9999 
ix[0:3,-1] = 6666 

#some previous data which is the basis for the new_data: 
previous_data = np.ones((ny,nx)) 
previous_data[2:8,10:20] = 2 
nonpythonic() 
pythonic() 

这是工作(nonpythonic)上面的代码的结果: nonpythonic working example of shifted data

回答

2

我实现的pythonic与某些掩蔽和索引摆弄复制nonpythonic一个版本 - 见下文。顺便说一下,我认为“新”索引应该是与新数组相对应的索引,而不是旧数组索引,但我已将它留在现有函数中。

实现的主要问题是,在你的问题的尝试,你的条件

ind_y_good = np.where(iy>=0) and np.where(iy<ny) 
ind_x_good = np.where(ix>=0) and np.where(ix<nx) 

必须结合,因为我们必须始终对xy指数。即如果x索引无效,那么y也是如此。

最后,如果指数实际上都被一个常数因子移位,那么可以通过使用NumPy的roll函数并对与有效区域对应的指数取一个片段来使其更简单。

import numpy as np 
import matplotlib.pyplot as plt 
from matplotlib import mpl 


def nonpythonic(previous_data, ix, iy, nx, ny): 
    #this works, but is slow (for large arrays)   
    new_data = np.zeros((ny,nx)) 
    for j in xrange(ny): 
     for i in xrange(nx): 
      #go through each item, check if it is within the bounds 
      #and assign the data to the new_data array 
      i_new = ix[j,i] 
      j_new = iy[j,i] 
      if ((i_new>=0) and (i_new<nx) and (j_new>=0) and (j_new<ny)): 
       new_data[j,i]=previous_data[j_new,i_new] 

    return new_data 

def pythonic(previous_data, ix, iy): 

    ny, nx = previous_data.shape 
    iy_old, ix_old = np.indices(previous_data.shape) 

    # note you must apply the same condition to both 
    # index arrays 
    valid = (iy >= 0) & (iy < ny) & (ix >= 0) & (ix < nx) 

    new_data = np.zeros((ny,nx)) 

    new_data[iy_old[valid], ix_old[valid]] = previous_data[iy[valid], ix[valid]] 
    return new_data 


def main(): 
    #some 2D array: 
    nx = 20 
    ny = 30  
    #array indices: 
    iy, ix = np.indices((ny,nx)) 
    #modify indices (shift): 
    iy = iy + 1 
    ix = ix - 4 
    #create some out of range indices (which might happen in my real scenario) 
    iy[0,2:7] = -9999 
    ix[0:3,-1] = 6666 

    #some previous data which is the basis for the new_data: 
    previous_data = np.ones((ny,nx)) 
    previous_data[2:8,10:20] = 2 
    data_nonpythonic = nonpythonic(previous_data, ix, iy, nx, ny) 
    data_pythonic = pythonic(previous_data, ix, iy) 

    new_data = data_nonpythonic 
    ef, axar = plt.subplots(1,2) 
    im = axar[0].pcolor(previous_data, vmin=0,vmax=2) 
    ef.colorbar(im, ax=axar[0], shrink=0.9) 
    im = axar[1].pcolor(new_data, vmin=0,vmax=2) 
    ef.colorbar(im, ax=axar[1], shrink=0.9) 
    plt.show() 
    print(np.allclose(data_nonpythonic, data_pythonic)) 

if __name__ == "__main__": 
    main() 
+1

非常感谢,这个作品!现在我看到了解决方案,看起来非常明显和直接! – user3497890 2015-03-16 12:05:08

+0

太棒了!用一双新鲜的眼睛总是比较容易 – YXD 2015-03-16 12:06:23