2016-03-18 35 views
9

我按照this link中的教程并尝试更改模型的评估方法(位于底部)。我想获得一个前5的评价,我想使用下面的代码:TensorFlow in_top_k评估输入论题

topFiver=tf.nn.in_top_k(y, y_, 5, name=None) 

然而,这会产生以下错误:

File "AlexNet.py", line 111, in <module> 
    topFiver = tf.nn.in_top_k(pred, y, 5, name=None) 
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_nn_ops.py", line 346, in in_top_k 
    targets=targets, k=k, name=name) 
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/op_def_library.py", line 486, in apply_op 
    _Attr(op_def, input_arg.type_attr)) 
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/op_def_library.py", line 59, in _SatisfiesTypeConstraint 
    ", ".join(dtypes.as_dtype(x).name for x in allowed_list))) 
TypeError: DataType float32 for attr 'T' not in list of allowed values: int32, int64 

至于我可以告诉大家,问题是tf.nn.in_top_k()只适用于tf.int32tf.int64数据,但我的数据是tf.float32格式。有没有解决方法?

回答

19

的参数targets必须是类ID的向量(即predictions矩阵中的列索引)。这意味着它只适用于单类分类问题。

如果您的问题是单类问题,那么我认为您的张量是您的示例的真实标签的一个热门编码(例如,因为您也将它们传递给像tf.nn.softmax_cross_entropy_with_logits()那样的操作符)。情况下,你有两个选择:

  • 如果标签最初存储为整数标签,直接tf.nn.in_top_k()通过他们无需将其转换为一个热(另外,可以考虑使用tf.nn.sparse_softmax_cross_entropy_with_logits()为您损失的功能,因为它可能效率更高)
  • 如果标签最初存储在独热格式,可以使用tf.argmax()它们转换为整数:

    labels = tf.argmax(y_, 1) 
    topFiver = tf.nn.in_top_k(y, labels, 5)