import tensorflow as tf
import numpy as np
isTrain = tf.placeholder(tf.bool)
user_input = tf.placeholder(tf.float32)
# ema = tf.train.ExponentialMovingAverage(decay=.5)
with tf.device('/cpu:0'):
beta = tf.Variable(tf.ones([1]))
batch_mean = beta.assign(user_input)
ema = tf.train.ExponentialMovingAverage(decay=0.5)
ema_apply_op = ema.apply([batch_mean])
ema_mean = ema.average(batch_mean)
def mean_var_with_update():
with tf.control_dependencies([ema_apply_op]):
return tf.identity(batch_mean)
mean = tf.cond(isTrain,
mean_var_with_update,
lambda: (ema_mean))
# ======= End Here ==========
saver = tf.train.Saver()
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
u_input = [[2], [3], [4] ]
for u in u_input:
aa = sess.run([mean], feed_dict={user_input:u, isTrain: True })
print("Train", aa)
for u in u_input:
aa = sess.run([ema_mean], feed_dict={user_input:u, isTrain: False })
print("Test correct", aa)
for u in u_input:
aa = sess.run([mean], feed_dict={user_input:u, isTrain: False })
print("Test", aa)
该代码片段应该计算在测试阶段期间训练阶段和输出平均值的user_input平均值。Tensorflow tf.cond评估两个脚蹬
这是输出结果:
('Train', [array([ 2.], dtype=float32)])
('Train', [array([ 3.], dtype=float32)])
('Train', [array([ 4.], dtype=float32)])
('Test correct', [array([ 3.], dtype=float32)])
('Test correct', [array([ 3.], dtype=float32)])
('Test correct', [array([ 3.], dtype=float32)])
('Test', [array([ 2.5], dtype=float32)])
('Test', [array([ 2.75], dtype=float32)])
('Test', [array([ 3.375], dtype=float32)])
但是,调用sess.run([mean])
即使isTrain = False
当ema_mean
始终得到评估。
代码中是否有任何错误? tensorflow版本是0.7.1
https://gist.github.com/24hours/0545f92d5407bcdd3106如果'isTrain = False'不应评估'batch_mean',如果'isTrain = True'不应评估'ema_mean',是火车的价值。这是正确的行为? – 24hours
@ 24小时:这可能是[这个问题](https://stackoverflow.com/questions/37063952/confused-by-the-behavior-of-tf-cond)。尝试在'mean_var_with_update'函数内创建'ema_apply_op'。 – Albert