2011-08-14 104 views
6

我想在C++中实现Karatsuba乘法算法,但现在我只是想让它在python中工作。Karatsuba算法太多递归

这里是我的代码:

def mult(x, y, b, m): 
    if max(x, y) < b: 
     return x * y 

    bm = pow(b, m) 
    x0 = x/bm 
    x1 = x % bm 
    y0 = y/bm 
    y1 = y % bm 

    z2 = mult(x1, y1, b, m) 
    z0 = mult(x0, y0, b, m) 
    z1 = mult(x1 + x0, y1 + y0, b, m) - z2 - z0 

    return mult(z2, bm ** 2, b, m) + mult(z1, bm, b, m) + z0 

我不明白的是:应该如何z2z1z0产生的呢?是否使用递归正确的mult函数?如果是这样,我搞乱了某个地方,因为递归没有停止。

有人能指出错误在哪里吗?

+1

当然,递归并没有停止:哪里是使递归停止的条件? – neurino

+2

我不确定,也许使用'x0,x1 = divmod(x,bm)'会更快。 – utdemir

+0

@neurino,在第一行,如果声明有回报。 – utdemir

回答

5

注:以下地址直接在OP的问题有关 过多的递归,但它并没有试图提供正确的 karatsuba算法的响应。其他回答在这方面 更多信息。

试试这个版本:

def mult(x, y, b, m): 
    bm = pow(b, m) 

    if min(x, y) <= bm: 
     return x * y 

    # NOTE the following 4 lines 
    x0 = x % bm 
    x1 = x/bm 
    y0 = y % bm 
    y1 = y/bm 

    z0 = mult(x0, y0, b, m) 
    z2 = mult(x1, y1, b, m) 
    z1 = mult(x1 + x0, y1 + y0, b, m) - z2 - z0 

    retval = mult(mult(z2, bm, b, m) + z1, bm, b, m) + z0 
    assert retval == x * y, "%d * %d == %d != %d" % (x, y, x * y, retval) 
    return retval 

与您的版本最严重的问题是,你的X0和X1,和y 0和y的计算翻转。此外,如果x1y1为0,算法的推导不成立,因为在这种情况下,分解步骤变得无效。因此,您必须通过确保x和y都大于b ** m来避免这种可能性。

编辑:修复了代码中的错字;加入澄清

EDIT2:

更清晰,直接在原来的版本注释:

def mult(x, y, b, m): 
    # The termination condition will never be true when the recursive 
    # call is either 
    # mult(z2, bm ** 2, b, m) 
    # or mult(z1, bm, b, m) 
    # 
    # Since every recursive call leads to one of the above, you have an 
    # infinite recursion condition. 
    if max(x, y) < b: 
     return x * y 

    bm = pow(b, m) 

    # Even without the recursion problem, the next four lines are wrong 
    x0 = x/bm # RHS should be x % bm 
    x1 = x % bm # RHS should be x/bm 
    y0 = y/bm # RHS should be y % bm 
    y1 = y % bm # RHS should be y/bm 

    z2 = mult(x1, y1, b, m) 
    z0 = mult(x0, y0, b, m) 
    z1 = mult(x1 + x0, y1 + y0, b, m) - z2 - z0 

    return mult(z2, bm ** 2, b, m) + mult(z1, bm, b, m) + z0 
+0

你能解释一下什么是关于这个版本从版本上面贴有什么不同?这如何解决无限递归? – templatetypedef

+2

