2016-07-20 105 views
2

我有一个形状为[batch_size, D]的二维张量A,以及形状为[batch_size]的一维张量BB的每个元素是A的列索引,对于A的每一行,例如。 B[i] in [0,D)Tensorflow索引到具有1d张量的2d张量

什么是tensorflow得到的值A[B]

例如最好的办法:

A = tf.constant([[0,1,2], 
       [3,4,5]]) 
B = tf.constant([2,1]) 

与所需的输出:

some_slice_func(A, B) -> [2,4] 

还有另一种约束。实际上,batch_size实际上是None

在此先感谢!

回答

3

我能得到它的工作使用线性指标:

def vector_slice(A, B): 
    """ Returns values of rows i of A at column B[i] 

    where A is a 2D Tensor with shape [None, D] 
    and B is a 1D Tensor with shape [None] 
    with type int32 elements in [0,D) 

    Example: 
     A =[[1,2], B = [0,1], vector_slice(A,B) -> [1,4] 
      [3,4]] 
    """ 
    linear_index = (tf.shape(A)[1] 
        * tf.range(0,tf.shape(A)[0])) 
    linear_A = tf.reshape(A, [-1]) 
    return tf.gather(linear_A, B + linear_index) 

这种感觉稍微哈克虽然。

如果有人知道更好(如更清晰或更快),也请留下一个答案! (我不会接受我自己的一段时间)

0

最简单的方法可能是通过连接范围(batch_size)和B来构建适当的2d索引,以获得batch_size x 2矩阵。然后将其传递给tf.gather_nd。

0

代码什么@Eugene Brevdo说:

def vector_slice(A, B): 
    """ Returns values of rows i of A at column B[i] 

    where A is a 2D Tensor with shape [None, D] 
    and B is a 1D Tensor with shape [None] 
    with type int32 elements in [0,D) 

    Example: 
     A =[[1,2], B = [0,1], vector_slice(A,B) -> [1,4] 
      [3,4]] 
    """ 
    B = tf.expand_dims(B, 1) 
    range = tf.expand_dims(tf.range(tf.shape(B)[0]), 1) 
    ind = tf.concat([range, B], 1) 
    return tf.gather_nd(A, ind)