2016-11-28 63 views
3

在一个大的代码库,我使用np.broadcast_to广播阵列(只是用简单的例子在这里):取消广播numpy的阵列

In [1]: x = np.array([1,2,3]) 

In [2]: y = np.broadcast_to(x, (2,1,3)) 

In [3]: y.shape 
Out[3]: (2, 1, 3) 

其他地方在代码中,我使用第三方功能,可以运行以Numpy数组的矢量化方式,但不是ufuncs。这些函数不理解广播,这意味着在像y这样的数组上调用这样的函数效率不高。诸如Numpy的vectorize这样的解决方案不好,因为虽然他们理解广播,但是他们在数组元素上引入了一个for循环,这是非常低效的。

理想情况下,我希望能够做的是有一个函数,我们可以调用例如unbroadcast,它返回一个最小形状的数组,如果需要可以将其广播回全尺寸。因此,例如为:

In [4]: z = unbroadcast(y) 

In [5]: z.shape 
Out[5]: (1, 1, 3) 

然后我就可以在z运行第三方功能,那么广播结果返回给y.shape

有没有一种方法可以实现unbroadcast,它依赖于Numpy的公共API?如果没有,是否有任何黑客可以产生期望的结果?

+0

“y [None,0]'? – Divakar

+0

你是什么意思“最小的形状”? 'unbroadcast'返回的N维的最小形状是不是总是'(1,1,...,1)'(甚至是'(1,)')? –

+0

我的意思是最小的形状,仍然包含所有需要的数据将其广播回完整阵列。所以在上面的例子中,“z.shape”是“(1,1,3)”而不是“(1,1,1)”。 – astrofrog

回答

3

这大概相当于自己的解决方案,只有一点点更内置-在。它采用as_stridednumpy.lib.stride_tricks

import numpy as np 
from numpy.lib.stride_tricks import as_strided 

x = np.arange(16).reshape(2,1,8,1) # shape (2,1,8,1) 
y = np.broadcast_to(x,(2,3,8,5)) # shape (2,3,8,5) broadcast 

def unbroadcast(arr): 
    #determine unbroadcast shape 
    newshape = np.where(np.array(arr.strides) == 0,1,arr.shape) # [2,1,8,1], thanks to @Divakar 
    return as_strided(arr,shape=newshape) # strides are automatically set here 

z = unbroadcast(x) 
np.all(z==x) # is True 

注意,在我原来的答案,我并没有定义一个函数,并将得到z阵列有(64,0,8,0)strides,而输入具有(64,64,8,8)。在当前版本中,返回的z数组与x有相同的大小,我猜想传递并返回数组强制创建副本。无论如何,我们可以在任何情况下手动设置as_strided以获得相同的数组,但这在上述设置中似乎不是必需的。

+1

或'np.where(np.array(y.strides)== 0,1,y.shape)'为新闻形式? – Divakar

+0

@Divakar权利,我一直忘记'np.where' :)显然这是numpy-idiomatic版本,谢谢。 –

+1

我的解决方案非常好的改进 - 谢谢! – astrofrog

4

我有一个可能的解决方案,所以会在这里发布(但是如果有人有更好的,请随时回复!)。一个解决方案是检查strides说法阵列,这将是0一起广播尺寸:

def unbroadcast(array): 
    slices = [] 
    for i in range(array.ndim): 
     if array.strides[i] == 0: 
      slices.append(slice(0, 1)) 
     else: 
      slices.append(slice(None)) 
    return array[slices] 

这给:

In [14]: unbroadcast(y).shape 
Out[14]: (1, 1, 3)