2017-07-05 90 views
1

我正在学习如何使用Keras和CIFAR-10数据集实现数据增强。我在网上教程和本书的帮助下学习Deep learning with Keras.Keras CONV数据增强培训似乎显示错误的批量大小和培训示例数

代码的具体细节是here

这里是我的问题,我确信它涉及到对我而言有些误会:

这是我CONV成立。

IMG_CHANNELS = 3 
IMG_ROWS = 32 
IMG_COLS = 32 
BATCH_SIZE = 128 
NB_EPOCH = 50 
NB_CLASSES = 10 
VERBOSE = 1 
VALIDATION_SPLIT = 0.2 
OPTIM = RMSprop() 

加载数据集,转化为明确的,漂浮和规范:

(X_train, y_train), (X_test, y_test) = cifar10.load_data() 
Y_train = np_utils.to_categorical(y_train, NB_CLASSES) 
Y_test = np_utils.to_categorical(y_test, NB_CLASSES) 
X_train = X_train.astype('float32') 
X_test = X_test.astype('float32') 
X_train /= 255 
X_test /= 255 

创建发电机

datagen = ImageDataGenerator(
     featurewise_center=False, # set input mean to 0 over the dataset 
     samplewise_center=False, # set each sample mean to 0 
     featurewise_std_normalization=False, # divide inputs by std of the dataset 
     samplewise_std_normalization=False, # divide each input by its std 
     zca_whitening=False, # apply ZCA whitening 
     rotation_range=0, # randomly rotate images in the range (degrees, 0 to 180) 
     width_shift_range=0.1, # randomly shift images horizontally (fraction of total width) 
     height_shift_range=0.1, # randomly shift images vertically (fraction of total height) 
     horizontal_flip=True, # randomly flip images 
     vertical_flip=False) # randomly flip images 
datagen.fit(X_train) 

火车模型(我还没有上市的型号)

model.fit_generator(datagen.flow(X_train, Y_train, 
        batch_size=BATCH_SIZE), 
        samples_per_epoch=X_train.shape[0], 
        nb_epoch=NB_EPOCH, 
        verbose=VERBOSE) 

我的问题是正如我训练以下显示:

Epoch 1/40 
390/390 [==============================] - 199s - loss: 0.9751 - acc: 0.6588 

我看不出为什么我得到390个例子。 Samples_per_epoch等于X_train.shape [0],它是50000,批量大小是128,所以我认为它应该在128个批次中上升到50000.

回答

2

进度条不显示样本数,但数字步骤或批次(当您使用model.fit而不是model.fit_generator时,它会自动显示样品)。每个批次包含128个样本,共有50,000个样本。 50,000/128 = 390.625。这就是为什么你看到390而不是50,000。

由于您使用的是model.fit_generator,因此无法显示样本总数。除非您将batch_size设置为1.原因是发生器预计会无限循环其数据,直到达到steps_per_epochssamples_per_epoch阈值(*)

顺便说一下,您可以在model.fit中对此进行更改,回拨为ProgbarLogger,看起来here

+0

这是我的意见,但由于某种奇怪的原因,这本书有一个数字显示的过程,它显示50000.你确定吗? – GhostRider

+0

是的,我编辑了我的问题。您可以使用ProgbarLogger回调来更改进度条,以便更新每个样本。 –

+0

我在拟合模型(并使用“样本”和“步骤”)之前添加了它,并且它没有影响... – GhostRider