2012-08-29 30 views
1

我试图边缘化一个轴上的数组,并检查一维峰值是否出现在与原始二维峰值相同的相关索引处。在什么情况下(x的形式)下面的断言是否失败?numpy array - argmax v argmax的总和的unravel_index

def check(x,axis=None): 
    import numpy 
    m = numpy.sum(x, axis=axis) 
    v,w = numpy.unravel_index(numpy.argmax(x), x.shape) 
    assert(v==numpy.argmax(m)) 
    return 

对于x=numpy.array(range(15)).reshape(5,3)check(x,axis=0)引发错误,但check(x,axis=1)没有。我看不出为什么有人提出AssertionError - 我是不是很愚蠢?

+0

您是否尝试打印'v'和numpy.argmax(m)的值?也许实际值会给你一个线索 – Dhara

+0

当axis = None时断言失败吗?这可能是因为sum函数将数组中的所有值相加,给出m [不是数组]的单个值。这个argmax总是0 – Dhara

回答

1

您正在检查拆散索引的错误坐标。取而代之的

v,w=numpy.unravel_index(numpy.argmax(x),x.shape) 
assert(v==numpy.argmax(m)) 

你可能想

vw = numpy.unravel_index(numpy.argmax(x),x.shape) 
assert vw[1 - axis] == numpy.argmax(m) 

或许

v,w=numpy.unravel_index(numpy.argmax(x),x.shape) 
assert (v if axis == 1 else w) == numpy.argmax(m) 
1

axis参数的值是非常关键的。

随着x一个(N, M)阵列,m=np.sum(x, axis=axis)会给你

  • 标如果axis=None;
  • a M array if axis=0;
  • a N array if axis=1

因此,你的np.argmax(m)将总是0是如果axis=None,或者如果axis=0(相应的axis=10M(相应的N)之间的整数。

但是,您的(v, w) = np.unravel_index(...)将始终会给您v作为0到N之间的整数。

正如你所看到的,与axis=0,潜力值的m的范围是不一样的,作为v,而它是axis=1

所以,比较mv如果axis=1,或w如果axis=0(@ ecatmur的答案告诉您如何)。