2011-03-29 39 views
2

我正在做一些关于Java 7中的叉/加入框架的性能研究。为了改善测试结果,我想在测试过程中使用不同的递归算法。其中之一是乘法矩阵。在Java中叉加入矩阵乘法

public class MatrixMultiply { 

    static final int DEFAULT_GRANULARITY = 16; 

    /** The quadrant size at which to stop recursing down 
    * and instead directly multiply the matrices. 
    * Must be a power of two. Minimum value is 2. 
    **/ 
    static int granularity = DEFAULT_GRANULARITY; 

    public static void main(String[] args) { 

    final String usage = "Usage: java MatrixMultiply <threads> <matrix size (must be a power of two)> [<granularity>] \n Size and granularity must be powers of two.\n For example, try java MatrixMultiply 2 512 16"; 

    try { 
     int procs; 
     int n; 
     try { 
     procs = Integer.parseInt(args[0]); 
     n = Integer.parseInt(args[1]); 
     if (args.length > 2) granularity = Integer.parseInt(args[2]); 
     } 

     catch (Exception e) { 
     System.out.println(usage); 
     return; 
     } 

     if (((n & (n - 1)) != 0) || 
      ((granularity & (granularity - 1)) != 0) || 
      granularity < 2) { 
     System.out.println(usage); 
     return; 
     } 

     float[][] a = new float[n][n]; 
     float[][] b = new float[n][n]; 
     float[][] c = new float[n][n]; 
     init(a, b, n); 

     FJTaskRunnerGroup g = new FJTaskRunnerGroup(procs); 
     g.invoke(new Multiplier(a, 0, 0, b, 0, 0, c, 0, 0, n)); 
     g.stats(); 

     // check(c, n); 
    } 
    catch (InterruptedException ex) {} 
    } 


    // To simplify checking, fill with all 1's. Answer should be all n's. 
    static void init(float[][] a, float[][] b, int n) { 
    for (int i = 0; i < n; ++i) { 
     for (int j = 0; j < n; ++j) { 
     a[i][j] = 1.0F; 
     b[i][j] = 1.0F; 
     } 
    } 
    } 

    static void check(float[][] c, int n) { 
    for (int i = 0; i < n; i++) { 
     for (int j = 0; j < n; j++) { 
     if (c[i][j] != n) { 
      throw new Error("Check Failed at [" + i +"]["+j+"]: " + c[i][j]); 
     } 
     } 
    } 
    } 

    /** 
    * Multiply matrices AxB by dividing into quadrants, using algorithm: 
    * <pre> 
    *  A  x  B        
    * 
    * A11 | A12  B11 | B12  A11*B11 | A11*B12  A12*B21 | A12*B22 
    * |----+----| x |----+----| = |--------+--------| + |---------+-------| 
    * A21 | A22  B21 | B21  A21*B11 | A21*B21  A22*B21 | A22*B22 
    * </pre> 
    */ 


    static class Multiplier extends FJTask { 
    final float[][] A; // Matrix A 
    final int aRow;  // first row of current quadrant of A 
    final int aCol;  // first column of current quadrant of A 

    final float[][] B; // Similarly for B 
    final int bRow; 
    final int bCol; 

    final float[][] C; // Similarly for result matrix C 
    final int cRow; 
    final int cCol; 

    final int size;  // number of elements in current quadrant 

    Multiplier(float[][] A, int aRow, int aCol, 
       float[][] B, int bRow, int bCol, 
       float[][] C, int cRow, int cCol, 
       int size) { 
     this.A = A; this.aRow = aRow; this.aCol = aCol; 
     this.B = B; this.bRow = bRow; this.bCol = bCol; 
     this.C = C; this.cRow = cRow; this.cCol = cCol; 
     this.size = size; 
    } 

    public void run() { 

     if (size <= granularity) { 
     multiplyStride2(); 
     } 

     else { 
     int h = size/2; 

     coInvoke(new FJTask[] { 
      seq(new Multiplier(A, aRow, aCol, // A11 
          B, bRow, bCol, // B11 
          C, cRow, cCol, // C11 
          h), 
       new Multiplier(A, aRow, aCol+h, // A12 
          B, bRow+h, bCol, // B21 
          C, cRow, cCol, // C11 
          h)), 

      seq(new Multiplier(A, aRow, aCol, // A11 
          B, bRow, bCol+h, // B12 
          C, cRow, cCol+h, // C12 
          h), 
       new Multiplier(A, aRow, aCol+h, // A12 
          B, bRow+h, bCol+h, // B22 
          C, cRow, cCol+h, // C12 
          h)), 

      seq(new Multiplier(A, aRow+h, aCol, // A21 
          B, bRow, bCol, // B11 
          C, cRow+h, cCol, // C21 
          h), 
       new Multiplier(A, aRow+h, aCol+h, // A22 
          B, bRow+h, bCol, // B21 
          C, cRow+h, cCol, // C21 
          h)), 

      seq(new Multiplier(A, aRow+h, aCol, // A21 
          B, bRow, bCol+h, // B12 
          C, cRow+h, cCol+h, // C22 
          h), 
       new Multiplier(A, aRow+h, aCol+h, // A22 
          B, bRow+h, bCol+h, // B22 
          C, cRow+h, cCol+h, // C22 
          h)) 
     }); 
     } 
    } 

