2012-04-27 18 views
2

这是我为乐趣而编写的Levenshtein距离的并行实现。我对结果感到失望。我在一个核心的i7处理器上运行它,所以我有很多可用的线程。但是,随着线程数的增加,性能显着下降。由此我的意思是它实际上运行速度较慢与更多的线程输入相同的大小。Levenshtein距离的并行实现使用更多的线程减慢

我希望有人可以看看我使用线程和java.util.concurrent包的方式,并告诉我如果我做错了什么。我真的只关心为什么并行性不能像我期望的那样工作。我不希望读者看到这里发生的复杂索引。我相信我所做的计算是正确的。但即使它们不是,我认为我应该看到接近线性加速,因为我增加了线程池中的线程数。

我已经包含了我使用的基准代码。我使用库发现here进行基准测试。第二个代码块是我用于基准测试的。

感谢您的任何帮助:)。

import java.util.ArrayList; 
import java.util.List; 
import java.util.concurrent.*; 

public class EditDistance { 
    private static final int MIN_CHUNK_SIZE = 5; 
    private final ExecutorService threadPool; 
    private final int threadCount; 
    private final String maxStr; 
    private final String minStr; 
    private final int maxLen; 
    private final int minLen; 

    public EditDistance(String s1, String s2, ExecutorService threadPool, 
      int threadCount) { 
     this.threadCount = threadCount; 
     this.threadPool = threadPool; 
     if (s1.length() < s2.length()) { 
      minStr = s1; 
      maxStr = s2; 
     } else { 
      minStr = s2; 
      maxStr = s1; 
     } 
     maxLen = maxStr.length(); 
     minLen = minStr.length(); 
    } 

    public int editDist() { 
     int iterations = maxLen + minLen - 1; 
     int[] prev = new int[0]; 
     int[] current = null; 

     for (int i = 0; i < iterations; i++) { 
      int currentLen; 
      if (i < minLen) { 
       currentLen = i + 1; 
      } else if (i < maxLen) { 
       currentLen = minLen; 
      } else { 
       currentLen = iterations - i; 
      } 

      current = new int[currentLen * 2 - 1]; 
      parallelize(prev, current, currentLen, i); 
      prev = current; 
     } 
     return current[0]; 
    } 

    private void parallelize(int[] prev, int[] current, int currentLen, 
      int iteration) { 
     int chunkSize = Math.max(current.length/threadCount, MIN_CHUNK_SIZE); 
     List<Future<?>> futures = new ArrayList<Future<?>>(currentLen); 
     for (int i = 0; i < currentLen; i += chunkSize) { 
      int stopIdx = Math.min(currentLen, i + chunkSize); 
      Runnable worker = new Worker(prev, current, currentLen, iteration, 
        i, stopIdx); 
      futures.add(threadPool.submit(worker)); 
     } 
     for (Future<?> future : futures) { 
      try { 
       Object result = future.get(); 
       if (result != null) { 
        throw new RuntimeException(result.toString()); 
       } 
      } catch (InterruptedException e) { 
       Thread.currentThread().interrupt(); 
      } catch (ExecutionException e) { 
       // We can only finish the computation if we complete 
       // all subproblems 
       throw new RuntimeException(e); 
      } 
     } 
    } 

    private void doChunk(int[] prev, int[] current, int currentLen, 
      int iteration, int startIdx, int stopIdx) { 
     int mergeStartIdx = (iteration < minLen) ? 0 : 2; 

     for (int i = startIdx; i < stopIdx; i++) { 
      // Edit distance 
      int x; 
      int y; 
      int leftIdx; 
      int downIdx; 
      int diagonalIdx; 
      if (iteration < minLen) { 
       x = i; 
       y = currentLen - i - 1; 
       leftIdx = i * 2 - 2; 
       downIdx = i * 2; 
       diagonalIdx = i * 2 - 1; 
      } else { 
       x = i + iteration - minLen + 1; 
       y = minLen - i - 1; 
       leftIdx = i * 2; 
       downIdx = i * 2 + 2; 
       diagonalIdx = i * 2 + 1; 
      } 
      int left = 1 + ((leftIdx < 0) ? iteration + 1 : prev[leftIdx]); 
      int down = 1 + ((downIdx < prev.length) ? prev[downIdx] 
        : iteration + 1); 
      int diagonal = penalty(x, y) 
        + ((diagonalIdx < 0 || diagonalIdx >= prev.length) ? iteration 
          : prev[diagonalIdx]); 
      int dist = Math.min(left, Math.min(down, diagonal)); 
      current[i * 2] = dist; 

      // Merge prev 
      int mergeIdx = i * 2 + 1; 
      if (mergeIdx < current.length) { 
       current[mergeIdx] = prev[mergeStartIdx + i * 2]; 
      } 
     } 
    } 

