2017-02-20 140 views
3

我正在使用tensorflow导入一些MNIST输入数据。我跟着这个教程... https://www.tensorflow.org/get_started/mnist/beginners使用matplotlib显示MNIST图像

我导入它们作为如此...

from tensorflow.examples.tutorials.mnist import input_data 

mnist = input_data.read_data_sets("/tmp/data/", one_hot=True) 

我希望能够从训练集显示任何图像。我知道图像的位置是mnist.train.images,所以我尝试访问的第一个图像,并像这样显示它...

with tf.Session() as sess: 
    #access first image 
    first_image = mnist.train.images[0] 

    first_image = np.array(first_image, dtype='uint8') 
    pixels = first_image.reshape((28, 28)) 
    plt.imshow(pixels, cmap='gray') 

我一个试图将图像转换为28 28 numpy的阵列,因为我知道每个图像是28乘28像素。

然而,当我运行代码我得到的是以下...

enter image description here

显然,我做错了什么。当我打印矩阵时,一切似乎都很好,但我认为我错误地重塑了它。

+1

的可能的复制的完整代码[TensorFlow - 从数据集MNIST显示图像](https://stackoverflow.com/questions/38308378/tensorflow-show-image -from-mnist-dataset) –

回答

1

您正在将一组浮点数(as described in the docs)转换为uint8,如果这些浮点数不是1.0,它们会将它们截断为0。你应该围绕它们,或者将它们用作浮动物或者使用它们作为浮动物或与255相乘。

我不确定,为什么你看不到白色背景,但是我建议使用明确的灰度等级。

4

以下代码显示了用于训练神经网络的MNIST数字数据库中显示的示例图像。它使用来自stackflow周围的各种代码并避免pil。

# Tested with Python 3.5.2 with tensorflow and matplotlib installed. 
from matplotlib import pyplot as plt 
import numpy as np 
from tensorflow.examples.tutorials.mnist import input_data 
mnist = input_data.read_data_sets('MNIST_data', one_hot = True) 
def gen_image(arr): 
    two_d = (np.reshape(arr, (28, 28)) * 255).astype(np.uint8) 
    plt.imshow(two_d, interpolation='nearest') 
    return plt 

# Get a batch of two random images and show in a pop-up window. 
batch_xs, batch_ys = mnist.test.next_batch(2) 
gen_image(batch_xs[0]).show() 
gen_image(batch_xs[1]).show() 

MNIST的定义是:https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/learn/python/learn/datasets/mnist.py

,导致我需要显示MNINST图像的tensorflow神经网络是:https://github.com/tensorflow/tensorflow/blob/r1.2/tensorflow/examples/tutorials/mnist/mnist_deep.py

因为我只用于编程的Python两个小时,我可能会犯一些新的错误。请随时纠正。

3

下面是使用matplotlib

表示图像
first_image = mnist.test.images[0] 
first_image = np.array(first_image, dtype='float') 
pixels = first_image.reshape((28, 28)) 
plt.imshow(pixels, cmap='gray') 
plt.show() 
+0

这对我很有用,谢谢..善意编辑以包括几个导入行,其中定义了'np','mnist','plt'等等,以便搜索快速答案的人可以快速复制并粘贴你的逐字。谢谢 – Somo