    /** 
    * Version of matrix multiplication that steps 2 rows and columns 
    * at a time. Adapted from Cilk demos. 
    * Note that the results are added into C, not just set into C. 
    * This works well here because Java array elements 
    * are created with all zero values. 
    **/ 

    void multiplyStride2() { 
     for (int j = 0; j < size; j+=2) { 
     for (int i = 0; i < size; i +=2) { 

      float[] a0 = A[aRow+i]; 
      float[] a1 = A[aRow+i+1]; 

      float s00 = 0.0F; 
      float s01 = 0.0F; 
      float s10 = 0.0F; 
      float s11 = 0.0F; 

      for (int k = 0; k < size; k+=2) { 

      float[] b0 = B[bRow+k]; 

      s00 += a0[aCol+k] * b0[bCol+j]; 
      s10 += a1[aCol+k] * b0[bCol+j]; 
      s01 += a0[aCol+k] * b0[bCol+j+1]; 
      s11 += a1[aCol+k] * b0[bCol+j+1]; 

      float[] b1 = B[bRow+k+1]; 

      s00 += a0[aCol+k+1] * b1[bCol+j]; 
      s10 += a1[aCol+k+1] * b1[bCol+j]; 
      s01 += a0[aCol+k+1] * b1[bCol+j+1]; 
      s11 += a1[aCol+k+1] * b1[bCol+j+1]; 
      } 

      C[cRow+i] [cCol+j] += s00; 
      C[cRow+i] [cCol+j+1] += s01; 
      C[cRow+i+1][cCol+j] += s10; 
      C[cRow+i+1][cCol+j+1] += s11; 
     } 
     } 
    } 

    } 

} 

该代码可用于旧版本的叉的书面/ join框架:

我Doug Lea的网站()下载下面的例子。所以我必须重写它。我重写的代码实现我自己的接口,看起来像这样:

public class Java7MatrixMultiply implements Algorithm { 
    private static final int SIZE = 32; 
    private static final int THRESHOLD = 8; 

    private float[][] a = new float[SIZE][SIZE]; 
    private float[][] b = new float[SIZE][SIZE]; 
    private float[][] c = new float[SIZE][SIZE]; 

    ForkJoinPool forkJoinPool; 

    @Override 
    public void initialize() { 
     init(a, b, SIZE); 
    } 

    @Override 
    public void execute() { 
     MatrixMultiplyTask mainTask = new MatrixMultiplyTask(a, 0, 0, b, 0, 0, c, 0, 0, SIZE); 
     forkJoinPool = new ForkJoinPool(); 
     forkJoinPool.invoke(mainTask); 

     System.out.println("Terminated!"); 
    } 

    @Override 
    public void printResult() { 
     check(c, SIZE); 

     for (int i = 0; i < SIZE; i++) { 
      for (int j = 0; j < SIZE; j++) { 
       System.out.print(c[i][j] + " "); 
      } 

      System.out.println(); 
     } 
    } 

    // To simplify checking, fill with all 1's. Answer should be all n's. 
    static void init(float[][] a, float[][] b, int n) { 
     for (int i = 0; i < n; ++i) { 
      for (int j = 0; j < n; ++j) { 
       a[i][j] = 1.0F; 
       b[i][j] = 1.0F; 
      } 
     } 
    } 

    static void check(float[][] c, int n) { 
     for (int i = 0; i < n; i++) { 
      for (int j = 0; j < n; j++) { 
       if (c[i][j] != n) { 
        //throw new Error("Check Failed at [" + i + "][" + j + "]: " + c[i][j]); 
        System.out.println("Check Failed at [" + i + "][" + j + "]: " + c[i][j]); 
       } 
      } 
     } 
    } 

    private class MatrixMultiplyTask extends RecursiveAction { 
     private final float[][] A; // Matrix A 
     private final int aRow; // first row of current quadrant of A 
     private final int aCol; // first column of current quadrant of A 

     private final float[][] B; // Similarly for B 
     private final int bRow; 
     private final int bCol; 

     private final float[][] C; // Similarly for result matrix C 
     private final int cRow; 
     private final int cCol; 

     private final int size; 

     MatrixMultiplyTask(float[][] A, int aRow, int aCol, float[][] B, 
       int bRow, int bCol, float[][] C, int cRow, int cCol, int size) { 
      this.A = A; 
      this.aRow = aRow; 
      this.aCol = aCol; 
      this.B = B; 
      this.bRow = bRow; 
      this.bCol = bCol; 
      this.C = C; 
      this.cRow = cRow; 
      this.cCol = cCol; 
      this.size = size; 
     } 

