2013-03-09 64 views
11

我编写函数int compare_16bytes(__m128i lhs, __m128i rhs),以便使用SSE指令比较两个16字节数:此函数在执行比较之后返回多少个字节相等。快速计算两个数组之间的相等字节数

现在我想使用上面的函数来比较两个任意长度的字节数组:长度可能不是16字节的倍数,所以我需要处理这个问题。我如何完成下面的函数的实现?我如何改进下面的功能?

int fast_compare(const char* s, const char* t, int length) 
{ 
    int result = 0; 

    const char* sPtr = s; 
    const char* tPtr = t; 

    while(...) 
    { 
     const __m128i* lhs = (const __m128i*)sPtr; 
     const __m128i* rhs = (const __m128i*)tPtr; 

     // compare the next 16 bytes of s and t 
     result += compare_16bytes(*lhs,*rhs); 

     sPtr += 16; 
     tPtr += 16; 
    } 

    return result; 
} 
+2

如果剩余字节数小于16,则使用for循环(长度/ 16次),并将零填充到lhs和one。填充应该是不同的,以便它不会错误地将填充计数为相等。 – 2013-03-09 17:42:10

+1

'while(length> = 16){/ *使用你的函数*/length - = 16; } if(length)/ *使用比较长度(最多15个字节)的版本* /;' – pmg 2013-03-09 17:42:55

+1

