2013-07-29 24 views
6

更新:使用Rcpp和openMP从截断正态分布快速采样

我试图实现Dirk的建议。注释? 我现在正忙于JSM,但是我希望在针织画廊的Rmd之前得到一些反馈。 我从Armadillo切换回普通Rcpp,因为它没有添加任何值。 R ::的标量版本相当不错。 如果将mean/sd作为标量输入,我应该在绘制次数中输入参数n,而不是作为所需输出长度的向量。


有很多MCMC应用程序需要从截断的正态分布中抽取样本。我建立在TN的现有实现上并为其添加了并行计算。

问题:

  1. 有谁看到进一步的潜在速度的改进? 在基准测试的最后一个例子中,rtruncnorm有时更快。 Rcpp的实现总是比现有的包快,但它可以进一步改进吗?
  2. 我在一个我无法分享的复杂模型中运行它,并且我的R会话崩溃了。但是,我无法系统地复制它,所以它可能是代码的另一部分。如果有人正在与TN合作,请对其进行测试并通知我。更新:我没有更新代码的问题,但让我知道。

我如何把东西放在一起: 据我所知,最快的实现是不是在CRAN,但源代码可以下载OSU statmsmtrunco​​rm中的竞争实现在我的基准测试中较慢。诀窍是有效调整投标分布,其中指数正好适用于截断法线的尾部。 所以我拿了克里斯的代码,“Rcpp'ed”它,并添加了一些openMP香料。这里的动态时间表是最优的,因为取决于边界,取样可能需要更多或更少的时间。 我发现一件令人讨厌的事情是:当我想要使用双打时,很多统计分布都基于NumericVector类型。我只是用我的方式编码。

继承人的RCPP代码:

#include <Rcpp.h> 
#include <omp.h> 


// norm_rs(a, b) 
// generates a sample from a N(0,1) RV restricted to be in the interval 
// (a,b) via rejection sampling. 
// ====================================================================== 

// [[Rcpp::export]] 

double norm_rs(double a, double b) 
{ 
    double x; 
    x = Rf_rnorm(0.0, 1.0); 
    while((x < a) || (x > b)) x = norm_rand(); 
    return x; 
} 

// half_norm_rs(a, b) 
// generates a sample from a N(0,1) RV restricted to the interval 
// (a,b) (with a > 0) using half normal rejection sampling. 
// ====================================================================== 

// [[Rcpp::export]] 

double half_norm_rs(double a, double b) 
{ 
    double x; 
    x = fabs(norm_rand()); 
    while((x<a) || (x>b)) x = fabs(norm_rand()); 
    return x; 
} 

// unif_rs(a, b) 
// generates a sample from a N(0,1) RV restricted to the interval 
// (a,b) using uniform rejection sampling. 
// ====================================================================== 

// [[Rcpp::export]] 

double unif_rs(double a, double b) 
{ 
    double xstar, logphixstar, x, logu; 

    // Find the argmax (b is always >= 0) 
    // This works because we want to sample from N(0,1) 
    if(a <= 0.0) xstar = 0.0; 
    else xstar = a; 
    logphixstar = R::dnorm(xstar, 0.0, 1.0, 1.0); 

    x = R::runif(a, b); 
    logu = log(R::runif(0.0, 1.0)); 
    while(logu > (R::dnorm(x, 0.0, 1.0,1.0) - logphixstar)) 
    { 
     x = R::runif(a, b); 
     logu = log(R::runif(0.0, 1.0)); 
    } 
    return x; 
} 

// exp_rs(a, b) 
// generates a sample from a N(0,1) RV restricted to the interval 
// (a,b) using exponential rejection sampling. 
// ====================================================================== 

// [[Rcpp::export]] 

double exp_rs(double a, double b) 
{ 
    double z, u, rate; 

// Rprintf("in exp_rs"); 
    rate = 1/a; 
//1/a 

    // Generate a proposal on (0, b-a) 
    z = R::rexp(rate); 
    while(z > (b-a)) z = R::rexp(rate); 
    u = R::runif(0.0, 1.0); 

    while(log(u) > (-0.5*z*z)) 
    { 
     z = R::rexp(rate); 
     while(z > (b-a)) z = R::rexp(rate); 
     u = R::runif(0.0,1.0); 
    } 
    return(z+a); 
} 




