纯c语言优雅地实现矩阵运算库的方法
编程既是技术输出也是艺术创作。鉴赏高手写的程序,往往让人眼前一亮,他们思路、逻辑清晰,所呈现的代码简洁、优雅、高效,令人为之叹服。烂代码则像“屎山"一般让人作呕,软件难以维护最大的原因除了需求模糊的客观因素,最重要的主观因素还是代码写得烂。有生之年,愿能保持对编程的热情,不断提升编程能力,真正体会其乐趣,共勉!
1.一个优雅好用的c语言库必须满足哪些条件
这里给出的软件开发应遵循的一般原则,摘自les piegl和wayne tiller所著的一本非常经典的书the nurbs book(second edition)。
(1)工具性(toolability):应该使用可用的工具来建立新的应用程序;
(2)可移植性(portability):应用程序应该容易被移植到不同的软件和硬件平台;
(3)可重用性(reusability):程序的编写应该便于重复使用代码段;
(4)可检验性(testability):程序应该简单一致,以便于测试和调试;
(5)可靠性(reliability):对程序运行过程中可能出现的各种错误应该进行合理、一致的处理,以使系统稳定、可靠;
(6)可扩展性(enhanceability):代码必须易于理解,以便可以容易地增加新的功能;
(7)可维护性(fixability):易于找出程序错误的位置;
(8)一致性(consistency):在整个库中,编程习惯应保持一致;
(9)可读性(communicability):程序应该易于阅读和理解;
(10)编程风格(style of programming):代码看起来像书上的数学公式那样以便于读者理解,同时遵循用户友好的编程风格;
(11)易用性(usability):应该使一些非专业的用户也能够方便地使用所开发的库来开发各种更高层次的应用程序;
(12)数值高效性(numerical efficiency):所有函数必须仔细推敲,保证其数值高效性;
(13)基于对象编程(object based programming):避免在函数间传递大量数据,并且使代码易于理解。
2.实现一个矩阵运算库的几点思考
(1)采用预定义的数据类型,避免直接使用编译器定义的数据类型
typedef unsigned int error_id; typedef int index; typedef short flag; typedef int integer; typedef double real; typedef char* string; typedef void void;
使用预定义的数据类型,有利于程序移植,且可以提高可读性。例如,如果一个系统只支持单精度浮点数,那么只需修改数据类型real为float,达到一劳永逸的效果。定义index与integer数据类型是为了在编程时区分索引变量与普通整形变量,同样提高可读性。
(2)基于对象编程,定义矩阵对象
typedef struct matrix { integer rows; integer columns; real* p; }matrix;
这里,用一级指针而非二级指针指向矩阵的数据内存地址,有诸多原因,详见博文:为什么我推荐使用一级指针创建二维数组?。
(3)除了特别编写的内存处理函数(使用栈链表保存、释放动态分配的内存地址),不允许任何函数直接分配和释放内存
不恰当的分配、使用与释放内存很可能导致内存泄漏、系统崩溃等致命的错误。如果一个函数需动态申请多个内存,那么可能会写出这样啰嗦的程序:
double* x = null, * y = null, * z = null; x = (double*)malloc(n1 * sizeof(double)); if (x == null) return -1; y = (double*)malloc(n2 * sizeof(double)); if (y == null) { free(x); x = null; return -1; } z = (double*)malloc(n3 * sizeof(double)); if (z == null) { free(x); x = null; free(y); y = null; return -1; }
为了优雅地实现动态内存分配与释放,les piegl大神分3步来处理内存申请与释放:
a)在进入一个新的程序时,一个内存堆栈被初始化为空;
b)当需要申请内存时,调用特定的函数来分配所需的内存,并将指向内存的指针存入堆栈中的正确位置;
c)在离开程序时,遍历内存堆栈,释放其中的指针所指向的内存。
程序结构大致如下:
stacks s; matrix* m = null; integer rows = 3, columns = 4; error_id errorid = _error_no_error; init_stack(&s); m = creat_matrix(rows, columns, &errorid, &s); if (m == null) goto exit; //do something // ... exit: free_stack(&s); return errorid;
(4)防御性编程,对输入参数做有效性检查,并返回错误号
例如输入的矩阵行数、列数应该是正整数,指针参数必须非空等等。
(5)注意编程细节的打磨
a)操作符(逗号,等号等)两边必须空一格;
b)逻辑功能相同的程序间不加空行,逻辑功能独立的程序间加空行;
c)条件判断关键字(for if while等)后必须加一空格,起到强调作用,也更清晰;
d)函数内部定义局部变量后,必须空一行后再编写函数主体。
3.完整c程序
本矩阵运算库只包含了矩阵的基本运算,包括创建任意二维/三维矩阵、创建零矩阵及单位矩阵、矩阵加法、矩阵减法、矩阵乘法、矩阵求逆、矩阵转置、矩阵的迹、矩阵lup分解、解矩阵方程ax=b。
common.h
/******************************************************************************* * file name : common.h * library/module name : matrixcomputation * author : marc pony(marc_pony@163.com) * create date : 2021/6/28 * abstract description : 矩阵运算库公用头文件 *******************************************************************************/ #ifndef __common_h__ #define __common_h__ /******************************************************************************* * (1)debug switch section *******************************************************************************/ /******************************************************************************* * (2)include file section *******************************************************************************/ #include <math.h> #include <stdio.h> #include <malloc.h> #include <stdlib.h> #include <time.h> #include <memory.h> /******************************************************************************* * (3)macro define section *******************************************************************************/ #define _in #define _out #define _in_out #define max(x,y) (x) > (y) ? (x) : (y) #define min(x,y) (x) < (y) ? (x) : (y) #define _crt_secure_no_warnings #define pi 3.14159265358979323846 #define positive_infinity 999999999 #define negative_infinity -999999999 #define _error_no_error 0x00000000 //无错误 #define _error_failed_to_allocate_heap_memory 0x00000001 //分配堆内存失败 #define _error_svd_exceed_max_iterations 0x00000002 //svd超过最大迭代次数 #define _error_matrix_rows_or_columns_not_equal 0x00000003 //矩阵行数或列数不相等 #define _error_matrix_multiplication 0x00000004 //矩阵乘法错误(第一个矩阵的列数不等于第二个矩阵行数) #define _error_matrix_must_be_square 0x00000005 //矩阵必须为方阵 #define _error_matrix_norm_type_invalid 0x00000006 //矩阵模类型无效 #define _error_matrix_equation_has_no_solutions 0x00000007 //矩阵方程无解 #define _error_matrix_equation_has_infinity_manny_solutions 0x00000008 //矩阵方程有无穷多解 #define _error_qr_decomposition_failed 0x00000009 //qr分解失败 #define _error_cholesky_decomposition_failed 0x0000000a //cholesky分解失败 #define _error_improved_cholesky_decomposition_failed 0x0000000b //improved cholesky分解失败 #define _error_lu_decomposition_failed 0x0000000c //lu分解失败 #define _error_create_matrix_failed 0x0000000d //创建矩阵失败 #define _error_matrix_transpose_failed 0x0000000e //矩阵转置失败 #define _error_create_vector_failed 0x0000000f //创建向量失败 #define _error_vector_dimension_not_equal 0x00000010 //向量维数不相同 #define _error_vector_norm_type_invalid 0x00000011 //向量模类型无效 #define _error_vector_cross_failed 0x00000012 //向量叉乘失败 #define _error_input_parameters_error 0x00010000 //输入参数错误 /******************************************************************************* * (4)struct(data types) define section *******************************************************************************/ typedef unsigned int error_id; typedef int index; typedef short flag; typedef int integer; typedef double real; typedef char* string; typedef void void; typedef struct matrix { integer rows; integer columns; real* p; }matrix; typedef struct matrix_node { matrix* ptr; struct matrix_node* next; } matrix_node; typedef struct matrix_element_node { real* ptr; struct matrix_element_node* next; } matrix_element_node; typedef struct stacks { matrix_node* matrixnode; matrix_element_node* matrixelementnode; // ... // 添加其他对象的指针 } stacks; /******************************************************************************* * (5)prototype declare section *******************************************************************************/ /********************************************************************************************** function: init_stack description: 初始化栈 input: 无 output: 无 input_output: 栈指针 return: 无 author: marc pony(marc_pony@163.com) ***********************************************************************************************/ void init_stack(_in_out stacks* s); /********************************************************************************************** function: free_stack description: 释放栈 input: 栈指针 output: 无 input_output: 无 return: 无 author: marc pony(marc_pony@163.com) ***********************************************************************************************/ void free_stack(_in stacks* s); #endif
matrix.h
/******************************************************************************* * file name : matrix.h * library/module name : matrixcomputation * author : marc pony(marc_pony@163.com) * create date : 2021/6/28 * abstract description : 矩阵运算库头文件 *******************************************************************************/ #ifndef __matrix_h__ #define __matrix_h__ /******************************************************************************* * (1)debug switch section *******************************************************************************/ /******************************************************************************* * (2)include file section *******************************************************************************/ #include "common.h" /******************************************************************************* * (3)macro define section *******************************************************************************/ /******************************************************************************* * (4)struct(data types) define section *******************************************************************************/ /******************************************************************************* * (5)prototype declare section *******************************************************************************/ void print_matrix(matrix* a, string string); /********************************************************************************************** function: creat_matrix description: 创建矩阵 input: 矩阵行数rows,列数columns output: 错误号指针errorid,栈指针s input_output: 无 return: 矩阵指针 author: marc pony(marc_pony@163.com) ***********************************************************************************************/ matrix* creat_matrix(_in integer rows, _in integer columns, _out error_id* errorid, _out stacks* s); /********************************************************************************************** function: creat_multiple_matrices description: 创建多个矩阵 input: 矩阵行数rows,列数columns,个数count output: 错误号指针errorid,栈指针s input_output: 无 return: 矩阵指针 author: marc pony(marc_pony@163.com) ***********************************************************************************************/ matrix* creat_multiple_matrices(_in integer rows, _in integer columns, _in integer count, _out error_id* errorid, _out stacks* s); /********************************************************************************************** function: creat_zero_matrix description: 创建零矩阵 input: 矩阵行数rows,列数columns output: 错误号指针errorid,栈指针s input_output: 无 return: 矩阵指针 author: marc pony(marc_pony@163.com) ***********************************************************************************************/ matrix* creat_zero_matrix(_in integer rows, _in integer columns, _out error_id* errorid, _out stacks* s); /********************************************************************************************** function: creat_eye_matrix description: 创建单位矩阵 input: 矩阵行数rows,列数columns output: 错误号指针errorid,栈指针s input_output: 无 return: 矩阵指针 author: marc pony(marc_pony@163.com) ***********************************************************************************************/ matrix* creat_eye_matrix(_in integer n, _out error_id* errorid, _out stacks* s); /********************************************************************************************** function: matrix_add description: 矩阵a + 矩阵b = 矩阵c input: 矩阵a,矩阵b output: 矩阵c input_output: 无 return: 错误号 author: marc pony(marc_pony@163.com) ***********************************************************************************************/ error_id matrix_add(_in matrix* a, _in matrix* b, _out matrix *c); /********************************************************************************************** function: matrix_subtraction description: 矩阵a - 矩阵b = 矩阵c input: 矩阵a,矩阵b output: 矩阵c input_output: 无 return: 错误号 author: marc pony(marc_pony@163.com) ***********************************************************************************************/ error_id matrix_subtraction(_in matrix* a, _in matrix* b, _out matrix* c); /********************************************************************************************** function: matrix_multiplication description: 矩阵乘法c = a * b input: 矩阵a,矩阵b output: 矩阵c input_output: 无 return: 错误号 author: marc pony(marc_pony@163.com) ***********************************************************************************************/ error_id matrix_multiplication(_in matrix* a, _in matrix* b, _out matrix* c); /********************************************************************************************** function: matrix_inverse description: 矩阵求逆 input: 矩阵a output: 矩阵a的逆矩阵 input_output: 无 return: 错误号 author: marc pony(marc_pony@163.com) ***********************************************************************************************/ error_id matrix_inverse(_in matrix* a, _out matrix* inva); /********************************************************************************************** function: matrix_transpose description: 矩阵转置 input: 矩阵a output: 矩阵a的转置 input_output: 无 return: 错误号 author: marc pony(marc_pony@163.com) ***********************************************************************************************/ error_id matrix_transpose(_in matrix* a, _out matrix* transposea); /********************************************************************************************** function: matrix_trace description: 矩阵的迹 input: 矩阵a output: 矩阵a的迹 input_output: 无 return: 错误号 author: marc pony(marc_pony@163.com) ***********************************************************************************************/ error_id matrix_trace(_in matrix* a, _out real* trace); /********************************************************************************************** function: lup_decomposition description: n行n列矩阵lup分解pa = l * u input: n行n列矩阵a output: n行n列下三角矩阵l,n行n列上三角矩阵u,n行n列置换矩阵p input_output: 无 return: 错误号 author: marc pony(marc_pony@163.com) 参考:https://zhuanlan.zhihu.com/p/84210687 ***********************************************************************************************/ error_id lup_decomposition(_in matrix* a, _out matrix* l, _out matrix* u, _out matrix* p); /********************************************************************************************** function: solve_matrix_equation_by_lup_decomposition description: lup分解解矩阵方程ax=b,其中a为n行n列矩阵,b为n行m列矩阵,x为n行m列待求矩阵(写到矩阵b) input: n行n列矩阵a output: 无 input_output: n行m列矩阵b(即n行m列待求矩阵x) return: 错误号 author: marc pony(marc_pony@163.com) ***********************************************************************************************/ error_id solve_matrix_equation_by_lup_decomposition(_in matrix* a, _in_out matrix* b); #endif
common.c
/******************************************************************************* * file name : common.c * library/module name : matrixcomputation * author : marc pony(marc_pony@163.com) * create date : 2021/7/16 * abstract description : 矩阵运算库公用源文件 *******************************************************************************/ /******************************************************************************* * (1)debug switch section *******************************************************************************/ /******************************************************************************* * (2)include file section *******************************************************************************/ #include "common.h" /******************************************************************************* * (3)macro define section *******************************************************************************/ /******************************************************************************* * (4)struct(data types) define section *******************************************************************************/ /******************************************************************************* * (5)prototype declare section *******************************************************************************/ /******************************************************************************* * (6)global variable declare section *******************************************************************************/ /******************************************************************************* * (7)file static variable define section *******************************************************************************/ /******************************************************************************* * (8)function define section *******************************************************************************/ /********************************************************************************************** function: init_stack description: 初始化栈 input: 无 output: 无 input_output: 栈指针 return: 无 author: marc pony(marc_pony@163.com) ***********************************************************************************************/ void init_stack(_in_out stacks* s) { if (s == null) { return; } memset(s, 0, sizeof(stacks)); } /********************************************************************************************** function: free_stack description: 释放栈 input: 栈指针 output: 无 input_output: 无 return: 无 author: marc pony(marc_pony@163.com) ***********************************************************************************************/ void free_stack(_in stacks* s) { matrix_node* matrixnode = null; matrix_element_node* matrixelementnode = null; if (s == null) { return; } while (s->matrixnode != null) { matrixnode = s->matrixnode; s->matrixnode = matrixnode->next; free(matrixnode->ptr); matrixnode->ptr = null; free(matrixnode); matrixnode = null; } while (s->matrixelementnode != null) { matrixelementnode = s->matrixelementnode; s->matrixelementnode = matrixelementnode->next; free(matrixelementnode->ptr); matrixelementnode->ptr = null; free(matrixelementnode); matrixelementnode = null; } // ... // 释放其他指针 } matrix.c /******************************************************************************* * file name : matrix.c * library/module name : matrixcomputation * author : marc pony(marc_pony@163.com) * create date : 2021/2/24 * abstract description : 矩阵运算库源文件 *******************************************************************************/ /******************************************************************************* * (1)debug switch section *******************************************************************************/ /******************************************************************************* * (2)include file section *******************************************************************************/ #include "matrix.h" /******************************************************************************* * (3)macro define section *******************************************************************************/ /******************************************************************************* * (4)struct(data types) define section *******************************************************************************/ /******************************************************************************* * (5)prototype declare section *******************************************************************************/ /******************************************************************************* * (6)global variable declare section *******************************************************************************/ /******************************************************************************* * (7)file static variable define section *******************************************************************************/ /******************************************************************************* * (8)function define section *******************************************************************************/ void print_matrix(matrix* a, string string) { index i, j; printf("matrix %s:", string); printf("\n"); for (i = 0; i < a->rows; i++) { for (j = 0; j < a->columns; j++) { printf("%f ", a->p[i * a->columns + j]); } printf("\n"); } printf("\n"); } /********************************************************************************************** function: creat_matrix description: 创建矩阵 input: 矩阵行数rows,列数columns output: 错误号指针errorid,栈指针s input_output: 无 return: 矩阵指针 author: marc pony(marc_pony@163.com) ***********************************************************************************************/ matrix* creat_matrix(_in integer rows, _in integer columns, _out error_id* errorid, _out stacks* s) { matrix* matrix = null; matrix_node* matrixnode = null; matrix_element_node* matrixelementnode = null; *errorid = _error_no_error; if (rows <= 0 || columns <= 0 || errorid == null || s == null) { *errorid = _error_input_parameters_error; return null; } matrix = (matrix*)malloc(sizeof(matrix)); matrixnode = (matrix_node*)malloc(sizeof(matrix_node)); matrixelementnode = (matrix_element_node*)malloc(sizeof(matrix_element_node)); if (matrix == null || matrixnode == null || matrixelementnode == null) { free(matrix); matrix = null; free(matrixnode); matrixnode = null; free(matrixelementnode); matrixelementnode = null; *errorid = _error_failed_to_allocate_heap_memory; return null; } matrix->rows = rows; matrix->columns = columns; matrix->p = (real*)malloc(rows * columns * sizeof(real)); //确保matrix非空才能执行指针操作 if (matrix->p == null) { free(matrix->p); matrix->p = null; free(matrix); matrix = null; free(matrixnode); matrixnode = null; free(matrixelementnode); matrixelementnode = null; *errorid = _error_failed_to_allocate_heap_memory; return null; } matrixnode->ptr = matrix; matrixnode->next = s->matrixnode; s->matrixnode = matrixnode; matrixelementnode->ptr = matrix->p; matrixelementnode->next = s->matrixelementnode; s->matrixelementnode = matrixelementnode; return matrix; } /********************************************************************************************** function: creat_multiple_matrices description: 创建多个矩阵 input: 矩阵行数rows,列数columns,个数count output: 错误号指针errorid,栈指针s input_output: 无 return: 矩阵指针 author: marc pony(marc_pony@163.com) ***********************************************************************************************/ matrix* creat_multiple_matrices(_in integer rows, _in integer columns, _in integer count, _out error_id* errorid, _out stacks* s) { index i; matrix* matrix = null, *p = null; matrix_node* matrixnode = null; *errorid = _error_no_error; if (rows <= 0 || columns <= 0 || count <= 0 || errorid == null || s == null) { *errorid = _error_input_parameters_error; return null; } matrix = (matrix*)malloc(count * sizeof(matrix)); matrixnode = (matrix_node*)malloc(sizeof(matrix_node)); if (matrix == null || matrixnode == null) { free(matrix); matrix = null; free(matrixnode); matrixnode = null; *errorid = _error_failed_to_allocate_heap_memory; return null; } for (i = 0; i < count; i++) { p = creat_matrix(rows, columns, errorid, s); if (p == null) { free(matrix); matrix = null; free(matrixnode); matrixnode = null; *errorid = _error_failed_to_allocate_heap_memory; return null; } matrix[i] = *p; } matrixnode->ptr = matrix; matrixnode->next = s->matrixnode; s->matrixnode = matrixnode; return matrix; } /********************************************************************************************** function: creat_zero_matrix description: 创建零矩阵 input: 矩阵行数rows,列数columns output: 错误号指针errorid,栈指针s input_output: 无 return: 矩阵指针 author: marc pony(marc_pony@163.com) ***********************************************************************************************/ matrix* creat_zero_matrix(_in integer rows, _in integer columns, _out error_id* errorid, _out stacks* s) { matrix* matrix = null; *errorid = _error_no_error; if (rows <= 0 || columns <= 0 || errorid == null || s == null) { *errorid = _error_input_parameters_error; return null; } matrix = creat_matrix(rows, columns, errorid, s); if (*errorid == _error_no_error) { memset(matrix->p, 0, rows * columns * sizeof(real)); } return matrix; } /********************************************************************************************** function: creat_eye_matrix description: 创建单位矩阵 input: 矩阵行数rows,列数columns output: 错误号指针errorid,栈指针s input_output: 无 return: 矩阵指针 author: marc pony(marc_pony@163.com) ***********************************************************************************************/ matrix* creat_eye_matrix(_in integer n, _out error_id* errorid, _out stacks* s) { index i; matrix* matrix = null; *errorid = _error_no_error; if (n <= 0 || errorid == null || s == null) { *errorid = _error_input_parameters_error; return null; } matrix = creat_matrix(n, n, errorid, s); if (*errorid == _error_no_error) { memset(matrix->p, 0, n * n * sizeof(real)); for (i = 0; i < n; i++) { matrix->p[i * n + i] = 1.0; } } return matrix; } /********************************************************************************************** function: matrix_add description: 矩阵a + 矩阵b = 矩阵c input: 矩阵a,矩阵b output: 矩阵c input_output: 无 return: 错误号 author: marc pony(marc_pony@163.com) ***********************************************************************************************/ error_id matrix_add(_in matrix* a, _in matrix* b, _out matrix* c) { index i, j; integer rows, columns; error_id errorid = _error_no_error; if (a == null || b == null || c == null) { errorid = _error_input_parameters_error; return errorid; } if (a->rows != b->rows || a->rows != c->rows || b->rows != c->rows || a->columns != b->columns || a->columns != c->columns || b->columns != c->columns) { errorid = _error_matrix_rows_or_columns_not_equal; return errorid; } rows = a->rows; columns = a->columns; for (i = 0; i < rows; i++) { for (j = 0; j < columns; j++) { c->p[i * columns + j] = a->p[i * columns + j] + b->p[i * columns + j]; } } return errorid; } /********************************************************************************************** function: matrix_subtraction description: 矩阵a - 矩阵b = 矩阵c input: 矩阵a,矩阵b output: 矩阵c input_output: 无 return: 错误号 author: marc pony(marc_pony@163.com) ***********************************************************************************************/ error_id matrix_subtraction(_in matrix* a, _in matrix* b, _out matrix* c) { index i, j; integer rows, columns; error_id errorid = _error_no_error; if (a == null || b == null || c == null) { errorid = _error_input_parameters_error; return errorid; } if (a->rows != b->rows || a->rows != c->rows || b->rows != c->rows || a->columns != b->columns || a->columns != c->columns || b->columns != c->columns) { errorid = _error_matrix_rows_or_columns_not_equal; return errorid; } rows = a->rows; columns = a->columns; for (i = 0; i < rows; i++) { for (j = 0; j < columns; j++) { c->p[i * columns + j] = a->p[i * columns + j] - b->p[i * columns + j]; } } return errorid; } /********************************************************************************************** function: matrix_multiplication description: 矩阵乘法c = a * b input: 矩阵a,矩阵b output: 矩阵c input_output: 无 return: 错误号 author: marc pony(marc_pony@163.com) ***********************************************************************************************/ error_id matrix_multiplication(_in matrix* a, _in matrix* b, _out matrix* c) { index i, j, k; real sum; error_id errorid = _error_no_error; if (a == null || b == null || c == null) { errorid = _error_input_parameters_error; return errorid; } if (a->columns != b->rows || a->rows != c->rows || b->columns != c->columns) { errorid = _error_matrix_multiplication; return errorid; } for (i = 0; i < a->rows; i++) { for (j = 0; j < b->columns; j++) { sum = 0.0; for (k = 0; k < a->columns; k++) { sum += a->p[i * a->columns + k] * b->p[k * b->columns + j]; } c->p[i * b->columns + j] = sum; } } return errorid; } /********************************************************************************************** function: matrix_inverse description: 矩阵求逆 input: 矩阵a output: 矩阵a的逆矩阵 input_output: 无 return: 错误号 author: marc pony(marc_pony@163.com) ***********************************************************************************************/ error_id matrix_inverse(_in matrix* a, _out matrix* inva) { index i; integer n; matrix* atemp = null; error_id errorid = _error_no_error; stacks s; if (a == null || inva == null) { errorid = _error_input_parameters_error; return errorid; } if (a->rows != a->columns) { errorid = _error_matrix_must_be_square; return errorid; } init_stack(&s); n = a->rows; atemp = creat_matrix(n, n, &errorid, &s); if (errorid != _error_no_error) goto exit; memcpy(atemp->p, a->p, n * n * sizeof(real)); memset(inva->p, 0, n * n * sizeof(real)); for (i = 0; i < n; i++) { inva->p[i * n + i] = 1.0; } errorid = solve_matrix_equation_by_lup_decomposition(atemp, inva); exit: free_stack(&s); return errorid; } /********************************************************************************************** function: matrix_transpose description: 矩阵转置 input: 矩阵a output: 矩阵a的转置 input_output: 无 return: 错误号 author: marc pony(marc_pony@163.com) ***********************************************************************************************/ error_id matrix_transpose(_in matrix* a, _out matrix* transposea) { index i, j; error_id errorid = _error_no_error; if (a == null || transposea == null) { errorid = _error_input_parameters_error; return errorid; } if (a->rows != transposea->columns || a->columns != transposea->rows) { errorid = _error_matrix_transpose_failed; return errorid; } for (i = 0; i < a->rows; i++) { for (j = 0; j < a->columns; j++) { transposea->p[j * a->rows + i] = a->p[i * a->columns + j]; } } return errorid; } /********************************************************************************************** function: matrix_trace description: 矩阵的迹 input: 矩阵a output: 矩阵a的迹 input_output: 无 return: 错误号 author: marc pony(marc_pony@163.com) ***********************************************************************************************/ error_id matrix_trace(_in matrix* a, _out real *trace) { index i; error_id errorid = _error_no_error; if (a == null || trace == null) { errorid = _error_input_parameters_error; return errorid; } if (a->rows != a->columns) { errorid = _error_matrix_must_be_square; return errorid; } *trace = 0.0; for (i = 0; i < a->rows; i++) { *trace += a->p[i * a->columns + i]; } return errorid; } /********************************************************************************************** function: lup_decomposition description: n行n列矩阵lup分解pa = l * u input: n行n列矩阵a output: n行n列下三角矩阵l,n行n列上三角矩阵u,n行n列置换矩阵p input_output: 无 return: 错误号 author: marc pony(marc_pony@163.com) 参考:https://zhuanlan.zhihu.com/p/84210687 ***********************************************************************************************/ error_id lup_decomposition(_in matrix* a, _out matrix* l, _out matrix* u, _out matrix* p) { index i, j, k, index, s, t; integer n; real maxvalue, temp; error_id errorid = _error_no_error; if (a == null || l == null || u == null || p == null) { errorid = _error_input_parameters_error; return errorid; } if (a->rows != a->columns) { errorid = _error_matrix_must_be_square; return errorid; } n = a->rows; memcpy(u->p, a->p, n * n * sizeof(real)); memset(l->p, 0, n * n * sizeof(real)); memset(p->p, 0, n * n * sizeof(real)); for (i = 0; i < n; i++) { l->p[i * n + i] = 1.0; p->p[i * n + i] = 1.0; } for (j = 0; j < n - 1; j++) { //select i(>= j) that maximizes | u(i, j) | index = -1; maxvalue = 0.0; for (i = j; i < n; i++) { temp = fabs(u->p[i * n + j]); if (temp > maxvalue) { maxvalue = temp; index = i; } } if (index == -1) { continue; } //interchange rows of u : u(j, j : n) < ->u(i, j : n) for (k = j; k < n; k++) { s = j * n + k; t = index * n + k; temp = u->p[s]; u->p[s] = u->p[t]; u->p[t] = temp; } //interchange rows of l : l(j, 1 : j - 1) < ->l(i, 1 : j - 1) for (k = 0; k < j; k++) { s = j * n + k; t = index * n + k; temp = l->p[s]; l->p[s] = l->p[t]; l->p[t] = temp; } //interchange rows of p : p(j, 1 : n) < ->p(i, 1 : n) for (k = 0; k < n; k++) { s = j * n + k; t = index * n + k; temp = p->p[s]; p->p[s] = p->p[t]; p->p[t] = temp; } for (i = j + 1; i < n; i++) { s = i * n + j; l->p[s] = u->p[s] / u->p[j * n + j]; for (k = j; k < n; k++) { u->p[i * n + k] -= l->p[s] * u->p[j * n + k]; } } } return errorid; } /********************************************************************************************** function: solve_matrix_equation_by_lup_decomposition description: lup分解解矩阵方程ax=b,其中a为n行n列矩阵,b为n行m列矩阵,x为n行m列待求矩阵(写到矩阵b) input: n行n列矩阵a output: 无 input_output: n行m列矩阵b(即n行m列待求矩阵x) return: 错误号 author: marc pony(marc_pony@163.com) ***********************************************************************************************/ error_id solve_matrix_equation_by_lup_decomposition(_in matrix* a, _in_out matrix* b) { index i, j, k, index, s, t; integer n, m; real sum, maxvalue, temp; matrix* l = null, * u = null, * y = null; error_id errorid = _error_no_error; stacks s; if (a == null || b == null) { errorid = _error_input_parameters_error; return errorid; } if (a->rows != a->columns) { errorid = _error_matrix_must_be_square; return errorid; } init_stack(&s); n = a->rows; m = b->columns; l = creat_matrix(n, n, &errorid, &s); if (errorid != _error_no_error) goto exit; u = creat_matrix(n, n, &errorid, &s); if (errorid != _error_no_error) goto exit; y = creat_matrix(n, m, &errorid, &s); if (errorid != _error_no_error) goto exit; memcpy(u->p, a->p, n * n * sizeof(real)); memset(l->p, 0, n * n * sizeof(real)); for (i = 0; i < n; i++) { l->p[i * n + i] = 1.0; } for (j = 0; j < n - 1; j++) { //select i(>= j) that maximizes | u(i, j) | index = -1; maxvalue = 0.0; for (i = j; i < n; i++) { temp = fabs(u->p[i * n + j]); if (temp > maxvalue) { maxvalue = temp; index = i; } } if (index == -1) { continue; } //interchange rows of u : u(j, j : n) < ->u(i, j : n) for (k = j; k < n; k++) { s = j * n + k; t = index * n + k; temp = u->p[s]; u->p[s] = u->p[t]; u->p[t] = temp; } //interchange rows of l : l(j, 1 : j - 1) < ->l(i, 1 : j - 1) for (k = 0; k < j; k++) { s = j * n + k; t = index * n + k; temp = l->p[s]; l->p[s] = l->p[t]; l->p[t] = temp; } //interchange rows of p : p(j, 1 : n) < ->p(i, 1 : n), c = p * b,等价于对b交换行 for (k = 0; k < m; k++) { s = j * m + k; t = index * m + k; temp = b->p[s]; b->p[s] = b->p[t]; b->p[t] = temp; } for (i = j + 1; i < n; i++) { s = i * n + j; l->p[s] = u->p[s] / u->p[j * n + j]; for (k = j; k < n; k++) { u->p[i * n + k] -= l->p[s] * u->p[j * n + k]; } } } for (i = 0; i < n; i++) { if (fabs(u->p[i * n + i]) < 1.0e-20) { errorid = _error_matrix_equation_has_no_solutions; goto exit; } } //l * y = c for (j = 0; j < m; j++) { for (i = 0; i < n; i++) { sum = 0.0; for (k = 0; k < i; k++) { sum = sum + l->p[i * n + k] * y->p[k * m + j]; } y->p[i * m + j] = b->p[i * m + j] - sum; } } //u * x = y for (j = 0; j < m; j++) { for (i = n - 1; i >= 0; i--) { sum = 0.0; for (k = i + 1; k < n; k++) { sum += u->p[i * n + k] * b->p[k * m + j]; } b->p[i * m + j] = (y->p[i * m + j] - sum) / u->p[i * n + i]; } } exit: free_stack(&s); return errorid; }
test_matrix.c
#include "matrix.h" void main() { real a[3 * 3] = { 1,2,3,6,5,5,8,7,2 }; real b[3 * 3] = {1,2,3,6,5,4,3,2,1}; matrix *a = null, * b = null, * c = null, * d = null, * e = null, * z = null, * inva = null, *m = null; error_id errorid = _error_no_error; real trace; stacks s; init_stack(&s); z = creat_zero_matrix(3, 3, &errorid, &s); print_matrix(z, "z"); e = creat_eye_matrix(3, &errorid, &s); print_matrix(e, "e"); a = creat_matrix(3, 3, &errorid, &s); a->p = a; print_matrix(a, "a"); b = creat_matrix(3, 3, &errorid, &s); b->p = b; print_matrix(b, "b"); c = creat_matrix(3, 3, &errorid, &s); d = creat_matrix(3, 3, &errorid, &s); inva = creat_matrix(3, 3, &errorid, &s); errorid = matrix_add(a, b, c); errorid = matrix_subtraction(a, b, c); errorid = matrix_multiplication(a, b, c); errorid = matrix_transpose(a, d); print_matrix(d, "d"); errorid = matrix_trace(a, &trace); errorid = matrix_inverse(a, inva); print_matrix(inva, "inva"); m = creat_multiple_matrices(3, 3, 2, &errorid, &s); m[0].p = a; m[1].p = b; free_stack(&s); }
参考资料
the nurbs book(second edition). les piegl,wayne tiller
到此这篇关于纯c语言优雅地实现矩阵运算库的方法的文章就介绍到这了,更多相关c语言 矩阵运算内容请搜索以前的文章或继续浏览下面的相关文章希望大家以后多多支持!
上一篇: PyTorch一小时掌握之神经网络分类篇