2017-07-31 96 views
2

我使用ImageDataGenerator().flow_from_directory(...)从目录生成批量数据。Keras:从ImageDataGenerator或predict_generator获取真实标签(y_test)

模型成功建立后,我想要得到一个True和Predicted类标签的两列数组。用model.predict_generator(validation_generator, steps=NUM_STEPS)我可以得到一个预测类的数组。是否可以让predict_generator输出相应的True类标签?

要添加:validation_generator.classes确实会打印True标签,但是按照从目录中检索它们的顺序,它不考虑通过扩充进行的批处理或样本扩展。

回答

0

你可以得到预测标签: y_pred = numpy.rint(预测) ,您可以通过获得真正的标签: y_true = validation_generator.classes 你应该在此之前设定洗牌= False在验证发电机。

最后你可以打印混淆矩阵 print confusion_matrix(y_true,y_pred)