你可以简单地创建一个自定义钩子并将它传递给MonitoredTrainingSession
。无需将您自己的tf.RunMetadata()
实例传递给运行调用。
下面是一个例子钩,其存储每N个步骤ckptdir元数据:
import tensorflow as tf
class TraceHook(tf.train.SessionRunHook):
"""Hook to perform Traces every N steps."""
def __init__(self, ckptdir, every_step=50, trace_level=tf.RunOptions.FULL_TRACE):
self._trace = every_step == 1
self.writer = tf.summary.FileWriter(ckptdir)
self.trace_level = trace_level
self.every_step = every_step
def begin(self):
self._global_step_tensor = tf.train.get_global_step()
if self._global_step_tensor is None:
raise RuntimeError("Global step should be created to use _TraceHook.")
def before_run(self, run_context):
if self._trace:
options = tf.RunOptions(trace_level=self.trace_level)
else:
options = None
return tf.train.SessionRunArgs(fetches=self._global_step_tensor,
options=options)
def after_run(self, run_context, run_values):
global_step = run_values.results - 1
if self._trace:
self._trace = False
self.writer.add_run_metadata(run_values.run_metadata,
f'{global_step}', global_step)
if not (global_step + 1) % self.every_step:
self._trace = True
它检查在before_run
它是否有跟踪与否,如果是,增加了RunOptions。在after_run
它检查是否需要跟踪下一个运行调用,如果是,它会再次将_trace
设置为True。此外,它在元数据可用时存储元数据。