1
给定一个3维numpy数组,如何查找前n个最小值的索引?最小值的指数可以发现:查找三维numpy数组中的n个最小值的索引
i,j,k = np.where(my_array == my_array.min())
给定一个3维numpy数组,如何查找前n个最小值的索引?最小值的指数可以发现:查找三维numpy数组中的n个最小值的索引
i,j,k = np.where(my_array == my_array.min())
下面是通用正变暗和通用n个最小号的一种方法 -
def smallestN_indices(a, N):
idx = a.ravel().argsort()[:N]
return np.stack(np.unravel_index(idx, a.shape)).T
的的2D
输出数组将持有的每一行索引与最小阵列号之一对应的元组。
我们也可以使用argpartition
,但可能无法维持订单。所以,我们需要argsort
还有多一点额外的工作 -
def smallestN_indices_argparitition(a, N, maintain_order=False):
idx = np.argpartition(a.ravel(),N)[:N]
if maintain_order:
idx = idx[a.ravel()[idx].argsort()]
return np.stack(np.unravel_index(idx, a.shape)).T
采样运行 -
In [141]: np.random.seed(1234)
...: a = np.random.randint(111,999,(2,5,4,3))
...:
In [142]: smallestN_indices(a, N=3)
Out[142]:
array([[0, 3, 2, 0],
[1, 2, 3, 0],
[1, 2, 2, 1]])
In [143]: smallestN_indices_argparitition(a, N=3)
Out[143]:
array([[1, 2, 3, 0],
[0, 3, 2, 0],
[1, 2, 2, 1]])
In [144]: smallestN_indices_argparitition(a, N=3, maintain_order=True)
Out[144]:
array([[0, 3, 2, 0],
[1, 2, 3, 0],
[1, 2, 2, 1]])
运行测试 -
In [145]: a = np.random.randint(111,999,(20,50,40,30))
In [146]: %timeit smallestN_indices(a, N=3)
...: %timeit smallestN_indices_argparitition(a, N=3)
...: %timeit smallestN_indices_argparitition(a, N=3, maintain_order=True)
...:
10 loops, best of 3: 97.6 ms per loop
100 loops, best of 3: 8.32 ms per loop
100 loops, best of 3: 8.34 ms per loop
工作过,谢谢! –
的可能的复制[获取N个最高值的索引在ndarray](https://stackoverflow.com/questions/26603747/get-the-indices-of-n-highest-values-in-an-ndarray) – stamaimer
发布的解决方案是否适合你? – Divakar
是的,它确实有效。非常感谢伙伴 –