2014-03-13 92 views
1

我有两个类型的字典:的Python:计算两个类型的字典的余弦相似度更快

d1 = {1234: 4, 125: 7, ...} 
d2 = {1234: 8, 1288: 5, ...} 

http://stardict.sourceforge.net/Dictionaries.php下载的长度为10至40000。变化要计算我使用此功能的余弦相似性:

from scipy.linalg import norm 
def simple_cosine_sim(a, b): 
    if len(b) < len(a): 
     a, b = b, a 

    res = 0 
    for key, a_value in a.iteritems(): 
     res += a_value * b.get(key, 0) 
    if res == 0: 
     return 0 

    try: 
     res = res/norm(a.values())/norm(b.values()) 
    except ZeroDivisionError: 
     res = 0 
    return res 

可以更快地计算相似度吗?

UPD:使用Cython +重写代码+速度提高15%。感谢@Davidmh

from scipy.linalg import norm 

def fast_cosine_sim(a, b): 
    if len(b) < len(a): 
     a, b = b, a 

    cdef long up, key 
    cdef int a_value, b_value 

    up = 0 
    for key, a_value in a.iteritems(): 
     b_value = b.get(key, 0) 
     up += a_value * b_value 
    if up == 0: 
     return 0 
    return up/norm(a.values())/norm(b.values()) 
+0

我已经评论了你用Cython代码,增加了一种替代方法。我希望这有帮助。 – Davidmh

回答

1

如果索引不是太高,可以将每个字典转换为数组。如果它们非常大,则可以使用稀疏数组。那么,余弦相似性只会使它们两者相乘。如果您需要重复使用同一个字典进行多次计算,则此方法的性能最佳。

如果这不是一个选项,只要您注释a_value和b_value,Cython应该是非常快的。

编辑: 看看你的Cython重写,我看到了一些改进。第一件事是做一个cython -a来生成汇编的HTML报告,看看哪些事情已经加速,哪些没有。首先,你定义“up”为止,但是你总结了整数。另外,在你的例子中,键是整数,但是你将它们声明为double。另一个简单的事情是将输入键入为字符串。

此外,检查C代码,似乎有一些没有检查,您可以通过使用@ cython.nonechecks(False)禁用。

实际上,字典的实现是非常有效的,所以在一般情况下,你可能不会比这更好。如果您需要挤压最出你的代码,也许是值得的C API替换一些电话:http://docs.python.org/2/c-api/dict.html

cpython.PyDict_GetItem(a, key) 

但是,你将负责引用计数和的PyObject *铸造为int的一个可疑的表现收益。

任何方式,代码的开头是这样的:

cimport cython 

@cython.nonecheck(False) 
@cython.cdivision(True) 
def fast_cosine_sim(dict a, dict b): 
    if len(b) < len(a): 
     a, b = b, a 

    cdef int up, key 
    cdef int a_value, b_value 

还有另一个问题:是你dicionaries大?因为如果它们不是,规范的计算实际上可能是一个重要的开销。

编辑2: 另一种可能的方法是只查看必要的键。说:

from scipy.linalg import norm 
cimport cython 

@cython.nonecheck(False) 
@cython.cdivision(True) 
def fast_cosine_sim(dict a, dict b): 
    cdef int up, key 
    cdef int a_value, b_value 

    up = 0 
    for key in set(a.keys()).intersection(b.keys()): 
     a_value = a[key] 
     b_value = b[key] 
     up += a_value * b_value 
    if up == 0: 
     return 0 
    return up/norm(a.values())/norm(b.values()) 

这在Cython中非常高效。实际的表现可能取决于键之间有多少重叠。

+0

该词典可以包含40000多个项目。因此,将它们转换为一个集合并找到交点不会很快。钥匙是“长”型。并且'a_value * b_value'的总和可以大于int值。我认为Cython不能自动转换类型(比如Python),这就是为什么我把'up'定义为'long'的原因。 –

1

从算法的角度来看,没有。你已经处于复杂的O(N)。虽然有一些计算技巧可以使用。

您可以使用多处理模块将a_value * b.get(key, 0)乘法调度给几个工人,从而利用您拥有的所有机器核心。请注意,您将不会使用线程获得此效果,因为Python具有全局解释器锁定。

最简单的方法是使用池对象的multiproccess.Poolmap方法。

我强烈建议使用Python内置的cProfiler来检查代码中的热点。这很容易。只要运行:

python -m cProfile myscript.py

+0

问题是''simple_cosine_sim'从'multiproccess.Pool'中的map中运行的函数调用:) –

+1

去核心?剖析代码并检查乘法是否是热点。如果确实如此,你可以使用PyCUDA。 – rafgoncalves

+0

我会朝这个方向看 –