2017-05-23 68 views
1

假设我有以下功能:重写功能,以避免在numpy.exp溢出

import numpy, itertools  

def my_func(x): 
     Z = 0 
     for y in itertools.product([-1, 1], repeat=10): 
      Z += numpy.exp(numpy.dot(x,y)) 
     return numpy.log(Z) 

功能工作正常,不包含“太极端”值输入向量。但是,如果输入值太过于极端,那么在numpy.exp函数中会有一些溢出。这里有一个例子:

x = numpy.random.normal(5, 10, 10) 
my_func(x) 

到目前为止,该功能工作正常。但是,如果我用一个极端值代入x的一个元素,我得到的溢出错误:

x[3] = 6000 
my_func(x) 

有没有办法改写功能使得溢出避免?我知道溢出出现的位置和原因。尽管如此,我仍然无法找到重写函数的方法来避免它。

回答

1

NumPy有np.logaddexp(x1, x2),计算log(exp(x1) + exp(x2))。所以,你可以重写你的循环为:

z = np.logaddexp(z, np.dot(x, y)) 

,并跳过最后调用np.log

+0

谢谢,这就是我一直在寻找! – nhoeft