2017-05-28 24 views
0

我有2个numpy的阵列如下:如何平衡numpy数组中的类?

images包含图像的文件的名称(images.shape是(N,3,128,128)): image_1.jpg image_2.jpg image_3.jpg image_4.jpg

labels包含相应的标签(0-3)(labels.shape是(N)): 1 1 3 2

我面对的问题是,这些类是不平衡,用3类>> 1> 2> 0

我想,以平衡由最终的数据集:

  • 每班
  • 计数图像(样本)的数量获取类的计数与最低的数字图像
  • 使用该计数为图像/标签的对于其他3个类别的最大数目
  • 随机
  • 弹出过量图像/在images从其他3类标签和labels

到目前为止,我使用Counter以确定每类图像的数量:

from Collections import Counter 
import numpy as np 

count = Counter(labels) 
print(count) 

>>>Counter({'1': 2991, '0': 2953, '2': 2510, '3': 2488}) 

你会如何建议我随机imageslabels所以它们包含的类0,1 2488个样本流行元素相匹配,和2?

回答

1

你可以使用np.random.choice创建一个整数值面膜,你可以应用到你的标签和图片来平衡数据集:

n = 2488 

mask = np.hstack([np.random.choice(np.where(labels == l)[0], n, replace=False) 
         for l in np.unique(labels)]) 
+0

将'ix'是我同时适用于一个布尔向量数组? (图片和标签) – pepe

+0

对不起,我的回答是错误的,我修好了; 'ix'是一个整数向量,用于索引两个数组 – maxymoo