2013-02-07 53 views
2

如何提高此功能的速度?嵌套循环为numpy convolve

def foo(mri_data, radius): 

    mask = mri_data.copy() 

    ny = len(mri_data[0,:]) 
    nx = len(mri_data[:]) 

    for y in xrange(0, ny): 
     for x in xrange(0, nx): 
      if (mri_data[x-radius:x+radius,y-radius:y+radius] != 1.0).all(): 
       mask[x,y] = 0.0      
    return mask.copy() 

它采用numpy阵列形式的图像切片。遍历每个像素并测试该像素周围的边界框。如果框中的值不等于1,则我们通过将其设置为0来丢弃该像素。

我被告知我可以使用numpy.convolve,但我不确定这是如何涉及的。

编辑:图像值在二进制范围内,所以最低值为0.0,最大值为1.0。数值在例如0.767之间。

+0

小心边缘。如果你有例如'radius = 3'和'mri_data = np.arange(8)'那么你的第一个窗口是'mri_data [-3:3]',它返回一个空数组... – Jaime

回答

3

您可以滥用卷积的情况之一。我不会用它,但是界限,否则乏味......

from scipy.ndimage import convolve 

not_one = (mri_data != 1.0) # are you sure you want to compare with float like that?! 

conv = convolve(not_one, np.ones((2*radius, 2*radius))) 
all_not_one = (conv == (2*radius)**2) 

mask[all_not_one] = 0 

应该做的确实同样的事情...

3

你在做什么叫做binary_dilation,但有一个小错误代码。特别是当x,y小于半径时,你会得到负指数。这些负数是使用numpy索引规则解释的,这不是您想要的more on indexing here,因此在图像的两个边缘给出了错误的结果。

这是一些使用二进制扩展来完成相同的事情,并修复上述错误的代码。

import numpy as np 
from scipy.ndimage import binary_dilation 

def foo(mri_data, radius): 
    structure = np.ones((2*radius, 2*radius)) 
    # I set the origin here to match your code 
    mask = binary_dilation(mri_data == 1, structure, origin=-1) 
    return np.where(mask, mri_data, 0) 
+0

感谢Bi Rico。我不知道我做的这种门槛已经是过滤器了。我将不得不尝试你的解决方案。 –