2017-06-12 48 views
0

我将我的意见从https://github.com/tensorflow/tensorflow/issues/8833转移到StackOverflow,因为SO似乎更合适。TensorFlow LSTM状态从元组切换到张量

我在尝试使用tensorflow.contrib.seq2seqtensorflow.contrib.rnnBasicLSTMCell实现序列模型。在rnn_cell_impl.py,行c, h = state导致以下错误:

TypeError: 'Tensor' object is not iterable.

当单步调试代码,我才知道,错误造成的第三次c, h = state进行评估。前两次,状态为<class 'tensorflow.python.ops.rnn_cell_impl.LSTMStateTuple'>,但第三次状态为<class 'tensorflow.python.framework.ops.Tensor'>。显然,我想第三次有类型LSTMStateTuple,但我不知道什么可能导致交换机。

有问题的state张量的名称是define_model/define_decoder/decoder/while/Identity_3。我写了define_model()define_decoder()的方法,其余的信息表明我的decoder内发生了一些事情。

如果相关,我使用Python 3.6和Tensorflow 1.2。

回答

0

答案可以在上面的linked Github issue page找到。简要总结一下,问题是我的编码器使用了一个双向RNN,它产生一个2元组的LSTMStateTuples,即每个定向RNN有一个c和一个h状态。之后,解码器接受一个单独的单元,该单元与单个LSTMStateTuple相关联。为了解决这个问题,你需要分别连接两个定向RNNS的c状态和h状态,把它作为一个新的LSTMStateTuple包装起来,并传递给解码器的状态。