Java实现任意矩阵Strassen算法
程序员文章站
2024-03-08 16:24:34
本例输入为两个任意尺寸的矩阵m * n, n * m,输出为两个矩阵的乘积。计算任意尺寸矩阵相乘时,使用了strassen算法。程序为自编,经过测试,请放心使用。基本算法是...
本例输入为两个任意尺寸的矩阵m * n, n * m,输出为两个矩阵的乘积。计算任意尺寸矩阵相乘时,使用了strassen算法。程序为自编,经过测试,请放心使用。基本算法是:
1. 对于方阵（正方形矩阵，边长为 n），找到最大的 l，使得 l = 2^k（k 为整数）且 l ≤ n。左上角边长为 l 的方阵采用 Strassen 算法计算，其余部分以及该方阵各元素中尚未累加的乘积项（求和下标不小于 l 的部分）用蛮力法计算。
2. 对于非方阵，按行列补 0 使其成为方阵，计算方阵乘积后取左上角对应尺寸的部分作为结果（补上的 0 只贡献值为 0 的乘积项，不影响结果）。
strassenmethodtest.java
package matrixalgorithm; import java.util.scanner; public class strassenmethodtest { private strassenmethod strassenmultiply; strassenmethodtest(){ strassenmultiply = new strassenmethod(); }//end cons public static void main(string[] args){ scanner input = new scanner(system.in); system.out.println("input row size of the first matrix: "); int arow = input.nextint(); system.out.println("input column size of the first matrix: "); int acol = input.nextint(); system.out.println("input row size of the second matrix: "); int brow = input.nextint(); system.out.println("input column size of the second matrix: "); int bcol = input.nextint(); double[][] a = new double[arow][acol]; double[][] b = new double[brow][bcol]; double[][] c = new double[arow][bcol]; system.out.println("input data for matrix a: "); /*in all of the codes later in this project, r means row while c means column. */ for (int r = 0; r < arow; r++) { for (int c = 0; c < acol; c++) { system.out.printf("data of a[%d][%d]: ", r, c); a[r][c] = input.nextdouble(); }//end inner loop }//end loop system.out.println("input data for matrix b: "); for (int r = 0; r < brow; r++) { for (int c = 0; c < bcol; c++) { system.out.printf("data of a[%d][%d]: ", r, c); b[r][c] = input.nextdouble(); }//end inner loop }//end loop strassenmethodtest algorithm = new strassenmethodtest(); c = algorithm.multiplyrectmatrix(a, b, arow, acol, brow, bcol); //display the calculation result: system.out.println("result from matrix c: "); for (int r = 0; r < arow; r++) { for (int c = 0; c < bcol; c++) { system.out.printf("data of c[%d][%d]: %f\n", r, c, c[r][c]); }//end inner loop }//end outter loop }//end main //deal with matrices that are not square: public double[][] multiplyrectmatrix(double[][] a, double[][] b, int arow, int acol, int brow, int bcol) { if (arow != bcol) //invalid multiplicatio return new double[][]{{0}}; double[][] c = new double[arow][bcol]; if (arow < acol) { double[][] newa = new double[acol][acol]; double[][] newb = 
new double[brow][brow]; int n = acol; for (int r = 0; r < acol; r++) for (int c = 0; c < acol; c++) newa[r][c] = 0.0; for (int r = 0; r < brow; r++) for (int c = 0; c < brow; c++) newb[r][c] = 0.0; for (int r = 0; r < arow; r++) for (int c = 0; c < acol; c++) newa[r][c] = a[r][c]; for (int r = 0; r < brow; r++) for (int c = 0; c < bcol; c++) newb[r][c] = b[r][c]; double[][] c2 = multiplysquarematrix(newa, newb, n); for(int r = 0; r < arow; r++) for(int c = 0; c < bcol; c++) c[r][c] = c2[r][c]; }//end if else if(arow == acol) c = multiplysquarematrix(a, b, arow); else { int n = arow; double[][] newa = new double[arow][arow]; double[][] newb = new double[bcol][bcol]; for (int r = 0; r < arow; r++) for (int c = 0; c < arow; c++) newa[r][c] = 0.0; for (int r = 0; r < bcol; r++) for (int c = 0; c < bcol; c++) newb[r][c] = 0.0; for (int r = 0; r < arow; r++) for (int c = 0; c < acol; c++) newa[r][c] = a[r][c]; for (int r = 0; r < brow; r++) for (int c = 0; c < bcol; c++) newb[r][c] = b[r][c]; double[][] c2 = multiplysquarematrix(newa, newb, n); for(int r = 0; r < arow; r++) for(int c = 0; c < bcol; c++) c[r][c] = c2[r][c]; }//end else return c; }//end method //deal with matrices that are square matrices. 
public double[][] multiplysquarematrix(double[][] a2, double[][] b2, int n){ double[][] c2 = new double[n][n]; for(int r = 0; r < n; r++) for(int c = 0; c < n; c++) c2[r][c] = 0; if(n == 1){ c2[0][0] = a2[0][0] * b2[0][0]; return c2; }//end if int exp2k = 2; while(exp2k <= (n / 2) ){ exp2k *= 2; }//end loop if(exp2k == n){ c2 = strassenmultiply.strassenmultiplymatrix(a2, b2, n); return c2; }//end else //the "biggest" strassen matrix: double[][][] a = new double[6][exp2k][exp2k]; double[][][] b = new double[6][exp2k][exp2k]; double[][][] c = new double[6][exp2k][exp2k]; for(int r = 0; r < exp2k; r++){ for(int c = 0; c < exp2k; c++){ a[0][r][c] = a2[r][c]; b[0][r][c] = b2[r][c]; }//end inner loop }//end outter loop c[0] = strassenmultiply.strassenmultiplymatrix(a[0], b[0], exp2k); for(int r = 0; r < exp2k; r++) for(int c = 0; c < exp2k; c++) c2[r][c] = c[0][r][c]; int middle = exp2k / 2; for(int r = 0; r < middle; r++){ for(int c = exp2k; c < n; c++){ a[1][r][c - exp2k] = a2[r][c]; b[3][r][c - exp2k] = b2[r][c]; }//end inner loop }//end outter loop for(int r = exp2k; r < n; r++){ for(int c = 0; c < middle; c++){ a[3][r - exp2k][c] = a2[r][c]; b[1][r - exp2k][c] = b2[r][c]; }//end inner loop }//end outter loop for(int r = middle; r < exp2k; r++){ for(int c = exp2k; c < n; c++){ a[2][r - middle][c - exp2k] = a2[r][c]; b[4][r - middle][c - exp2k] = b2[r][c]; }//end inner loop }//end outter loop for(int r = exp2k; r < n; r++){ for(int c = middle; c < n - exp2k + 1; c++){ a[4][r - exp2k][c - middle] = a2[r][c]; b[2][r - exp2k][c - middle] = b2[r][c]; }//end inner loop }//end outter loop for(int i = 1; i <= 4; i++) c[i] = multiplyrectmatrix(a[i], b[i], middle, a[i].length, a[i].length, middle); /* calculate the final results of grids in the "biggest 2^k square, according to the rules of matrice multiplication. 
*/ for (int row = 0; row < exp2k; row++) { for (int col = 0; col < exp2k; col++) { for (int k = exp2k; k < n; k++) { c2[row][col] += a2[row][k] * b2[k][col]; }//end loop }//end inner loop }//end outter loop //use brute force to solve the rest, will be improved later: for(int col = exp2k; col < n; col++){ for(int row = 0; row < n; row++){ for(int k = 0; k < n; k++) c2[row][col] += a2[row][k] * b2[k][row]; }//end inner loop }//end outter loop for(int row = exp2k; row < n; row++){ for(int col = 0; col < exp2k; col++){ for(int k = 0; k < n; k++) c2[row][col] += a2[row][k] * b2[k][row]; }//end inner loop }//end outter loop return c2; }//end method }//end class
strassenmethod.java
package matrixalgorithm; import java.util.scanner; public class strassenmethod { private double[][][][] a = new double[2][2][][]; private double[][][][] b = new double[2][2][][]; private double[][][][] c = new double[2][2][][]; /*//codes for testing this class: public static void main(string[] args) { scanner input = new scanner(system.in); system.out.println("input size of the matrix: "); int n = input.nextint(); double[][] a = new double[n][n]; double[][] b = new double[n][n]; double[][] c = new double[n][n]; system.out.println("input data for matrix a: "); for (int r = 0; r < n; r++) { for (int c = 0; c < n; c++) { system.out.printf("data of a[%d][%d]: ", r, c); a[r][c] = input.nextdouble(); }//end inner loop }//end loop system.out.println("input data for matrix b: "); for (int r = 0; r < n; r++) { for (int c = 0; c < n; c++) { system.out.printf("data of a[%d][%d]: ", r, c); b[r][c] = input.nextdouble(); }//end inner loop }//end loop strassenmethod algorithm = new strassenmethod(); c = algorithm.strassenmultiplymatrix(a, b, n); system.out.println("result from matrix c: "); for (int r = 0; r < n; r++) { for (int c = 0; c < n; c++) { system.out.printf("data of c[%d][%d]: %f\n", r, c, c[r][c]); }//end inner loop }//end outter loop }//end main*/ public double[][] strassenmultiplymatrix(double[][] a2, double b2[][], int n){ double[][] c2 = new double[n][n]; //initialize the matrix: for(int rowindex = 0; rowindex < n; rowindex++) for(int colindex = 0; colindex < n; colindex++) c2[rowindex][colindex] = 0.0; if(n == 1) c2[0][0] = a2[0][0] * b2[0][0]; //"slice matrices into 2 * 2 parts: else{ double[][][][] a = new double[2][2][n / 2][n / 2]; double[][][][] b = new double[2][2][n / 2][n / 2]; double[][][][] c = new double[2][2][n / 2][n / 2]; for(int r = 0; r < n / 2; r++){ for(int c = 0; c < n / 2; c++){ a[0][0][r][c] = a2[r][c]; a[0][1][r][c] = a2[r][n / 2 + c]; a[1][0][r][c] = a2[n / 2 + r][c]; a[1][1][r][c] = a2[n / 2 + r][n / 2 + c]; b[0][0][r][c] = b2[r][c]; 
b[0][1][r][c] = b2[r][n / 2 + c]; b[1][0][r][c] = b2[n / 2 + r][c]; b[1][1][r][c] = b2[n / 2 + r][n / 2 + c]; }//end loop }//end loop n = n / 2; double[][][] s = new double[10][n][n]; s[0] = minusmatrix(b[0][1], b[1][1], n); s[1] = addmatrix(a[0][0], a[0][1], n); s[2] = addmatrix(a[1][0], a[1][1], n); s[3] = minusmatrix(b[1][0], b[0][0], n); s[4] = addmatrix(a[0][0], a[1][1], n); s[5] = addmatrix(b[0][0], b[1][1], n); s[6] = minusmatrix(a[0][1], a[1][1], n); s[7] = addmatrix(b[1][0], b[1][1], n); s[8] = minusmatrix(a[0][0], a[1][0], n); s[9] = addmatrix(b[0][0], b[0][1], n); double[][][] p = new double[7][n][n]; p[0] = strassenmultiplymatrix(a[0][0], s[0], n); p[1] = strassenmultiplymatrix(s[1], b[1][1], n); p[2] = strassenmultiplymatrix(s[2], b[0][0], n); p[3] = strassenmultiplymatrix(a[1][1], s[3], n); p[4] = strassenmultiplymatrix(s[4], s[5], n); p[5] = strassenmultiplymatrix(s[6], s[7], n); p[6] = strassenmultiplymatrix(s[8], s[9], n); c[0][0] = addmatrix(minusmatrix(addmatrix(p[4], p[3], n), p[1], n), p[5], n); c[0][1] = addmatrix(p[0], p[1], n); c[1][0] = addmatrix(p[2], p[3], n); c[1][1] = minusmatrix(minusmatrix(addmatrix(p[4], p[0], n), p[2], n), p[6], n); n *= 2; for(int r = 0; r < n / 2; r++){ for(int c = 0; c < n / 2; c++){ c2[r][c] = c[0][0][r][c]; c2[r][n / 2 + c] = c[0][1][r][c]; c2[n / 2 + r][c] = c[1][0][r][c]; c2[n / 2 + r][n / 2 + c] = c[1][1][r][c]; }//end inner loop }//end outter loop }//end else return c2; }//end method //add two matrices according to matrix addition. private double[][] addmatrix(double[][] a, double[][] b, int n){ double c[][] = new double[n][n]; for(int r = 0; r < n; r++) for(int c = 0; c < n; c++) c[r][c] = a[r][c] + b[r][c]; return c; }//end method //substract two matrices according to matrix addition. private double[][] minusmatrix(double[][] a, double[][] b, int n){ double c[][] = new double[n][n]; for(int r = 0; r < n; r++) for(int c = 0; c < n; c++) c[r][c] = a[r][c] - b[r][c]; return c; }//end method }//end class
希望本文所述对大家学习Java程序设计有所帮助。
上一篇: PHP对象相关知识总结
下一篇: php 静态属性和静态方法区别详解