2017-06-19 58 views
0

我正在关注TensorFlow的Generative Adversarial Network的教程。本教程使用MNIST数据集来训练模型。我想减少输入的大小,以便我的程序运行速度更快,但不知道如何获取我正在使用的MNIST数据集的子集。下面是我用于提取所述数据集的代码:如何子集MNIST数据集?

from tensorflow.examples.tutorials.mnist import input_data 
mnist = input_data.read_data_sets("MNIST_data/") 

回答

0

有一种方法

mnist.next_batch(batchsize) 

提取从列车组的长度BATCHSIZE的随机样本。

如果你不想要的东西是随机的,您可以通过

x = mnist.train.images[start_batch:end_batch] 
y = mnist.train.labels[start_batch:end_batch] 

或类似与mnist.test访问它们的测试集。

+0

嗨,非常感谢您的回复。我能够使用您提供的方法对train.images和train.labels进行子集划分。但是,在将这些数据集分组后,我得到一个NDArray对象,并且我无法调用为ndarray的mnist数据集设计的任何方法。有没有什么办法可以将ndarray放回mnist数据集? – nnguyen24