2017-06-26 122 views
2

我一直在使用tensorflow的LinearClassifier()班培养了逻辑回归模型的模型,并设置model_dir参数,它指定的位置在哪里模型训练过程中保存检查站的metagrahps :如何从tensorflow高层API恢复训练的LinearClassifier并作出预测

# Create temporary directory where metagraphs will evenually be saved 
model_dir = tempfile.mkdtemp() 

logistic_model = tf.contrib.learn.LinearClassifier(
    feature_columns=feature_columns, 
    n_classes=num_labels, model_dir=model_dir) 

我一直在阅读有关从metagraphs恢复模式,但没有发现任何有关如何使用高级API创建的模型这样做。 LinearClassifier()的预测()函数,但我不能找到如何使用已通过关卡元图恢复模型的实例来运行预测任何文件。我会如何去做这件事?一旦模型被恢复,我的理解是,我有tf.Sess对象,它缺乏所有建在LinearClassifier类的功能,像这样的工作:

with tf.Session() as sess: 
    new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta') 
    new_saver.restore(sess, 'my-save-dir/my-model-10000') 
    # Run prediction algorithm... 

如何运行相同的预测所使用的高级API,使一个恢复模型预测算法?有没有更好的方法来解决这个问题?

感谢您的输入。

+0

建议的解决方案适合您吗? –

+0

潜在的,我实现了你的建议的修复,但仍需要确认它的工作原理。现在超级淹没,当我有机会时会报告回来。谢谢你的帮助。 –

回答

1

LinearClassifier()有“model_dir” PARAM,如果在指向一个训练模型将恢复模型。
在训练过程中,你做:

logistic_model = tf.contrib.learn.LinearClassifier(feature_columns=feature_columns, n_classes=num_labels, model_dir=model_dir) 
classifier.fit(X_train, y_train, steps=10) 

在推理,LinearClassifier()将加载从给出的路径训练模型,并且不使用fit()方法,但调用predict()方法:

logistic_model = tf.contrib.learn.LinearClassifier(feature_columns=feature_columns, n_classes=num_labels, model_dir=model_dir) 
y_pred = classifier.predict(X_test) 
相关问题