2012-06-18 105 views
6

我有一个100x200的二维数组,表示为由黑色(0)和白色(255)单元组成的numpy数组。这是一个位图文件。然后我有二维形状(最简单的把它们想象成字母),它们也是2D黑白单元格。查找矩阵内的匹配子矩阵

我知道我可以天真地遍历矩阵,但这将是我的代码的'热'部分,所以速度是一个问题。在numpy/scipy中有快速的方法吗?

我简单看了一下Scipy的相关函数。我对'模糊匹配'不感兴趣,只有完全匹配。我也看了一些学术论文,但他们高于我的头脑。

回答

8

可以使用相关。您需要将黑色值设置为-1,将白色值设置为1(反之亦然),以便知道相关峰值的值,并且只会出现正确的字母。

下面的代码做我认为你想要的。

import numpy 
from scipy import signal 

# Set up the inputs 
a = numpy.random.randn(100, 200) 
a[a<0] = 0 
a[a>0] = 255 

b = numpy.random.randn(20, 20) 
b[b<0] = 0 
b[b>0] = 255 

# put b somewhere in a 
a[37:37+b.shape[0], 84:84+b.shape[1]] = b 

# Now the actual solution... 

# Set the black values to -1 
a[a==0] = -1 
b[b==0] = -1 

# and the white values to 1 
a[a==255] = 1 
b[b==255] = 1 

max_peak = numpy.prod(b.shape) 

# c will contain max_peak where the overlap is perfect 
c = signal.correlate(a, b, 'valid') 

overlaps = numpy.where(c == max_peak) 

print overlaps 

此输出(array([37]), array([84])),在代码中设置的偏移量的位置。

您可能会发现,如果您的字母大小乘以大数组大小大于Nlog(N),其中N是您要搜索的大数组的相应大小(对于每个维度),则您可能会通过使用基于fft的算法(如scipy.signal.fftconvolve)(考虑到如果您使用卷积而非相关性 - flipudfliplr)需要翻转其中一个数据集的每个轴来加快速度。唯一的修改是将分配C:

c = signal.fftconvolve(a, numpy.fliplr(numpy.flipud(b)), 'valid') 

上面的时序上的尺寸比较:

In [5]: timeit c = signal.fftconvolve(a, numpy.fliplr(numpy.flipud(b)), 'valid') 
100 loops, best of 3: 6.78 ms per loop 

In [6]: timeit c = signal.correlate(a, b, 'valid') 
10 loops, best of 3: 151 ms per loop 
+0

哇,伟大的答案!我有一些测试运行。 – DaveO

+1

刚刚发生的事情,您可以通过将值设置为0来“不关心”您的子矩阵的区域。这意味着这些值将不会影响交叉关联。然后'max_peak'值可以被找到为'max_peak = b [b!= 0] .size'(无论你有没有0值,这都会起作用)。 –

+0

所以我花了下午编辑我的代码,并使它工作!假设在array([0,6]),array([1,7]))处发现了2个2x3形状,意味着左上角是[0,1]和[6,7]。我想要做的是能够索引形状中的所有2x3单元格,并将它们赋值为0,因此在下一个要查找的形状中,我们将不检查图像的部分(根据上面的注释)。我如何使用correlate/fftconvolve的返回值在不使用循环的情况下索引2d形状?对位置列表片段进行排序。 – DaveO

7

这里是您可以使用,或适应,这取决于细节的方法你的要求。它采用ndimage.label and ndimage.find_objects:

  1. 标签使用ndimage.label此找到所有斑点的阵列中,并将其标签为整数的图像。

    import scipy 
    from scipy import ndimage 
    import matplotlib.pyplot as plt 
    
    #flatten to ensure greyscale. 
    im = scipy.misc.imread('letters.png',flatten=1) 
    objects, number_of_objects = ndimage.label(im) 
    letters = ndimage.find_objects(objects) 
    
    #to save the images for illustrative purposes only: 
    plt.imsave('ob.png',objects) 
    for i,j in enumerate(letters): 
        plt.imsave('ob'+str(i)+'.png',objects[j]) 
    

    例如输入:

  2. 使用ndimage.find_objects
  3. 然后使用交集,看是否found blobs符合您的wanted blobs

代码1.2.获取这些斑点的片

enter image description here

标记:

enter image description here

孤立的斑点再次进行测试:(甚至计时)

enter image description here enter image description here enter image description here enter image description here enter image description here enter image description here

+0

令人惊叹!这几乎是我想要做的。我将不得不尝试两个答案,并看看什么效果最好。感谢您花时间发布此信息! – DaveO