0
我想在 TF 中训练 CNN 期间预加载(prefetch)训练数据。我的简单实现如下。但是,我发现一个奇怪的现象:它似乎是一个同步过程——无论 PRE_FETCH 是 True 还是 False,装载一批数据的时间成本几乎相同。(主题:TF 中的预加载数据)
class Demo(object):
    """Toy CNN training driver that can prefetch batches in a background process.

    NOTE(review): the original code could never actually prefetch — see the
    BUG FIX comments below.  Also be aware that ``multiprocessing.Queue``
    pickles every batch on ``put`` and unpickles it on ``get``, so even a
    working prefetcher still pays (de)serialization cost in ``queue.get()``;
    for large numpy batches that cost can rival loading the batch directly.
    """

    def __init__(self):
        self._name = 'demo'

    def load_batch(self):
        """Load and return one training batch (implementation elided)."""
        ...

    def prefetch(self, func):
        """Worker loop: repeatedly call *func* and push its result onto the queue.

        Blocks when the queue is full (maxsize=30), which bounds memory use.
        Runs forever; the parent terminates it via the atexit hook.
        """
        while True:
            data = func()
            self.queue.put(data)

    def train(self):
        # Shapes B/H/W/C and flags PRE_FETCH/MAX_ITER_SIZE come from module
        # scope elsewhere in the file.
        input_data = tf.placeholder(tf.float32, shape=[B, H, W, C])
        optim_op = build_model(input_data)

        if PRE_FETCH:
            self.queue = Queue(30)
            # BUG FIX: ``args`` must be a tuple.  The original passed
            # ``args=(self.load_batch)`` -- parentheses without a trailing
            # comma are just grouping, so ``args`` was the bound method
            # itself and ``Process(...)`` raised TypeError before the
            # worker ever started.  This is why prefetching appeared to
            # make no difference.
            self.process = Process(target=self.prefetch,
                                   args=(self.load_batch,))
            self.process.start()

            def cleanup():
                # Stop the infinite prefetch loop when the interpreter exits.
                self.process.terminate()
                self.process.join()

            import atexit
            atexit.register(cleanup)

        sess = tf.Session()
        i = 1
        while i < MAX_ITER_SIZE:
            start = time.time()
            if PRE_FETCH:
                # If training consumes batches faster than the worker
                # produces them, this get() blocks for roughly one full
                # load -- timings then look identical to the synchronous
                # path even though prefetching works.
                tmp = self.queue.get()
            else:
                tmp = self.load_batch()
            end = time.time()
            print('load data time: ', (end - start))
            # BUG FIX: the original line was missing its closing ')'.
            sess.run(optim_op, feed_dict={input_data: tmp})
            # BUG FIX: the original never incremented ``i``, making this an
            # infinite loop.  TODO confirm the increment was lost in the
            # paste rather than intentionally absent.
            i += 1