欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

多项式多点求值

程序员文章站 2024-03-21 20:25:52
...

给定一个多项式和m个x,求相应的y

我们把需要求值的点均分成两个集合S1,S2,构造两个多项式P1,P2,使得这两个多项式分别为这两个集合的零点。则多项式A%P1对于S1满足A%P1对S1内元素求值和A相同,A%P2对于S2内求值和A相同,而它们次数都是n/2,分治递归下去继续求值即可。

由于多项式取模的数组版还不会写,这里借用的以前的vector版,有非常大的优化空间(vector真香)

#include <bits/stdc++.h>
using namespace std;

const int p = 998244353;

int qpow(int x, int y)
{
    int res = 1;
    while (y > 0)
    {
        if (y & 1)
            res = 1LL * res * x % p;
        x = 1LL * x * x % p;
        y >>= 1;
    }
    return res;
}

void FNTT(vector<int> &A, int len, int flag)
{
    A.resize(len);
    int *r = new int[len];
    r[0] = 0;
    for (int i = 0; i < len; i++)
        r[i] = (r[i >> 1] >> 1) | ((i & 1) * (len >> 1));
    for (int i = 0; i < len; i++)
        if (i < r[i])
            swap(A[i], A[r[i]]);
    int gn, g, t, A0, A1;
    for (int i = 1; i < len; i <<= 1)
    {
        gn = qpow(3, (p - 1) / (i * 2));
        for (int j = 0; j < len; j += (i << 1))
        {
            g = 1;
            A0 = j;
            A1 = A0 + i;
            for (int k = 0; k < i; k++, A0++, A1++, g = (1LL * g * gn) % p)
            {
                t = (1LL * A[A1] * g) % p;
                A[A1] = ((A[A0] - t) % p + p) % p;
                A[A0] = (A[A0] + t) % p;
            }
        }
    }
    if (flag == -1)
    {
        reverse(A.begin() + 1, A.end());
        int inv = qpow(len, p - 2);
        for (int i = 0; i < len; i++)
            A[i] = 1LL * A[i] * inv % p;
    }
    delete []r;
}

vector<int> operator*(vector<int> a, vector<int> b)
{
    int len = 1;
    int sz = a.size() + b.size() - 1;
    while (len <= sz) len <<= 1;
    FNTT(a, len, 1);
    FNTT(b, len, 1);
    vector<int> res;
    res.resize(len);
    for (int i = 0; i < len; i++)
        res[i] = 1LL * a[i] * b[i] % p;
    FNTT(res, len, -1);
    res.resize(sz);
    return res;
}

vector<int> poly_inv(vector<int> a)
{
    if (a.size() == 1)
    {
        a[0] = qpow(a[0], p - 2);
        return a;
    }
    int n = a.size(), newsz = (n + 1) >> 1;
    vector<int> b(a);
    b.resize(newsz);
    b = poly_inv(b);
    int len = 1;
    while (len <= (n << 1)) len <<= 1;
    vector<int> c(a);
    FNTT(a, len, 1);
    FNTT(b, len, 1);
    for (int i = 0; i < len; i++)
        a[i] = ((1LL * b[i] * (2 - 1LL * a[i] * b[i] % p)) % p + p) % p;
    FNTT(a, len, -1);
    a.resize(n);
    return a;
}

vector<int> poly_r(vector<int> a)
{
    reverse(a.begin(), a.end());
    return a;
}

void div(vector<int> f, vector<int> g, vector<int> &q, vector<int> &r)
{
    int n = f.size() - 1, m = g.size() - 1;
    vector<int> gr = poly_r(g);
    gr.resize(n - m + 1);
    q = poly_r(f) * poly_inv(gr);
    q.resize(n - m + 1);
    q = poly_r(q);
    vector<int> gq = g * q;
    r.resize(m);
    gq.resize(m);
    f.resize(m);
    for (int i = 0; i < m; i++)
        r[i] = ((f[i] - gq[i]) % p + p) % p;
}

int n, m;
vector<int> f, tmp[1000010];
int a[100010], res[100010], le[1000010], re[1000010], tot;

vector<int> prework(int l, int r)
{
    int id = ++tot;
    if (l == r)
    {
        vector<int> res;
        res.push_back(p - a[l]);
        res.push_back(1);
        tmp[id] = res;
        return res;
    }
    int mid = (l + r) / 2;
    le[id] = tot + 1;
    vector<int> res = prework(l, mid);
    re[id] = tot + 1;
    res = res * prework(mid + 1, r);
    return tmp[id] = res;
}

void work(int l, int r, vector<int> sb)
{
    int id = ++tot;
    if (l == r)
    {
        int tmp = 1;
        for (int i = 0; i < (int)sb.size(); i++)
            res[l] = (res[l] + tmp * sb[i]) % p, tmp = tmp * (long long)a[l] % p;
        return;
    }
    vector<int> fl = tmp[le[id]], fr = tmp[re[id]];
    vector<int> tmp1, rel, rer;
    div(sb, fl, tmp1, rel);
    div(sb, fr, tmp1, rer);
    int mid = (l + r) / 2;
    work(l, mid, rel);
    work(mid + 1, r, rer);
}

int main()
{
    scanf("%d%d", &n, &m); f.resize(n + 1);
    for (int i = 0; i <= n; i++) scanf("%d", &f[i]);
    for (int i = 1; i <= m; i++) scanf("%d", &a[i]);
    prework(1, max(n, m));
    tot = 0;
    work(1, max(n, m), f);
    for (int i = 1; i <= m; i++) printf("%d\n", res[i]);
    return 0;
}