2017-04-05 70 views
4

我正在使用Tensorflow Java Api将已创建的Tensorflow模型加载到JVM中。 我用这作为一个例子:tensorflow/examples/LabelImage.java什么是Python中的Tensorflow Java Api toGraphDef`相当于什么?

这里是我的简单的Scala代码:

import java.nio.file.{Files, Path, Paths} 
import org.tensorflow.{Graph, Session, Tensor} 

def readAllBytesOrExit(path: Path): Array[Byte] = Files.readAllBytes(path) 
val graphDef = readAllBytesOrExit(Paths.get("PATH_TO_A_SINGLE_FILE_DESCRIBING_TF_MODEL.pb")) 
val g = new Graph() 
g.importGraphDef(graphDef) 
val session = new Session(g) 
val result: Tensor = session.runner().feed("input", image).fetch("output").run().get(0)) 

我如何保存我的模型来获得Session和存储在同一个文件中的图表。如上面“PATH_TO_A_SINGLE_FILE_DESCRIBING_TF_MODEL.pb”中所述。

描述here它提到:

曲线图的序列化表示,通常被称为一个 GraphDef,可以通过toGraphDef生成()并且在其它 语言API当量。

其他语言API中的等价物是什么?我没有发现它明显

注意:我已经看过tensorflow_serving下的mnist_saved_model.py,但通过该过程保存给我一个.pb文件和一个variables文件夹。当我试图加载.pb文件时,我得到:java.lang.IllegalArgumentException: Invalid GraphDef

+0

我试图使用https://www.tensorflow.org/api_docs/python/tf/GraphDef#SerializeToString,这有意义将图加载到会话中,但运行会话时变量不在那里。 –

回答

1

当前使用tensorflow的Java API,我只找到了如何将graph保存为graphDef(即没有它的变量和元数据)。这可以通过刚写入阵列[字节]对文件进行:

Files.write(Paths.get(modelDir, modelName), myGraph.toGraphDef) 

这里myGraphGraph class的Java对象。

我建议使用此处定义的SavedModel api从Python API保存模型。它会将你的模型保存在一个文件夹中,其中序列化图形在.pb文件中,而变量在文件夹中。请注意您使用的tag_constants,因为您需要在scala/java代码中使用它来使用变量加载模型。然后,可以使用java api中的SavedModelBundle Java类轻松加载带变量的图表和会话。它会返回一个包装用图形和包含的变量值的会话都:

val model = SavedModelBundle.load(modelDir, modelTag) 

如果你已经尝试过这一点,也许你可以分享你的代码,看看它为什么返回无效GraphDef。

另一种方法是冻结图形,即将变量节点变为常量节点,以便所有内容都包含在.pb文件中。详细信息here冻结部分

+0

'SavedModelBundle'尚未发布,但我可以尝试重新编译并使用它。 [api_docs](https://www.tensorflow.org/api_docs/) 我会尝试冻结我的模型,看看是否有效。谢谢! –

+1

你是对的@DanielHasegan,你只需要从源代码构建java api依赖[here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/java#building-from-source) – nicodri

+0

我设法服务于冻结的模型,并将我的本地库重新编译为最新的1.1版本。谢谢您的帮助 –

相关问题