
Using MultiRNNCell with TensorFlow 0.12: the way MultiRNNCell works has changed. For starters, state_is_tuple is now set to True by default. In addition, the TensorFlow 0.12 documentation says the following about it:

state_is_tuple: If True, accepted and returned states are n-tuples, where n = len(cells). If False, the states are all concatenated along the column axis. This latter behavior will soon be deprecated.
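
To make the difference concrete, here is a minimal sketch (TensorFlow 0.12-era tf.nn.rnn_cell API; the sizes 100 and 2 are just example values that happen to match the error below) showing what state_size looks like in each mode:

import tensorflow as tf

hidden_size, num_layers = 100, 2
gru = tf.nn.rnn_cell.GRUCell(hidden_size)

# state_is_tuple=True (new default): the state is an n-tuple of [batch, hidden] tensors
stacked_tuple = tf.nn.rnn_cell.MultiRNNCell([gru] * num_layers, state_is_tuple=True)
print(stacked_tuple.state_size)   # (100, 100)

# state_is_tuple=False (soon deprecated): states are concatenated along the column axis
stacked_concat = tf.nn.rnn_cell.MultiRNNCell([gru] * num_layers, state_is_tuple=False)
print(stacked_concat.state_size)  # 200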

I would like to know how exactly I can build a multi-layer RNN with GRU cells. This is my code so far:

def _run_rnn(self, inputs):
    # embedded inputs are passed in here
    self.initial_state = tf.zeros([self._batch_size, self._hidden_size], tf.float32)
    cell = tf.nn.rnn_cell.GRUCell(self._hidden_size)
    cell = tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=self._dropout_placeholder)
    cell = tf.nn.rnn_cell.MultiRNNCell([cell] * self._num_layers, state_is_tuple=False)

    outputs, last_state = tf.nn.dynamic_rnn(
        cell=cell,
        inputs=inputs,
        sequence_length=self.sequence_length,
        initial_state=self.initial_state
    )

    return outputs, last_state

My inputs are word IDs that get looked up and passed in as the corresponding embedding vectors. Right now, running the code above, I get the following error:

ValueError: Dimension 1 in both shapes must be equal, but are 100 and 200 for 'rnn/while/Select_1' (op: 'Select') with input shapes: [?], [64,100], [64,200]

The ? in the shapes comes from my placeholders:

def _add_placeholders(self):
    self.input_placeholder = tf.placeholder(tf.int32, shape=[None, self._max_steps])
    self.label_placeholder = tf.placeholder(tf.int32, shape=[None, self._max_steps])
    self.sequence_length = tf.placeholder(tf.int32, shape=[None])
    self._dropout_placeholder = tf.placeholder(tf.float32)

Answer


Your main problem is how you set up initial_state. Because your state is now a tuple (more precisely, a tuple containing one state tensor per layer), you cannot assign it directly with tf.zeros. Instead, use:

self.initial_state = cell.zero_state(self._batch_size, tf.float32) 

Take a look at the documentation for more details.
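
Putting it together, a minimal sketch of the corrected method could look like the following (assuming the same attributes as in your class, e.g. self._hidden_size, self._num_layers, self._dropout_placeholder; a sketch, not a verbatim fix):

def _run_rnn(self, inputs):
    # embedded inputs are passed in here
    cell = tf.nn.rnn_cell.GRUCell(self._hidden_size)
    cell = tf.nn.rnn_cell.DropoutWrapper(cell, output_keep_prob=self._dropout_placeholder)
    # keep the (new default) tuple states: one [batch_size, hidden_size] tensor per layer
    cell = tf.nn.rnn_cell.MultiRNNCell([cell] * self._num_layers, state_is_tuple=True)

    # let the cell build a zero state with the right (tuple) structure
    self.initial_state = cell.zero_state(self._batch_size, tf.float32)

    outputs, last_state = tf.nn.dynamic_rnn(
        cell=cell,
        inputs=inputs,
        sequence_length=self.sequence_length,
        initial_state=self.initial_state
    )

    # expose the final state so it can be fed back in between batches (see below)
    self.final_state = last_state
    return outputs, last_state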


To use this code, you would need to pass the state through feed_dict. Do something like this,

state = sess.run(model.initial_state) 
for batch in batches: 
    # Logic to add input placeholder in `feed_dict` 
    feed_dict[model.initial_state] = state 
    # Note I'm re-using `state` below 
    (loss, state) = sess.run([model.loss, model.final_state], feed_dict=feed_dict) 
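
Note that with state_is_tuple=True, model.initial_state is a tuple of tensors. Depending on your TensorFlow version, feeding the whole tuple as a single feed_dict key may or may not be supported; a safe alternative is to feed each layer's state tensor individually (same hypothetical model attributes as above):

state = sess.run(model.initial_state)
for batch in batches:
    # build feed_dict for the batch as before, then add one entry per layer
    for layer_zero_state, layer_state in zip(model.initial_state, state):
        feed_dict[layer_zero_state] = layer_state
    (loss, state) = sess.run([model.loss, model.final_state], feed_dict=feed_dict)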

Given the above, how can I keep passing the current state in as a placeholder? The state is currently a single tensor, so it's simple for me to keep feeding a tensor of zeros with matching dimensions. – TheM00s3


You can pass tensors in the feed_dict. So if you want to pass the cell's next state, just compute that state with `sess` and pass it in the feed dict – martianwars


@TheM00s3 see the updated answer – martianwars