     @Override 
     protected void compute() {  
      if (size <= THRESHOLD) { 
       multiplyStride2(); 
      } else { 

       int h = size/2;    

       invokeAll(new MatrixMultiplyTask[] { 
         new MatrixMultiplyTask(A, aRow, aCol, // A11 
           B, bRow, bCol, // B11 
           C, cRow, cCol, // C11 
           h), 

         new MatrixMultiplyTask(A, aRow, aCol + h, // A12 
           B, bRow + h, bCol, // B21 
           C, cRow, cCol, // C11 
           h), 

         new MatrixMultiplyTask(A, aRow, aCol, // A11 
           B, bRow, bCol + h, // B12 
           C, cRow, cCol + h, // C12 
           h), 

         new MatrixMultiplyTask(A, aRow, aCol + h, // A12 
           B, bRow + h, bCol + h, // B22 
           C, cRow, cCol + h, // C12 
           h), 

         new MatrixMultiplyTask(A, aRow + h, aCol, // A21 
           B, bRow, bCol, // B11 
           C, cRow + h, cCol, // C21 
           h), 

         new MatrixMultiplyTask(A, aRow + h, aCol + h, // A22 
           B, bRow + h, bCol, // B21 
           C, cRow + h, cCol, // C21 
           h), 

         new MatrixMultiplyTask(A, aRow + h, aCol, // A21 
           B, bRow, bCol + h, // B12 
           C, cRow + h, cCol + h, // C22 
           h), 

         new MatrixMultiplyTask(A, aRow + h, aCol + h, // A22 
           B, bRow + h, bCol + h, // B22 
           C, cRow + h, cCol + h, // C22 
           h) }); 

      } 
     } 

     /** 
     * Version of matrix multiplication that steps 2 rows and columns at a 
     * time. Adapted from Cilk demos. Note that the results are added into 
     * C, not just set into C. This works well here because Java array 
     * elements are created with all zero values. 
     **/ 

     void multiplyStride2() { 
      for (int j = 0; j < size; j += 2) { 
       for (int i = 0; i < size; i += 2) { 

        float[] a0 = A[aRow + i]; 
        float[] a1 = A[aRow + i + 1]; 

        float s00 = 0.0F; 
        float s01 = 0.0F; 
        float s10 = 0.0F; 
        float s11 = 0.0F; 

        for (int k = 0; k < size; k += 2) { 

         float[] b0 = B[bRow + k]; 

         s00 += a0[aCol + k] * b0[bCol + j]; 
         s10 += a1[aCol + k] * b0[bCol + j]; 
         s01 += a0[aCol + k] * b0[bCol + j + 1]; 
         s11 += a1[aCol + k] * b0[bCol + j + 1]; 

         float[] b1 = B[bRow + k + 1]; 

         s00 += a0[aCol + k + 1] * b1[bCol + j]; 
         s10 += a1[aCol + k + 1] * b1[bCol + j]; 
         s01 += a0[aCol + k + 1] * b1[bCol + j + 1]; 
         s11 += a1[aCol + k + 1] * b1[bCol + j + 1]; 
        } 

        C[cRow + i][cCol + j] += s00; 
        C[cRow + i][cCol + j + 1] += s01; 
        C[cRow + i + 1][cCol + j] += s10; 
        C[cRow + i + 1][cCol + j + 1] += s11; 
       } 
      } 
     } 
    } 
} 

有时我的计算未能通过检查。矩阵的某些区域具有与预期不同的值。这些不一致是随机的,并不总是会发生。我怀疑计算方法出了问题,因为我不得不重写使用Seq类的部分。 Seq klass按顺序执行任务,与invokeAll()方法不同。在当前版本的fork/join框架中,该类不再存在。我对矩阵乘法算法不是很熟悉,所以很难看出错误。有什么建议么?

回答

0

正如您已经注意到的那样,属于同一象限的子任务的顺序执行对于此算法很重要。因此,您需要实现您自己的seq()函数,例如,如下所示,并将其用作原始代码:

public ForkJoinTask<?> seq(final ForkJoinTask<?> a, final ForkJoinTask<?> b) { 
    return adapt(new Runnable() { 
     public void run() { 
      a.invoke(); 
      b.invoke(); 
     } 
    }); 
} 
+0

谢谢。它完美无瑕。 – TheArchitect 2011-03-29 13:56:02

1

您正在积累C[cRow + i][cCol + j] += s00;等的结果。这不是线程安全操作,因此您必须同步行或确保只有一个任务更新单元。没有这个,你会看到随机单元格设置不正确。

我会检查你找到正确答案,用1

BTW并发:float可能不是这里最好的选择。它的精度数字相当低,而且在繁重的矩阵操作中(我假设你正在做或者没有多少点使用多个线程),舍入误差可能会占用大部分或全部的精度。我建议改为考虑double

例如float大约有7位数字的精度,一个经验法则是错误与计算次数成正比。因此,对于1K x 1K矩阵,您可能会有4位数字的精度。对于10K x 10K,最多只能有三个。 double具有16位数字的精度,这意味着在10K x 10K加密后可能有12位精度。

+0

感谢您的快速回复。我测试了并发性为1的算法。没有发生错误。你的解释当然是对的,但使用同步并不是很有效。 Dough Lea在原始代码中也没有使用这种方法。是否有可能重新实现计算方法,因此不需要同步? – TheArchitect 2011-03-29 13:12:13