FYI这通常称为[*汉明距离*](http://en.wikipedia.org/wiki/Hamming_distance ) - 这可能是一个有用的搜索词。 – 2013-03-09 18:02:50

回答

6

正如@Mysticial在上面的评论说,做比较和纵向总结,然后只在主循环的末尾水平总结:

#include <stdio.h> 
#include <stdlib.h> 
#include <time.h> 
#include <emmintrin.h> 

// reference implementation 
int fast_compare_ref(const char *s, const char *t, int length) 
{ 
    int result = 0; 
    int i; 

    for (i = 0; i < length; ++i) 
    { 
     if (s[i] == t[i]) 
      result++; 
    } 
    return result; 
} 

// optimised implementation 
int fast_compare(const char *s, const char *t, int length) 
{ 
    int result = 0; 
    int i; 

    __m128i vsum = _mm_set1_epi32(0); 
    for (i = 0; i < length - 15; i += 16) 
    { 
     __m128i vs, vt, v, vh, vl, vtemp; 

     vs = _mm_loadu_si128((__m128i *)&s[i]); // load 16 chars from input 
     vt = _mm_loadu_si128((__m128i *)&t[i]); 
     v = _mm_cmpeq_epi8(vs, vt);    // compare 
     vh = _mm_unpackhi_epi8(v, v);   // unpack compare result into 2 x 8 x 16 bit vectors 
     vl = _mm_unpacklo_epi8(v, v); 
     vtemp = _mm_madd_epi16(vh, vh);   // accumulate 16 bit vectors into 4 x 32 bit partial sums 
     vsum = _mm_add_epi32(vsum, vtemp); 
     vtemp = _mm_madd_epi16(vl, vl); 
     vsum = _mm_add_epi32(vsum, vtemp); 
    } 

    // get sum of 4 x 32 bit partial sums 
    vsum = _mm_add_epi32(vsum, _mm_srli_si128(vsum, 8)); 
    vsum = _mm_add_epi32(vsum, _mm_srli_si128(vsum, 4)); 
    result = _mm_cvtsi128_si32(vsum); 

    // handle any residual bytes (< 16) 
    if (i < length) 
    { 
     result += fast_compare_ref(&s[i], &t[i], length - i); 
    } 

    return result; 
} 

// test harness 
int main(void) 
{ 
    const int n = 1000000; 
    char *s = malloc(n); 
    char *t = malloc(n); 
    int i, result_ref, result; 

    srand(time(NULL)); 

    for (i = 0; i < n; ++i) 
    { 
     s[i] = rand(); 
     t[i] = rand(); 
    } 

    result_ref = fast_compare_ref(s, t, n); 
    result = fast_compare(s, t, n); 

    printf("result_ref = %d, result = %d\n", result_ref, result);; 

    return 0; 
} 

编译并运行上述测试工具:

$ gcc -Wall -O3 -msse3 fast_compare.c -o fast_compare 
$ ./fast_compare 
result_ref = 3955, result = 3955 
$ ./fast_compare 
result_ref = 3947, result = 3947 
$ ./fast_compare 
result_ref = 3945, result = 3945 

请注意,有在我们使用_mm_madd_epi16解包和积累16位以上SSE代码一个可能的非显而易见的伎俩0/-1值到32位部分和。我们利用了这个事实-1*-1 = 1(当然还有0*0 = 0) - 我们在这里并没有真正进行乘法运算,只是在一条指令中进行解包和求和。


UPDATE:在下面的评论中指出,该方案不是最优的 - 我只花了相当最佳的16位解决方案,并加入8位到16位的拆包,使之成为8位的数据。然而,对于8位数据,存在更有效的方法,例如,使用psadbw/_mm_sad_epu8。我将这里的答案留给子孙后代,对于任何想要用16位数据做这种事情的人来说,但其他答案中的其中一个不需要拆开输入数据的答案应该是公认的答案。

+0

好极了!它正常工作!而且,这两个向量's'和't'是否对齐是重要的?什么是对齐? – enzom83 2013-03-11 14:08:12

+1

我已经在上面的例子中使用了'_mm_loadu_si128',所以它对齐并不重要。如果你可以保证's'和't'是16字节对齐的,那么使用'_mm_load_si128'而不是'_mm_loadu_si128'来获得更好的性能,特别是在较老的CPU上。 – 2013-03-11 21:38:41

+0

_mm_setzero_si128()可能比_mm_set1_epi32(0)更快以便将vsum置零。 – leecbaker 2015-02-03 18:26:26

1

SSE中的整数比较产生全部为零或全为1的字节。如果要计数,首先需要将比较结果右移7位(不算算),然后添加到结果向量中。最后,您仍然需要通过对其元素进行求和来减少结果向量。这种减少必须在标量代码中完成,或者使用一系列添加/移位来完成。通常这部分是不值得困扰的。

3

使用16 x uint8元素中的部分和可能会获得更好的性能。
我已经将循环分为内循环和外循环。
内部循环和uint8元素(每个uint8元素可以总计为255“1”)。小技巧:_mm_cmpeq_epi8将相等元素设置为0xFF,并且(char)0xFF = -1,因此您可以从总和中减去结果(减1以加1)。

这里是我的fast_compare优化版本:

int fast_compare2(const char *s, const char *t, int length) 
{ 
    int result = 0; 
    int inner_length = length; 
    int i; 
    int j = 0; 

    //Points beginning of 4080 elements block. 
    const char *s0 = s; 
    const char *t0 = t; 


    __m128i vsum = _mm_setzero_si128(); 

    //Outer loop sum result of 4080 sums. 
    for (i = 0; i < length; i += 4080) 
    { 
     __m128i vsum_uint8 = _mm_setzero_si128(); //16 uint8 sum elements (each uint8 element can sum up to 255). 
     __m128i vh, vl, vhl, vhl_lo, vhl_hi; 

     //Points beginning of 4080 elements block. 
     s0 = s + i; 
     t0 = t + i; 

     if (i + 4080 <= length) 
     { 
      inner_length = 4080; 
     } 
     else 
     { 
      inner_length = length - i; 
     } 

     //Inner loop - sum up to 4080 (compared) results. 
     //Each uint8 element can sum up to 255. 16 uint8 elements can sum up to 255*16 = 4080 (compared) results. 
     ////////////////////////////////////////////////////////////////////////// 
     for (j = 0; j < inner_length-15; j += 16) 
     { 
       __m128i vs, vt, v; 

       vs = _mm_loadu_si128((__m128i *)&s0[j]); // load 16 chars from input 
       vt = _mm_loadu_si128((__m128i *)&t0[j]); 
       v = _mm_cmpeq_epi8(vs, vt);    // compare - set to 0xFF where equal, and 0 otherwise. 

       //Consider this: (char)0xFF = (-1) 
       vsum_uint8 = _mm_sub_epi8(vsum_uint8, v); //Subtract the comparison result - subtract (-1) where equal. 
     } 
     ////////////////////////////////////////////////////////////////////////// 

     vh = _mm_unpackhi_epi8(vsum_uint8, _mm_setzero_si128());  // unpack result into 2 x 8 x 16 bit vectors 
     vl = _mm_unpacklo_epi8(vsum_uint8, _mm_setzero_si128()); 
     vhl = _mm_add_epi16(vh, vl); //Sum high and low as uint16 elements. 

     vhl_hi = _mm_unpackhi_epi16(vhl, _mm_setzero_si128()); //unpack sum of vh an vl into 2 x 4 x 32 bit vectors 
     vhl_lo = _mm_unpacklo_epi16(vhl, _mm_setzero_si128()); //unpack sum of vh an vl into 2 x 4 x 32 bit vectors 

     vsum = _mm_add_epi32(vsum, vhl_hi); 
     vsum = _mm_add_epi32(vsum, vhl_lo); 
    } 

    // get sum of 4 x 32 bit partial sums 
    vsum = _mm_add_epi32(vsum, _mm_srli_si128(vsum, 8)); 
    vsum = _mm_add_epi32(vsum, _mm_srli_si128(vsum, 4)); 
    result = _mm_cvtsi128_si32(vsum); 

    // handle any residual bytes (< 16) 
    if (j < inner_length) 
    { 
     result += fast_compare_ref(&s0[j], &t0[j], inner_length - j); 
    } 

    return result; 
} 
+0

嘿,我应该在评论保罗的之前看过新的答案;我建议同样的事情(内部循环内的'psubb')。这就是我的意思,除了你应该使用'psadbw'来执行'vsum_uint8'的水平总和(参见我对Paul的回答的评论)。 – 2016-06-21 05:32:17

+0

我想过使用水平和,但决定保持SSE2的兼容性。 – Rotem 2016-06-21 17:38:09

+0

你在说“phaddd”吗?这不是我说的。 'phaddd' [唯一的优点是代码大小](http:// stackoverflow。com/questions/6996764 /最快的方式做水平浮点矢量求和在x86/35270026#35270026)在当前的CPU。另请参阅我对此问题的回答,该问题仅使用SSE2指令。 – 2016-06-21 17:40:53

2

最快的方式为大投入是Rotem公司的回答,其中内环为pcmpeqb/psubb,载体的任何字节元素之前打破了在水平方向和累加器溢出。针对全零矢量,使用psadbw对无符号字节进行处理。

没有展开/嵌套循环,最好的选择可能是

pcmpeqb -> vector of 0 or 0xFF elements 
psadbw -> two 64bit sums of (0*no_matches + 0xFF*matches) 
paddq  -> accumulate the psadbw result in a vector accumulator 

#outside the loop: 
horizontal sum 
divide the result by 255 

如果你没有大量的寄存器压力在循环,psadbw反对0x7f而不是全零向量。

  • psadbw(0x00, set1(0x7f)) =>sum += 0x7f
  • psadbw(0xff, set1(0x7f)) =>sum += 0x80

因此,而不是由255分(编译器应该有效地做没有实际div),你只需要减去n * 0x7f,其中n是元素的数量。

另外请注意,paddq在Nehalem和Atom上运行缓慢,因此如果您不希望128 *计数溢出32位整数,您可以使用paddd_mm_add_epi32)。

这与Paul R的pcmpeqb/2x punpck/2x pmaddwd/2x paddw非常相似。