2017-04-07 50 views
2

我正在收集每批中张量流的汇总统计。如何累积张量流中的汇总统计

我想收集在测试集上计算出的相同汇总统计信息,但测试集太大而无法在一个批处理中进行处理。

当我遍历测试集时,是否有一种方便的方式来计算相同的汇总统计信息?

+0

可能重复? https://stackoverflow.com/questions/40788785/how-to-average-summaries-over-multiple-batches/ – Maikefer

+0

这是一个重复的问题,但接受的答案没有提到流平均包,它现在有转移到'tf.metrics',这个问题有一个更新的答案,但确实提到了它。 –

回答

1

另一种可能性是累积在测试批次的摘要tensorflow的外侧,具有在图中的虚拟变量,然后可以分配积累的结果。举个例子:假设你通过几个批次计算验证集上的损失,并希望得到平均值的总结。在评估时间

with tf.name_scope('valid_loss'): 
    v_loss = tf.Variable(tf.constant(0.0), trainable=False) 
    self.v_loss_pl = tf.placeholder(tf.float32, shape=[], name='v_loss_pl') 
    self.update_v_loss = tf.assign(v_loss, self.v_loss_pl, name='update_v_loss') 

with tf.name_scope('valid_summaries'): 
    v_loss_s = tf.summary.scalar('validation_loss', v_loss) 
    self.valid_summaries = tf.summary.merge([v_loss_s], name='valid_summaries') 

然后:你可以做下面的方式实现这一目标

total_loss = 0.0 
for batch in all_batches: 
    loss, _ = sess.run([get_loss, ...], feed_dict={...}) 
    total_loss += loss 
total_loss /= float(n_batches) 

[_, v_summary_str] = sess.run([self.update_v_loss, self.valid_summaries], 
           feed_dict={self.v_loss_pl: total_loss}) 
writer.add_summary(v_summary_str) 

虽然这能够完成任务,但无可否认的感觉有点哈克。您发布的contrib的流量指标评估可能会更加优雅 - 我从来没有真正遇到它,所以很想查看它。