欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

POJ3233Matrix Power Series(矩阵快速幂)

程序员文章站 2022-05-30 09:28:05
题意 题目链接 给出$n \times n$的矩阵$A$,求$\sum_{i = 1}^k A^i $,每个元素对$m$取模 Sol 考虑直接分治 当$k$为奇数时 $\sum_{i = 1}^k A^i = \sum_{i = 1}^{k / 2 + 1} A^i + A^{k / 2 + 1}( ......

题意

给出$n \times n$的矩阵$a$,求$\sum_{i = 1}^k a^i $,每个元素对$m$取模

sol

考虑直接分治

当$k$为奇数时

$\sum_{i = 1}^k a^i = \sum_{i = 1}^{k / 2 + 1} a^i + a^{k / 2 + 1}(\sum_{i = 1}^{k / 2} a^i)$

当$k$为偶数时

$sum_{i = 1}^k = \sum_{i = 1}^{k / 2} a^i + a^{k / 2}(\sum_{i = 1}^{k / 2}a^i)$

 

当然还可以按套路对前缀和构造矩阵也是可以做的。

#include<cstdio>
#include<cstring>
#include<iostream>
#include<map>
#define ll long long 
using namespace std;
int n, k, mod;
int mul(int x, int y) {
    if(1ll * x * y > mod) return 1ll * x * y % mod;
    else return 1ll * x * y;
}
int add(int x, int y) {
    if(x + y > mod) return x + y - mod;
    else return x + y;
}
struct matrix {
    int m[31][31];
    matrix() {
        memset(m, 0, sizeof(m));
    }
    bool operator < (const matrix &rhs) const {
        for(int i = 1; i <= n; i++)
            for(int j = 1; j <= n; j++)
                if(m[i][j] != rhs.m[i][j])
                    return m[i][j] < rhs.m[i][j];
        return 1;
    }
    matrix operator * (const matrix &rhs) const {
        matrix ans;
        for(int k = 1; k <= n; k++)
            for(int i = 1; i <= n; i++)
                for(int j = 1; j <= n; j++)
                    ans.m[i][j] = add(ans.m[i][j], mul(m[i][k], rhs.m[k][j]));
        return ans;
    }
    matrix operator + (const matrix &rhs) const {
        matrix ans;
        for(int i = 1; i <= n; i++)
            for(int j = 1; j <= n; j++)
                ans.m[i][j] = add(m[i][j], rhs.m[i][j]);
        return ans;
    }
}a;
matrix getbase() {
    matrix base;
    for(int i = 1; i <= n; i++) base.m[i][i] = 1;
    return base;
}
matrix fp(matrix a, int p) {
    matrix base = getbase();
    while(p) {
        if(p & 1) base = base * a;
        a = a * a; p >>= 1;
    }
    return base;
}
matrix solve(int k) {
    if(k == 1) return a;
    matrix res = solve(k / 2);
    if(k & 1) {
        matrix po = fp(a, k / 2 + 1);
        return res + po + po * res;
    }
    else return res + fp(a, k / 2) * res;

}
main() {
//    freopen("a.in", "r", stdin);
    cin >> n >> k >> mod;
    for(int i = 1; i <= n; i++)
        for(int j = 1; j <= n; j++)
            cin >> a.m[i][j];
    matrix ans = solve(k);
    for(int i = 1; i <= n; i++, puts(""))
        for(int j = 1; j <= n; j++)
            printf("%d ", ans.m[i][j] % mod);
}