我正在使用TF.LEARN和mnist数据。我以0.96的准确度训练了我的神经网络,但现在我不太确定如何预测一个值。获取有关预测MNIST数据集的奇怪值
这里是我的代码..
#getting mnist data to a zip in the computer.
mnist.SOURCE_URL = 'https://web.archive.org/web/20160117040036/http://yann.lecun.com/exdb/mnist/'
trainX, trainY, testX, testY = mnist.load_data(one_hot=True)
# Define the neural network
def build_model():
# This resets all parameters and variables
tf.reset_default_graph()
net = tflearn.input_data([None, 784])
net = tflearn.fully_connected(net, 100, activation='ReLU')
net = tflearn.fully_connected(net, 10, activation='softmax')
net = tflearn.regression(net, optimizer='sgd', learning_rate=0.1, loss='categorical_crossentropy')
# This model assumes that your network is named "net"
model = tflearn.DNN(net)
return model
# Build the model
model = build_model()
model.fit(trainX, trainY, validation_set=0.1, show_metric=True, batch_size=100, n_epoch=8)
#Here is the problem
#lets say I want to predict what my neural network will reply back if I put back the send value from my trainX
the value of trainX[2] is 4
pred = model.predict([trainX[2]])
print(pred)
#What I get is
[[2.6109733880730346e-05, 4.549271125142695e-06, 1.8098366126650944e-05, 0.003199575003236532, 0.20630565285682678, 0.0003870908112730831, 4.902480941382237e-05, 0.006617342587560415, 0.018498118966817856, 0.764894425868988]]
我要的是 - > 4
的问题是,我不知道如何使用此功能预测并放入trainX值以获得预测。
我做这个 预解码= model.predict([trainX [5]) 打印(np.argmax(预解码)) 得到了答案,但谢谢你告诉我关于tf.argmax(预解码,1) –
我想我不清楚我的问题,我只是想知道如何计算索引号,这基本上是通过使用np.argmax ...对不起,混乱。我非常感谢答案! –
@MasnadNihit如果你只有一个预测,那么'np.argmax'适合你。如果你有多个,那么你需要'np.argmax(pred,1)'来同时获得所有预测的所有索引。 –