2016-03-02
import tensorflow as tf 
import numpy as np 

isTrain = tf.placeholder(tf.bool) 
user_input = tf.placeholder(tf.float32) 

# ema = tf.train.ExponentialMovingAverage(decay=.5) 

with tf.device('/cpu:0'): 
    beta = tf.Variable(tf.ones([1])) 

    batch_mean = beta.assign(user_input) 
    ema = tf.train.ExponentialMovingAverage(decay=0.5) 
    ema_apply_op = ema.apply([batch_mean]) 
    ema_mean = ema.average(batch_mean) 

    def mean_var_with_update():
        with tf.control_dependencies([ema_apply_op]):
            return tf.identity(batch_mean)

    mean = tf.cond(isTrain,
                   mean_var_with_update,
                   lambda: ema_mean)

# ======= End Here ========== 
saver = tf.train.Saver() 
init = tf.initialize_all_variables() 
sess = tf.Session() 
sess.run(init) 

u_input = [[2], [3], [4] ] 
for u in u_input: 
    aa = sess.run([mean], feed_dict={user_input:u, isTrain: True }) 
    print("Train", aa) 

for u in u_input: 
    aa = sess.run([ema_mean], feed_dict={user_input:u, isTrain: False }) 
    print("Test correct", aa) 

for u in u_input: 
    aa = sess.run([mean], feed_dict={user_input:u, isTrain: False }) 
    print("Test", aa) 

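The snippet is supposed to compute the mean of user_input during the training phase and output the moving average during the test phase. The problem is that tf.cond appears to evaluate both branches.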

This is the output:

('Train', [array([ 2.], dtype=float32)]) 
('Train', [array([ 3.], dtype=float32)]) 
('Train', [array([ 4.], dtype=float32)]) 
('Test correct', [array([ 3.], dtype=float32)]) 
('Test correct', [array([ 3.], dtype=float32)]) 
('Test correct', [array([ 3.], dtype=float32)]) 
('Test', [array([ 2.5], dtype=float32)]) 
('Test', [array([ 2.75], dtype=float32)]) 
('Test', [array([ 3.375], dtype=float32)]) 

However, calling sess.run([mean]) still updates ema_mean even when isTrain = False: the Test values drift away from the Test correct value of 3.

Is there a bug in the code? The TensorFlow version is 0.7.1.
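The Test numbers can be reproduced by hand if ema_apply_op runs on every sess.run([mean]) call, including those with isTrain = False. The sketch below is plain Python, not TensorFlow; it assumes the shadow value starts at 0.0 and that each update computes shadow = decay * shadow + (1 - decay) * value with decay = 0.5. The helpers ema_update and simulate are hypothetical names used only for illustration.

```python
# Plain-Python sketch of the moving-average arithmetic (not TensorFlow code).
# Assumptions: the shadow value starts at 0.0 and each run of ema_apply_op
# performs shadow = decay * shadow + (1 - decay) * value.


def ema_update(shadow, value, decay=0.5):
    """One exponential-moving-average update step."""
    return decay * shadow + (1.0 - decay) * value


def simulate(inputs, decay=0.5, update_during_test=True):
    shadow = 0.0
    train_out = []
    for u in inputs:  # "Train" loop: update, then return the batch value
        shadow = ema_update(shadow, u, decay)
        train_out.append(u)

    # "Test correct" loop: only reads ema_mean, so nothing is updated
    test_correct = [shadow for _ in inputs]

    test_out = []
    for u in inputs:  # "Test" loop through sess.run([mean])
        if update_during_test:  # models ema_apply_op firing anyway
            shadow = ema_update(shadow, u, decay)
        test_out.append(shadow)
    return train_out, test_correct, test_out


print(simulate([2.0, 3.0, 4.0]))
# ([2.0, 3.0, 4.0], [3.0, 3.0, 3.0], [2.5, 2.75, 3.375])
```

The third list matches the observed Test output, which suggests the update op is running in the test loop. With update_during_test=False the Test values stay at 3.0, matching Test correct.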

Answers


I added some logging statements, and ema_mean seems to be evaluated only when isTrain is False:

tf.reset_default_graph() 

isTrain = tf.placeholder(tf.bool) 
user_input = tf.placeholder(tf.float32) 

# ema = tf.train.ExponentialMovingAverage(decay=.5) 

with tf.device('/cpu:0'): 
    beta = tf.Variable(tf.ones([1])) 

    batch_mean = beta.assign(user_input) 
    ema = tf.train.ExponentialMovingAverage(decay=0.5) 
    ema_apply_op = ema.apply([batch_mean]) 
    ema_mean = ema.average(batch_mean) 

    def mean_var_with_update():
        with tf.control_dependencies([ema_apply_op]):
            return tf.Print(tf.identity(batch_mean), ["mean_var_with_update"])
            # return tf.identity(batch_mean)

    mean = tf.Print(tf.cond(isTrain,
                            mean_var_with_update,
                            lambda: tf.Print(ema_mean, ["ema_mean"])),
                    ["evaluating mean", isTrain])

# ======= End Here ========== 
saver = tf.train.Saver() 
init = tf.initialize_all_variables() 
sess = tf.Session() 
sess.run(init) 

u_input = [[2], [3], [4] ] 
for u in u_input: 
    aa = sess.run([mean], feed_dict={user_input:u, isTrain: True }) 
    print("Train", aa) 

for u in u_input: 
    aa = sess.run([ema_mean], feed_dict={user_input:u, isTrain: False }) 
    print("Test correct", aa) 

for u in u_input: 
    aa = sess.run([mean], feed_dict={user_input:u, isTrain: False }) 
    print("Test", aa) 

You will see:

[mean_var_with_update] 
[evaluating mean][True] 
[mean_var_with_update] 
[evaluating mean][True] 
[mean_var_with_update] 
[evaluating mean][True] 
[ema_mean] 
[evaluating mean][False] 
[ema_mean] 
[evaluating mean][False] 
[ema_mean] 
[evaluating mean][False] 

Note that a Print op is evaluated only after all of its inputs have been evaluated, which is why the outer print statement appears last.


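https://gist.github.com/24hours/0545f92d5407bcdd3106 If 'isTrain = False', 'batch_mean' should not be evaluated, and if 'isTrain = True', 'ema_mean' should not be evaluated, yet the value returned is the training value. Is this the correct behavior? – 24hours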


@24hours: This may be [this issue](https://stackoverflow.com/questions/37063952/confused-by-the-behavior-of-tf-cond). Try creating 'ema_apply_op' inside the 'mean_var_with_update' function. – Albert


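I think this is the same issue as the one answered here: a tf.control_dependencies inside the conditional adds the dependency to the tf.cond itself. Because ema_apply_op is created outside the tf.cond, it runs whenever mean is evaluated, whichever branch is taken.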

So try creating ema_apply_op inside the mean_var_with_update function.
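In the real graph, the suggested change amounts to moving the ema.apply([batch_mean]) call into mean_var_with_update. The toy model below is plain Python, not TensorFlow, and does not reproduce TF internals. It only illustrates the idea that work captured from outside the conditional runs no matter which branch is chosen, while work created inside a branch function runs only when that branch is taken. MovingAverage, toy_cond and run_mean are hypothetical helpers.

```python
# Toy model of the tf.cond behaviour described above (plain Python, not TensorFlow).
# Assumption: ops created outside the cond are modelled as work done before the
# branch is chosen; ops created inside a branch function run only in that branch.


class MovingAverage:
    def __init__(self, decay=0.5, shadow=0.0):
        self.decay = decay
        self.shadow = shadow

    def apply(self, value):
        self.shadow = self.decay * self.shadow + (1.0 - self.decay) * value


def toy_cond(pred, true_fn, false_fn, outer_deps=()):
    # Dependencies captured from outside the cond run regardless of the branch.
    for dep in outer_deps:
        dep()
    return true_fn() if pred else false_fn()


def run_mean(ema, value, is_train, apply_inside_branch):
    def update():
        ema.apply(value)

    def mean_var_with_update():
        if apply_inside_branch:
            update()  # fixed version: the update belongs to the true branch
        return value

    # Original version: the update is an outside dependency of the cond.
    outer = () if apply_inside_branch else (update,)
    return toy_cond(is_train, mean_var_with_update, lambda: ema.shadow, outer)


for inside in (False, True):
    ema = MovingAverage(decay=0.5, shadow=3.0)  # state after the training loop
    print(inside, [run_mean(ema, u, False, inside) for u in [2.0, 3.0, 4.0]])
# False [2.5, 2.75, 3.375]
# True [3.0, 3.0, 3.0]
```

With the update outside the conditional, the test-phase output drifts exactly as in the question. With the update inside the true branch, the test phase only reads the stored average.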