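Trying to use tflearn with my own data. ValueError: Cannot feed value of shape (64, 32, 32) for Tensor u'InputData/X:0', which has shape '(?, 32, 32, 1)'
I have 19,748 grayscale images that I want to train my model on. I used tflearn's image_preloader function to load the images. All images are resized to 32 x 32. But when I start training, I get this error: "ValueError: Cannot feed value of shape (64, 32, 32) for Tensor u'InputData/X:0', which has shape '(?, 32, 32, 1)'"
I have tried everything I know of but cannot solve it. There are similar questions on Stack Overflow, but their answers did not work for me.
Here is my code.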
from __future__ import division, print_function, absolute_import
import tflearn
import pickle
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.conv import conv_2d, max_pool_2d
from tflearn.layers.normalization import local_response_normalization
from tflearn.layers.estimator import regression
from time import gmtime, strftime
from tflearn.data_utils import image_preloader
import numpy as np
dataset_file = 'noww.txt'
X = np.zeros((19748,32,32,1))
Y = np.zeros((19748,10))
X, Y = image_preloader(dataset_file, image_shape=(32, 32), mode='file', categorical_labels=True, normalize=True)
network = input_data(shape=[None, 32, 32, 1])
network = conv_2d(network, 64, 3, activation='relu')
network = conv_2d(network, 64, 3, activation='relu')
network = max_pool_2d(network, 2, strides=2)
network = conv_2d(network, 128, 3, activation='relu')
network = conv_2d(network, 128, 3, activation='relu')
network = max_pool_2d(network, 2, strides=2)
network = conv_2d(network, 256, 3, activation='relu')
network = conv_2d(network, 256, 3, activation='relu')
network = conv_2d(network, 256, 3, activation='relu')
network = max_pool_2d(network, 2, strides=2)
network = fully_connected(network, 1024, activation='relu')
network = dropout(network, 0.5)
network = fully_connected(network, 1024, activation='relu')
network = dropout(network, 0.5)
network = fully_connected(network, 10, activation='softmax')
network = regression(network, optimizer='rmsprop',
                     loss='categorical_crossentropy',
                     learning_rate=0.0001)
model = tflearn.DNN(network, checkpoint_path='model_1',
                    max_checkpoints=1, tensorboard_verbose=0)
model.fit(X, Y, n_epoch=200, shuffle=True,
          show_metric=True, batch_size=64, snapshot_step=200,
          snapshot_epoch=False, run_id='model_1')
Please help.
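The error message suggests that each batch arrives as (64, 32, 32), with no channel axis, while the input layer declared with shape=[None, 32, 32, 1] expects a trailing axis of size 1. A likely cause is that the grayscale files are loaded as 2-D arrays. The sketch below uses only NumPy and a zero-filled stand-in batch (the name `batch` is hypothetical and not part of the code above) to show the mismatch and one possible way to add the missing axis.

```python
import numpy as np

# Stand-in for one batch of loaded grayscale images: 64 images of 32x32
# with no channel axis, which is the shape reported in the error message.
batch = np.zeros((64, 32, 32), dtype=np.float32)

# Append a channel axis of size 1 so the batch matches an input layer
# declared as input_data(shape=[None, 32, 32, 1]).
batch_with_channel = np.reshape(batch, (-1, 32, 32, 1))

print(batch.shape)               # (64, 32, 32)
print(batch_with_channel.shape)  # (64, 32, 32, 1)
```

In the script above, the analogous step would probably be something like X = np.reshape(X, (-1, 32, 32, 1)) right after the image_preloader call. This assumes the preloaded images can be converted to a single NumPy array and that all 19,748 of them fit in memory. Depending on the tflearn version, passing grayscale=True to image_preloader may also produce the channel axis directly, so it is worth checking that option in the installed version first.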