2016-05-02 105 views
5

我注意到在tensorflow中已经存在批量归一化函数。但我不明白的一件事是如何改变训练和测试之间的程序?张量流中的批量归一化

批处理标准化在测试期间的行为与在培训期间不同。具体来说,在训练期间使用固定均值和方差。

是否有一些很好的示例代码?我看到了一些,但有了范围变量,它变得令人困惑

+0

如'tf.contrib考虑使用来自高层API预先定义的层.layers'。 – danijar

回答

9

你是对的,tf.nn.batch_normalization提供了实现批量标准化的基本功能。您必须添加额外的逻辑来跟踪训练期间的移动均值和差异,并在推理期间使用训练的均值和差异。你可以看一下这个example一个非常普遍实施,但不使用gamma一个快速的版本是在这里:

beta = tf.Variable(tf.zeros(shape), name='beta') 
    moving_mean = tf.Variable(tf.zeros(shape), name='moving_mean', 
           trainable=False) 
    moving_variance = tf.Variable(tf.ones(shape), 
            name='moving_variance', 
            trainable=False) 
    control_inputs = [] 
    if is_training: 
    mean, variance = tf.nn.moments(image, [0, 1, 2]) 
    update_moving_mean = moving_averages.assign_moving_average(
     moving_mean, mean, self.decay) 
    update_moving_variance = moving_averages.assign_moving_average(
     moving_variance, variance, self.decay) 
    control_inputs = [update_moving_mean, update_moving_variance] 
    else: 
    mean = moving_mean 
    variance = moving_variance 
    with tf.control_dependencies(control_inputs): 
    return tf.nn.batch_normalization(
     image, mean=mean, variance=variance, offset=beta, 
     scale=None, variance_epsilon=0.001) 
+0

非常感谢。另一个简单的问题。伽马版本真的更复杂吗?似乎你只需要为它初始化另一个tf.Variable?其余的代码应该是相同的,如果不是的话? – user3358117

+0

是的,您可以按照我提供的用于添加'gamma'的链接中的更一般的实现。 – keveman