2017-08-23 26 views
1

我想设计这样的损失函数:Tensorflow:tf.argmax和切片

sum((y[argmax(y_)] - y_[argmax(y_)])²) 

我没有找到一个方法来做到y[argmax(y_)]。我试过y[k],y[:,k]y[None,k]这些工作都没有。这是我的代码:

Na = 3 
    x = tf.placeholder(tf.float32, [None, 2]) 
    W = tf.Variable(tf.zeros([2, Na])) 
    b = tf.Variable(tf.zeros([Na])) 
    y = tf.nn.relu(tf.matmul(x, W) + b) 
    y_ = tf.placeholder(tf.float32, [None, 3]) 
    k = tf.argmax(y_, 1) 
    diff = y[k] - y_[k] 
    loss = tf.reduce_sum(tf.square(diff)) 

和错误:

File "/home/ncarrara/phd/code/cython/robotnavigation/ftq/cftq19.py", line 156, in <module> 
    diff = y[k] - y_[k] 
    File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", line 499, in _SliceHelper 
    name=name) 
    File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/ops/array_ops.py", line 663, in strided_slice 
    shrink_axis_mask=shrink_axis_mask) 
    File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/ops/gen_array_ops.py", line 3515, in strided_slice 
    shrink_axis_mask=shrink_axis_mask, name=name) 
    File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/framework/op_def_library.py", line 767, in apply_op 
    op_def=op_def) 
    File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 2508, in create_op 
    set_shapes_for_outputs(ret) 
    File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1873, in set_shapes_for_outputs 
    shapes = shape_func(op) 
    File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1823, in call_with_requiring 
    return call_cpp_shape_fn(op, require_shape_fn=True) 
    File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/framework/common_shapes.py", line 610, in call_cpp_shape_fn 
    debug_python_shape_fn, require_shape_fn) 
    File "/home/ncarrara/miniconda3/lib/python2.7/site-packages/tensorflow/python/framework/common_shapes.py", line 676, in _call_cpp_shape_fn_impl 
    raise ValueError(err.message) 
ValueError: Shape must be rank 1 but is rank 2 for 'strided_slice' (op: 'StridedSlice') with input shapes: [?,3], [1,?], [1,?], [1]. 

回答

0

这是可以做到使用tf.gather_nd

import tensorflow as tf 

Na = 3 
x = tf.placeholder(tf.float32, [None, 2]) 
W = tf.Variable(tf.zeros([2, Na])) 
b = tf.Variable(tf.zeros([Na])) 
y = tf.nn.relu(tf.matmul(x, W) + b) 
y_ = tf.placeholder(tf.float32, [None, 3]) 
k = tf.argmax(y_, 1) 
# Make index tensor with row and column indices 
num_examples = tf.cast(tf.shape(x)[0], dtype=k.dtype) 
idx = tf.stack([tf.range(num_examples), k], axis=-1) 
diff = tf.gather_nd(y, idx) - tf.gather_nd(y_, idx) 
loss = tf.reduce_sum(tf.square(diff)) 

说明:

在这种情况下,这个想法tf.gather_nd是制作一个矩阵(一个二维张量),其中每一行包含索引输出中的行和列。例如,如果我有一个矩阵a含有:含有

| 1 2 | 
| 0 1 | 
| 2 2 | 
| 1 0 | 

接着的tf.gather_nd(a, i)结果将是载体(一维张量)::

| 1 2 3 | 
| 4 5 6 | 
| 7 8 9 | 

和含有基质i

| 6 | 
| 2 | 
| 9 | 
| 4 | 

在这种情况下,列索引由tf.argmaxk中给出;它会告诉你每一行,哪一列是最高值的列。现在你只需要将行索引与每一个这些。 k中的第一个元素是行0中最大值列的索引,第1行中的下一个元素是索引,依此类推。 num_examples只是xtf.range(num_examples)中的行数,然后给出从0到x减去1的行数(即所有的行索引序列)中的行向量。现在您只需要将ktf.stack所做的一样,结果idx就是tf.gather_nd的参数。

+0

看起来不错,但现在我不确定是否足以验证您的答案,无论如何,谢谢! –

+0

@nicolascarrara我已经添加了一些解释。 – jdehesa

+0

非常感谢,现在很清楚! –