欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

BZOJ3992: [SDOI2015]序列统计(NTT 原根 生成函数)

程序员文章站 2022-05-25 19:56:46
题意 "题目链接" 给出大小为$S$的集合,从中选出$N$个数,满足他们的乘积$\% M = X$的方案数 Sol 神仙题Orz 首先不难列出最裸的dp方程,设$f[i][j]$表示选了$i$个数,他们的乘积为$j$的方案数 设$g[k] = [\exists a_i = k]$ 转移的时候 $$f ......

题意

题目链接

给出大小为\(s\)的集合,从中选出\(n\)个数,满足他们的乘积\(\% m = x\)的方案数

sol

神仙题orz

首先不难列出最裸的dp方程,设\(f[i][j]\)表示选了\(i\)个数,他们的乘积为\(j\)的方案数

\(g[k] = [\exists a_i = k]\)

转移的时候

\[f[i + 1][(j * k) \% m] += f[i][j] * g[k]\]

不难发现每次的转移都是相同的,因此可以直接矩阵快速幂,时间复杂度变为\(logn m^2\)

观察上面的式子,如果我们能把\((j * k) \% m\),变成\((j + k) \% m\)的话,就是一个循环卷积的形式了

这里可以用原根来实现,设\(g\)表示\(m\)的原根,\(mp[i] = j\)表示\(g^i = j\)

直接对每个物品构造生成函数,利用mp转移即可

因为转移是个循环卷积,所以统计答案的时候应该把第\(i\)项和第\(i+m-1\)项的系数加起来

至于为啥只统计一项。

BZOJ3992: [SDOI2015]序列统计(NTT 原根 生成函数)

#include<bits/stdc++.h>
using namespace std;
const int mod = 1004535809, g = 3, gi = 334845270, maxn = 1e5 + 10; 
inline int read() {
    char c = getchar(); int x = 0, f = 1;
    while(c < '0' || c > '9') {if(c == '-') f = -1; c = getchar();}
    while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
    return x * f;
}
int n, m, x, s;
int r[maxn], lim, l, ind[maxn], s[maxn], f[maxn], a[maxn], b[maxn];
int mul(int a, int b) {
    return 1ll * a * b % mod;
}
int add(int x, int y) {
    if(x + y < 0) return x + y + mod;
    return x + y >= mod ? x + y - mod : x + y;
}
int dec(int x, int y) {
    return x - y < 0 ? x - y + mod : x - y;
}
int fp(int a, int p, int mod) {
    int base = 1;
    while(p) {
        if(p & 1) base = 1ll * base * a % mod; 
        a = 1ll * a * a % mod; p >>= 1;
    }
    return base;
}
int getg(int x) {
    static int q[maxn]; int tot = 0, tp = x - 1;
    for(int i = 2; i * i <= tp; i++) {
        if(!(tp % i)) {
            q[++tot] = i;
            while(!(tp % i)) tp /= i;
        }
    }
    if(tp > 1) q[++tot] = tp;
    for(int i = 2, j; i <= x - 1; i++) {
        for(j = 1; j <= tot; j++) if(fp(i, (x - 1) / q[j], x) == 1) break;
        if(j == tot + 1) return i;
    }
}
void ntt(int *a, int n, int type) {
    for(int i = 1; i < n; i++) if(i < r[i]) swap(a[i], a[r[i]]);
    for(int mid = 1; mid < n; mid <<= 1) {
        int r = mid << 1, wn = fp(type == 1 ? g : gi, (mod - 1) / r, mod);
        for(int j = 0; j < lim; j += r) {
            for(int w = 1, k = 0; k < mid; k++, w = mul(w, wn)) {
                int x = a[j + k], y = mul(w, a[j + k + mid]);
                a[j + k] = add(x, y);
                a[j + k + mid] = dec(x, y);
            }
        }
    }
    if(type == -1) {
        for(int i = 0, inv = fp(lim, mod - 2, mod); i < n; i++) a[i] = mul(a[i], inv);
    }
}
void mul(int *a1, int *b1, int *c) {
    memset(a, 0, sizeof(a)); memset(b, 0, sizeof(b));//tag
    for(int i = 0; i < m - 1; i++) a[i] = a1[i], b[i] = b1[i];
    ntt(a, lim, 1); ntt(b, lim, 1);
    for(int i = 0; i < lim; i++) a[i] = mul(a[i], b[i]);
    ntt(a, lim, -1);
    for(int i = 0; i < m - 1; i++) c[i] = add(a[i], a[i + m - 1]);
}
void pre() {
    lim = 1;
    while(lim <= 2 * (m - 2)) lim <<= 1, l++;
    for(int i = 0; i < lim; i++) r[i] = (r[i >> 1] >> 1) | (i & 1) << (l - 1);
    int d = getg(m);
    for(int i = 0; i < m - 1; i++) ind[fp(d, i, m)] = i;
}
int main() {
    n = read(); m = read(); x = read(); s = read();
    pre();
    for(int i = 1; i <= s; i++) {
        int x = read();
        if(x) f[ind[x]]++;
    }
    s[ind[1]] = 1;
    while(n) {
        if(n & 1) mul(s, f, s);
        mul(f, f, f); n >>= 1;
    }
    printf("%d", s[ind[x]]);
    return 0;
}
/*
40000000 3 1 2
1 2

4 3 1 2
1 2
*/