2016-03-01 33 views
2

我是Tensorflow的初学者。我从“入门”页面 中选择了一个适合行的示例,并且做出了我认为对其进行了几乎微不足道的修改,但完全失败。 我不明白。初学者 - 解决一个简单的凸优化

在修改后的版本中,数组b_data是两个已知高斯的和,其权重未知。 尝试解决这些重量。这是一个二次问题,可以作为线性系统来解决 。

尽管真实权重为0.4,0.2,但梯度下降给出w [0]为负数, 和w [1]为正数。

这就是问题:虽然问题是凸的(二次均匀),但tensorflow并没有找到正确的答案。

我想我一定是做了错误的损失功能? 事实上,我认为用损失

tf.reduce_sum(tf.square(b - b_data))

是我想要的(对应于平方2范数|| b - b_data ||^2),但是 尝试这种更是雪上加霜,将导致在NaNs。

import numpy as np 
import matplotlib.pyplot as pl 
import tensorflow as tf 

RES = 200 
CEN = [0.2, 0.3, 0.6] 
SD = [0.1, 0.15, 0.07] 
X = np.linspace(0., 1., RES).astype(np.float32) 
G0 = np.exp(- np.power(X - CEN[0], 2)/SD[0]) 
G1 = np.exp(- np.power(X - CEN[1], 2)/SD[1]) 
B = np.vstack([G0,G1]) 
B = B.T 

b_data = 0.4*G0 + 0.2*G1 

# check numpy answer 
w_ = np.linalg.lstsq(B,b_data) 
print('numpy answer',w_[0])  # correct: 0.4, 0.2 

w = tf.Variable(tf.random_uniform([2,1], 0., 0.5)) 
b = tf.matmul(B,w) 

loss = tf.reduce_mean(tf.square(b - b_data)) 
#loss = tf.reduce_sum(tf.square(b - b_data)) 
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01) 
train = optimizer.minimize(loss) 

init = tf.initialize_all_variables() 

sess = tf.Session() 
sess.run(init) 

for step in xrange(8001): 
    sess.run(train) 
    if step % 100 == 0: 
     print(step, sess.run(loss), sess.run(tf.transpose(w))) 

print('w', sess.run(w)) 
bfit = sess.run(w[0,0])*G0 + sess.run(w[1,0])*G1 

pl.clf() 
pl.plot(G0,'g-') 
pl.plot(G1,'b-') 
pl.plot(b_data,'r-') 
pl.plot(bfit,'k-') 
pl.show() 
+0

你得到的错误究竟是什么?或者它只是打印(步骤,sess.run(损失),sess.run(tf.transpose(w))) 返回NaN? –

+0

我补充说明。问题是张量流程代码找不到正确的解决方案。 – bullwinkle

+0

我尝试打印渐变,它快速到零,所以我认为这是一个局部最小值。 '[wgrad,_] = optimizer.compute_gradients(loss,[w])[0]'但这个问题似乎是二次的,所以这很奇怪 –

回答

0

你就没有优化配置适当的成本函数,但一个较大的减法播放你的阵列b_datab

> print(tf.square(b - b_data)) 
Tensor("Square:0", shape=(200, 200), dtype=float32) 

>print(tf.square(b[:, 0] - b_data)) 
Tensor("Square:0", shape=(200,), dtype=float32) 

这是通过在tensorflow实施(see for instance this issue)反直觉的广播引起的。

如果用loss = tf.reduce_mean(tf.square(b[:, 0] - b_data))代替损失,优化成功并返回结果。