2017-07-24 69 views
1

我试图通过遵循运行时统计指令here来获得我的张量流代码配置文件(网络中每层的运行和内存消耗)。据我了解,我需要创建运行选项像这样通过tensorflow监测训练获得运行时统计信息

run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) 
run_metadata = tf.RunMetadata() 

运行的元数据,并将其传递给sess.run

然而,正如我也尝试使用tf.train.MonitoredTrainingSession我不知道如果我可以将同样的事情传递给这个班级。一个合理的方法可以使用钩子,但我不知道如何去做。我对他们还很陌生

回答

2

你可以简单地创建一个自定义钩子并将它传递给MonitoredTrainingSession。无需将您自己的tf.RunMetadata()实例传递给运行调用。

下面是一个例子钩,其存储每N个步骤ckptdir元数据:

import tensorflow as tf 

class TraceHook(tf.train.SessionRunHook): 
    """Hook to perform Traces every N steps.""" 

    def __init__(self, ckptdir, every_step=50, trace_level=tf.RunOptions.FULL_TRACE): 
     self._trace = every_step == 1 
     self.writer = tf.summary.FileWriter(ckptdir) 
     self.trace_level = trace_level 
     self.every_step = every_step 

    def begin(self): 
     self._global_step_tensor = tf.train.get_global_step() 
     if self._global_step_tensor is None: 
      raise RuntimeError("Global step should be created to use _TraceHook.") 

    def before_run(self, run_context): 
     if self._trace: 
      options = tf.RunOptions(trace_level=self.trace_level) 
     else: 
      options = None 
     return tf.train.SessionRunArgs(fetches=self._global_step_tensor, 
             options=options) 

    def after_run(self, run_context, run_values): 
     global_step = run_values.results - 1 
     if self._trace: 
      self._trace = False 
      self.writer.add_run_metadata(run_values.run_metadata, 
             f'{global_step}', global_step) 
     if not (global_step + 1) % self.every_step: 
      self._trace = True 

它检查在before_run它是否有跟踪与否,如果是,增加了RunOptions。在after_run它检查是否需要跟踪下一个运行调用,如果是,它会再次将_trace设置为True。此外,它在元数据可用时存储元数据。