我有一个形状为[batch_size, D]
的二维张量A
,以及形状为[batch_size]
的一维张量B
。 B
的每个元素是A
的列索引,对于A
的每一行,例如。 B[i] in [0,D)
。Tensorflow索引到具有1d张量的2d张量
什么是tensorflow得到的值A[B]
例如最好的办法:
A = tf.constant([[0,1,2],
[3,4,5]])
B = tf.constant([2,1])
与所需的输出:
some_slice_func(A, B) -> [2,4]
还有另一种约束。实际上,batch_size
实际上是None
。
在此先感谢!