
Implementing Strassen's algorithm for matrices of arbitrary size in Java

程序员文章站 2024-03-09 11:00:05

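In this example, the input is two matrices of arbitrary size, m * n and n * m, and the output is their product. Strassen's algorithm is used to multiply matrices of arbitrary size. The program is my own and has been tested. The basic approach is: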
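1. For a square matrix with side length n, find the largest l such that l = 2 ^ k for some integer k and l ≤ n. Multiply the l * l top-left block with Strassen's algorithm. Compute the rest of the matrix by brute force, together with the terms of the l * l block that this product leaves out.
2. For a non-square matrix, pad the rows and columns with zeros to make it a square matrix.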
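The split in step 1 can be checked on its own. The sketch below uses a hypothetical `BlockSplitSketch` class, which is not part of the original program. It computes an n * n product in the same four pieces that `multiplysquarematrix` uses:

- the l * l core over k < l;
- the missing k ≥ l terms of that core;
- the right strip;
- the bottom-left strip.

To keep it short, it assumes a plain triple loop in place of Strassen's algorithm for the core. Only the block bookkeeping is being illustrated.

```java
// Hypothetical sketch: the block split of step 1, with a naive loop standing in for Strassen.
final class BlockSplitSketch {

    // Reference product: the textbook triple loop for square n * n matrices.
    static double[][] naive(double[][] A, double[][] B) {
        int n = A.length;
        double[][] C = new double[n][n];
        for (int r = 0; r < n; r++)
            for (int c = 0; c < n; c++)
                for (int k = 0; k < n; k++)
                    C[r][c] += A[r][k] * B[k][c];
        return C;
    }

    // Largest l = 2^k with l <= n, for n >= 1.
    static int largestPowerOfTwoAtMost(int n) {
        int l = 1;
        while (l * 2 <= n) l *= 2;
        return l;
    }

    // Same product as naive(), but assembled from the four pieces described in step 1.
    static double[][] splitMultiply(double[][] A, double[][] B) {
        int n = A.length;
        int l = largestPowerOfTwoAtMost(n);
        double[][] C = new double[n][n];

        // 1) Core l * l block over k < l: the part Strassen's algorithm would compute.
        for (int r = 0; r < l; r++)
            for (int c = 0; c < l; c++)
                for (int k = 0; k < l; k++)
                    C[r][c] += A[r][k] * B[k][c];

        // 2) Terms of the core block that the l * l product leaves out (k >= l).
        for (int r = 0; r < l; r++)
            for (int c = 0; c < l; c++)
                for (int k = l; k < n; k++)
                    C[r][c] += A[r][k] * B[k][c];

        // 3) Right strip: columns l..n-1 of every row, full inner sum.
        for (int c = l; c < n; c++)
            for (int r = 0; r < n; r++)
                for (int k = 0; k < n; k++)
                    C[r][c] += A[r][k] * B[k][c];

        // 4) Bottom-left strip: rows l..n-1, columns 0..l-1, full inner sum.
        for (int r = l; r < n; r++)
            for (int c = 0; c < l; c++)
                for (int k = 0; k < n; k++)
                    C[r][c] += A[r][k] * B[k][c];
        return C;
    }
}
```

Comparing `splitMultiply(A, B)` with `naive(A, B)` for several sizes n shows that the four pieces together cover every entry of the product, each term exactly once.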
strassenmethodtest.java

package matrixalgorithm;
 
import java.util.Scanner;
 
public class strassenmethodtest {
 
  private strassenmethod strassenmultiply;
   
   strassenmethodtest(){
    strassenmultiply = new strassenmethod();
  }//end cons
 
  public static void main(String[] args){
    Scanner input = new Scanner(System.in);
    System.out.println("input row size of the first matrix: ");
    int arow = input.nextInt();
    System.out.println("input column size of the first matrix: ");
    int acol = input.nextInt();
    System.out.println("input row size of the second matrix: ");
    int brow = input.nextInt();
    System.out.println("input column size of the second matrix: ");
    int bcol = input.nextInt();

    //a * b is only defined when the column size of a equals the row size of b.
    if (acol != brow) {
      System.out.println("invalid sizes: the column size of a must equal the row size of b.");
      return;
    }//end if

    double[][] A = new double[arow][acol];
    double[][] B = new double[brow][bcol];
    System.out.println("input data for matrix a: ");

    /*in all of the codes later in this project,
    r means row while c means column;
    upper-case names (A, B, C, ...) are matrices.
    */
    for (int r = 0; r < arow; r++) {
      for (int c = 0; c < acol; c++) {
        System.out.printf("data of a[%d][%d]: ", r, c);
        A[r][c] = input.nextDouble();
      }//end inner loop
    }//end loop

    System.out.println("input data for matrix b: ");
    for (int r = 0; r < brow; r++) {
      for (int c = 0; c < bcol; c++) {
        System.out.printf("data of b[%d][%d]: ", r, c);
        B[r][c] = input.nextDouble();
      }//end inner loop
    }//end loop
    input.close();

    strassenmethodtest algorithm = new strassenmethodtest();
    double[][] C = algorithm.multiplyrectmatrix(A, B, arow, acol, brow, bcol);

    //display the calculation result:
    System.out.println("result from matrix c: ");
    for (int r = 0; r < arow; r++) {
      for (int c = 0; c < bcol; c++) {
        System.out.printf("data of c[%d][%d]: %f%n", r, c, C[r][c]);
      }//end inner loop
    }//end outer loop
  }//end main
   
  //deal with matrices that are not square: pad both with zeros to an n * n square.
  public double[][] multiplyrectmatrix(double[][] A, double[][] B,
      int arow, int acol, int brow, int bcol) {
    if (acol != brow) //invalid multiplication: the inner sizes differ
      return new double[][]{{0}};

    int n = Math.max(arow, Math.max(acol, bcol));
    if (arow == n && acol == n && bcol == n)
      return multiplysquarematrix(A, B, n);

    //new double arrays start at 0.0, so only the data has to be copied.
    double[][] newA = new double[n][n];
    double[][] newB = new double[n][n];

    for (int r = 0; r < arow; r++)
      for (int c = 0; c < acol; c++)
        newA[r][c] = A[r][c];

    for (int r = 0; r < brow; r++)
      for (int c = 0; c < bcol; c++)
        newB[r][c] = B[r][c];

    //multiply the padded squares and keep the arow * bcol corner.
    double[][] C2 = multiplysquarematrix(newA, newB, n);
    double[][] C = new double[arow][bcol];
    for (int r = 0; r < arow; r++)
      for (int c = 0; c < bcol; c++)
        C[r][c] = C2[r][c];

    return C;
  }//end method
   
  //deal with matrices that are square matrices.
  public double[][] multiplysquarematrix(double[][] A2, double[][] B2, int n){

    //java initializes a new double array to 0.0.
    double[][] C2 = new double[n][n];

    if(n == 1){
      C2[0][0] = A2[0][0] * B2[0][0];
      return C2;
    }//end if

    //find the largest power of 2, exp2k, with exp2k <= n:
    int exp2k = 2;
    while(exp2k <= (n / 2)){
      exp2k *= 2;
    }//end loop

    if(exp2k == n)
      return strassenmultiply.strassenmultiplymatrix(A2, B2, n);

    //the "biggest" strassen matrix: the top-left exp2k * exp2k blocks of A2 and B2.
    double[][] A0 = new double[exp2k][exp2k];
    double[][] B0 = new double[exp2k][exp2k];
    for(int r = 0; r < exp2k; r++){
      for(int c = 0; c < exp2k; c++){
        A0[r][c] = A2[r][c];
        B0[r][c] = B2[r][c];
      }//end inner loop
    }//end outer loop

    double[][] C0 = strassenmultiply.strassenmultiplymatrix(A0, B0, exp2k);
    for(int r = 0; r < exp2k; r++)
      for(int c = 0; c < exp2k; c++)
        C2[r][c] = C0[r][c];

    /*
    C0 only sums over k < exp2k. Add the terms with k >= exp2k to finish the
    grids of the "biggest" 2^k square, according to the rule of matrix multiplication.
    */
    for (int row = 0; row < exp2k; row++) {
      for (int col = 0; col < exp2k; col++) {
        for (int k = exp2k; k < n; k++) {
          C2[row][col] += A2[row][k] * B2[k][col];
        }//end loop
      }//end inner loop
    }//end outer loop

    //use brute force to solve the rest, will be improved later.
    //right strip: columns exp2k..n-1 of every row.
    for(int col = exp2k; col < n; col++){
      for(int row = 0; row < n; row++){
        for(int k = 0; k < n; k++)
          C2[row][col] += A2[row][k] * B2[k][col];
      }//end inner loop
    }//end outer loop

    //bottom-left strip: rows exp2k..n-1, columns 0..exp2k-1.
    for(int row = exp2k; row < n; row++){
      for(int col = 0; col < exp2k; col++){
        for(int k = 0; k < n; k++)
          C2[row][col] += A2[row][k] * B2[k][col];
      }//end inner loop
    }//end outer loop

    return C2;
  }//end method
   
}//end class
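Step 2, which `multiplyrectmatrix` implements, works because zero padding does not change the product. Every padded row or column of a and b is all zeros. The padded product therefore agrees with the true product in its top-left m * p corner, and everything outside that corner is discarded.

The sketch below uses a hypothetical `PaddingSketch` class; its names are illustrative and not from the original program. It pads an m * n and an n * p matrix to N = max(m, n, p), multiplies the squares with a plain triple loop, and keeps the corner. It assumes rectangular, non-empty input arrays.

```java
// Hypothetical sketch of step 2: zero-pad both operands to N * N, multiply, crop the result.
final class PaddingSketch {

    // Copies M into the top-left corner of a new N * N matrix; java fills the rest with 0.0.
    static double[][] pad(double[][] M, int N) {
        double[][] P = new double[N][N];
        for (int r = 0; r < M.length; r++)
            for (int c = 0; c < M[r].length; c++)
                P[r][c] = M[r][c];
        return P;
    }

    // Plain triple loop for an m * n matrix times an n * p matrix.
    static double[][] multiply(double[][] A, double[][] B) {
        int m = A.length, n = B.length, p = B[0].length;
        double[][] C = new double[m][p];
        for (int r = 0; r < m; r++)
            for (int c = 0; c < p; c++)
                for (int k = 0; k < n; k++)
                    C[r][c] += A[r][k] * B[k][c];
        return C;
    }

    // Pads both operands to the largest dimension, multiplies the squares, keeps the m * p corner.
    static double[][] multiplyPadded(double[][] A, double[][] B) {
        int m = A.length, n = B.length, p = B[0].length;
        if (A[0].length != n)
            throw new IllegalArgumentException("columns of A must equal rows of B");
        int N = Math.max(m, Math.max(n, p));
        double[][] full = multiply(pad(A, N), pad(B, N));
        double[][] C = new double[m][p];
        for (int r = 0; r < m; r++)
            for (int c = 0; c < p; c++)
                C[r][c] = full[r][c];
        return C;
    }
}
```

For a = {{1, 2, 3}, {4, 5, 6}} and b = {{7, 8}, {9, 10}, {11, 12}}, `multiplyPadded` works on 3 * 3 squares and returns {{58, 64}, {139, 154}}.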

strassenmethod.java

package matrixalgorithm;
 
import java.util.Scanner;
 
public class strassenmethod {
 
 
  /*//codes for testing this class:
    public static void main(String[] args) {
    Scanner input = new Scanner(System.in);
    System.out.println("input size of the matrix (a power of 2): ");
    int n = input.nextInt();

    double[][] A = new double[n][n];
    double[][] B = new double[n][n];
    System.out.println("input data for matrix a: ");
    for (int r = 0; r < n; r++) {
      for (int c = 0; c < n; c++) {
        System.out.printf("data of a[%d][%d]: ", r, c);
        A[r][c] = input.nextDouble();
      }//end inner loop
    }//end loop

    System.out.println("input data for matrix b: ");
    for (int r = 0; r < n; r++) {
      for (int c = 0; c < n; c++) {
        System.out.printf("data of b[%d][%d]: ", r, c);
        B[r][c] = input.nextDouble();
      }//end inner loop
    }//end loop

    strassenmethod algorithm = new strassenmethod();
    double[][] C = algorithm.strassenmultiplymatrix(A, B, n);

    System.out.println("result from matrix c: ");
    for (int r = 0; r < n; r++) {
      for (int c = 0; c < n; c++) {
        System.out.printf("data of c[%d][%d]: %f%n", r, c, C[r][c]);
      }//end inner loop
    }//end outer loop

  }//end main*/
   
  //multiply two n * n matrices with strassen's algorithm; n must be a power of 2.
  public double[][] strassenmultiplymatrix(double[][] A2, double[][] B2, int n){
    //java initializes a new double array to 0.0.
    double[][] C2 = new double[n][n];

    if(n == 1)
      C2[0][0] = A2[0][0] * B2[0][0];
    //slice matrices into 2 * 2 parts:
    else{
      double[][][][] A = new double[2][2][n / 2][n / 2];
      double[][][][] B = new double[2][2][n / 2][n / 2];
      double[][][][] C = new double[2][2][][];

      for(int r = 0; r < n / 2; r++){
        for(int c = 0; c < n / 2; c++){
          A[0][0][r][c] = A2[r][c];
          A[0][1][r][c] = A2[r][n / 2 + c];
          A[1][0][r][c] = A2[n / 2 + r][c];
          A[1][1][r][c] = A2[n / 2 + r][n / 2 + c];

          B[0][0][r][c] = B2[r][c];
          B[0][1][r][c] = B2[r][n / 2 + c];
          B[1][0][r][c] = B2[n / 2 + r][c];
          B[1][1][r][c] = B2[n / 2 + r][n / 2 + c];
        }//end inner loop
      }//end outer loop

      n = n / 2;

      //the ten sums and differences of the sub-blocks:
      double[][][] s = new double[10][][];
      s[0] = minusmatrix(B[0][1], B[1][1], n);
      s[1] = addmatrix(A[0][0], A[0][1], n);
      s[2] = addmatrix(A[1][0], A[1][1], n);
      s[3] = minusmatrix(B[1][0], B[0][0], n);
      s[4] = addmatrix(A[0][0], A[1][1], n);
      s[5] = addmatrix(B[0][0], B[1][1], n);
      s[6] = minusmatrix(A[0][1], A[1][1], n);
      s[7] = addmatrix(B[1][0], B[1][1], n);
      s[8] = minusmatrix(A[0][0], A[1][0], n);
      s[9] = addmatrix(B[0][0], B[0][1], n);

      //the seven recursive products:
      double[][][] p = new double[7][][];
      p[0] = strassenmultiplymatrix(A[0][0], s[0], n);
      p[1] = strassenmultiplymatrix(s[1], B[1][1], n);
      p[2] = strassenmultiplymatrix(s[2], B[0][0], n);
      p[3] = strassenmultiplymatrix(A[1][1], s[3], n);
      p[4] = strassenmultiplymatrix(s[4], s[5], n);
      p[5] = strassenmultiplymatrix(s[6], s[7], n);
      p[6] = strassenmultiplymatrix(s[8], s[9], n);

      C[0][0] = addmatrix(minusmatrix(addmatrix(p[4], p[3], n), p[1], n), p[5], n);
      C[0][1] = addmatrix(p[0], p[1], n);
      C[1][0] = addmatrix(p[2], p[3], n);
      C[1][1] = minusmatrix(minusmatrix(addmatrix(p[4], p[0], n), p[2], n), p[6], n);

      n *= 2;

      for(int r = 0; r < n / 2; r++){
        for(int c = 0; c < n / 2; c++){
          C2[r][c] = C[0][0][r][c];
          C2[r][n / 2 + c] = C[0][1][r][c];
          C2[n / 2 + r][c] = C[1][0][r][c];
          C2[n / 2 + r][n / 2 + c] = C[1][1][r][c];
        }//end inner loop
      }//end outer loop
    }//end else

    return C2;
  }//end method
   
  //add two matrices according to matrix addition.
  private double[][] addmatrix(double[][] A, double[][] B, int n){
    double[][] C = new double[n][n];

    for(int r = 0; r < n; r++)
      for(int c = 0; c < n; c++)
        C[r][c] = A[r][c] + B[r][c];

    return C;
  }//end method
   
  //subtract two matrices according to matrix subtraction.
  private double[][] minusmatrix(double[][] A, double[][] B, int n){
    double[][] C = new double[n][n];

    for(int r = 0; r < n; r++)
      for(int c = 0; c < n; c++)
        C[r][c] = A[r][c] - B[r][c];

    return C;
  }//end method
   
}//end class
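The sums `s[0]`..`s[9]` and products `p[0]`..`p[6]` in `strassenmultiplymatrix` follow the usual one-level Strassen scheme. A quick way to check the formulas is to apply them to a 2 * 2 matrix of plain numbers, where each block is a single scalar.

The `StrassenTwoByTwo` class below is a hypothetical sketch for that purpose only. It uses 7 multiplications instead of the usual 8, and its result can be compared with the direct formula.

```java
// Hypothetical sketch: one level of Strassen's scheme on scalar 2 * 2 matrices.
final class StrassenTwoByTwo {

    // Returns {{c11, c12}, {c21, c22}} using 7 multiplications instead of 8.
    static double[][] multiply(double[][] a, double[][] b) {
        double a11 = a[0][0], a12 = a[0][1], a21 = a[1][0], a22 = a[1][1];
        double b11 = b[0][0], b12 = b[0][1], b21 = b[1][0], b22 = b[1][1];

        // The ten sums and differences, in the same order as s[0]..s[9].
        double s1 = b12 - b22, s2 = a11 + a12, s3 = a21 + a22, s4 = b21 - b11;
        double s5 = a11 + a22, s6 = b11 + b22, s7 = a12 - a22, s8 = b21 + b22;
        double s9 = a11 - a21, s10 = b11 + b12;

        // The seven products, in the same order as p[0]..p[6].
        double p1 = a11 * s1, p2 = s2 * b22, p3 = s3 * b11, p4 = a22 * s4;
        double p5 = s5 * s6, p6 = s7 * s8, p7 = s9 * s10;

        // Recombination, matching C[0][0], C[0][1], C[1][0], C[1][1] in the code above.
        return new double[][] {
            {p5 + p4 - p2 + p6, p1 + p2},
            {p3 + p4, p5 + p1 - p3 - p7}
        };
    }
}
```

For a = {{1, 2}, {3, 4}} and b = {{5, 6}, {7, 8}}, it returns {{19, 22}, {43, 50}}, the same as the direct product.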

I hope this article helps you learn Java programming.