2016-03-24 96 views
3

只有一个共享变量阵列的一部分欲执行以下操作:计算梯度为在Theano

import theano, numpy, theano.tensor as T 

a = T.fvector('a') 

w = theano.shared(numpy.array([1, 2, 3, 4], dtype=theano.config.floatX)) 
w_sub = w[1] 

b = T.sum(a * w) 

grad = T.grad(b, w_sub) 

这里,w_sub是例如W [1],但我不想要显式写出来b在函数w_sub。尽管经历了http://deeplearning.net/software/theano/tutorial/faq_tutorial.html和其他相关问题,我无法解决它。

这只是为了向你展示我的问题。其实,我真正想做的是与千层面稀疏卷积。权重矩阵中的零条目不需要更新,因此不需要计算w的这些条目的梯度。

亲切问候,并提前谢谢!

的Jeroen

PS:现在这是完整的错误消息:

Traceback (most recent call last): 
    File "D:/Jeroen/Project_Lasagne_General/test_script.py", line 9, in <module> 
    grad = T.grad(b, w_sub) 
    File "C:\Anaconda2\lib\site-packages\theano\gradient.py", line 545, in grad 
    handle_disconnected(elem) 
    File "C:\Anaconda2\lib\site-packages\theano\gradient.py", line 532, in handle_disconnected 
    raise DisconnectedInputError(message) 
theano.gradient.DisconnectedInputError: grad method was asked to compute the gradient with respect to a variable that is not part of the computational graph of the cost, or is used only by a non-differentiable operator: Subtensor{int64}.0 
Backtrace when the node is created: 
    File "D:/Jeroen/Project_Lasagne_General/test_script.py", line 6, in <module> 
    w_sub = w[1] 

回答

2

当theano编译图表,只看到变量在图形如所明确定义。在您的示例中,w_sub未明确用于计算b,因此不是计算图的一部分。

使用带以下代码的theano打印库,您可以在此 graph vizualization上看到确实w_sub不是b图的一部分。

import theano 
import theano.tensor as T 
import numpy 
import theano.d3viz as d3v 

a = T.fvector('a') 
w = theano.shared(numpy.array([1, 2, 3, 4], dtype=theano.config.floatX)) 
w_sub = w[1] 
b = T.sum(a * w) 

o = b, w_sub 

d3v.d3viz(o, 'b.html') 

为了解决这个问题,就需要在b计算明确使用w_sub

然后你就能够计算b WRT w_sub的梯度和更新共享变量的值,如下面的例子:

import theano 
import theano.tensor as T 
import numpy 


a = T.fvector('a') 
w = theano.shared(numpy.array([1, 2, 3, 4], dtype=theano.config.floatX)) 
w_sub = w[1] 
b = T.sum(a * w_sub) 
grad = T.grad(b, w_sub) 
updates = [(w, T.inc_subtensor(w_sub, -0.1*grad))] 

f = theano.function([a], b, updates=updates, allow_input_downcast=True) 

f(numpy.arange(10)) 
+0

上正在发生的事情一个很好的解释。谢谢。 –