// rnorm_trunc(mu, sigma, lower, upper) 
// 
// generates one random normal RVs with mean 'mu' and standard 
// deviation 'sigma', truncated to the interval (lower,upper), where 
// lower can be -Inf and upper can be Inf. 
//====================================================================== 

// [[Rcpp::export]] 
double rnorm_trunc (double mu, double sigma, double lower, double upper) 
{ 
int change; 
double a, b; 
double logt1 = log(0.150), logt2 = log(2.18), t3 = 0.725; 
double z, tmp, lograt; 

change = 0; 
a = (lower - mu)/sigma; 
b = (upper - mu)/sigma; 

// First scenario 
if((a == R_NegInf) || (b == R_PosInf)) 
    { 
    if(a == R_NegInf) 
     { 
    change = 1; 
    a = -b; 
    b = R_PosInf; 
     } 

    // The two possibilities for this scenario 
    if(a <= 0.45) z = norm_rs(a, b); 
    else z = exp_rs(a, b); 
    if(change) z = -z; 
    } 
// Second scenario 
else if((a * b) <= 0.0) 
    { 
    // The two possibilities for this scenario 
    if((R::dnorm(a, 0.0, 1.0,1.0) <= logt1) || (R::dnorm(b, 0.0, 1.0, 1.0) <= logt1)) 
     { 
    z = norm_rs(a, b); 
     } 
    else z = unif_rs(a,b); 
    } 
// Third scenario 
else 
    { 
    if(b < 0) 
     { 
    tmp = b; b = -a; a = -tmp; change = 1; 
     } 

    lograt = R::dnorm(a, 0.0, 1.0, 1.0) - R::dnorm(b, 0.0, 1.0, 1.0); 
    if(lograt <= logt2) z = unif_rs(a,b); 
    else if((lograt > logt1) && (a < t3)) z = half_norm_rs(a,b); 
    else z = exp_rs(a,b); 
    if(change) z = -z; 
    } 
    double output; 
    output = sigma*z + mu; 
return (output); 
} 


// rtnm(mu, sigma, lower, upper, cores) 
// 
// generates one random normal RVs with mean 'mu' and standard 
// deviation 'sigma', truncated to the interval (lower,upper), where 
// lower can be -Inf and upper can be Inf. 
// mu, sigma, lower, upper are vectors, and vectorized calls of this function 
// speed up computation 
// cores is an intege, representing the number of cores to be used in parallel 
//====================================================================== 


// [[Rcpp::export]] 

Rcpp::NumericVector rtnm(Rcpp::NumericVector mus, Rcpp::NumericVector sigmas, Rcpp::NumericVector lower, Rcpp::NumericVector upper, int cores){ 
    omp_set_num_threads(cores); 
    int nobs = mus.size(); 
    Rcpp::NumericVector out(nobs); 
    double logt1 = log(0.150), logt2 = log(2.18), t3 = 0.725; 
    double a,b, z, tmp, lograt; 

    int change; 

    #pragma omp parallel for schedule(dynamic) 
    for(int i=0;i<nobs;i++) { 

    a = (lower(i) - mus(i))/sigmas(i); 
    b = (upper(i) - mus(i))/sigmas(i); 
    change=0; 
    // First scenario 
    if((a == R_NegInf) || (b == R_PosInf)) 
     { 
     if(a == R_NegInf) 
      { 
       change = 1; 
       a = -b; 
       b = R_PosInf; 
      } 

     // The two possibilities for this scenario 
     if(a <= 0.45) z = norm_rs(a, b); 
     else z = exp_rs(a, b); 
     if(change) z = -z; 
     } 
    // Second scenario 
    else if((a * b) <= 0.0) 
     { 
     // The two possibilities for this scenario 
     if((R::dnorm(a, 0.0, 1.0,1.0) <= logt1) || (R::dnorm(b, 0.0, 1.0, 1.0) <= logt1)) 
      { 
       z = norm_rs(a, b); 
      } 
     else z = unif_rs(a,b); 
     } 

    // Third scenario 
    else 
     { 
     if(b < 0) 
      { 
       tmp = b; b = -a; a = -tmp; change = 1; 
      } 

     lograt = R::dnorm(a, 0.0, 1.0, 1.0) - R::dnorm(b, 0.0, 1.0, 1.0); 
     if(lograt <= logt2) z = unif_rs(a,b); 
     else if((lograt > logt1) && (a < t3)) z = half_norm_rs(a,b); 
     else z = exp_rs(a,b); 
     if(change) z = -z; 
     } 
    out(i)=sigmas(i)*z + mus(i);   
    } 

return(out); 
} 

