2017-03-08 133 views
2

我正在测试mxnet的RNN模型。教程here不起作用,错误消息表示许多功能已被弃用。我没有找到最新的RNN教程。 mxnet项目中还有一些示例。但是对于RNN,examples仅显示如何使用训练集来训练模型。他们没有展示如何使用训练好的模型进行进一步的预测。训练代码如下:mxnet:如何使用训练有素的RNN模型进行预测

model.fit(
    train_data   = data_train, 
    eval_data   = data_val, 
    eval_metric   = mx.metric.Perplexity(invalid_label), 
    kvstore    = args.kv_store, 
    optimizer   = args.optimizer, 
    optimizer_params = { 'learning_rate': args.lr, 
          'momentum': args.mom, 
          'wd': args.wd }, 
    initializer   = mx.init.Xavier(factor_type="in", magnitude=2.34), 
    num_epoch   = args.num_epochs, 
    batch_end_callback = mx.callback.Speedometer(args.batch_size, args.disp_batches)) 

是否有人知道如何使用的培训RNN模型作出推断或预测?

我必须明白,我正在寻找如何使用RNN模型作出预测,而不是CNN或其他模型。

非常感谢您的帮助!

+0

https://github.com/dmlc/mxnet/blob/master/example/rnn/cudnn_lstm_bucketing.py既有列车和测试代码。这有帮助吗? –

+2

不可以。但https://github.com/dmlc/mxnet/tree/master/python/mxnet/module中的示例确实有帮助。 – pfc

+1

@pfc如果你找到了答案,你会回答你自己的问题,可能需要相同的帮助吗? – lynguyen

回答

1

通常模型是扩展BaseModel类。而BaseModel有the method predict。该方法可以使用与fit方法使用的相同类型:DataIter只有一个区别,它不需要train_data,只有eval_data。所以实际的预测过程可以以简单的方式来实现这样的:

result = mod.predict(dataiter.next) 
相关问题