2016-11-28 29 views
1

我已经意识到Tensorflow似乎在管理图形的方式上有一些时髦的东西。由于构建(和重建)模型非常繁琐,我决定将自定义模型包装在一个类中,以便我可以在其他地方轻松地重新实例化它。Tensorflow如何管理图形?

当我在训练和测试代码时(在原来的地方),它会工作的很好,但是在我加载图形变量的代码中,我会得到各种奇怪的错误 - 变量重定义和其他一切。这个(从我最后一个类似的问题)提示,一切都被称为两次。

做了跟踪TON后,它回到了我使用加载的代码的方式。它正在从一个类,有一个结构,像这样

class MyModelUser(object): 
    def forecast(self): 
     # .. build the model in the same way as in the training code 
     # load the model checkpoint 
     # call the "predict" function on the model 
     # manipulate the prediction and return it 

然后在一些代码,使用MyModelUser我有

def test_the_model(self): 
    model_user = MyModelUser() 
    print(model_user.forecast()) # 1 
    print(model_user.forecast()) # 2 

和我(显然)有望看到两个预测这时候内使用被称为。相反,第一个预测被称为按预期工作,但第二个电话扔变量重用的TON ValueError异常的这些中的一个例子是:

ValueError: Variable weight_def/weights already exists, disallowed. Did you mean to set reuse=True in VarScope? 

我设法通过增加一系列平息错误试图/使用get_variable创建变量的块除外,然后在例外情况下,在范围上调用reuse_variables,然后在名称上调用get_variable。这带来了一套新的严重的错误,其中之一就是:

tensorflow.python.framework.errors.NotFoundError: Tensor name "weight_def/weights/Adam_1" not found in checkpoint files 

一时心血来潮我说:“如果我的造型建筑物代码移到__init__所以其只内置了一次?”

我的新机型的用户:

class MyModelUser(object): 
    def __init__(self): 
     # ... build the model in the same way as in the training code 
     # load the model checkpoint 


    def forecast(self): 
     # call the "predict" function on the model 
     # manipulate the prediction and return it 

现在:

def test_the_model(self): 
    model_user = MyModelUser() 
    print(model_user.forecast()) # 1 
    print(model_user.forecast()) # 2 

按预期工作,印花两大预测没有错误。这使我相信我也可以摆脱可变重用的东西。

我的问题是这样的:

这是为什么解决它?从理论上讲,应该在原始预测方法中每次都重新调整图形,因此它不应该创建多个图形。即使函数完成后,Tensorflow是否仍然保持图形?这就是为什么将创建代码移动到__init__工作?这让我无望地感到困惑。

回答

2

默认情况下,TensorFlow使用首次调用TensorFlow API时创建的单个全局tf.Graph实例。如果您不明确创建tf.Graph,则将在该默认实例中创建所有操作,张量和变量。这意味着您在model_user.forecast()的代码中的每个调用都会将操作添加到同一个全局图中,这有点浪费。

有(至少)动作的两种可能的课程在这里:

  • 理想的行动是调整你的代码,以便MyModelUser.__init__()构建整个tf.Graph所有进行预测所需要的操作,而MyModelUser.forecast()只需在现有图上执行sess.run()调用。理想情况下,您也只能创建一个tf.Session,因为TensorFlow会在会话中缓存关于图形的信息,并且执行效率会更高。

  • 的创伤更小—但可能不太有效—变化将是创建一个新的tf.Graph每次调用MyModelUser.forecast()。这是由很多国家是如何在MyModelUser.__init__()方法创建的问题尚不清楚,但你可以不喜欢下面把两个调用不同的图表:

    def test_the_model(self): 
        with tf.Graph(): # Create a local graph 
        model_user_1 = MyModelUser() 
        print(model_user_1.forecast()) 
        with tf.Graph(): # Create another local graph 
        model_user_2 = MyModelUser() 
        print(model_user_2.forecast()) 
    
0

TF有一个默认图表,新的操作等被添加到。当你调用你的函数两次时,你会将同样的东西两次添加到同一个图中。因此,无论是构建一次图并多次评估(就像你已经完成的那样,这也是“正常”方法),或者,如果你想改变一些东西,你可以使用reset_default_graph https://www.tensorflow.org/versions/r0.11/api_docs/python/framework.html#reset_default_graph来重置图,以便拥有一个新鲜的状态。