2017-04-19 36 views
0

我想在训练期间预加载训练数据CNN in TF我的简单实现如下。但是,我发现一个奇怪的现象。这似乎是一个同步过程。装载一批数据的时间成本几乎相同,无论是PRE_FETCHTrue还是FalseTF中的预加载数据

class Demo(object): 
    def __init__(self): 
     self._name = 'demo' 

    def load_batch(self): 
     ... 

    def prefetch(self, func): 
     while True: 
      data = func() 
      self.queue.put(data) 

    def train(self): 
     input_data = tf.placeholder(tf.float32, shape=[B, H, W, C]) 
     optim_op = build_model(input_data) 

     if PRE_FETCH: 
      self.queue = Queue(30) 
      self.process = Process(target=self.prefetch, args=(self.load_batch)) 
      self.process.start() 
      def cleanup(): 
       self.process.terminate() 
       self.process.join() 
      import atexit 
      atexit.register(cleanup) 
     sess = tf.Session() 
     i = 1 
     while i < MAX_ITER_SIZE: 
      if PRE_FETCH: 
       start = time.time() 
       tmp = self.queue.get() 
       end = time.time() 
       print 'load data time: ', (end - start) 
      else: 
       start = time.time() 
       tmp = self.load_batch() 
       end = time.time() 
       print 'load data time: ', (end - start) 
      sess.run(optim_op, feed_dict={input_data: tmp} 

回答

0

需要花费时间的是通过占位符将数据加载到图中。如果你希望你的预加载有效,你应该调查替换你的python队列并用tensorflow图中的操作线程mecanisme。在tensorflow网站上有一个很好的教程:https://www.tensorflow.org/programmers_guide/reading_data