2017-05-31 111 views
2

的选择分支我有一个多输出Keras模型类似于此的结构:列车多输出keras模型

s = some_shared_layers()(input) 
non_trainable1 = Dense(trainable=False) (s) 
non_trainable2 = Dense(trainable=False) (s) 
trainable = Dense() (s) 

model = Model(input, outputs=[non_trainable1, non_trainable2, trainable]) 

我的模型首先计算一个直传并使用该第一2个输出以操纵输入。然后计算另一个正向传球以获得第三个输出。

out1, out2,_ =model.predict(input_data) 
processed_data = foo(input_data, out1, out2) 
_,_, out3 = model.predict(processed_data) 

我应该如何调用model.fit()训练只有trainable层?如果我排除其他产出的损失,凯拉斯警告we will not be expecting any data to be passed to "non_trainable1" during training并将其从计算图中排除。

有没有更好的方法来构建这个任务的模型?

回答

0

如果我正确理解它,你根本不需要这些层,事实上,你应该有两个模型,一个用于预测,另一个用于训练。

非可训练:

model1 = Model(input, [non_trainable1, non_trainable2]) 
#model 1 doesn't need to be compiled, since you won't train it  

可训练:

model2 = Model(input, trainable) 
model2.compile(loss=onlyTheLossForTrainable)  

使用它们:

out1, out2 =model1.predict(input_data) 
processed_data = foo(input_data, out1, out2) 

model2.fit(processed_data, expected_outputs, ....)  
+0

我目前使用作为一种变通方法,它如果我按照您所描述的运行代码,则会起作用。但是,如果我尝试以更复杂的方式使用处理函数'foo' - 比如在'ImageDataGenerator'内部 - 然后'model2.fit_generator()'产生一个'ValueError:Tensor'non_trainable1“不是该图的一个元素。' – Manas

+0

为什么在发电机内?你在发电机里究竟想要做什么?这听起来你正在使用张量而不是使用预测。 –

+0

我正在处理图像,并正在使用生成器进行数据增强。作为一个天真的第一次尝试,我尝试传递'foo'作为'preprocessing_function'参数(ref [docs](https://keras.io/preprocessing/image/))。 'foo'拍摄图像,调用'model1.predict()'并返回一个编辑后的图像。 – Manas