我想对张量应用过滤器并删除不符合我的标准的值。例如,可以说我有一个张,看起来像这样:从softmax中删除低质量张量预测
softmax_tensor = [[ 0.05 , 0.05, 0.2, 0.7], [ 0.25 , 0.25, 0.3, 0.2 ]]
眼下,分类挑选张量argmax
预测:
predictions = [[3],[2]]
但是这ISN”这正是我想要的,因为我放弃了有关该预测信心的信息。我宁愿不做出预测,也不愿做出错误的预测。所以,我希望做的是返回过滤张量,像这样:
new_softmax_tensor = [[ 0.05 , 0.05, 0.2, 0.7]]
new_predictions = [[3]]
如果这是直线上升的蟒蛇,我有没有问题:
new_softmax_tensor = []
new_predictions = []
for idx,listItem in enumerate(softmax_tensor):
# get two highest max values and see if they are far enough apart
M = max(listItem)
M2 = max(n for n in listItem if n!=M)
if M2 - M > 0.3: # just making up a criteria here
new_softmax_tensor.append(listItem)
new_predictions.append(predictions[idx])
但鉴于tensorflow上工作的张量,我不知道如何做到这一点 - 如果我这样做,它会打破计算图吗?
A previous SO post建议使用tf.gather_nd,但在这种情况下,他们已经有了一个他们想要过滤的张量。我也看过tf.cond但仍不明白。我想很多其他人会从这个完全相同的解决方案中受益。
谢谢大家。
你的意思是这样吗? '' #Compute顶部2 SOFTMAX分数 actual_diff = tf.subtract(tf.nn.top_k(softmax_tensor中,k = 2,分类= TRUE)) #创建一个布尔张量,指出保持或之间的差不 cond_result = tf.cond(actual_diff
是的,可能是这样的... –