这里是标杆:作为速度取决于上/下限

libs=c("truncnorm","msm","inline","Rcpp","RcppArmadillo","rbenchmark") 
if(sum(!(libs %in% .packages(all.available = TRUE)))>0){ install.packages(libs[!(libs %in% .packages(all.available = TRUE))])} 
for(i in 1:length(libs)) {library(libs[i],character.only = TRUE,quietly=TRUE)} 


#needed for openMP parallel 
Sys.setenv("PKG_CXXFLAGS"="-fopenmp") 
Sys.setenv("PKG_LIBS"="-fopenmp") 

#no of cores for openMP version 
cores = 4 

#surce code from same dir 
Rcpp::sourceCpp('truncnorm.cpp') 


#sample size 
nn=1000000 


bb= 100 
aa=-100 
benchmark(rtnm(rep(0,nn),rep(1,nn),rep(aa,nn),rep(100,nn),cores), rtnm(rep(0,nn),rep(1,nn),rep(aa,nn),rep(100,nn),1),rtnorm(nn,rep(0,nn),rep(1,nn),rep(aa,nn),rep(100,nn)),rtruncnorm(nn, a=aa, b=100, mean = 0, sd = 1) , order="relative", replications=3 )[,1:4] 

aa=0 
benchmark(rtnm(rep(0,nn),rep(1,nn),rep(aa,nn),rep(100,nn),cores), rtnm(rep(0,nn),rep(1,nn),rep(aa,nn),rep(100,nn),1),rtnorm(nn,rep(0,nn),rep(1,nn),rep(aa,nn),rep(100,nn)),rtruncnorm(nn, a=aa, b=100, mean = 0, sd = 1) , order="relative", replications=3 )[,1:4] 

aa=2 
benchmark(rtnm(rep(0,nn),rep(1,nn),rep(aa,nn),rep(100,nn),cores), rtnm(rep(0,nn),rep(1,nn),rep(aa,nn),rep(100,nn),1),rtnorm(nn,rep(0,nn),rep(1,nn),rep(aa,nn),rep(100,nn)),rtruncnorm(nn, a=aa, b=100, mean = 0, sd = 1) , order="relative", replications=3 )[,1:4] 

aa=50 
benchmark(rtnm(rep(0,nn),rep(1,nn),rep(aa,nn),rep(100,nn),cores), rtnm(rep(0,nn),rep(1,nn),rep(aa,nn),rep(100,nn),1),rtnorm(nn,rep(0,nn),rep(1,nn),rep(aa,nn),rep(100,nn)),rtruncnorm(nn, a=aa, b=100, mean = 0, sd = 1) , order="relative", replications=3 )[,1:4] 

几个基准运行是必要的。对于不同的情况,在算法踢的不同部分

+0

为什么不试图实施N. Chopine快速截断的正态分布http://link.springer.com/article/10.1007%2Fs11222-009-9168-1? – dickoa

+0

inv_sqrt_2pi的所有这些额外数字都是无用的。你无法获得r中那么高精度的浮点数。 'print(inv_sqrt_2pi,digits = 18)'[1] 0.398942280401432703 –

+0

@dickoa有趣的是,那篇论文没有提到我。 – Inferrator

回答

3

真快评论:

  1. 如果包括RcppArmadillo.h你并不需要包括Rcpp.h - 事实上,你不应该,我们甚至测试

  2. rep(oneDraw, n)发出n个呼叫。我会写一个函数一旦返回您将n行将被称为 - 这将是更快,你救自己n-1个函数调用的开销

  3. 您对大量的统计分布的评论都是基于NumericVector类型,当我想与双打工作可能会泄露一些误解:NumericVector是我们的内部R类型的便利代理类:无副本。您可以自由使用std::vector<double>或您喜欢的任何形式。

  4. 我对截断法线知道不多,所以我不能评论你算法的细节。

  5. 一旦你有了它,考虑一个Rcpp Gallery的职位。