2016-08-22 101 views
1

T.meanthis example中有什么意义?如果实现是矢量化的,我认为T.mean是有意义的。这里的输入xytrain(x, y)是标量,而cost只发现单个输入的平方误差,并迭代数据。theano中的线性回归

cost = T.mean(T.sqr(y - Y)) 
gradient = T.grad(cost=cost, wrt=w) 
updates = [[w, w - gradient * 0.01]] 

train = theano.function(inputs=[X, Y], outputs=cost, updates=updates, allow_input_downcast=True) 

for i in range(100): 
    for x, y in zip(trX, trY): 
     train(x, y) 

print w.get_value() 

删除T.mean对输出模式没有影响。

回答

1

你是对的,T.mean在这里没有意义。成本函数一次对单个训练样本进行操作,所以“均方误差”实际上只是样本的平方误差。

本示例通过stochastic gradient descent(一种用于在线优化的算法)实现线性回归。 SGD像样本一样逐个迭代样本。但是,在更复杂的情况下,数据集通常是processed in mini-batches,这会提供更好的性能和收敛性能。

我认为T.mean留在这个例子中作为小批量梯度下降的人工产物,或者更明确地说成本函数是MSE。