2016-02-26 129 views
0

与此相关的复制变量:How can I copy a variable in tensorflowtensorflow在RNN

我试图复制LSTM解码单元的值在其他地方使用它beamsearch。在伪代码,我想是这样的:

lstm_decode = tf.nn.rnn_cell(...) 
training_output = tf.nn.seq2seq.rnn_decoder(...) 
... do training by back-prop the error on trainint_output ... 

# duplicate the lstm_decode unit (same weights) 
lstm_decode_copy = copy(lstm_decode) 
... do beam search with the duplicated lstm ... 

的问题是,在tensorflow,没有召唤“tf.nn.rnn_cell(......)”过程中产生的LSTM变量,但它是实际上是在函数调用展开到rnn_decoder期间生成的。

我可以将范围设置为“tf.nn.seq2seq.rnn_decoder”函数调用,但lstm权重的实际初始化对我来说并不透明。我如何捕获这些值并重新使用它们来创建一个与所学的权重相同的lstm单元?

谢谢!

回答

0

我想我想通了。

你想要的是用于解码器呼叫范围设置为特定值,说“解码”,在这一行:

training_output = tf.nn.seq2seq.rnn_decoder(...scope="decoding") 

,稍后当你想用你学到的LSTM单位在解码期间,您将变量范围再次设置为“解码”,并使用scope.reuse_variables()来允许重新使用解码的变量。然后简单地使用“lstm_decode”,否则将被绑定到与以前相同的值。

with tf.variable_scope("decoding") as scope: 
    scope.reuse_variables() 
    ... use lstm_decode as usual ... 

这种方式在lstm_decode所有的权重将在这两个子图共享,并取其值学到期间的训练将出现在第二部分也是如此。