我正在构建用于语言识别的statefull LSTM。 正在有条件的我可以用更小的文件来训练网络,并且新的批处理将会像讨论中的下一句话一样。 但是,要正确训练网络,我需要重置一些批次之间的LSTM的隐藏状态。Tensorflow RNN-LSTM - 重置隐藏状态
我使用一个变量来存储LSTM的hidden_state性能:
with tf.variable_scope('Hidden_state'):
hidden_state = tf.get_variable("hidden_state", [self.num_layers, 2, self.batch_size, self.hidden_size],
tf.float32, initializer=tf.constant_initializer(0.0), trainable=False)
# Arrange it to a tuple of LSTMStateTuple as needed
l = tf.unstack(hidden_state, axis=0)
rnn_tuple_state = tuple([tf.contrib.rnn.LSTMStateTuple(l[idx][0], l[idx][1])
for idx in range(self.num_layers)])
# Build the RNN
with tf.name_scope('LSTM'):
rnn_output, _ = tf.nn.dynamic_rnn(cell, rnn_inputs, sequence_length=input_seq_lengths,
initial_state=rnn_tuple_state, time_major=True)
现在我对如何重置隐藏状态混乱。我已经尝试了两种解决方案,但它不工作:
首先解决
重置与“hidden_state”变量:
rnn_state_zero_op = hidden_state.assign(tf.zeros_like(hidden_state))
它不工作,我想这是因为拆散和元组在运行rnn_state_zero_op操作后,构造不会“重新播放”到图中。
解决方法二
继LSTMStateTuple vs cell.zero_state() for RNN in Tensorflow我试着细胞状态重置:
rnn_state_zero_op = cell.zero_state(self.batch_size, tf.float32)
它似乎没有任何工作。
问题
我在心中另一种解决方案,但它在猜测最好的:我没有保持由tf.nn.dynamic_rnn返回的状态,我已经想到这一点,但我得到一个元组我无法找到一种方法来构建重置元组的操作。
在这一点上,我必须承认,我不太了解tensorflow的内部工作,如果甚至有可能做我想做的事情。 有没有适当的方法来做到这一点?
谢谢!