2017-07-17 35 views
4

如果我可以在保存和恢复LSTM方面得到一些帮助,我将非常感激。如何恢复LSTM层

我有这个LSTM层 -

# LSTM cell 
cell = tf.contrib.rnn.LSTMCell(n_hidden) 
output, current_state = tf.nn.dynamic_rnn(cell, word_vectors, dtype=tf.float32) 

outputs = tf.transpose(output, [1, 0, 2]) 
last = tf.gather(outputs, int(outputs.get_shape()[0]) - 1) 

# Saver function 
saver = tf.train.Saver() 
saver.save(sess, 'test-model') 

金丹节省模式,可以让我保存和恢复LSTM的重量和偏见。但是,我需要恢复此LSTM图层并为其提供一组新的输入。

要恢复整个模型,我做:

with tf.Session() as sess: 
    saver = tf.train.import_meta_graph('test-model.meta') 
    saver.restore(sess, tf.train.latest_checkpoint('./')) 
  1. 是否有可能对我来说,与预先训练的权重和偏见初始化LSTM细胞?

  2. 如果不是,我该如何恢复这个LSTM层?

非常感谢!

回答

1

您已经加载模型,因此模型的权重。您只需使用get_tensor_by_name即可从图表中获取任何张量并将其用于推断。

实施例:

with tf.Session() as sess: 
    saver = tf.train.import_meta_graph('test-model.meta') 
    saver.restore(sess, tf.train.latest_checkpoint('./')) 

    # Get the tensors by their variable name 
    word_vec = = detection_graph.get_tensor_by_name('word_vec:0') 
    output_tensor = detection_graph.get_tensor_by_name('outputs:0') 

    sess.run(output_tensor, feed_dict={word_vec: ...}) 

在上面的例子word_vecoutputs是创建图的期间分配给该张量名称。确保你指定了名字,以便他们可以通过他们的名字进行调用。

+0

非常感谢您回答我的问题!对此,我真的非常感激。所以我不需要再次通过LSTM传递'word_vec'?这是如何工作的? – AnnaR

+0

它只是一个例子,你应该通过feed_dict传递你在图中定义的输入。 –

+0

非常感谢! – AnnaR