我正在按照 Keras 博客上的教程操作,遇到了以下问题:模型训练完成后,我该如何选取一张图像并对其进行分类(预测)?(主题:使用 Keras 进行图像分类预测)
我知道 train_generator.class_indices 中保存了类别名称到类别索引的映射。
目标是传入图像的路径,并返回对应的类别。
下面是代码:
#libraries used
from keras import backend as K
from keras import applications
from keras import optimizers
from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img
from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D
from keras.layers import Activation, Dropout, Flatten, Dense
from keras.optimizers import SGD
import cv2
from PIL import Image
import matplotlib.pyplot as plt
from scipy.misc import imread
import numpy as np
%matplotlib inline
# Image size: width and height share the same value.
img_width = img_height = 150

# Root folders handed to flow_from_directory for training and validation.
train_data_dir = 'data/train'
validation_data_dir = 'data/validation'

# Sample counts; floor-divided by batch_size to obtain the step counts for fit_generator.
nb_train_samples = 3000
nb_validation_samples = 1200

epochs = 10  # alternative value noted by the author: 50
batch_size = 16
n_classes = 3
# Build the data pipelines.
# Training: multiply pixel values by 1/255 and apply shear, zoom and horizontal-flip augmentation.
train_datagen = ImageDataGenerator(
    rescale=1./255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
)

# Validation: rescaling only, no augmentation.
test_datagen = ImageDataGenerator(rescale=1./255)

# Both directory iterators use the same target size, batch size and label encoding.
train_generator = train_datagen.flow_from_directory(
    train_data_dir,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode='categorical',  # alternative: 'binary'
)
validation_generator = test_datagen.flow_from_directory(
    validation_data_dir,
    target_size=(img_width, img_height),
    batch_size=batch_size,
    class_mode='categorical',
)
# Choose the input tensor layout according to the backend's configured image data format.
# The branch bodies must be indented; without indentation this is a syntax error.
if K.image_data_format() == 'channels_first':
    input_shape = (3, img_width, img_height)
else:
    input_shape = (img_width, img_height, 3)
# Three Conv2D -> ReLU -> MaxPooling2D stages, then a flattened dense head
# ending in an n_classes-way softmax. Layers are listed in the order they are stacked.
model = Sequential([
    Conv2D(32, (3, 3), input_shape=input_shape),
    Activation('relu'),
    MaxPooling2D(pool_size=(2, 2)),

    Conv2D(32, (3, 3)),
    Activation('relu'),
    MaxPooling2D(pool_size=(2, 2)),

    Conv2D(64, (3, 3)),
    Activation('relu'),
    MaxPooling2D(pool_size=(2, 2)),

    Flatten(),
    Dense(64),
    Activation('relu'),
    Dropout(0.5),
    Dense(n_classes),
    Activation('softmax'),  # alternative: 'sigmoid'
])

model.compile(
    loss='categorical_crossentropy',  # alternative: 'binary_crossentropy'
    optimizer='rmsprop',
    metrics=['accuracy'],
)
# Alternative compile settings kept from the original ("on the other model"):
# model.compile(optimizer=SGD(lr=0.0001, momentum=0.9), loss='categorical_crossentropy', metrics=['accuracy'])
# Train from the directory iterators; step counts are the sample counts floor-divided by batch_size.
model.fit_generator(
    train_generator,
    steps_per_epoch=nb_train_samples // batch_size,
    epochs=epochs,
    validation_data=validation_generator,
    validation_steps=nb_validation_samples // batch_size,
)
编辑1:
我写了下面的函数,但它无法正常运行:
def predict(model, img, target_size):
    """First attempt at classifying a single image with ``model``.

    NOTE(review): this is the version reported as failing; the statements are
    kept exactly as posted so they still match the traceback shown after it.
    Only the lost indentation has been restored.

    Args:
        model: Object whose ``predict`` method is called on the prepared batch.
        img: Image object exposing ``size`` and ``resize`` (opened with
            ``Image.open`` by the caller).
        target_size: Size compared against ``img.size`` and passed to ``img.resize``.

    Returns:
        The first element of ``model.predict``'s result.
    """
    if img.size != target_size:
        img = img.resize(target_size)

    # NOTE(review): `image` is not defined in the visible snippet. The posted
    # traceback shows it bound to a numpy.ndarray, hence the AttributeError here.
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)
    # NOTE(review): `preprocess_input` is not imported in the visible snippet.
    x = preprocess_input(x)
    preds = model.predict(x)
    return preds[0]
# Driver cell for the first attempt: load a model, display the test image, classify it.
target_size = (150, 150)
# NOTE(review): `load_model` and `model_name` are not defined in the visible snippet —
# presumably set elsewhere in the notebook; confirm.
model = load_model(model_name)
img_path = 'image_test/test1.jpg'
img = Image.open(img_path)
fig = plt.figure()
plt.imshow(img)
plt.show()
# This is the call that raises in the traceback shown after this cell.
preds = predict(model, img, target_size)
# Bare expression: in a notebook cell this displays the value.
preds
编辑2:
抛出的错误如下:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
<ipython-input-55-516f01bf49e9> in <module>()
17 plt.imshow(img)
18 plt.show()
---> 19 preds = predict(model, img, target_size)
20 preds
<ipython-input-55-516f01bf49e9> in predict(model, img, target_size)
3 img = img.resize(target_size)
4
----> 5 x = image.img_to_array(img)
6 x = np.expand_dims(x, axis=0)
7 x = preprocess_input(x)
AttributeError: 'numpy.ndarray' object has no attribute 'img_to_array'
编辑3:解决方案如下(必须对图像数组进行 reshape):
def predict(model, img, target_size):
    """Classify a single PIL image and return the model's output for it.

    The image is converted to RGB if needed, resized to ``target_size`` and
    scaled by 1/255 (the same factor as ``ImageDataGenerator(rescale=1./255)``
    used for training). It is then passed to the model as a batch of one.

    Args:
        model: Object exposing ``predict(batch)``. It receives a float32 array
            of shape ``(1, height, width, 3)``.
            NOTE(review): this assumes a channels-last model, as the previous
            hard-coded ``reshape(1, 150, 150, 3)`` did.
        img: PIL image (anything exposing ``mode``, ``size``, ``convert``,
            ``resize`` and the NumPy array interface).
        target_size: ``(width, height)`` in PIL's convention; list or tuple.

    Returns:
        The first (and only) row of ``model.predict``'s result.
    """
    # PIL reports ``size`` as a tuple, so a list target would never compare equal.
    target_size = tuple(target_size)

    # Grayscale, palette or RGBA inputs would not yield three channels.
    if img.mode != 'RGB':
        img = img.convert('RGB')
    if img.size != target_size:
        img = img.resize(target_size)

    # np.asarray gives (height, width, 3) directly, so no manual reshape is
    # needed and non-square target sizes keep the correct axis order.
    x = np.asarray(img, dtype='float32') / 255.
    x = np.expand_dims(x, axis=0)

    preds = model.predict(x)
    return preds[0]
# Driver cell for the reshaping version of predict(): show the sample image, then classify it.
img_path = 'image_test/bird.jpg'
target_size = (150, 150)

img = Image.open(img_path)

fig = plt.figure()
plt.imshow(img)
plt.show()

preds = predict(model, img, target_size)
preds  # bare expression: displays the returned prediction in a notebook cell
你的意思是 `model.predict(image)` 无法正常工作吗? – desertnaut
我已经编辑并更新了我的帖子。 – NunodeSousa
具体是哪里不行?是报错了,还是预测结果与你的预期不符? – Harald