2014-04-04 58 views

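Entropy and probabilities in Theano

I wrote a simple Python function to compute the entropy of a set of values, and I am trying to write the same thing in Theano.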

import math 

# this computes the probabilities of each element in the set 
def prob(values): 
    return [float(values.count(v))/len(values) for v in values] 

# this computes the entropy 
def entropy(values): 
    p = prob(values) 
    return -sum([v*math.log(v) for v in p]) 

I tried to write the equivalent code in Theano, but I don't know how to do it:

import theano 
import theano.tensor as T 

v = T.vector('v') # I create a symbolic vector to represent my initial values 
p = T.vector('p') # The same for the probabilities 

# this is my attempt to compute the probabilities which would feed vector p 
theano.scan(fn=prob,outputs_info=p,non_sequences=v,n_steps=len(values)) 

# considering the previous step would work, the entropy is just 
e = -T.sum(p*T.log(p)) 
entropy = theano.function([values],e) 

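However, the scan line is incorrect and I get tons of errors. I'm not sure whether there is a simple way to do this (computing the entropy of a vector), or whether I have to put more effort into the scan function. Any ideas?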

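Theano cannot do computations on Python lists. You have to update your code to use ndarray. First do it with NumPy only; that alone should already speed up your code. – nouiz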
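
To make the NumPy-only step concrete, here is a minimal sketch of what it might look like. The function name `entropy_np` and the `base` parameter are my own choices for illustration, not something from the question. Note that it sums over the distinct values, which is the usual definition of Shannon entropy, whereas the list-based version in the question sums over every element.

```python
import numpy as np

def entropy_np(values, base=2.0):
    """Shannon entropy of the empirical distribution of `values` (illustrative helper)."""
    values = np.asarray(values)
    # Count how often each distinct value occurs.
    _, counts = np.unique(values, return_counts=True)
    probs = counts / float(values.size)
    # Sum over distinct values; dividing by log(base) changes the log base.
    return float(-np.sum(probs * np.log(probs)) / np.log(base))

# For [0, 1, 1] the distinct-value probabilities are 1/3 and 2/3.
print(entropy_np([0, 1, 1]))
```

With `base=2.0` the result is in bits; pass `base=np.e` to get nats, matching the `math.log` used in the question.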

Answer

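In addition to the point raised by nouiz, p should not be declared as a T.vector, because it will be the result of a computation on your vector of values. Also, to compute something like entropy you do not need scan (scan introduces computational overhead, so it should only be used when there is no other way to compute what you want, or to reduce memory usage). You can take an approach like this: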

import theano
import theano.tensor as T

values = T.vector('values')
nb_values = values.shape[0] 

# For every element in 'values', obtain the total number of times 
# its value occurs in 'values'. 
# NOTE : I've done the broadcasting a bit more explicitly than 
# needed, for clarity. 
freqs = T.eq(values[:,None], values[None, :]).sum(0).astype("float32") 

# Compute a vector containing, for every element of 'values', the
# probability of that element's value within 'values'.
# NOTE : these probabilities do *not* sum to 1, because there is one
# entry per element of 'values' rather than one per distinct value.
# For instance, if 'values' is [1, 1, 0] then 'probs' will be
# [2/3, 2/3, 1/3], because the value 1 has probability 2/3 and the
# value 0 has probability 1/3 in 'values'.
probs = freqs/nb_values 

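# Averaging -log2(p) over all elements is the same as summing
# -p * log2(p) over the distinct values, so this is the entropy in bits.
entropy = -T.sum(T.log2(probs)/nb_values)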
fct = theano.function([values], entropy) 

# Will output 0.918296... 
print(fct([0, 1, 1]))
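
If you want to sanity-check the broadcasting trick without Theano, the same steps can be mirrored in plain NumPy. This is only a sketch for checking the arithmetic; the name `entropy_broadcast` is made up for this example.

```python
import numpy as np

def entropy_broadcast(values):
    """NumPy mirror of the symbolic graph above (illustrative helper)."""
    values = np.asarray(values, dtype=np.float64)
    n = values.shape[0]
    # Pairwise equality matrix: entry (i, j) is True when values[i] == values[j].
    equal = values[:, None] == values[None, :]
    # Column sums give, for each element, how often its value occurs.
    freqs = equal.sum(axis=0).astype(np.float64)
    probs = freqs / n
    # A value with probability p appears n * p times, so averaging -log2(p)
    # over elements equals -sum(p * log2(p)) over distinct values.
    return float(-np.sum(np.log2(probs)) / n)

print(entropy_broadcast([0, 1, 1]))
```

Keep in mind that the equality matrix has n x n entries, so this approach (in NumPy or in Theano) uses memory quadratic in the length of the vector. For long vectors, counting distinct values first is likely the cheaper route.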