2016-03-03 53 views
0

我正在训练一个简单的网络来预测单个对象的边界框坐标。但是有图片没有找到任何物体。由于网络总是进行预测,因此它也预测0和1之间的置信度值,这应该表示图片中存在对象的概率。我的张量与预测被称为logits,它是一个(batch_size, 5)张量(置信度,x,y,宽度和高度)。同样,labels张量也是(batch_size, 5)Tensorflow如何检查张量行是否仅为零?

以前我是训练只与始终有一个物体的图像,这样我就可以基本上做到

loss = tf.l2_loss(logits - labels)

我也想开始训练,没有对象和图片时没有对象的图片,我不希望网络因其预测的坐标而受到惩罚。在这种情况下,所有重要的是置信度值,它应该接近0(没有对象)。

我应该如何构建我的标签和损失函数来完成此操作?我可以将没有对象的图像标签设置为全零,但是如何检查特定行只有零?在这种情况下,logits中的对应行也需要设置为零(置信度值!除外),以便由于坐标而产生的损失也为零。

回答

1

的一种方法找到,如果张量的一排都是零:

import tensorflow as tf 

image = tf.fill([8,8], 0) 
sess = tf.Session() 
sess.run(tf.initialize_all_variables()) 
image_row = tf.slice(image, [1,0], [1, -1]) 
total = tf.reduce_sum(tf.abs(image_row)) 
is_all_zero = tf.equal(total, 0) 
print sess.run([total, is_all_zero, image_row]) 

结果:

[0, True, array([[0, 0, 0, 0, 0, 0, 0, 0]], dtype=int32)] 
+0

这解决了OP的问题,但实际上并没有检查该行是全零,但总和为零。一行可能包含'[100,-100,0,0,0]',这将加和为零。如果你改变了一些东西,那么它可以在其他情况下一致地工作,并更精确地匹配OP问题的标题。 –

+0

先取绝对值 – zenna

+0

谢谢,增加了abs运算符来修复此案例。 – fabrizioM