2016-10-27 93 views
7

这个问题的两个部分:如何为张量流中的张量子集赋值?

(1)什么是更新tensorflow张量子集的最佳方法是什么?我见过几个相关问题:

Adjust Single Value within Tensor -- TensorFlowHow to update a subset of 2D tensor in Tensorflow?

和我知道,可变对象可以使用Variable.assign()(和/或scatter_update等)进行分配,但对我来说似乎很奇怪,tensorflow没有更直观的方式来更新Tensor对象的一部分。我已经搜索了tensorflow api docs和stackoverflow很长一段时间了,似乎无法找到比上面链接中提供的更简单的解决方案。这看起来特别奇怪,特别是考虑到Theano与Tensor.set_subtensor()具有相同的版本。我是否错过了某些东西,或者没有简单的方法通过tensorflow API在这一点上做到这一点?

(2)如果有一个更简单的方法,它是否可区分?

谢谢!

+0

您是否足够使用numpy数组初始化张量值?然后我推荐这种方式。 – Jin

+4

在最新版本的Tensorflow中,您可以使用类似numpy的切片来更新变量,如下所示:'v [2:4] .assign([1,2])',其中'v'是'Variable'。这是否回答你的问题? –

+0

谢谢你,欣赏这些想法/评论。不幸的是,我不是在寻找什么,尽管...使用numpy-like切片的更新变量就是它,除了它只适用于“变量”而不是“张量”。我重新设计了我的模型以避免明确需要这个操作,但现实看来,Tensor对象在tf中是完全不可变的(不像可变对象)。再次感谢您的想法! – joeliven

回答

0

我认为张量的不变性是构造计算图所必需的;你不能让张量更新它的某些值而不成为另一个张量,否则在它之前没有任何东西可以放入图中。 The same issue comes up in Autograd

使用布尔模板(使它们变量并使用assign,甚至在numpy之前定义它们)可以做到这一点(但很难看)。这是可以区分的,但在实践中我会避免更新副本。

如果你真的要和我真的希望有一个更好的方式来做到这一点,但这里是一个方式使用tf.dynamic_stitchtf.setdiff1d做到在1D:

def set_subtensor1d(a, b, slice_a, slice_b): 
    # a[slice_a] = b[slice_b] 
    a_range = tf.range(a.shape[0]) 
    _, a_from = tf.setdiff1d(a_range, a_range[slice_a]) 
    a_to = a_from 
    b_from, b_to = tf.range(b.shape[0])[slice_b], a_range[slice_a]  
    return tf.dynamic_stitch([a_to, b_to], 
        [tf.gather(a, a_from),tf.gather(b, b_from)]) 

对于更高的层面,这可能滥用reshape一概而论(其中nd_slice could be implemented like this但有可能是一个更好的方法):

def set_subtensornd(a, b, slice_tuple_a, slice_tuple_b): 
    # a[*slice_tuple_a] = b[*slice_tuple_b] 
    a_range = tf.range(tf.reduce_prod(tf.shape(a))) 
    a_idxed = tf.reshape(a_range, tf.shape(a)) 
    a_dropped = tf.reshape(nd_slice(a_idxed, slice_tuple_a), [-1]) 
    _, a_from = tf.setdiff1d(a_range, a_dropped) 
    a_to = a_from 
    b_range = tf.range(tf.reduce_prod(tf.shape(b))) 
    b_idxed = tf.reshape(b_range, tf.shape(b)) 
    b_from = tf.reshape(nd_slice(b_idxed, slice_tuple_b), [-1]) 
    b_to = a_dropped 
    a_flat, b_flat = tf.reshape(a, [-1]), tf.reshape(b, [-1]) 
    stitched = tf.dynamic_stitch([a_to, b_to], 
        [tf.gather(a_flat, a_from),tf.gather(b_flat, b_from)]) 
    return tf.reshape(stitched, tf.shape(a)) 

我不知道如何缓慢,这将是。我猜很慢。而且,除了在几个张量上运行之外,我还没有对它进行太多的测试。