2017-06-23 110 views
0

我已经训练了一个使用Python使用Tensorflow实现的网络。最后,我用tf.train.Saver()保存了模型。现在我想用C++来使用这个预先训练好的网络进行预测。使用Tensorflow检查点在C++中恢复模型

我该怎么做?有没有办法转换检查点,所以我可以使用tiny-dnn或Tensorflow C++?

欢迎任何想法:)谢谢!

回答

5

您可能应该在SavedModel format中导出模型,它封装了计算图和保存的变量(tf.train.Saver仅保存变量,因此您必须保存图)。

然后,您可以使用LoadSavedModel以C++加载保存的模型。

确切的调用将取决于您的模型的输入和输出是什么。但是,Python代码会看起来像这样:

# You'd adjust the arguments here according to your model 
signature = tf.saved_model.signature_def_utils.predict_signature_def(                  
    inputs={'image': input_tensor}, outputs={'scores': output_tensor})                   


builder = tf.saved_model.builder.SavedModelBuilder('/tmp/my_saved_model')                  

builder.add_meta_graph_and_variables(                          
    sess=sess,                              
    tags=[tf.saved_model.tag_constants.SERVING],                        
    signature_def_map={                          
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:                 
     signature                               
})                                  

builder.save() 

然后在C++中,你会做这样的事情:

tensorflow::SavedModelBundle model; 
auto status = tensorflow::LoadSavedModel(session_options, run_options, "/tmp/my_saved_model", {tensorflow::kSavedModelTagServe}, &model); 
if (!status.ok()) { 
    std::cerr << "Failed: " << status; 
    return; 
} 
// At this point you can use model.session 

(注意,使用SavedModel格式也可以让您以服务模式使用TensorFlow Serving,如果这对您的应用程序有意义)

希望有所帮助。

+0

看起来像那正是我需要的!我会尝试一下,我会告诉你!谢谢 ! –

相关问题