这就解决了终端问题,但在'retval'的分配两个最终递归调用是错误的,所以是BM的计算:这不会给你Karatsuba的为O(n ^(log_2(3))的运行时间。 – huitseeker

+0

@huitseeker,在OP解释说,他/她只是想要得到的基本逻辑的工作,没有运行时间最佳,这是明显的,不仅从bm'的'在每一个递归运算,而且从使用'% '和'/'。至于最后两个递归调用,我不知道他们是否是原始算法的一部分(我是不是能够轻易找到它的完整版),所以在这个意义上,它们可能是“错误“,但是所显示的算法在我的所有测试运行中产生了正确的结果。如果您有反例,请将其发布。 – kjo

1

我相信这项技术背后的想法是使用递归算法计算术语,但结果并不是以这种方式统一在一起的。既然你想要的最终结果是

z0 B^2m + z1 B^m + z2 

假设你选择B的合适值(比如,2)你可以计算乙^ M没有做任何乘法。例如,当使用B = 2时,可以使用位移而不是乘法来计算B^m。这意味着最后一步可以不做任何乘法运算。

还有一件事 - 我注意到你已经为整个算法选择了一个固定的m值。通常情况下,你可以通过让m总是一个值来实现这个算法,使得B^m在写入基数B时是x和y的位数的一半。如果你使用2的幂,这将会完成通过选取m = ceil((log x)/ 2)。

希望这会有所帮助!

+0

谢谢。虐待最终会改变它,以便它将使用2的幂作为b。现在,我只是试图让这个工作 – calccrypto

4

通常大的数字存储为整数数组。每个整数代表一个数字。这种方法允许通过阵列的简单左移使基数乘以任意数量。

这是我基于列表的实现(可能包含bug):

def normalize(l,b): 
    over = 0 
    for i,x in enumerate(l): 
     over,l[i] = divmod(x+over,b) 
    if over: l.append(over) 
    return l 
def sum_lists(x,y,b): 
    l = min(len(x),len(y)) 
    res = map(operator.add,x[:l],y[:l]) 
    if len(x) > l: res.extend(x[l:]) 
    else: res.extend(y[l:]) 
    return normalize(res,b) 
def sub_lists(x,y,b): 
    res = map(operator.sub,x[:len(y)],y) 
    res.extend(x[len(y):]) 
    return normalize(res,b) 
def lshift(x,n): 
    if len(x) > 1 or len(x) == 1 and x[0] != 0: 
     return [0 for i in range(n)] + x 
    else: return x 
def mult_lists(x,y,b): 
    if min(len(x),len(y)) == 0: return [0] 
    m = max(len(x),len(y)) 
    if (m == 1): return normalize([x[0]*y[0]],b) 
    else: m >>= 1 
    x0,x1 = x[:m],x[m:] 
    y0,y1 = y[:m],y[m:] 
    z0 = mult_lists(x0,y0,b) 
    z1 = mult_lists(x1,y1,b) 
    z2 = mult_lists(sum_lists(x0,x1,b),sum_lists(y0,y1,b),b) 
    t1 = lshift(sub_lists(z2,sum_lists(z1,z0,b),b),m) 
    t2 = lshift(z1,m*2) 
    return sum_lists(sum_lists(z0,t1,b),t2,b) 

sum_listssub_lists回报未标准化结果 - 单一的数字可能比基值。 normalize函数解决了这个问题。

所有函数都希望以相反顺序获取数字列表。例如,基数10中的12应该写为[2,1]。让我们的9987654321.

» a = [1,2,3,4,5,6,7,8,9] 
» res = mult_lists(a,a,10) 
» res.reverse() 
» res 
[9, 7, 5, 4, 6, 1, 0, 5, 7, 7, 8, 9, 9, 7, 1, 0, 4, 1] 
4

方形的Karastuba倍增的目标是提高对分而通过使3个递归调用,而不是四个征服乘算法。因此,脚本中应包含递归调用乘法的唯一行是分配z0,z1z2的那些行。别的什么会给你一个更糟的复杂性。如果尚未定义乘法(并且还有一个幂级数指数),则不能使用pow来计算b^m

为此,该算法关键性地使用了它使用位置表示法系统的事实。如果你有一个基于b的数字的表示x,那么x*b^m只是通过将该表示m次的位移动到左边来获得。任何位置符号系统都可以实现这种移位操作。这也意味着如果你想实现这一点,你必须重现这种位置表示法和“自由”转变。要么你选择以b=2为基数进行计算,并且使用python的位运算符(或者如果测试平台具有给定的十进制,十六进制... base),,或者您决定实施用于教育目的的适用于一个任意的b,并且你用字符串,数组或列表来重现这个位置算术。

您已经有solution with lists了。我喜欢在python中使用字符串,因为int(s,base)会给出与字符串s相对应的整数,以base为基数表示:它使测试变得简单。 我已经发布了一个严格评论的基于字符串的实现作为要点here,其中包括字符串到数字和数字到字符串的基元以实现良好的度量。

可以通过提供填充字符串与底座和他们(等于)长作为参数传递给mult测试:

In [169]: mult("987654321","987654321",10,9) 

Out[169]: '966551847789971041' 

如果你不想弄清楚填充或计数的字符串长度,一个填充功能可以为你做到这一点:

In [170]: padding("987654321","2") 

Out[170]: ('987654321', '000000002', 9) 

当然它b>10工作:

In [171]: mult('987654321', '000000002', 16,9) 

Out[171]: '130eca8642' 

(使用wolfram alpha查询)