2016-08-18 48 views
3

我训练了一个递归神经网络(LSTM)并保存了权重和metagraph。当我检索metagraph进行预测时,只要序列长度与训练过程中的序列长度相同,一切都可以正常工作。Tensorflow:检索metagraph时修改占位符的形状

LSTM的好处之一是输入的序列长度可以变化(例如,如果输入是形成句子的字母,则句子的长度可以变化)。

从metagraph中检索图形时,如何更改输入的序列长度?

更多细节与代码:

在培训过程中,我使用占位符xy养活数据。对于预测,我检索这些占位符,但无法设法更改其形状(从[None, previous_sequence_length=100, n_input][None, new_sequence_length=50, n_input])。

在文件model.py,定义体系结构和占位符:

self.x = tf.placeholder("float32", [None, self.n_steps, self.n_input], name='x_input') 
self.y = tf.placeholder("float32", [None, self.n_classes], name='y_labels') 
tf.add_to_collection('x', self.x) 
tf.add_to_collection('y', self.y) 
... 

def build_model(self): 
    #using the placeholder self.x to build the model 
    ... 
    tf.split(0, self.n_input, self.x) # split input for RNN cell 
    ... 

在文件prediction.py,我检索预测元图:

with tf.Session() as sess: 
    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_dir=checkpoint_dir) 
    new_saver = tf.train.import_meta_graph(latest_checkpoint + '.meta') 
    new_saver.restore(sess, latest_checkpoint) 
    x = tf.get_collection('x')[0] 
    y = tf.get_collection('y')[0] 
    ... 
    sess.run(..., feed_dict={x: batch_x}) 

这里是我的错误:

ValueError: Cannot feed value of shape (128, 50, 2) for Tensor u'placeholders/x_input:0', which has shape '(?, 100, 2)' 

注:我设法解决这个问题n 不使用metagraph,而是重新从头开始重建模型并仅加载保存的权重(而不是元数据图)。

编辑:与None更换self.n_steps和修改tf.split(0, self.n_input, self.x)tf.split(0, self.x.get_shape()[1], self.x)时,我得到了以下错误:TypeError: Expected int for argument 'num_split' not Dimension(None).

+0

你通常不能在事实后改变张量的形状。但是,您可以做的一件事是*不*修正训练过程中所有维度的形状,但不指定它们。您提供的张量的尺寸必须与占位符的形状兼容,但您不必强制首先指定所有占位符尺寸。在这里,尝试将“无”替换为self.n_steps。 –

+0

我在发布问题之前就尝试过这样做,但在创建模型期间的某个时候,我有'tf.split(0,self.n_input,self.x)'。当我不知道/修正'self.n_input'时,我把'self.x.get_shape()[1]'('tf.split(0,self.x.get_shape()[1],self) x)的')。但是,我得到以下错误:'TypeError:参数'num_split'的预期int不是Dimension(无).'。 – BiBi

回答

2

当你定义varible,我建议你把它写如下

[None, None, n_input] 

代替:

[None, new_sequence_length=50, n_input] 

它适用于我的情况。我希望它有帮助

+0

我试过了(参考初始文章中的评论),但由于'tf.split'函数在输入时需要使用分割数量,所以在此解决方案中为“无”,因此它不起作用。 – BiBi