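Preprocessing a numpy image dataset for a CNN: MemoryError

I have a dataset (71,094 training images and 17,000 test images) and I need to train a CNN on it. During preprocessing I tried to build a single numpy array holding all the images, and it turns out to be absurdly large: 71094 × 100 × 100 × 3 for the training data (all images are 100 × 100 RGB). As a result I get a MemoryError. How can I fix this? My code is below.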
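For scale, the size of that array can be worked out directly. This is a rough calculation that counts only the array's data buffer: `np.zeros` creates `float64` (8 bytes per element) when no `dtype` is given, whereas `cv2.imread` returns 8-bit (`uint8`) pixels. The helper name `array_nbytes` is hypothetical, and `math.prod` needs Python 3.8+:

```python
import math

import numpy as np


def array_nbytes(shape, dtype):
    """Bytes needed for the data of a dense NumPy array of this shape and dtype."""
    return math.prod(shape) * np.dtype(dtype).itemsize


TRAIN_SHAPE = (71094, 100, 100, 3)  # N x H x W x C, taken from the question

for dt in (np.float64, np.float32, np.uint8):
    print(f"{np.dtype(dt).name:>8}: {array_nbytes(TRAIN_SHAPE, dt) / 1e9:.2f} GB")
```

This prints about 17.06 GB for float64, 8.53 GB for float32 and 2.13 GB for uint8, so the preallocated float64 array alone needs eight times the memory of the raw pixels.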
import numpy as np
import cv2
from matplotlib import pyplot as plt
data_dir = './fashion-data/images/'
train_data = './fashion-data/train.txt'
test_data = './fashion-data/test.txt'
with open(train_data, 'r') as fh:
    ims = fh.read().splitlines()  # one image path per line; the label is the first path component
print(len(ims))
train = np.zeros((71094, 100, 100, 3))  # this line causes the error
for ix in range(train.shape[0]):
    i = cv2.imread(data_dir + ims[ix] + '.jpg')
    label = ims[ix].split('/')[0]
    train[ix, :, :, :] = cv2.resize(i, (100, 100))
print(train[0])
train_labels = np.zeros((71094, 1))
for ix in range(train_labels.shape[0]):
    l = ims[ix].split('/')[0]
    train_labels[ix] = int(l)
print(train_labels[0])
np.save('./data/train', train)
np.save('./data/train_labels', train_labels)
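One common way around this, sketched here under assumptions rather than as a definitive fix: keep the pixels as `uint8`, write them into a disk-backed `.npy` file created with `numpy.lib.format.open_memmap` one image at a time, and convert to `float32` only one batch at a time during training. All function names here (`cv2_loader`, `build_uint8_npy`, `labels_from_paths`, `iter_batches`) are hypothetical. The image loader is passed in as a callable so the sketch runs without OpenCV; `cv2_loader` assumes the same `data_dir + name + '.jpg'` layout as the code above.

```python
import numpy as np


def cv2_loader(data_dir, size=100):
    """Return a loader reading data_dir + name + '.jpg' (assumed layout), resized."""
    import cv2  # imported lazily so the rest of the sketch works without OpenCV

    def load(name):
        path = data_dir + name + ".jpg"
        img = cv2.imread(path)  # uint8 BGR array, or None if unreadable
        if img is None:
            raise FileNotFoundError(path)
        return cv2.resize(img, (size, size))

    return load


def build_uint8_npy(paths, load_fn, out_path, size=100):
    """Fill an on-disk uint8 .npy file image by image instead of holding it in RAM."""
    arr = np.lib.format.open_memmap(
        out_path, mode="w+", dtype=np.uint8, shape=(len(paths), size, size, 3)
    )
    for ix, name in enumerate(paths):
        arr[ix] = load_fn(name)  # only this slice is written
    arr.flush()  # push pending writes to the file
    return arr


def labels_from_paths(paths):
    """Label = first path component, as in the original code."""
    return np.array([int(p.split("/")[0]) for p in paths], dtype=np.int64)


def iter_batches(npy_path, labels, batch_size=64):
    """Yield (float32 images scaled to [0, 1], labels) one batch at a time."""
    data = np.load(npy_path, mmap_mode="r")  # memory-mapped, not read fully into RAM
    for start in range(0, len(data), batch_size):
        x = data[start:start + batch_size].astype(np.float32) / 255.0
        yield x, labels[start:start + batch_size]
```

With the question's variables this would be used roughly as `build_uint8_npy(ims, cv2_loader(data_dir), './data/train.npy')` (with empty lines removed from `ims` and the `./data` directory already existing), followed by `iter_batches('./data/train.npy', labels_from_paths(ims))` in the training loop. Only one batch is ever converted to float32, so peak memory is set by `batch_size` rather than by the dataset size.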