2016-08-19 287 views
1

我正尝试使用Keras基于文档[this example][1]构建自编码器。因为我的数据很大,所以我想使用生成器来避免将其加载到内存中。Keras ImageDataGenerator不按预期方式工作

我的模型看起来像:

model = Sequential() 
model.add(Convolution2D(16, 3, 3, activation='relu', border_mode='same', input_shape=(3, 256, 256))) 
model.add(MaxPooling2D((2, 2), border_mode='same')) 
model.add(Convolution2D(8, 3, 3, activation='relu', border_mode='same')) 
model.add(MaxPooling2D((2, 2), border_mode='same')) 
model.add(Convolution2D(8, 3, 3, activation='relu', border_mode='same')) 
model.add(MaxPooling2D((2, 2), border_mode='same')) 
model.add(Convolution2D(8, 3, 3, activation='relu', border_mode='same')) 
model.add(UpSampling2D((2, 2))) 
model.add(Convolution2D(8, 3, 3, activation='relu', border_mode='same')) 
model.add(UpSampling2D((2, 2))) 
model.add(Convolution2D(16, 3, 3, activation='relu')) 
model.add(UpSampling2D((2, 2))) 
model.add(Convolution2D(1, 3, 3, activation='sigmoid', border_mode='same')) 

model.compile(optimizer='adadelta', loss='binary_crossentropy') 

我的发电机:

from keras.preprocessing.image import ImageDataGenerator 
train_datagen = ImageDataGenerator(rescale=1./255) 
train_generator = train_datagen.flow_from_directory('IMAGE DIRECTORY', color_mode='rgb', class_mode='binary', batch_size=32, target_size=(256, 256)) 

再恰当不过的模型:

model.fit_generator(
     train_generator, 
     samples_per_epoch=1, 
     nb_epoch=1, 
     verbose=1, 
     ) 

我得到这个错误:

例外:当che时出错cking模型目标:期望convolution2d_76具有4个维度,但获得了具有形状的数组(32,1)

看起来像我的批处理的大小而不是样本。我究竟做错了什么?

回答

2

错误很可能是由于class_mode='binary'。它使发生器产生二进制类,所以输出形状为(batch_size, 1),而你的模型产生四维输出(因为最后一层是卷积)。

我想你想让你的标签成为图像本身。根据其使用的flow_from_directoryDirectoryIterator的来源,仅通过更改class_mode就无法做到。一种可能的解决办法是沿着线:

train_generator_ = train_datagen.flow_from_directory('IMAGE DIRECTORY', color_mode='rgb', class_mode=None, batch_size=32, target_size=(256, 256)) 
def train_generator(): 
    for x in train_iterator_: 
     yield x, x 

注意,我设置class_modeNone。它使发电机仅返回image而不是tuple(image, label)。然后我定义一个新的生成器,它将图像作为输入和标签返回。

+0

太棒了,非常感谢!现在一切正常。我很困惑,因为我认为它是在讨论输入形状,但是在实现您的解决方案后,我发现它是输出形状是问题所在。 – Lester