作为学习我在tensorflow方法的一部分,我正在转换一些现有的矩阵处理逻辑。其中一个步骤是分散操作,例如下例中使用scatter_add的操作。我对这个例子的问题是,每次评估操作时,它都会在前一个结果的基础上累加。随着3 run()的调用,如下图所示,打印结果是:如何进行非累积张量流散射_add?
[[8 12 8]...]
[[16 24 16]...]
[[24 36 24]...]
而我要的是每次[[8 12 8]...]
。 indices
矢量包含重复项,updates
中的对应元素需要加在一起,但不能与已存在于scattered
中的现有值相加。
张量流文档中的散布操作都不是我所期望的。是否有适当的操作使用?如果不是,那么实现我所需要的最好方法是什么?
import tensorflow as tf
indices = tf.constant([0, 1, 0, 1, 0, 1, 0, 1], tf.int32)
updates = tf.constant([
[1., 2., 3., 4.],
[2., 3., 4., 1.],
[3., 4., 1., 2.],
[4., 1., 2., 3.],
[1., 2., 3., 4.],
[2., 3., 4., 1.],
[3., 4., 1., 2.],
[4., 1., 2., 3.]], tf.float32)
scattered = tf.Variable([
[0., 0., 0., 0.,],
[0., 0., 0., 0.,]], tf.float32)
# Requirement:
# scattered[i, j] = Sum of updates[k, j] where indices[k] == i
#
# i.e.
# scattered_data = [
# [1+3+1+3, 2+4+2+4, 3+1+3+1, 4+2+4+2],
# [2+4+2+4, 3+1+3+1, 4+2+4+2, 1+3+1+3]]
# == [
# [ 8, 12, 8, 12],
# [12, 8, 12, 8]]
scattered = tf.scatter_add(scattered, indices, updates, use_locking=True, name='scattered')
scattered_print = tf.Print(scattered, [scattered])
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
sess.run(scattered_print)
# Printout: [[8 12 8]...]
sess.run(scattered_print)
# Printout: [[16 24 16]...]
sess.run(scattered_print)
# Printout: [[24 36 24]...]
sess.close()
感谢您的回复。但是,所显示的代码(当然)是真正问题的玩具版本。实际上,“索引”中的值范围从0到约15,000,其长度约为10,000,000。转换为32位条目的矩阵乘法需要创建一个大约600GB的中间矩阵。我正在试图将它变成GPU。现有的实现使用每个索引条目16位,因此表示大约20MB的索引数组。 –