2016-08-11 22 views
3

MNIST集合包含60000个用于训练集的图像。在训练我的Tensorflow时,我想运行训练步骤来训练整个训练集的模型。 Tensorflow网站上的深度学习示例使用20000次迭代,批量大小为50(总计为1,000,000批次)。当我尝试超过30,000次迭代时,我的数字预测失败(预测所有手写数字为0)。我的问题是,我应该使用多少次迭代,批量大小为50来训练整个MNIST集的张量流模型?用整个MNIST数据集(60000图像)训练张量流所需的迭代次数?

self.mnist = input_data.read_data_sets('MNIST_data', one_hot=True) 
for i in range(FLAGS.training_steps): 
    batch = self.mnist.train.next_batch(50) 
    self.train_step.run(feed_dict={self.x: batch[0], self.y_: batch[1], self.keep_prob: 0.5}) 
    if (i+1)%1000 == 0: 
     saver.save(self.sess, FLAGS.checkpoint_dir + 'model.ckpt', global_step = i) 

回答

2

我认为这取决于您的停止标准。如果损失没有改善,您可以停止培训,或者您可以拥有验证数据集,并在验证准确性无法再提高时停止培训。

+0

我想我会做到这一点。可能在每1000次迭代中,我会尽量准确。如果在某个时候,积分下降到0,我应该停止在那里下雨。 – Swapnil

1

随着机器学习,你往往会有严重的收益递减情况。例如这里是从我的细胞神经网络的一个准确的列表:

Epoch 0 current test set accuracy : 0.5399 
Epoch 1 current test set accuracy : 0.7298 
Epoch 2 current test set accuracy : 0.7987 
Epoch 3 current test set accuracy : 0.8331 
Epoch 4 current test set accuracy : 0.8544 
Epoch 5 current test set accuracy : 0.8711 
Epoch 6 current test set accuracy : 0.888 
Epoch 7 current test set accuracy : 0.8969 
Epoch 8 current test set accuracy : 0.9064 
Epoch 9 current test set accuracy : 0.9148 
Epoch 10 current test set accuracy : 0.9203 
Epoch 11 current test set accuracy : 0.9233 
Epoch 12 current test set accuracy : 0.929 
Epoch 13 current test set accuracy : 0.9334 
Epoch 14 current test set accuracy : 0.9358 
Epoch 15 current test set accuracy : 0.9395 
Epoch 16 current test set accuracy : 0.942 
Epoch 17 current test set accuracy : 0.9436 
Epoch 18 current test set accuracy : 0.9458 

正如你所看到的收益开始下降后〜10个历元*,但是这可能会因您的网络和学习速度上。基于多少时间你有多少时间有好处做的不尽相同,但我发现20是一个合理的数字

*我一直使用时代这个词来表示一个整个运行通过一个数据集但我不知道是该定义的准确性,这里每个时代为〜带的大小批量429个训练步128

0

您可以使用类似no_improve_epoch并将其设置为假设3.什么便索性意味着如果在3次迭代中没有> 1%的改善,则停止迭代。

no_improve_epoch= 0 
     with tf.Session() as sess: 
      sess.run(cls.init) 
      if cls.config.reload=='True': 
       print(cls.config.reload) 
       cls.logger.info("Reloading the latest trained model...") 
       saver.restore(sess, cls.config.model_output) 
      cls.add_summary(sess) 
      for epoch in range(cls.config.nepochs): 
       cls.logger.info("Epoch {:} out of {:}".format(epoch + 1, cls.config.nepochs)) 
       dev = train 
       acc, f1 = cls.run_epoch(sess, train, dev, tags, epoch) 

       cls.config.lr *= cls.config.lr_decay 

       if f1 >= best_score: 
        nepoch_no_imprv = 0 
        if not os.path.exists(cls.config.model_output): 
         os.makedirs(cls.config.model_output) 
        saver.save(sess, cls.config.model_output) 
        best_score = f1 
        cls.logger.info("- new best score!") 

       else: 
        no_improve_epoch+= 1 
        if nepoch_no_imprv >= cls.config.nepoch_no_imprv: 
         cls.logger.info("- early stopping {} Iterations without improvement".format(
          nepoch_no_imprv)) 
         break 

Sequence Tagging GITHUB

相关问题