2013-06-03 77 views
2

我正在使用CUDA和Thrust。我发现输入thrust::transform [plus/minus/divide]乏味,所以我只想重载一些简单的操作符。为Thrust重载“+”运算符,有什么想法?

这将是真棒,如果我能做到:

thrust::[host/device]_vector<float> host; 
thrust::[host/device]_vector<float> otherHost; 
thrust::[host/device]_vector<float> result = host + otherHost; 

下面是+一个例子片段:

template <typename T> 
__host__ __device__ T& operator+(T &lhs, const T &rhs) { 
    thrust::transform(rhs.begin(), rhs.end(), 
         lhs.begin(), lhs.end(), thrust::plus<?>()); 
    return lhs; 
} 

然而,thrust::plus<?>没有正确过载,或者我没有做它正确地......一个或另一个。 (如果为此重载简单的操作符是个不好的主意,请解释原因)。最初,我认为我可以用typename T::iterator之类的东西超载?占位符,但那不起作用。

我不知道如何使矢量类型的矢量迭代器的类型过载+运算符。这有意义吗?

感谢您的帮助!

+0

''是什么意思? – Elazar

+0

@Elazar这意味着我不知道该放什么。也许某种类型的'T :: iterator'类型或其他类型。 –

+0

我刚编辑它 –

回答

2

这似乎是工作,其他人可能有更好的想法:

#include <ostream> 
#include <thrust/host_vector.h> 
#include <thrust/device_vector.h> 
#include <thrust/transform.h> 
#include <thrust/functional.h> 
#include <thrust/copy.h> 
#include <thrust/fill.h> 

#define DSIZE 10 


template <typename T> 
thrust::device_vector<T> operator+(thrust::device_vector<T> &lhs, const thrust::device_vector<T> &rhs) { 
    thrust::transform(rhs.begin(), rhs.end(), 
         lhs.begin(), lhs.begin(), thrust::plus<T>()); 
    return lhs; 
} 

template <typename T> 
thrust::host_vector<T> operator+(thrust::host_vector<T> &lhs, const thrust::host_vector<T> &rhs) { 
    thrust::transform(rhs.begin(), rhs.end(), 
         lhs.begin(), lhs.begin(), thrust::plus<T>()); 
    return lhs; 
} 
int main() { 


    thrust::device_vector<float> dvec(DSIZE); 
    thrust::device_vector<float> otherdvec(DSIZE); 
    thrust::fill(dvec.begin(), dvec.end(), 1.0f); 
    thrust::fill(otherdvec.begin(), otherdvec.end(), 2.0f); 
    thrust::host_vector<float> hresult1 = dvec + otherdvec; 

    std::cout << "result 1: "; 
    thrust::copy(hresult1.begin(), hresult1.end(), std::ostream_iterator<float>(std::cout, " ")); std::cout << std::endl; 

    thrust::host_vector<float> hvec(DSIZE); 
    thrust::fill(hvec.begin(), hvec.end(), 5.0f); 
    thrust::host_vector<float> hresult2 = hvec + hresult1; 


    std::cout << "result 2: "; 
    thrust::copy(hresult2.begin(), hresult2.end(), std::ostream_iterator<float>(std::cout, " ")); std::cout << std::endl; 

    // this line would produce a compile error: 
    // thrust::host_vector<float> hresult3 = dvec + hvec; 

    return 0; 
} 

注意,在任何情况下,我可以指定结果的主机或设备向量,由于推力会看到其中的差别并自动生成必要的复制操作。所以我的模板中的结果矢量类型(主机,设备)并不重要。

另请注意,您在模板定义中使用的thrust::transform函数参数并不完全正确。

+0

啊,我认为这会起作用。我会在早上回去工作时检查它。 –

+1

我想我应该指出的另外一件事是(我认为)你提出了一个原地变换(这就是为什么你建议返回'lhs',而且我遵循你的约定),但是这个效果是一个两个操作数(lhs)被覆盖。因此,这可能会产生稍微不直观的行为,即result = vec1 + vec2;将结果放在* result *和'vec1'中。 –

相关问题