2017-07-17 129 views
0

我想创建一个循环计算图。这个想法很简单,详细如下:循环计算图

  • 初始化一个网络网络的权重。
  • 从初始化权重是高斯均值的多变量高斯样本中抽样N个权重。
  • 评估每组权重的损失函数。
  • 适当地更新权重。

的基本方法的一个图像可以被看作是如下:

enter image description here

我目前的做法是for循环在培训期间进行采样和更新的权重。但是,这很慢,我想知道是否可以将此功能构建到计算图表中并加快我的培训速度。

回答

0

您应该可以在计算图表中完成所有操作。例如,与权重变量W

NUM_SAMPLES = 10 
STDDEV = 1 

# Assuming W statically shaped, otherwise you'd use tf.shape and tf.concat 
samples_shape = [0] + W.shape.as_list() 
# Generate random numbers with W as mean 
samples = tf.random_normal(samples_shape, 
          stddev=tf.constant(STDDEV, dtype=W.dtype), 
          dtype=W.dtype) 
samples += W[tf.newaxis, :] 
# The loss function should return a vector the size of 
# the first dimension of samples 
samples_loss = loss(samples) 
idx = tf.argmin(samples_loss, axis=0) 
# Update W 
update_op = tf.assign(W, samples[idx]) 

然后,你跑update_op执行一个更新步骤,或者用它作为控制依赖与其他运营下去:​​

with tf.control_dependencies([update_op]): 
    # More ops... 
+0

优秀。非常感谢。 – Garland