2017-05-11 104 views
1

我有一个3d NumPy数组,我想从中切取很多切片。这些切片在第一维和第二维中将具有一个或多个长度,而第三维将全部返回。切片应始终为3d。保存可变NumPy切片的尺寸

我在此尝试:

import numpy as np 

a = np.zeros((1000, 10, 100)) 
row_sets = ([19, 20], [21]) 
col_sets = ([6], [7, 8]) 

for rows in row_sets: 
    for cols in col_sets: 
     b = a[[rows], [cols]] 
     print(rows, cols, b.shape) 

结果:

[19, 20] [6] (1, 2, 100) 
[19, 20] [7, 8] (1, 2, 100) 
[21] [6] (1, 1, 100) 
[21] [7, 8] (1, 2, 100) 

如果我从切片删除嵌套括号:

b = a[rows, cols] 

我有什么似乎是第二维中的同一问题,维度不保留:

[19, 20] [6] (2, 100) 
[19, 20] [7, 8] (2, 100) 
[21] [6] (1, 100) 
[21] [7, 8] (2, 100) 

我寻找的结果会是这样:

[19, 20] [6] (2, 1, 100) 
[19, 20] [7, 8] (2, 2, 100) 
[21] [6] (1, 1, 100) 
[21] [7, 8] (1, 2, 100) 

回答

2

你触发advanced indexing通过使用整数作为索引,从而降低了结果阵列的尺寸的列表,如果你想要仍然切片数组,您可以使用np.ix_从整数列表重建切片索引:

for rows in row_sets: 
    for cols in col_sets: 
     b = a[np.ix_(rows, cols)] 
     print(rows, cols, b.shape) 

#[19, 20] [6] (2, 1, 100) 
#[19, 20] [7, 8] (2, 2, 100) 
#[21] [6] (1, 1, 100) 
#[21] [7, 8] (1, 2, 100)