2017-06-21

I am training an LSTM network. How can I extract all the weights from an LSTM cell in vanilla TensorFlow?

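import tensorflow as tf  # TensorFlow 1.x (uses tf.contrib)

# rnn_inputs: a list of [batch_size, input_size] tensors, one per time step
cell_fw = tf.contrib.rnn.BasicLSTMCell(HIDDEN_SIZE)
cell_bw = tf.contrib.rnn.BasicLSTMCell(HIDDEN_SIZE)

rnn_outputs, final_state_fw, final_state_bw = tf.contrib.rnn.static_bidirectional_rnn(
    cell_fw=cell_fw,
    cell_bw=cell_bw,
    inputs=rnn_inputs,
    dtype=tf.float32
)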

I also tried to save its coefficients:

d = {}
with tf.Session() as sess:
    # train code ...
    # fetch the current value of every global variable, keyed by its name
    variables_names = [v.name for v in tf.global_variables()]
    values = sess.run(variables_names)
    for k, v in zip(variables_names, values):
        d[k] = v

The dictionary d contains only 2 objects for each LSTM cell:

[(k,v.shape) for (k,v) in sorted(d.items(), key=lambda x:x[0])] 
[('bidirectional_rnn/bw/basic_lstm_cell/biases:0', (1024,)), 
('bidirectional_rnn/bw/basic_lstm_cell/weights:0', (272, 1024)), 
('bidirectional_rnn/fw/basic_lstm_cell/biases:0', (1024,)), 
('bidirectional_rnn/fw/basic_lstm_cell/weights:0', (272, 1024)), 
('char_embedding:0', (70, 16)), 
('softmax_biases:0', (5068,)), 
('softmax_weights:0', (5068, 512))] 

I'm confused. Each LSTM cell should contain as many as 4 trainable layers, shouldn't it? If so, how do I get all the weights from the LSTM cell?

Answer


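The 4 weight matrices (and biases) of an LSTM cell are stored as a single tensor, where slices along the second axis correspond to the different kinds of weights (input gate, forget gate, etc.).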
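To illustrate why one tensor is enough, here is a minimal NumPy sketch of a single BasicLSTMCell step computed from one fused kernel. It assumes the gate order i, j, f, o (input gate, candidate, forget gate, output gate) and the default forget_bias=1.0 as they appear in the TF 1.x BasicLSTMCell source; `basic_lstm_step` is a hypothetical helper, not a TensorFlow API, and the sizes 16/256 are chosen only to reproduce the shapes in the question.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def basic_lstm_step(x, h, c, kernel, bias, forget_bias=1.0):
    """One LSTM step from a single fused kernel (hypothetical helper).

    x: [batch, input_size]; h, c: [batch, num_units]
    kernel: [input_size + num_units, 4 * num_units]; bias: [4 * num_units]
    """
    # Input and previous hidden state are concatenated, so one matmul
    # produces the pre-activations of all four gates at once.
    gates = np.concatenate([x, h], axis=1) @ kernel + bias
    # Assumed gate order (TF 1.x source): input, candidate, forget, output.
    i, j, f, o = np.split(gates, 4, axis=1)
    new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
    new_h = np.tanh(new_c) * sigmoid(o)
    return new_h, new_c

rng = np.random.default_rng(0)
input_size, num_units, batch = 16, 256, 2
kernel = rng.standard_normal((input_size + num_units, 4 * num_units)) * 0.01
bias = np.zeros(4 * num_units)
x = rng.standard_normal((batch, input_size))
h = np.zeros((batch, num_units))
c = np.zeros((batch, num_units))
new_h, new_c = basic_lstm_step(x, h, c, kernel, bias)
print(kernel.shape, new_h.shape)  # (272, 1024) (2, 256)
```

The four "layers" are therefore still there; they are just four column blocks of one matrix, fused so that a single matmul serves all gates.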

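For example, I guess that in your case HIDDEN_SIZE is 256, since the second axis has length 1024 = 4 × 256. (The first axis, 272, is the 16-dimensional character embedding plus the 256 hidden units, because the cell concatenates its input with the previous hidden state before the matrix multiplication.)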
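The shape reasoning can be checked mechanically. This small sketch assumes the BasicLSTMCell layout of a [input_size + num_units, 4 * num_units] kernel and a [4 * num_units] bias; `infer_lstm_sizes` is a hypothetical helper written for illustration.

```python
def infer_lstm_sizes(kernel_shape, bias_shape):
    """Recover (input_size, num_units) from fused BasicLSTMCell parameter shapes."""
    rows, cols = kernel_shape
    if cols % 4 != 0 or tuple(bias_shape) != (cols,):
        raise ValueError("expected a [rows, 4 * num_units] kernel and a [4 * num_units] bias")
    num_units = cols // 4          # four gates share the second axis
    input_size = rows - num_units  # rows = input_size + num_units (input and h concatenated)
    return input_size, num_units

print(infer_lstm_sizes((272, 1024), (1024,)))  # (16, 256)
```

The recovered input size of 16 matches the second dimension of `char_embedding:0` in the question.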

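To access the different parts, slice the tensor along the axis of length 1024 (but I don't know in which order the different kinds of weights are stored...).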
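A minimal sketch of that slicing on the saved dictionary, using NumPy. The gate order i, j, f, o is an assumption taken from the TF 1.x BasicLSTMCell source and is worth double-checking against your installed version; `split_lstm_params` and `GATE_NAMES` are hypothetical names, and the random arrays only stand in for the real values in `d`.

```python
import numpy as np

# Assumed gate order (TF 1.x BasicLSTMCell source): i, j, f, o.
GATE_NAMES = ("input_gate", "candidate", "forget_gate", "output_gate")

def split_lstm_params(weights, biases, input_size):
    """Split fused BasicLSTMCell weights/biases into per-gate pieces.

    Returns {gate: (W_x, W_h, b)}: W_x multiplies the input,
    W_h multiplies the previous hidden state, b is the gate bias.
    """
    w_parts = np.split(weights, 4, axis=1)  # slice along the 4 * num_units axis
    b_parts = np.split(biases, 4)
    return {
        name: (w[:input_size], w[input_size:], b)
        for name, w, b in zip(GATE_NAMES, w_parts, b_parts)
    }

# Stand-in for d from the question: random values with the same shapes.
rng = np.random.default_rng(0)
d = {
    "bidirectional_rnn/fw/basic_lstm_cell/weights:0": rng.standard_normal((272, 1024)),
    "bidirectional_rnn/fw/basic_lstm_cell/biases:0": rng.standard_normal(1024),
}
fw = split_lstm_params(d["bidirectional_rnn/fw/basic_lstm_cell/weights:0"],
                       d["bidirectional_rnn/fw/basic_lstm_cell/biases:0"],
                       input_size=16)
for name, (w_x, w_h, b) in fw.items():
    print(name, w_x.shape, w_h.shape, b.shape)
# input_gate (16, 256) (256, 256) (256,), and likewise for the other three
```

With the real `d`, drop the stand-in dictionary; the same call works for the `bw` cell's weights and biases.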


Oh, that's right. Thanks, now I can relax. – Roosh
