2014-04-24 37 views
1

我有这个R函数来生成一个矩阵,它包含0和n之间的k个数的所有组合,其总和等于n。这是我的程序的瓶颈之一,因为它变得更小数目非常慢(因为它是计算发电机组)0和n之间的k个数的所有组合,其总和等于n,速度优化

下面是代码

sum.comb <- 
function(n,k) { 

ls1 <- list()       # generate empty list 
for(i in 1:k) {      # how could this be done with apply? 
    ls1[[i]] <- 0:n      # fill with 0:n 
} 
allc <- as.matrix(expand.grid(ls1))  # generate all combinations, already using the built in function 
colnames(allc) <- NULL 
index <- (rowSums(allc) == n)  # make index with only the ones that sum to n 
allc[index, ,drop=F]     # matrix with only the ones that sum to n 
} 
+0

您应该删除部分数据集。例如当你看nz时,你只想考虑k = 2时的数字1:z。然后使用相同的算法从第三列中删除数字(如果k = 3等)。 –

+0

@HansRoggeman,这意味着几个嵌套的循环还是有更优雅的方式? – spore234

+0

n和k的典型值是多少?不同的算法在飞机的不同部分可能表现更好。至少我们可以尝试改善你的情况。 – flodel

回答

4

这是很难说,如果这将是有益的除非你回答我关于对nk(请做的。)下面是一个使用递归,似乎用他的基准测试中比josilber的更快版本的典型值的问题:

sum.comb3 <- function(n, k) { 

    stopifnot(k > 0L) 

    REC <- function(n, k) { 
     if (k == 1L) list(n) else 
     unlist(lapply(0:n, function(i)Map(c, i, REC(n - i, k - 1L))), 
      recursive = FALSE) 
    } 

    matrix(unlist(REC(n, k)), ncol = k, byrow = TRUE) 
} 

microbenchmark(sum.comb(3, 10), sum.comb2(3, 10), sum.comb3(3, 10)) 
# Unit: milliseconds 
#    expr  min  lq median  uq  max neval 
# sum.comb2(3, 10) 39.55612 40.60798 41.91954 44.26756 70.44944 100 
# sum.comb3(3, 10) 25.86008 27.74415 28.37080 29.65567 34.18620 100 
+0

谢谢,这看起来非常好。我会试图找出k和n的值。 – spore234

1

下可以用lapply完成。

ls1 <- list() 
for(i in 1:k) { 
    ls1[[i]] <- 0:n 
} 

尝试替换这是,看看你是否得到任何加快。

ls1 = lapply(1:k,function(x) 0:n) 

我改变 'LS' 到 'LS1',因为LS()是R的功能。

+0

谢谢,我很好奇这里如何使用lapply,虽然这不是瓶颈。我也忘记了ls是内部的,我将编辑原始代码。 – spore234

+2

'rep(list(0:n),k)'也 –

3

下面是一个不同的方法,它在每次迭代时逐渐扩大集合从1到k的集合,修剪其总和超过n的组合。这应该会导致加速比你有一个很大的k相对于n,因为你不需要计算任何接近功率集的大小。

sum.comb2 <- function(n, k) { 
    combos <- 0:n 
    sums <- 0:n 
    for (width in 2:k) { 
    combos <- apply(expand.grid(combos, 0:n), 1, paste, collapse=" ") 
    sums <- apply(expand.grid(sums, 0:n), 1, sum) 
    if (width == k) { 
     return(combos[sums == n]) 
    } else { 
     combos <- combos[sums <= n] 
     sums <- sums[sums <= n] 
    } 
    } 
} 

# Simple test 
sum.comb2(3, 2) 
# [1] "3 0" "2 1" "1 2" "0 3" 

这里的小n和大k时的加速的一个例子:

library(microbenchmark) 
microbenchmark(sum.comb2(1, 100)) 
# Unit: milliseconds 
#    expr  min  lq median  uq  max neval 
# sum.comb2(1, 100) 149.0392 158.716 162.1919 174.0482 236.2095 100 

这种方法运行在一秒之内,而当然与发电机组的做法会不会把过去的通话到expand.grid,因为在结果矩阵中最终会有2^100行。

即使是在不那么极端的情况下,N = 3且k = 10,我们看到了一个20倍的加速相比,在原来的职位功能:

microbenchmark(sum.comb(3, 10), sum.comb2(3, 10)) 
# Unit: milliseconds 
#    expr  min  lq median  uq  max neval 
# sum.comb(3, 10) 404.00895 439.94472 446.67452 461.24909 574.80426 100 
# sum.comb2(3, 10) 23.27445 24.53771 25.60409 26.97439 65.59576 100 
1

什么喜欢的东西短:

comb = function(n, k) { 
    all = combn(0:n, k) 
    sums = colSums(all) 
    all[, sums == n] 
} 

然后是这样的:

comb(5, 3) 

产生如你要求的矩阵:

 [,1] [,2] 
[1,] 0 0 
[2,] 1 2 
[3,] 4 3 

由于@josilber和原来的海报指出所需的OP所有排列与重复而不是组合。对于排列类似的方法将如下所示:

perm = function(n, k) { 
    grid = matrix(rep(0:n, k), n + 1, k) 
    all = expand.grid(data.frame(grid)) 
    sums = rowSums(all) 
    all[sums == n,] 
} 

然后是这样的:

perm(5, 3) 

产生如你要求的矩阵:

X1 X2 X3 
6 5 0 0 
11 4 1 0 
16 3 2 0 
21 2 3 0 
26 1 4 0 
31 0 5 0 
... 
+0

这个解决方案的问题是,你不会得到像(5,0,0)这样的结果,它重用0和'n'之间的一个数字。此外,你没有得到所有的顺序(例如,而不是014该OP正在寻找014,041,140,​​104,401,410)。 – josliber

+0

而不是一个“问题”我的解决方案符合组合的数学定义:'集合S的k-组合是S的k **不同**元素的子集http://en.wikipedia.org/wiki /组合 – user14382

+0

好的,但是通过运行OP发布的代码,您可以看到这不是他们实际要做的。 – josliber

2

partitions包partitcularly compositions()blockparts()它们将会像整个矩阵生成器和迭代操作一样更快。那么如果这还不够快,请参阅关于组合和分区生成算法(空洞代码,灰色代码和并行)的各种出版物,如Daniel Page's research

library(partitions) 
library(microbenchmark) 

# rcpp_comps is an Rcpp implementation of compositions using loop 
# free grey code, just for illustrative purposes. 

# Just get the matrix 
microbenchmark(compositions(3,10), compositions(10,3), 
       blockparts(rep(10,3),10), blockparts(rep(3,10),3), 
       rcpp_comps(10), times=10) 
## Unit: microseconds 
##      expr min  lq median uq max neval 
##   compositions(3, 10) 1967.4 2050.9 2097.1 2173 3189.6 10 
##   compositions(10, 3) 618.2 638.5 654.6 688 700.7 10 
## blockparts(rep(10, 3), 10) 612.2 620.8 645.6 663 963.5 10 
## blockparts(rep(3, 10), 3) 2057.2 2089.2 2176.0 2242 3116.4 10 
##    rcpp_comps(10) 359.9 360.7 367.6 378 404.2 10 
相关问题