2016-11-30 33 views
0

我试图在TensorFlow中使用javacpp-presets在java中运行TensorFlow培训。我已经使用tf.train.write_graph(sess.graph_def, '.', 'example.pb', as_text=False)生成了一个.pb文件,如下所示。运行导入张量流图对未初始化的变量失败

import tensorflow as tf 
import numpy as np 

x_data = np.random.rand(100).astype(np.float32) 
y_data = x_data * 0.1 + 0.3 
Weights = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name='Weights') 
biases = tf.Variable(tf.zeros([1]), name='biases') 
y = Weights * x_data + biases 
loss = tf.reduce_mean(tf.square(y - y_data)) #compute the loss 
optimizer = tf.train.GradientDescentOptimizer(0.5) 
train = optimizer.minimize(loss, name='train') 
init = tf.global_variables_initializer() 

with tf.Session() as sess: 
    print(sess.run(Weights), sess.run(biases)) 
    tf.train.write_graph(sess.graph_def, '.', 'example.pb', as_text=False) 

我:

异常线程 “main” java.lang.Exception的:试图使用未初始化值权重”

当我运行:

tensorflow.Status s = session.Run(new StringTensorPairVector(new String[] {}, new Tensor[] {}), new tensorflow.StringVector(), new tensorflow.StringVector("train"), outputs); 

加载图后,tensorflow.ReadBinaryProto(Env.Default(), "./example.pb", def);

有没有javacpp-presets API做init = tf.global_variables_initializer()的工作?
或者我可以用来初始化所有变量的任何C++ TensorFlow api?

回答

1

在你的Python程序中,inittf.global_variables_initializer()的结果)是tf.Operation,当时传递给sess.run()。如果在构建Python图时捕获值init.name,那么在运行训练步骤之前,可以在Java程序中将该名称传递给session.Run()

我不是100%肯定什么javacpp-presets API的样子,但我认为你将能够做到这一点是:

tensorflow.Status s = session.Run(
    new StringTensorPairVector(new String[] {}, new Tensor[] {}), 
    new tensorflow.StringVector(), 
    new tensorflow.StringVector(value_of_init_dot_name), 
    outputs); 

...其中value_of_init_dot_name是你所得到的的init.name值来自Python程序。

+0

非常感谢!我把它解释为你的答案。 – linfeng