2014-02-19 66 views
3

我正在开发一个CUDA应用程序,它要求我将一些任意函数传递给CUDA内核。由于为每种可能的情况声明一个函数指针并将它们传递给内核将会太麻烦(> 50个不同的函数),并且它们都是基本函数的组合,如sin(x)/y,我想要一些最小的Lambda - CUDA内核的表达功能。由于C++ 11功能尚未被设备代码支持(据我所知),并且我没有在网上找到任何相关信息,于是我决定自学自定义表达式模板并实现一些简单的lambda表达式规则以传入内核。在Cuda中使用表达式模板构建lambda表达式

我想出了下面的代码,这是一种在NVCC上编译并运行良好的最小实现。然而,沿着这条道路,我只能用1个变量来实现函数。有什么方法可以扩展我的代码来处理函数组合,如sin(_x) + _y

在此先感谢!

#include<math.h> 

#ifdef __CUDACC__ 
#define HOST_DEVICE __host__ __device__ 
#else 
#define HOST_DEVICE 
#endif 

struct Id {}; 

template <typename Op, typename Left, typename Right> 
struct BinaryOp 
{ 
    Left left; 
    Right right; 
    HOST_DEVICE BinaryOp(Left t1, Right t2) : left(t1), right(t2) {} 

    HOST_DEVICE double operator() (double x) { 
     return Op::apply(left(x), right(x)); 
    } 
}; 

template <typename Op, typename Arg> 
struct UnaryOp 
{ 
    Arg arg; 
    HOST_DEVICE UnaryOp(Arg t1) : arg(t1) {} 

    HOST_DEVICE double operator() (double x) { 
     return Op::apply(arg(x)); 
    } 
}; 

template <> 
struct UnaryOp<Id, double> 
{ 
    HOST_DEVICE UnaryOp() {} 
    HOST_DEVICE double operator() (double x) { 
     return x; 
    } 
}; 

struct Sin 
{ 
    HOST_DEVICE static double apply(double x) { 
     return sin(x); 
    } 
}; 

struct Plus 
{ 
    HOST_DEVICE static double apply(double a, double b) { 
     return a + b; 
    } 
}; 

template <typename Left, typename Right> 
BinaryOp<Plus, Left, Right> operator+ (Left lhs, Right rhs) { 
    return BinaryOp<Plus, Left, Right>(lhs, rhs); 
} 

template <typename Arg> 
UnaryOp<Sin, Arg> _sin(Arg arg) { 
    return UnaryOp<Sin, Arg>(arg); 
} 

template <class T> 
__global__ void test(T func, double x) { 
    printf("%e\n", func(x)); 
} 

int main() 
{ 
    UnaryOp<Id, double> _x; 
    double x = 1.0; 
    test<<<1, 1>>>(_sin(_x) + _x, x); 
    cudaDeviceSynchronize(); // Needed or the host will return before kernel is finished 
    return 0; 
} 
+0

看[表达式模板](http://en.wikipedia.org/wiki/Expression_templates)。 – Constructor

+0

@Constructor谢谢,但我已详细阅读,并提出了自己的代码来实现表达式模板。但是我不认为该页面有足够的信息用于我想要做的事情:构建超过1个变量的lambda表达式。 –

+0

你能向我解释为什么你所做的比仅仅函数指针更简单吗?我真的很想知道。我一直盯着你的代码2天,我仍然没有看到优势 – portforwardpodcast

回答

1

所以我花了一些时间问这个问题后设计了一个简单的解决方案。这很丑,但它对我自己很有用。这是修改的代码,最多支持3个自由变量。更多变量可以用硬编码,但目前我没有需要我的项目。

#include<math.h> 

#ifdef __CUDACC__ 
#define HOST_DEVICE __host__ __device__ 
#else 
#define HOST_DEVICE 
#endif 

struct Id {}; 

template <typename Op, typename Left, typename Right> 
struct BinaryOp 
{ 
    Left left; 
    Right right; 
    HOST_DEVICE BinaryOp(Left t1, Right t2) : left(t1), right(t2) {} 

    HOST_DEVICE double operator() (double x1, double x2 = 0.0, double x3 = 0.0) { 
     return Op::apply(left(x1, x2, x3), right(x1, x2, x3)); 
    } 
}; 

template <typename Op, typename Arg> 
struct UnaryOp 
{ 
    Arg arg; 
    HOST_DEVICE UnaryOp(Arg t1) : arg(t1) {} 

    HOST_DEVICE double operator() (double x1, double x2 = 0.0, double x3 = 0.0) { 
     return Op::apply(arg(x1, x2, x3)); 
    } 
}; 

template <int argnum> 
struct Var 
{ 
    HOST_DEVICE Var() {} 
    HOST_DEVICE double operator() (double x1, double x2 = 0.0, double x3 = 0.0) { 
     if (1 == argnum) return x1; 
     else if (2 == argnum) return x2; 
     else return x3; 
    } 
}; 

struct Sin 
{ 
    HOST_DEVICE static double apply(double x) { 
     return sin(x); 
    } 
}; 

struct Plus 
{ 
    HOST_DEVICE static double apply(double a, double b) { 
     return a + b; 
    } 
}; 

template <typename Left, typename Right> 
BinaryOp<Plus, Left, Right> operator+ (Left lhs, Right rhs) { 
    return BinaryOp<Plus, Left, Right>(lhs, rhs); 
} 

template <typename Arg> 
UnaryOp<Sin, Arg> _sin(Arg arg) { 
    return UnaryOp<Sin, Arg>(arg); 
} 

template <class T> 
__global__ void test(T func, double x, double y, double z = 0.0) { 
    printf("%e\n", func(x, y)); 
} 

Var<1> _x; 
Var<2> _y; 

int main() 
{ 
    test<<<1, 1>>>(_sin(_x) + _y, 1.0, 2.0); 
    cudaDeviceSynchronize(); // Needed or the host will return before kernel is finished 
    return 0; 
} 

这显然是一个丑陋的黑客。 lambda表达式仅适用于double(或可转换为double的类型)。不过,我无法想象目前有什么办法可以解决这个问题。希望NVCC可以很快支持C++ 11的功能,这样我就不再需要这种破解了。

如果任何人都能向我展示一个更好的解决方案,无论是图书馆还是更好的方法一起黑客攻击,我们将不胜感激。谢谢你的帮助!