2016-08-26 55 views
5

我想要做的是同时运行多个预先训练的Tensorflow网络。由于每个网络中的一些变量的名称可以相同,因此常见的解决方案是在创建网络时使用名称范围。但问题是我已经训练了这些模型并将训练好的变量保存在多个检查点文件中。在我创建网络时使用名称范围后,我无法从检查点文件加载变量。同时运行多个预先训练的Tensorflow网络

例如,我已经培训了一个AlexNet,我想比较两组变量,一组来自历元10(保存在文件epoch_10.ckpt中),另一组来自历元50(保存在文件epoch_50.ckpt)。因为这两个网络完全相同,所以里面的变量名称是相同的。我可以用

with tf.name_scope("net1"): 
    net1 = CreateAlexNet() 
with tf.name_scope("net2"): 
    net2 = CreateAlexNet() 

然而,因为当我训练的这个网,我没有用一个名字的范围,我不能加载从.ckpt文件训练有素的变量创建两个网。尽管我可以在训练网络时将名称范围设置为“net1”,但这会阻止我加载net2的变量。

我曾尝试:

with tf.name_scope("net1"): 
    mySaver.restore(sess, 'epoch_10.ckpt') 
with tf.name_scope("net2"): 
    mySaver.restore(sess, 'epoch_50.ckpt') 

这是行不通的。

解决此问题的最佳方法是什么?

回答

10

最简单的解决方案是创建一个使用单独的图形各型号不同的会话:

# Build a graph containing `net1`. 
with tf.Graph().as_default() as net1_graph: 
    net1 = CreateAlexNet() 
    saver1 = tf.train.Saver(...) 
sess1 = tf.Session(graph=net1_graph) 
saver1.restore(sess1, 'epoch_10.ckpt') 

# Build a separate graph containing `net2`. 
with tf.Graph().as_default() as net2_graph: 
    net2 = CreateAlexNet() 
    saver2 = tf.train.Saver(...) 
sess2 = tf.Session(graph=net1_graph) 
saver2.restore(sess2, 'epoch_50.ckpt') 

如果这不出于某种原因,你必须使用一个tf.Session(如因为你要的结果从两个网络中的另一TensorFlow计算相结合),最好的解决办法是:

  1. 创建名称范围的不同的网络,你已经这样做,和
  2. 为两个网络创建单独的tf.train.Saver实例,并使用附加参数重新映射变量名称。

constructing的储户,就可以通过一本字典为var_list说法,在检查点(即没有名称范围前缀)以您在每个模型创建的tf.Variable对象映射变量的名称。

你可以建设var_list编程,你应该能够做到像下面这样:

with tf.name_scope("net1"): 
    net1 = CreateAlexNet() 
with tf.name_scope("net2"): 
    net2 = CreateAlexNet() 

# Strip off the "net1/" prefix to get the names of the variables in the checkpoint. 
net1_varlist = {v.name.lstrip("net1/"): v 
       for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope="net1/")} 
net1_saver = tf.train.Saver(var_list=net1_varlist) 

# Strip off the "net2/" prefix to get the names of the variables in the checkpoint. 
net2_varlist = {v.name.lstrip("net2/"): v 
       for v in tf.get_collection(tf.GraphKeys.VARIABLES, scope="net2/")} 
net2_saver = tf.train.Saver(var_list=net2_varlist) 

# ... 
net1_saver.restore(sess, "epoch_10.ckpt") 
net2_saver.restore(sess, "epoch_50.ckpt") 
+0

太棒了! – denru

+0

使用lstrip剥离前缀可能会导致错误的结果。请使用切片代替。代码的其他部分完美地工作。另一个问题是,我发现一个变量的名称有一个像“:0”,“:1”的后缀。在将变量存储到检查点文件之前,我需要摆脱这个后缀吗? – denru

+0

任何人都试过这个答案?我遇到的问题与'恢复'功能没有做任何事情:http://stackoverflow.com/questions/41607144/loading-two-models-from-saver-in-the-same-tensorflow-session – TheCriticalImperitive