2017-08-18 63 views
0

我试图使用两个不同的mobilenet模型。以下是我如何初始化模型的代码。在张量流中使用两个不同的模型

def initialSetup(): 
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
    start_time = timeit.default_timer() 

    # This takes 2-5 seconds to run 
    # Unpersists graph from file 
    with tf.gfile.FastGFile('age/output_graph.pb', 'rb') as f: 
     age_graph_def = tf.GraphDef() 
     age_graph_def.ParseFromString(f.read()) 
     tf.import_graph_def(age_graph_def, name='') 

    with tf.gfile.FastGFile('output_graph.pb', 'rb') as f: 
     gender_graph_def = tf.GraphDef() 
     gender_graph_def.ParseFromString(f.read()) 
     tf.import_graph_def(gender_graph_def, name='') 

    print ('Took {} seconds to unpersist the graph'.format(timeit.default_timer() - start_time)) 

因为两者都是两种不同的模型,我如何将它用于预测?

UPDATE

initialSetup() 

age_session = tf.Session(graph=age_graph_def) 
gender_session = tf.Session(graph=gender_graph_def) 

with tf.Session() as sess: 
    start_time = timeit.default_timer() 

    # Feed the image_data as input to the graph and get first prediction 
    softmax_tensor = age_session.graph.get_tensor_by_name('final_result:0') 

    print ('Took {} seconds to feed data to graph'.format(timeit.default_timer() - start_time)) 

    while True: 
     # Capture frame-by-frame 
     ret, frame = video_capture.read() 

错误

Traceback (most recent call last): File "C:/Users/Desktop/untitled/testimg/testimg/combo.py", line 48, in age_session = tf.Session(graph=age_graph_def) File "C:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 1292, in init super(Session, self).init(target, graph, config=config) File "C:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 529, in init raise TypeError('graph must be a tf.Graph, but got %s' % type(graph)) TypeError: graph must be a tf.Graph, but got Exception ignored in: > Traceback (most recent call last): File "C:\Program Files\Anaconda3\lib\site-packages\tensorflow\python\client\session.py", line 587, in del if self._session is not None: AttributeError: 'Session' object has no attribute '_session'

+0

你成功地使用过这种方式加载单个模型吗?通常的做法是将不同的非空'name'参数传递给每个'tf.import_graph_def()'调用,然后将这些名称用作每个模型中要提供和读取的特定张量的前缀。 – mrry

+0

是单独的它正在working.if我添加的名称,它说,没有这样的张量 –

+0

您可以添加您用来调用会话的代码,并打印完整的错误?如果给导入的图形添加一个'name',则需要在该图形中使用的任何张量名称的前缀名称为'name',后跟一个'/'。 – mrry

回答

2

当您在同一个图表中使用多个模型时,使用名称范围来给出单个张量的可预测名称。例如,你可以重写initial_setup()如下:

def initialSetup(): 
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' 
    start_time = timeit.default_timer() 

    # This takes 2-5 seconds to run 
    # Unpersists graph from file 
    with tf.gfile.FastGFile('age/output_graph.pb', 'rb') as f: 
     age_graph_def = tf.GraphDef() 
     age_graph_def.ParseFromString(f.read()) 
     tf.import_graph_def(age_graph_def, name='age_model') 

    with tf.gfile.FastGFile('output_graph.pb', 'rb') as f: 
     gender_graph_def = tf.GraphDef() 
     gender_graph_def.ParseFromString(f.read()) 
     tf.import_graph_def(gender_graph_def, name='gender_model') 

    print ('Took {} seconds to unpersist the graph'.format(timeit.default_timer() - start_time)) 

现在所有节点的从age_graph_def名字将与"age_model/"前缀和所有节点的从gender_graph_def的名字将与"gender_model/"前缀。它们都是同一个默认图形的一部分,因此您可以使用单个tf.Session而不使用参数来访问任一模型。

initialSetup() 

with tf.Session() as sess: 
    start_time = timeit.default_timer() 

    # Feed the image_data as input to the graph and get first prediction 
    softmax_tensor = sess.graph.get_tensor_by_name('age_model/final_result:0') 

    # Alternatively, to get a tensor from the gender model: 
    # tensor = sess.graph.get_tensor_by_name('gender_model/...') 

    print ('Took {} seconds to feed data to graph'.format(timeit.default_timer() - start_time)) 

    while True: 
     # Capture frame-by-frame 
     ret, frame = video_capture.read() 
+0

谢谢它的作品..但是框架现在有点落后...有什么方法可以提高速度吗? –

+0

如果我将两个模型的两个类和再培训结合起来,它会影响准确性吗? –

1

tf.Session需要tf.Graph实例不tf.GraphDef,接下来的步骤解决问题。

def initialSetup(): 
    with tf.gfile.FastGFile('age/output_graph.pb', 'rb') as f: 
     age_graph_def = tf.GraphDef() 
     age_graph_def.ParseFromString(f.read()) 
     with tf.Graph().as_default() as graph: 
      tf.import_graph_def(age_graph_def, name='') 
      age_graph = graph 

    ... 
    return age_graph, gender_graph 

age_graph, gender_graph = initial_setup() 
age_session = tf.Session(graph=age_graph) 
... 
# also delete the following line, as it creates another new context 
with tf.Session() as sess: 
+0

现在它说,'KeyError:'名称'final_result:0'是指一个不存在的张量,操作'final_result'在图中不存在。 ' –

+0

然而这个张量已经存在于图中 –

+0

它没有打印任何东西:( –