    private int penalty(int maxIdx, int minIdx) { 
     return (maxStr.charAt(maxIdx) == minStr.charAt(minIdx)) ? 0 : 1; 
    } 

    private class Worker implements Runnable { 
     private final int[] prev; 
     private final int[] current; 
     private final int currentLen; 
     private final int iteration; 
     private final int startIdx; 
     private final int stopIdx; 

     Worker(int[] prev, int[] current, int currentLen, int iteration, 
       int startIdx, int stopIdx) { 
      this.prev = prev; 
      this.current = current; 
      this.currentLen = currentLen; 
      this.iteration = iteration; 
      this.startIdx = startIdx; 
      this.stopIdx = stopIdx; 
     } 

     @Override 
     public void run() { 
      doChunk(prev, current, currentLen, iteration, startIdx, stopIdx); 
     } 
    } 

    public static void main(String args[]) { 
     int threadCount = 4; 
     ExecutorService threadPool = Executors.newFixedThreadPool(threadCount); 
     EditDistance ed = new EditDistance("Saturday", "Sunday", threadPool, 
       threadCount); 
     System.out.println(ed.editDist()); 
     threadPool.shutdown(); 
    } 
} 

EditDistance中有一个私有的内部类Worker。每个工作人员负责使用EditDistance.doChunk填写当前数组的范围。 EditDistance.parallelize负责创建这些工作人员,并等待他们完成任务。

,我使用的基准代码:

import java.io.PrintStream; 
import java.util.concurrent.*; 
import org.apache.commons.lang3.RandomStringUtils; 
import bb.util.Benchmark; 

public class EditDistanceBenchmark { 

    public static void main(String[] args) { 
     if (args.length != 2) { 
      System.out.println("Usage: <string length> <thread count>"); 
      System.exit(1); 
     } 
     PrintStream oldOut = System.out; 
     System.setOut(System.err); 

     int strLen = Integer.parseInt(args[0]); 
     int threadCount = Integer.parseInt(args[1]); 
     String s1 = RandomStringUtils.randomAlphabetic(strLen); 
     String s2 = RandomStringUtils.randomAlphabetic(strLen); 
     ExecutorService threadPool = Executors.newFixedThreadPool(threadCount); 

     Benchmark b = new Benchmark(new Benchmarker(s1, s2, threadPool,threadCount)); 
     System.setOut(oldOut); 

     System.out.println("threadCount: " + threadCount + 
       " string length: "+ strLen + "\n\n" + b); 
     System.out.println("s1: " + s1 + "\ns2: " + s2); 

     threadPool.shutdown(); 
    } 

    private static class Benchmarker implements Runnable { 
     private final String s1, s2; 
     private final int threadCount; 
     private final ExecutorService threadPool; 

     private Benchmarker(String s1, String s2, ExecutorService threadPool, int threadCount) { 
      this.s1 = s1; 
      this.s2 = s2; 
      this.threadPool = threadPool; 
      this.threadCount = threadCount; 
     } 

     @Override 
     public void run() { 
      EditDistance d = new EditDistance(s1, s2, threadPool, threadCount); 
      d.editDist(); 
     } 

    } 
} 
+0

你有没有尝试用非常长的字符串对代码进行基准测试?对于这样短的字符串,问题可能是不断增加的线程管理开销超过了使用多线程的好处。 – flacs 2012-04-27 22:04:18

+0

@Tilka你看过我提供的基准代码吗?我正在运行那个长度为10000-50000个字符的字符串。如果将字符串长度设置为EditDistanceBenchmark,则会为您生成这些长字符串。 – Kevin 2012-04-28 03:31:21

回答

2

这很容易不小心编写代码,不并行很好。主要的罪魁祸首是你的线程竞争底层系统资源(例如缓存线)。由于这种算法固有地作用于物理记忆中彼此接近的事物,我怀疑这可能是罪魁祸首。

我建议你检讨错误共享

http://www.drdobbs.com/go-parallel/article/217500206?pgno=3

这个优秀的文章,然后仔细检查你的代码,其中线程将阻止另一个案例。

此外,如果您的线程受到CPU限制(如果您已经使用所有内核接近100%,添加更多线程将只会增加上下文切换的开销),那么运行更多的线程会比CPU内核更慢, 。

+0

看起来这可能确实是罪魁祸首。本文仅讨论C(++)中的虚假共享。如果你知道任何有用的文章解决Java中的相同问题,那将是有帮助的。 – Kevin 2012-04-28 15:57:57