欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

4.13 ~ 4.20做题情况

程序员文章站 2022-07-15 08:46:05
...

这周主要在做莫队及其各种变形算法和FFT/NTT的一些板子基础题

小B的询问:莫队

最最基础的莫队,没啥好说的

#include <iostream>
#include <cstdio>
#include <cmath>
#include <algorithm>
using namespace std;
const int maxn = 5e4 + 50;
// a[]: input values, c[v]: occurrences of value v in the current window [l, r],
// pos[i]: block index of position i. The window starts empty (l = 1, r = 0).
int n,m,k,l = 1,r,sqt,a[maxn],c[maxn],pos[maxn];
// The answer is the sum of squared counts. For a window of length L it can reach L * L,
// i.e. 2.5e9 when n = 5e4 and all values are equal, which overflows int.
long long sum,ans[maxn];
struct query{
    int l,r,pos;
}q[maxn];
// Reads a non-negative decimal integer from stdin, skipping leading non-digit characters.
int read(){
    int x = 0;
    char c = getchar();
    while(c < '0' || c > '9') c = getchar();
    while(c >= '0' && c <= '9') x = x * 10 + (c ^ 48),c = getchar();
    return x;
}
// Mo's ordering: by block of the left endpoint; inside a block the right endpoint goes
// up in odd blocks and down in even blocks (odd-even optimisation).
bool cmp(query a,query b){
    if(pos[a.l] != pos[b.l]) return pos[a.l] < pos[b.l];
    if(pos[a.l] & 1) return a.r < b.r;
    return a.r > b.r;
}
// Adds one occurrence of value x: (c + 1)^2 - c^2 = 2c + 1.
inline void add(int x){
    sum += 2ll * c[x] + 1,c[x] ++;
}
// Removes one occurrence of value x: c^2 - (c - 1)^2 = 2c - 1.
inline void del(int x){
    sum -= 2ll * c[x] - 1,c[x] --;
}
// Reads the input, answers all range queries offline with Mo's algorithm and prints
// the answers in input order.
int main(){
    n = read(),m = read(),k = read();
    // Block size ~ sqrt(n), clamped to 1 so pos[] never divides by zero.
    sqt = max(1,(int)sqrt(n));
    for(int i = 1; i <= n; i ++) a[i] = read(),pos[i] = (i - 1) / sqt + 1;
    for(int i = 1; i <= m; i ++) q[i].l = read(),q[i].r = read(),q[i].pos = i;
    sort(q + 1,q + m + 1,cmp);
    for(int i = 1; i <= m; i ++){
        // Move the window to [q[i].l, q[i].r] one element at a time.
        while(l < q[i].l) del(a[l]),l ++;
        while(l > q[i].l) l --,add(a[l]);
        while(r < q[i].r) r ++,add(a[r]);
        while(r > q[i].r) del(a[r]),r --;
        ans[q[i].pos] = sum;
    }
    for(int i = 1; i <= m; i ++) printf("%lld\n",ans[i]);
    return 0;
}

大爷的字符串题:莫队

本题主要考察了选手的语文能力 题目的转化十分巧妙,然后就变成了一道普通的莫队

#include <iostream>
#include <cstdio>
#include <cmath>
#include <algorithm>
using namespace std;
const int maxn = 2e5 + 50;
// Window state for Mo's algorithm: [l, r] starts empty, sum is the largest frequency in it.
int n,m,len,sqt,sum;
int l = 1,r;
// a: compressed values, pos: block of each index, t: sorted copy used for compression.
int a[maxn],pos[maxn],t[maxn];
// num[v]: occurrences of v in the window; cnt[f] (f >= 1): values occurring exactly f times.
int num[maxn],cnt[maxn];
int ans[maxn];
struct query{
    int l,r,pos;
}q[maxn];
// Reads the next non-negative decimal integer from stdin, skipping non-digit characters.
int read(){
    char ch;
    do ch = getchar(); while(ch < '0' || ch > '9');
    int value = 0;
    for(; ch >= '0' && ch <= '9'; ch = getchar()) value = value * 10 + (ch - '0');
    return value;
}
// Mo's ordering: block of the left endpoint, then right endpoint ascending.
bool cmp(query a,query b){
    return pos[a.l] == pos[b.l] ? a.r < b.r : pos[a.l] < pos[b.l];
}
// Adds one occurrence of value x; the maximum frequency grows iff x was at the maximum.
void add(int x){
    int &freq = num[x];
    if(freq == sum) ++sum;
    --cnt[freq];
    ++freq;
    ++cnt[freq];
}
// Removes one occurrence of value x; the maximum drops iff x was the only value at it.
void del(int x){
    int &freq = num[x];
    const bool lastAtMax = freq == sum && cnt[freq] == 1;
    --cnt[freq];
    --freq;
    ++cnt[freq];
    if(lastAtMax) --sum;
}
// Reads the sequence and the queries, answers every query offline with Mo's algorithm
// and prints the negated maximum frequency of each range.
int main(){
    n = read(),m = read();
    // Block size ~ sqrt(n), clamped to 1 so the division below is always defined.
    sqt = max(1,(int)sqrt(n));
    for(int i = 1; i <= n; i ++) a[i] = t[i] = read(),pos[i] = (i - 1) / sqt + 1;
    // Coordinate-compress the values into 1..len so they can index num[].
    sort(t + 1,t + n + 1);
    len = unique(t + 1,t + n + 1) - t - 1;
    for(int i = 1; i <= n; i ++) a[i] = lower_bound(t + 1,t + len + 1,a[i]) - t;
    for(int i = 1; i <= m; i ++) q[i].l = read(),q[i].r = read(),q[i].pos = i;
    sort(q + 1,q + m + 1,cmp);
    // There are m queries. The loop previously ran to n, which skipped queries when
    // m > n and processed zero-filled entries when m < n.
    for(int i = 1; i <= m; i ++){
        while(r < q[i].r) r ++,add(a[r]);
        while(r > q[i].r) del(a[r]),r --;
        while(l < q[i].l) del(a[l]),l ++;
        while(l > q[i].l) l --,add(a[l]);
        ans[q[i].pos] = sum;
    }
    // sum is the maximum frequency in the range; the required answer is its negation.
    for(int i = 1; i <= m; i ++) printf("%d\n",-ans[i]);
    return 0;
}

小清新人渣的本愿:莫队,bitset

用到了类似状压的思想,由于长度高达 $10^5$ 所以不能直接用普通整型,可以用一个bitset维护 (表示第一次听说bitset这种东西) 。其中加法的询问转化比较巧妙,不过题解说都是很套路的东西?

#include <iostream>
#include <cstdio>
#include <cmath>
#include <algorithm>
#include <bitset>
using namespace std;
const int maxn = 1e5 + 50,N = 1e5;
// The Mo window [l, r] starts empty: l = 1, r = 0. Starting from l = 0 made the first
// left move call del(num[0]) on an element that was never added, leaving cnt[0] at -1.
// A value 0 read from the input would then never get its bit set in a / b.
int n,m,sqt,l = 1,r,num[maxn],pos[maxn],cnt[maxn],ans[maxn];
// a[v] is set iff value v occurs in the window; b[N - v] mirrors it for sum queries.
bitset <maxn> a,b;
// opt: query kind, [l, r]: range, x: target value, pos: input order.
struct query{
    int l,r,x,opt,pos;
}q[maxn];
// Reads a non-negative decimal integer from stdin, skipping leading non-digit characters.
int read(){
    int x = 0;
    char c = getchar();
    while(c < '0' || c > '9') c = getchar();
    while(c >= '0' && c <= '9') x = x * 10 + (c ^ 48),c = getchar();
    return x;
}
// Mo's ordering with the odd-even optimisation on the right endpoint.
bool cmp(query a,query b){
    if(pos[a.l] != pos[b.l]) return pos[a.l] < pos[b.l];
    if(pos[a.l] & 1) return a.r < b.r;
    return a.r > b.r;
}
// Adds one occurrence of value x; sets its bits when it becomes present.
void add(int x){
    if(!cnt[x]) a[x] = 1,b[N - x] = 1;
    cnt[x] ++;
}
// Removes one occurrence of value x; clears its bits when it disappears.
void del(int x){
    cnt[x] --;
    if(!cnt[x]) a[x] = 0,b[N - x] = 0;
}
// Reads the sequence and the queries, answers them offline in Mo's order and prints
// "hana" when the answer is yes and "bi" otherwise.
int main(){
    n = read(),m = read(),sqt = sqrt(n);
    for(int i = 1; i <= n; i ++) num[i] = read(),pos[i] = (i - 1) / sqt + 1;
    // Each query: opt, range [l, r] and target value x.
    for(int i = 1; i <= m; i ++) q[i].opt = read(),q[i].l = read(),q[i].r = read(),q[i].x = read(),q[i].pos = i;
    sort(q + 1,q + m + 1,cmp);
    for(int i = 1; i <= m; i ++){
        // Move the window to [q[i].l, q[i].r]; add/del keep the bitsets a and b in sync.
        while(l < q[i].l) del(num[l]),l ++;
        while(l > q[i].l) l --,add(num[l]);
        while(r < q[i].r) r ++,add(num[r]);
        while(r > q[i].r) del(num[r]),r --;
        if(q[i].opt == 1){
            // Difference x: bit i of (a << x) is a[i - x], so a non-empty intersection
            // means both i - x and i are present.
            if((a & (a << q[i].x)).count()) ans[q[i].pos] = 1;
        }else if(q[i].opt == 2){
            // Sum x: bit i of (b >> (N - x)) is b[i + N - x], set iff x - i is present,
            // so a non-empty intersection with a means i + (x - i) = x is reachable.
            if((a & (b >> (N - q[i].x))).count()) ans[q[i].pos] = 1;
        }else{
            // Product x: try every divisor pair (j, x / j) with j * j <= x.
            for(int j = 1; j * j <= q[i].x; j ++){
                if(q[i].x % j == 0 && a[j] && a[q[i].x / j]){
                    ans[q[i].pos] = 1;
                    break;
                }
            }
        }
    }
    for(int i = 1; i <= m; i ++){
        if(ans[i]) printf("hana\n");
        else printf("bi\n");
    }
    return 0;
}

COT2 - Count on a tree II:树上莫队模板

求出树的欧拉序即可按照某些方法转化成区间问题,套莫队就好了

我在想能不能用LCT做

#include <iostream>
#include <cstdio>
#include <cmath>
#include <algorithm>
using namespace std;
const int maxn = 4e4 + 50;
const int maxm = 1e5 + 50;
// Input sizes and scratch values; cnt is the edge counter, later the Euler-tour counter.
int n,m,x,y,sum,len,cnt;
// Current window [l, r] over the Euler tour, initially empty.
int l = 1,r;
// num[v]: compressed colour of vertex v; t[]: scratch array for the compression.
int num[maxn],t[maxn];
// a[i]: vertex at Euler-tour position i; pos[i]: block of position i.
int a[2 * maxn],pos[2 * maxn];
// fst / lst: entry and exit tour positions; last: adjacency heads; d: depth.
int fst[maxn],lst[maxn],last[maxn],d[maxn];
// vis[v]: parity of v inside the window; tot[c]: window vertices of colour c.
int vis[maxn],tot[maxn],ans[maxm];
// f[v][j]: 2^j-th ancestor of v (0 when it does not exist).
int f[maxn][20];
struct edge{
    int v,nxt;
}e[2 * maxn];
struct query{
    int l,r,lca,pos;
}q[maxm];
// Reads the next non-negative decimal integer from stdin, skipping non-digit characters.
int read(){
    char ch = getchar();
    while(ch < '0' || ch > '9') ch = getchar();
    int res = 0;
    for(; ch >= '0' && ch <= '9'; ch = getchar()) res = res * 10 + (ch ^ 48);
    return res;
}
// Mo's ordering over tour positions with the odd-even optimisation.
bool cmp(query a,query b){
    const int ba = pos[a.l],bb = pos[b.l];
    if(ba != bb) return ba < bb;
    return (ba & 1) ? a.r < b.r : a.r > b.r;
}
// Adds a directed edge x -> y at the head of x's adjacency list.
inline void insert(int x,int y){
    e[++cnt] = edge{y,last[x]};
    last[x] = cnt;
}
// Builds the Euler tour (each vertex recorded on entry and exit), depths and the
// binary-lifting table of the subtree rooted at u.
void dfs(int u,int fa){
    a[++cnt] = u;
    fst[u] = cnt;
    d[u] = d[fa] + 1;
    f[u][0] = fa;
    for(int j = 1; f[f[u][j - 1]][j - 1] != 0; ++j) f[u][j] = f[f[u][j - 1]][j - 1];
    for(int ed = last[u]; ed != 0; ed = e[ed].nxt)
        if(e[ed].v != fa) dfs(e[ed].v,u);
    a[++cnt] = u;
    lst[u] = cnt;
}
// Lowest common ancestor of x and y by binary lifting.
int Lca(int x,int y){
    if(d[x] < d[y]) swap(x,y);
    for(int j = 16; j >= 0; --j)
        if(d[f[x][j]] >= d[y]) x = f[x][j];
    if(x == y) return x;
    for(int j = 16; j >= 0; --j){
        if(f[x][j] == f[y][j]) continue;
        x = f[x][j];
        y = f[y][j];
    }
    return f[x][0];
}
// Toggles vertex x in the window and keeps sum = number of distinct colours present.
void update(int x){
    if(x == 0) return;
    int &colourCount = tot[num[x]];
    vis[x] ^= 1;
    if(vis[x]){
        if(++colourCount == 1) ++sum;
    }else{
        if(--colourCount == 0) --sum;
    }
}
// Reads the tree, builds its Euler tour, turns each path query into a tour range and
// answers the queries offline with Mo's algorithm (distinct colours on the path).
int main(){
    n = read(),m = read();
    // Compress vertex colours to 1..len so they can index tot[].
    for(int i = 1; i <= n; i ++) num[i] = t[i] = read();
    sort(t + 1,t + n + 1);
    len = unique(t + 1,t + n + 1) - t - 1;
    for(int i = 1; i <= n; i ++) num[i] = lower_bound(t + 1,t + len + 1,num[i]) - t;
    for(int i = 1; i < n; i ++){
        x = read(),y = read();
        insert(x,y),insert(y,x);
    }
    // cnt was the edge counter; from here on dfs reuses it for Euler-tour positions.
    cnt = 0;
    dfs(1,0);
    // Split the cnt tour positions into blocks of size ~sqrt(cnt).
    // Note: this local t shadows the global array t[] for the rest of main.
    int sqt = sqrt(cnt),t = ceil((double) cnt / sqt);
	for(int i = 1; i <= t; ++i)
		for(int j = sqt * (i - 1) + 1; j <= i * sqt; ++j)
            pos[j] = i;
    for(int i = 1; i <= m; i ++){
        x = read(),y = read(),q[i].pos = i;
        int t = Lca(x,y);
        // Make x the endpoint entered first in the tour.
        if(fst[x] > fst[y]) swap(x,y);
        // x is an ancestor of y: the range [fst[x], fst[y]] covers the whole path.
        if(t == x) q[i].l = fst[x],q[i].r = fst[y];
        // Otherwise [lst[x], fst[y]] lies inside the LCA's subtree span and misses the
        // LCA itself, so the LCA is stored and toggled separately when answering.
        else q[i].l = lst[x],q[i].r = fst[y],q[i].lca = t;
    }
    sort(q + 1,q + m + 1,cmp);
    for(int i = 1; i <= m; i ++){
        // update() toggles a vertex, so vertices appearing twice in the range cancel out.
        while(l < q[i].l) update(a[l]),l ++;
        while(l > q[i].l) l --,update(a[l]);
        while(r < q[i].r) r ++,update(a[r]);
        while(r > q[i].r) update(a[r]),r --;
        // Temporarily include the LCA (lca is 0 when not needed and update(0) is a no-op).
        update(q[i].lca);
        ans[q[i].pos] = sum;
        update(q[i].lca);
    }
    for(int i = 1; i <= m; i ++) printf("%d\n",ans[i]);
    return 0;
}

[国家集训队] 数颜色:带修莫队模板,毒瘤卡常

谁来告诉我块大小到底取 $\sqrt[3]{nt}$ 还是 $n^{\frac23}$ ($t$ 为修改次数,下面代码里用的是前者)

在原算法的基础上加一维时间,表示当前进行了几个修改操作,移动的时候再考虑上时间即可。有个小优化是修改时直接swap要修改的值与当前值,这样就不用判断当前是修改还是撤回修改,减小了代码量。

真的就是毒瘤题 ,实在卡不过去了,就从pb大佬的剪贴板里粘了一大堆 几行东西,不过就算这样最大点也只差0.02s啊!!!

人傻常数大的代码

#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize("Ofast")
#pragma GCC optimize("inline")
#pragma GCC optimize("-fgcse")
#pragma GCC optimize("-fgcse-lm")
#pragma GCC optimize("-fipa-sra")
#pragma GCC optimize("-ftree-pre")
#pragma GCC optimize("-ftree-vrp")
#pragma GCC optimize("-fpeephole2")
#pragma GCC optimize("-ffast-math")
#pragma GCC optimize("-fsched-spec")
#pragma GCC optimize("unroll-loops")
#pragma GCC optimize("-falign-jumps")
#pragma GCC optimize("-falign-loops")
#pragma GCC optimize("-falign-labels")
#pragma GCC optimize("-fdevirtualize")
#pragma GCC optimize("-fcaller-saves")
#pragma GCC optimize("-fcrossjumping")
#pragma GCC optimize("-fthread-jumps")
#pragma GCC optimize("-funroll-loops")
#pragma GCC optimize("-fwhole-program")
#pragma GCC optimize("-freorder-blocks")
#pragma GCC optimize("-fschedule-insns")
#pragma GCC optimize("inline-functions")
#pragma GCC optimize("-ftree-tail-merge")
#pragma GCC optimize("-fschedule-insns2")
#pragma GCC optimize("-fstrict-aliasing")
#pragma GCC optimize("-fstrict-overflow")
#pragma GCC optimize("-falign-functions")
#pragma GCC optimize("-fcse-skip-blocks")
#pragma GCC optimize("-fcse-follow-jumps")
#pragma GCC optimize("-fsched-interblock")
#pragma GCC optimize("-fpartial-inlining")
#pragma GCC optimize("no-stack-protector")
#pragma GCC optimize("-freorder-functions")
#pragma GCC optimize("-findirect-inlining")
#pragma GCC optimize("-frerun-cse-after-loop")
#pragma GCC optimize("inline-small-functions")
#pragma GCC optimize("-finline-small-functions")
#pragma GCC optimize("-ftree-switch-conversion")
#pragma GCC optimize("-foptimize-sibling-calls")
#pragma GCC optimize("-fexpensive-optimizations")
#pragma GCC optimize("-funsafe-loop-optimizations")
#pragma GCC optimize("inline-functions-called-once")
#pragma GCC optimize("-fdelete-null-pointer-checks")

#include <iostream>
#include <cstdio>
#include <cmath>
#include <algorithm>
using namespace std;
const int maxn = 1.4e5,maxv = 1e6 + 50;
// [l, r]: current window (initially empty); t: modifications applied; sum: distinct
// colours in the window; cntq / cntr: queries / modifications read; len: block size.
int n,m,x,y,l = 1,r,t,sum,cntq,cntr,len,num[maxn],cnt[maxv],ans[maxn];
// A modification: set position pos to colour val.
struct node{
    int pos,val;
}a[maxn];
// A query on [l, r] issued after t modifications; pos is its input order.
struct query{
    int l,r,t,pos;
}q[maxn];
// Reads a non-negative decimal integer from stdin, skipping leading non-digit characters.
int read(){
    int x = 0;
    char c = getchar();
    while(c < '0' || c > '9') c = getchar();
    while(c >= '0' && c <= '9') x = x * 10 + (c ^ 48),c = getchar();
    return x;
}
// Writes a non-negative integer followed by '\n'.
void print(int x){
    // The digit loop below produces no digits for 0, so emit it explicitly.
    if(x == 0){
        putchar('0');
        putchar('\n');
        return;
    }
    int h = 0,c[20];
    while(x) c[++h] = x % 10,x /= 10;
    for(int i = h; i >= 1; i --) putchar(c[i] + '0');
    putchar('\n');
}
// Order for Mo's algorithm with modifications: block of l, then block of r, then time.
bool cmp(query a,query b){
    if(a.l / len != b.l / len) return a.l / len < b.l / len;
    if(a.r / len != b.r / len) return a.r / len < b.r / len;
    return a.t < b.t;
}
// void add(int x){
//     if(!cnt[x]) sum ++;
//     cnt[x] ++;
// }
// void del(int x){
//     cnt[x] --;
//     if(!cnt[x]) sum --;
// }
// Reads the colours and operations, answers the queries offline with Mo's algorithm
// with modifications, and prints the number of distinct colours for each query.
int main(){
    n = read(),m = read();
    for(int i = 1; i <= n; i ++) num[i] = read();
    for(int i = 1; i <= m; i ++){
        // Skip separators (spaces, '\n', a stray '\r' from CRLF input) up to the operation
        // letter instead of assuming it is the very next character.
        int op = getchar();
        while(op != 'Q' && op != 'R' && op != EOF) op = getchar();
        // 'Q l r': query, remembering how many modifications precede it.
        if(op == 'Q') cntq ++,q[cntq].l = read(),q[cntq].r = read(),q[cntq].t = cntr,q[cntq].pos = cntq;
        // 'R p v': set position p to colour v.
        else cntr ++,a[cntr].pos = read(),a[cntr].val = read();
    }
    // Block size (n * cntr)^(1/3). With no modifications log(0) is -inf, which made len 0
    // and cmp divide by zero; fall back to sqrt(n) then and never go below 1.
    if(cntr > 0) len = ceil(exp((log(n) + log(cntr)) / 3));
    else len = ceil(sqrt(n));
    if(len < 1) len = 1;
    sort(q + 1,q + cntq + 1,cmp);
    for(int i = 1; i <= cntq; i ++){
        // Move [l, r]; sum changes exactly when a colour count moves between 0 and 1.
        while(l < q[i].l) sum -= !--cnt[num[l++]];
        while(l > q[i].l) sum += !cnt[num[--l]] ++;
        while(r < q[i].r) sum += !cnt[num[++r]] ++;
        while(r > q[i].r) sum -= !--cnt[num[r--]];
        // Move the time pointer. A modification swaps the stored colour with the current
        // one, so the same code applies it going forward and undoes it going backward.
        while(t != q[i].t){
            if(t < q[i].t) t ++;
            if(l <= a[t].pos && r >= a[t].pos){
                cnt[num[a[t].pos]] --;
                if(!cnt[num[a[t].pos]]) sum --;
                if(!cnt[a[t].val]) sum ++;
                cnt[a[t].val] ++;
            }
            swap(num[a[t].pos],a[t].val);
            if(t > q[i].t) t --;
        }
        ans[q[i].pos] = sum;
    }
    for(int i = 1; i <= cntq; i ++) print(ans[i]);
    return 0;
}

回滚莫队模板

用于某些不好删除但可以插入的题,对每一块计算左端点在当前块的区间,因为左端点在同一块的区间右端点单调递增,每次右端点可以从上一个右端点继续扩展,左端点只需要从 $\min\{\text{当前区间右端点},\text{块右端}\}$ 暴力向左扩展,这样就避免了莫队的删除操作。
考虑时间复杂度,凭感觉取块的大小为 $\sqrt n$。
对于左端点,每个区间至多从块右端扩展到块左端,时间为 $O(\sqrt n)$,有 $m$ 个询问,则总复杂度为 $O(m\sqrt n)$。
对于右端点,每块至多从块右端扩展到 $n$ 号点,单块时间 $O(n)$,共 $\frac{n}{\sqrt n} = \sqrt n$ 块,总复杂度 $O(n\sqrt n)$。
一般 $n,m$ 同阶,故总复杂度 $O(n\sqrt n)$。

#include <iostream>
#include <cstdio>
#include <cmath>
#include <algorithm>
using namespace std;
const int maxn = 2e5 + 50;
const int inf = 1e9;
// Sizes, Mo pointers, block length, distinct-value count, current / saved answer.
int n,m,l,r,len,tot,mx,tmx;
// a: compressed values; t: sorted copy for compression; pos: block of each index.
int a[maxn],t[maxn],pos[maxn];
// lb[v] / rb[v]: leftmost / rightmost position of value v added so far;
// tl / tr: snapshots used to roll lb / rb back after a left extension.
int tl[maxn],tr[maxn],lb[maxn],rb[maxn];
int ans[maxn];
struct query{
    int l,r,pos;
}q[maxn];
// Reads the next non-negative decimal integer from stdin, skipping non-digit characters.
int read(){
    char ch = getchar();
    for(; ch < '0' || ch > '9'; ch = getchar());
    int value = 0;
    while(ch >= '0' && ch <= '9'){
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return value;
}
// Rollback-Mo ordering: block of the left endpoint, then right endpoint ascending.
bool cmp(query a,query b){
    if(pos[a.l] == pos[b.l]) return a.r < b.r;
    return pos[a.l] < pos[b.l];
}
// Records position x for its value, widening the value's [lb, rb] span.
inline void update(int x){
    const int val = a[x];
    if(x < lb[val]) lb[val] = x;
    if(x > rb[val]) rb[val] = x;
}
// Reads the sequence and the queries, then answers "largest distance between two equal
// values in [l, r]" with rollback Mo's algorithm (insertions only, no deletions).
int main(){
    n = read();
    // Block size ~ sqrt(n), never below 1.
    len = max(1,(int)sqrt(n));
    for(int i = 1; i <= n; i ++) pos[i] = (i - 1) / len + 1;
    for(int i = 1; i <= n; i ++) a[i] = t[i] = read();
    // Compress values to 1..tot.
    sort(t + 1,t + n + 1);
    tot = unique(t + 1,t + n + 1) - t - 1;
    for(int i = 1; i <= n; i ++) a[i] = lower_bound(t + 1,t + tot + 1,a[i]) - t;
    m = read();
    for(int i = 1; i <= m; i ++) q[i].l = read(),q[i].r = read(),q[i].pos = i;
    sort(q + 1,q + m + 1,cmp);
    for(int i = 1; i <= n; i ++) lb[i] = inf,rb[i] = 0;
    for(int k = 1,i = 1; k <= pos[n]; k ++){
        // Queries lying entirely inside block k: brute force, then reset what they touched.
        while(i <= m && q[i].r <= k * len){
            for(int j = q[i].l; j <= q[i].r; j ++) update(j),mx = max(mx,rb[a[j]] - lb[a[j]]);
            ans[q[i].pos] = mx;
            for(int j = q[i].l; j <= q[i].r; j ++) lb[a[j]] = inf,rb[a[j]] = 0;
            mx = 0,i ++;
        }
        r = k * len;
        // Nothing lies right of block k, so every query starting here was answered above.
        // The bound is the array length n. Comparing with the query count m skipped this
        // phase for every block once k * len >= m, pushing all later queries into the
        // O(length) brute-force loop.
        if(r >= n) continue;
        while(i <= m && pos[q[i].l] == k){
            l = k * len + 1;
            // Within a block the right pointer only moves forward.
            while(r < q[i].r) r ++,update(r),mx = max(mx,rb[a[r]] - lb[a[r]]);
            tmx = mx;
            // Snapshot the entries the left extension will touch, extend, then roll back.
            for(int j = q[i].l; j <= k * len; j ++) tl[a[j]] = lb[a[j]],tr[a[j]] = rb[a[j]];
            while(l > q[i].l) l --,update(l),mx = max(mx,rb[a[l]] - lb[a[l]]);
            for(int j = q[i].l; j <= k * len; j ++) lb[a[j]] = tl[a[j]],rb[a[j]] = tr[a[j]];
            ans[q[i].pos] = mx,mx = tmx,i ++;
        }
        // Clear the right part before moving to the next block.
        mx = 0;
        for(int j = k * len + 1; j <= r; j ++) lb[a[j]] = inf,rb[a[j]] = 0;
    }
    for(int i = 1; i <= m; i ++) printf("%d\n",ans[i]);
    return 0;
}

歴史の研究:回滚莫队

国外OJ真烦,不给错误信息,因为long long调了三个小时

真就不开long long见祖宗

#include <iostream>
#include <cstdio>
#include <cmath>
#include <algorithm>
using namespace std;
const int maxn = 1e5 + 50;
// Sizes, Mo pointers, block length and number of distinct values.
int n,m,l,r,len,tot;
// a: compressed ids; num: original values; t: sorted copy; pos: block of each index;
// tmp: counters saved for the rollback; cnt: occurrences of each id in the window.
int a[maxn],num[maxn],t[maxn],pos[maxn],tmp[maxn],cnt[maxn];
// mx: best value * count in the window; tmx: its saved copy; ans: results.
long long mx,tmx,ans[maxn];
struct query{
    int l,r,pos;
}q[maxn];
// Reads the next non-negative decimal integer from stdin, skipping non-digit characters.
int read(){
    char ch = getchar();
    while(!(ch >= '0' && ch <= '9')) ch = getchar();
    int res = 0;
    do{
        res = res * 10 + (ch ^ 48);
        ch = getchar();
    }while(ch >= '0' && ch <= '9');
    return res;
}
// Rollback-Mo ordering: block of the left endpoint, then right endpoint ascending.
bool cmp(query a,query b){
    const int ba = pos[a.l],bb = pos[b.l];
    return ba != bb ? ba < bb : a.r < b.r;
}
// Reads the sequence and the queries, then answers "max over values v in [l, r] of
// v * (occurrences of v)" with rollback Mo's algorithm.
int main(){
    n = read(),m = read();
    // Block size ~ sqrt(n), never below 1.
    len = max(1,(int)sqrt(n));
    for(int i = 1; i <= n; i ++) pos[i] = (i - 1) / len + 1;
    // num[] keeps the original values (used as weights); a[] holds their compressed ids.
    for(int i = 1; i <= n; i ++) num[i] = t[i] = read();
    sort(t + 1,t + n + 1);
    tot = unique(t + 1,t + n + 1) - t - 1;
    for(int i = 1; i <= n; i ++) a[i] = lower_bound(t + 1,t + tot + 1,num[i]) - t;
    for(int i = 1; i <= m; i ++) q[i].l = read(),q[i].r = read(),q[i].pos = i;
    sort(q + 1,q + m + 1,cmp);
    for(int k = 1,i = 1; k <= pos[n]; k ++){
        // Queries lying entirely inside block k: brute force, then clear the counters.
        while(i <= m && q[i].r <= k * len){
            for(int j = q[i].l; j <= q[i].r; j ++) cnt[a[j]] ++,mx = max(mx,1ll * num[j] * cnt[a[j]]);
            ans[q[i].pos] = mx;
            for(int j = q[i].l; j <= q[i].r; j ++) cnt[a[j]] = 0;
            mx = 0,i ++;
        }
        r = k * len;
        // Nothing lies right of block k, so every query starting here was answered above.
        // The bound is the array length n. Comparing with the query count m skipped this
        // phase once k * len >= m and sent later queries to the slow brute-force loop.
        if(r >= n) continue;
        while(i <= m && pos[q[i].l] == k){
            l = k * len + 1;
            // Within a block the right pointer only moves forward.
            while(r < q[i].r) r ++,cnt[a[r]] ++,mx = max(mx,1ll * num[r] * cnt[a[r]]);
            tmx = mx;
            // Save the counters the left extension will touch, extend, then roll back.
            for(int j = q[i].l; j <= k * len; j ++) tmp[a[j]] = cnt[a[j]];
            while(l > q[i].l) l --,cnt[a[l]] ++,mx = max(mx,1ll * num[l] * cnt[a[l]]);
            for(int j = q[i].l; j <= k * len; j ++) cnt[a[j]] = tmp[a[j]];
            ans[q[i].pos] = mx,mx = tmx,i ++;
        }
        // Clear the right part before moving to the next block.
        mx = 0;
        for(int j = k * len + 1; j <= r; j ++) cnt[a[j]] = 0;
    }
    for(int i = 1; i <= m; i ++) printf("%lld\n",ans[i]);
    return 0;
}

[WC2013]糖果公园:树上莫队+带修莫队

大　　　　杂　　　　烩
莫队毕业题,融合了几个算法,要注意移动时间的时候也要判欧拉序里的vis标记
回滚莫队:我不配拥有姓名

#include <iostream>
#include <cstdio>
#include <cmath>
#include <algorithm>
using namespace std;
const int maxn = 1e5 + 50;
const int maxp = 2 * maxn;
const int maxv = 1e6;
// n vertices, m colours, Q operations; x / y / type: scratch for input.
int n,m,Q,x,y,type;
// Mo state: window [l, r] over the Euler tour (initially empty), t modifications
// applied, len: block size.
int l = 1,r,t,len;
// ecnt: edges stored; pcnt: Euler-tour length; cntr / cntq: modifications / queries read.
int ecnt,pcnt,cntr,cntq;
// v[colour] and w[k]: weights read from input; c[vertex]: current colour.
int v[maxn],w[maxn],c[maxn];
// p[i]: vertex at tour position i; pos[i]: block of position i.
int p[maxp],pos[maxp];
// fst / lst: entry and exit tour positions; last: adjacency heads; d: depth;
// f[u][j]: 2^j-th ancestor of u.
int fst[maxn],lst[maxn],last[maxn],d[maxn],f[maxn][20];
// cnt[colour]: window vertices with that colour; vis[u]: parity of u in the window.
int cnt[maxv],vis[maxn];
long long sum,ans[maxn];
struct edge{
    int v,nxt;
}e[2 * maxn];
// A modification: vertex pos gets colour val.
struct revise{
    int pos,val;
}a[maxn];
// A path query as a tour range [l, r] after t modifications; lca is 0 when not needed.
struct query{
    int l,r,t,lca,pos;
}q[maxn];
// Reads the next non-negative decimal integer from stdin, skipping non-digit characters.
int read(){
    char ch = getchar();
    while(ch < '0' || ch > '9') ch = getchar();
    int res = 0;
    for(; ch >= '0' && ch <= '9'; ch = getchar()) res = res * 10 + (ch ^ 48);
    return res;
}
// Adds a directed edge x -> y at the head of x's adjacency list.
inline void insert(int x,int y){
    e[++ecnt] = edge{y,last[x]};
    last[x] = ecnt;
}
// Ordering for Mo's algorithm with modifications: block of l, block of r, then time.
bool cmp(query a,query b){
    if(pos[a.l] == pos[b.l]){
        if(pos[a.r] == pos[b.r]) return a.t < b.t;
        return pos[a.r] < pos[b.r];
    }
    return pos[a.l] < pos[b.l];
}
// Builds the Euler tour, depths and the 18-level binary-lifting table below u.
void dfs(int u,int fa){
    p[++pcnt] = u;
    fst[u] = pcnt;
    d[u] = d[fa] + 1;
    f[u][0] = fa;
    for(int j = 1; j <= 17; ++j) f[u][j] = f[f[u][j - 1]][j - 1];
    for(int ed = last[u]; ed; ed = e[ed].nxt)
        if(e[ed].v != fa) dfs(e[ed].v,u);
    p[++pcnt] = u;
    lst[u] = pcnt;
}
// Lowest common ancestor of x and y by binary lifting.
int Lca(int x,int y){
    if(d[x] < d[y]) swap(x,y);
    for(int j = 17; j >= 0; --j)
        if(d[f[x][j]] >= d[y]) x = f[x][j];
    if(x == y) return x;
    for(int j = 17; j >= 0; --j){
        if(f[x][j] == f[y][j]) continue;
        x = f[x][j];
        y = f[y][j];
    }
    return f[x][0];
}
// Toggles vertex x in the window. Adding the k-th vertex of colour col contributes
// w[k] * v[col]; removing it takes the same amount back.
inline void update(int x){
    if(x == 0) return;
    const int col = c[x];
    vis[x] ^= 1;
    if(vis[x] == 0){
        sum -= 1ll * w[cnt[col]] * v[col];
        --cnt[col];
        return;
    }
    ++cnt[col];
    sum += 1ll * w[cnt[col]] * v[col];
}
// Reads the tree, colours and operations, then answers path queries with tree Mo's
// algorithm with modifications. Each answer is the sum over colours col present on the
// path of v[col] * (w[1] + ... + w[count of col]).
int main(){
    n = read(),m = read(),Q = read();
    for(int i = 1; i <= m; i ++) v[i] = read();
    for(int i = 1; i <= n; i ++) w[i] = read();
    for(int i = 1; i < n; i ++){
        x = read(),y = read();
        insert(x,y),insert(y,x);
    }
    dfs(1,0);
    // Block size pcnt^(2/3) over the 2n Euler-tour positions.
    len = pow(pcnt,2.0 / 3);
    for(int i = 1; i <= pcnt; i ++) pos[i] = (i - 1) / len + 1;
    for(int i = 1; i <= n; i ++) c[i] = read();
    for(int i = 1; i <= Q; i ++){
        type = read(),x = read(),y = read();
        // type 0: modification "vertex x gets colour y".
        if(type == 0) cntr ++,a[cntr].pos = x,a[cntr].val = y;
        else{
            // Path query (x, y), issued after cntr modifications.
            cntq ++,q[cntq].t = cntr,q[cntq].pos = cntq;
            if(fst[x] > fst[y]) swap(x,y);
            int t = Lca(x,y);
            // x is an ancestor of y: [fst[x], fst[y]] covers the whole path.
            if(t == x) q[cntq].l = fst[x],q[cntq].r = fst[y];
            // Otherwise [lst[x], fst[y]] misses the LCA, which is toggled separately.
            else q[cntq].l = lst[x],q[cntq].r = fst[y],q[cntq].lca = t;
        }
    }
    sort(q + 1,q + cntq + 1,cmp);
    for(int i = 1; i <= cntq; i ++){
        // update() toggles a vertex, so vertices seen twice in the range cancel out.
        while(l < q[i].l) update(p[l]),l ++;
        while(l > q[i].l) l --,update(p[l]);
        while(r < q[i].r) r ++,update(p[r]);
        while(r > q[i].r) update(p[r]),r --;
        while(t != q[i].t){
            if(t < q[i].t) t ++;
            // If the modified vertex is currently counted (vis set), move its
            // contribution from the old colour to the new one.
            if(vis[a[t].pos]){
                sum -= 1ll * w[cnt[c[a[t].pos]]] * v[c[a[t].pos]],cnt[c[a[t].pos]] --;
                cnt[a[t].val] ++,sum += 1ll * w[cnt[a[t].val]] * v[a[t].val];
            }
            // Swapping lets the same code apply the change forward and undo it backward.
            swap(c[a[t].pos],a[t].val);
            if(t > q[i].t) t --;
        }
        // Temporarily include the LCA (lca is 0 when not needed; update(0) is a no-op).
        update(q[i].lca);
        ans[q[i].pos] = sum;
        update(q[i].lca);
    }
    for(int i = 1; i <= cntq; i ++) printf("%lld\n",ans[i]);
    return 0;
}

FFT模板

看了半天证明,打代码时才发现并没什么用

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
using namespace std;
const int maxn = 2.1e6 + 50;
const double Pi = acos(-1);
// n, m: polynomial degrees; N = 2^l: transform length; rev: bit-reversal permutation.
int n,m,N,l;
int rev[maxn],ans[maxn];
char s[maxn];
// Complex number with real part x and imaginary part y.
struct cp{
    double x,y;
}a[maxn],b[maxn];
cp operator + (cp a,cp b){
    cp res;
    res.x = a.x + b.x;
    res.y = a.y + b.y;
    return res;
}
cp operator - (cp a,cp b){
    cp res;
    res.x = a.x - b.x;
    res.y = a.y - b.y;
    return res;
}
cp operator * (cp a,cp b){
    cp res;
    res.x = a.x * b.x - a.y * b.y;
    res.y = a.x * b.y + a.y * b.x;
    return res;
}
// Reads the next non-negative decimal integer from stdin, skipping non-digit characters.
int read(){
    char ch = getchar();
    while(ch < '0' || ch > '9') ch = getchar();
    int value = 0;
    for(; ch >= '0' && ch <= '9'; ch = getchar()) value = value * 10 + (ch ^ 48);
    return value;
}
// In-place iterative FFT of a[0..N-1]. p = 1 is the forward transform; p = -1 the
// inverse, which also divides the real parts by N. rev[] must hold the bit reversal.
void FFT(cp *a,int p){
    for(int i = 0; i < N; i ++)
        if(i < rev[i]) swap(a[i],a[rev[i]]);
    for(int half = 1; half < N; half <<= 1){
        // Root of unity for blocks of length 2 * half (conjugated when p = -1).
        const cp step = {cos(Pi / half),p * sin(Pi / half)};
        for(int start = 0; start < N; start += half << 1){
            cp w = {1,0};
            for(int k = start; k < start + half; k ++){
                const cp even = a[k],odd = w * a[k + half];
                a[k] = even + odd;
                a[k + half] = even - odd;
                w = w * step;
            }
        }
    }
    if(p == -1)
        for(int i = 0; i < N; i ++) a[i].x /= N;
}
// Reads two polynomials (degrees n and m, coefficients from low to high) and prints
// the coefficients of their product, computed with FFT.
int main(){
    n = read(),m = read();
    for(int i = 0; i <= n; i ++) a[i].x = read();
    for(int i = 0; i <= m; i ++) b[i].x = read();
    // Smallest power of two N > n + m, with N = 2^l.
    N = 1;
    while(N <= n + m) N <<= 1,l ++;
    // Bit-reversal permutation of 0..N-1. Only indices below N are part of the transform;
    // the old bound i <= N also computed (1 << (l - 1)) with l = 0 when n = m = 0,
    // a negative shift count (undefined behaviour).
    for(int i = 1; i < N; i ++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (l - 1));
    FFT(a,1),FFT(b,1);
    for(int i = 0; i < N; i ++) a[i] = a[i] * b[i];
    FFT(a,-1);
    // Round to the nearest integer (the inputs read are non-negative).
    for(int i = 0; i <= n + m; i ++) printf("%d ",int(a[i].x + 0.5));
    return 0;
}

A*B Problem:FFT

听说,恶臭的题号与恶臭的样例更配哦(雾

用FFT加速高精乘,把每个数看做一个多项式,系数为数的每一位(将 $x=10$ 代入即得该数的值),用FFT求两多项式之积即可

#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
using namespace std;
const int maxn = 2.1e6 + 50;
const double Pi = acos(-1);
// n, m: digit counts; N = 2^l: transform length; rev: bit reversal; ans: result digits.
int n,m,N,l,rev[maxn],ans[maxn];
// Input buffer for the decimal strings.
char s[maxn];
// Complex number with real part x and imaginary part y.
struct cp{
    double x,y;
}a[maxn],b[maxn];
cp operator + (cp a,cp b){
    const double re = a.x + b.x,im = a.y + b.y;
    return cp{re,im};
}
cp operator - (cp a,cp b){
    const double re = a.x - b.x,im = a.y - b.y;
    return cp{re,im};
}
cp operator * (cp a,cp b){
    const double re = a.x * b.x - a.y * b.y,im = a.x * b.y + a.y * b.x;
    return cp{re,im};
}
// In-place iterative FFT of a[0..N-1]: p = 1 forward, p = -1 inverse (real parts / N).
void FFT(cp *a,int p){
    for(int i = 0; i < N; i ++) if(i < rev[i]) swap(a[i],a[rev[i]]);
    for(int len = 2; len <= N; len <<= 1){
        const int half = len >> 1;
        const cp root = {cos(Pi / half),p * sin(Pi / half)};
        for(cp *blk = a; blk != a + N; blk += len){
            cp w = {1,0};
            for(int k = 0; k < half; k ++){
                const cp u = blk[k],v = w * blk[k + half];
                blk[k] = u + v;
                blk[k + half] = u - v;
                w = w * root;
            }
        }
    }
    if(p == -1) for(int i = 0; i < N; i ++) a[i].x /= N;
}
// Multiplies two non-negative decimal integers read as strings using FFT and prints
// the product.
int main(){
    // Store each number's digits least-significant first as polynomial coefficients.
    scanf("%s",s);
    n = strlen(s);
    for(int i = 0; i < n; i ++) a[i].x = s[n - i - 1] - '0';
    scanf("%s",s);
    m = strlen(s);
    for(int i = 0; i < m; i ++) b[i].x = s[m - i - 1] - '0';
    // Smallest power of two N >= n + m (the product has at most n + m digits), N = 2^l.
    N = 1;
    while(N < n + m) N <<= 1,l ++;
    // Bit-reversal permutation; only indices below N belong to the transform.
    for(int i = 1; i < N; i ++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (l - 1));
    FFT(a,1),FFT(b,1);
    for(int i = 0; i < N; i ++) a[i] = a[i] * b[i];
    FFT(a,-1);
    // Round each coefficient and propagate carries; ans[N] receives the final carry.
    for(int i = 0; i < N; i ++){
        ans[i] += int(a[i].x + 0.5);
        ans[i + 1] += ans[i] / 10;
        ans[i] %= 10;
    }
    // Strip leading zeros but always keep at least one digit.
    while(!ans[N] && N) N --;
    for(int i = N; i >= 0; i --) printf("%d",ans[i]);
    return 0;
}

[ZJOI2014]力:FFT

#include <iostream>
#include <cstdio>
#include <cmath>
using namespace std;
const int maxn = 3e5 + 50;
const double Pi = acos(-1);
// n: number of charges; m = 2^l: transform length; rev: bit-reversal permutation.
int n,m = 1,l,rev[maxn];
// Complex number with member arithmetic.
struct cp{
    double x,y;
    cp operator + (cp a){
        cp res = {x + a.x,y + a.y};
        return res;
    }
    cp operator - (cp a){
        cp res = {x - a.x,y - a.y};
        return res;
    }
    cp operator * (cp a){
        cp res = {x * a.x - y * a.y,x * a.y + y * a.x};
        return res;
    }
}a[maxn],b[maxn],t[maxn];
// In-place iterative FFT of a[0..m-1]: p = 1 forward, p = -1 inverse (real parts / m).
void FFT(cp *a,int p){
    for(int i = 0; i < m; i ++)
        if(i < rev[i]) swap(a[i],a[rev[i]]);
    for(int half = 1; half < m; half <<= 1){
        cp root = {cos(Pi / half),p * sin(Pi / half)};
        for(int base = 0; base < m; base += half << 1){
            cp w = {1,0};
            cp *lo = a + base,*hi = a + base + half;
            for(int k = 0; k < half; k ++){
                cp u = lo[k],v = w * hi[k];
                lo[k] = u + v;
                hi[k] = u - v;
                w = w * root;
            }
        }
    }
    if(p == -1)
        for(int i = 0; i < m; i ++) a[i].x /= m;
}
// Reads n charges and prints the field strength at every position:
// E_i = sum_{j<i} q_j / (i-j)^2 - sum_{j>i} q_j / (j-i)^2, via two FFT convolutions.
int main(){
    scanf("%d",&n);
    // The product a * t has indices up to 2n (a: 1..n, t: 0..n), so the cyclic transform
    // needs m > 2n to avoid wrap-around. The old bound m >= 2n - 1 gave m = 1 for n = 1:
    // a[1] was then outside the transform and the program printed q1 instead of 0.
    while(m <= 2 * n) m <<= 1,l ++;
    for(int i = 1; i < m; i ++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (l - 1));
    // a: charges at 1..n; b: the same charges reversed into 0..n-1; t[k] = 1 / k^2.
    for(int i = 1; i <= n; i ++) scanf("%lf",&a[i].x),b[n - i] = a[i],t[i].x = 1.0 / i / i;
    FFT(a,1),FFT(b,1),FFT(t,1);
    for(int i = 0; i < m; i ++) a[i] = a[i] * t[i],b[i] = b[i] * t[i];
    FFT(a,-1),FFT(b,-1);
    // (a*t)[i] is the pull from the left; (b*t)[n-i] is the pull from the right.
    for(int i = 1; i <= n; i ++) printf("%.3f\n",a[i].x - b[n - i].x);
    return 0;
}

[AH/HNOI2017]礼物:FFT

#include <iostream>
#include <cstdio>
#include <cmath>
using namespace std;
const int maxn = 3e5 + 50;
const long long inf = 1e18;
const double Pi = acos(-1);
// n: bracelet length; m: brightness range; N = 2^l: transform length; t: input scratch.
int n,m,N = 1,l,t,rev[maxn];
// s1: sum of squares of both sequences; s2: sum(x) - sum(y); ans: best cost so far.
long long s1,s2,ans = inf;
// Reads the next non-negative decimal integer from stdin, skipping non-digit characters.
int read(){
    char ch;
    do ch = getchar(); while(ch < '0' || ch > '9');
    int value = 0;
    while(ch >= '0' && ch <= '9'){
        value = value * 10 + (ch ^ 48);
        ch = getchar();
    }
    return value;
}
// Complex number with member arithmetic.
struct cp{
    double x,y;
    cp operator + (cp a){
        return cp{x + a.x,y + a.y};
    }
    cp operator - (cp a){
        return cp{x - a.x,y - a.y};
    }
    cp operator * (cp a){
        const double re = x * a.x - y * a.y;
        const double im = x * a.y + y * a.x;
        return cp{re,im};
    }
}a[maxn],b[maxn];
// In-place iterative FFT of a[0..N-1]. The inverse (p = -1) also rounds every real part
// to the nearest integer after dividing by N.
void FFT(cp *a,int p){
    for(int i = 0; i < N; i ++)
        if(i < rev[i]) swap(a[i],a[rev[i]]);
    for(int half = 1; half < N; half <<= 1){
        const cp root = {cos(Pi / half),p * sin(Pi / half)};
        for(int base = 0; base < N; base += 2 * half){
            cp w = {1,0};
            for(int k = base; k < base + half; k ++){
                cp u = a[k],v = w * a[k + half];
                a[k] = u + v;
                a[k + half] = u - v;
                w = w * root;
            }
        }
    }
    if(p != -1) return;
    for(int i = 0; i < N; i ++) a[i].x = (long long)(a[i].x / N + 0.5);
}
// Reads both bracelets and prints the minimum of sum (x_i + c - y_{i+k})^2 over every
// rotation k and every brightness change c in [-m, m].
int main(){
    n = read();
    m = read();
    // N = 2^l is the smallest power of two above 3n: the doubled x (length 2n) times the
    // reversed y (length n) then does not wrap around.
    for(; N <= 3 * n; N <<= 1) ++l;
    for(int i = 1; i < N; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (l - 1));
    // x is stored twice (a[1..2n]) so every rotation is a contiguous window.
    for(int i = 1; i <= n; ++i){
        t = read();
        a[i + n].x = t;
        a[i].x = a[i + n].x;
        s1 += 1ll * t * t;
        s2 += t;
    }
    // y is stored reversed in b[1..n] so the correlation becomes a convolution.
    for(int i = 1; i <= n; ++i){
        t = read();
        b[n - i + 1].x = t;
        s1 += 1ll * t * t;
        s2 -= t;
    }
    FFT(a,1);
    FFT(b,1);
    for(int i = 0; i < N; ++i) a[i] = a[i] * b[i];
    FFT(a,-1);
    // Cost of rotation i with shift j: s1 + n*j^2 + 2*j*s2 - 2 * (cross term a[n + i]).
    for(int i = 1; i <= n; ++i){
        const long long cross = (long long)a[n + i].x;
        for(int j = -m; j <= m; ++j){
            const long long cost = 1ll * n * j * j + s1 + 2ll * s2 * j - 2 * cross;
            if(cost < ans) ans = cost;
        }
    }
    printf("%lld",ans);
    return 0;
}

任意模数NTT模板

谁来教我NTT模数和原根怎么背

#include <iostream>
#include <cstdio>
using namespace std;
const int maxn = 3e5 + 50;
// Three NTT primes, each used with root g = 3; results are merged by CRT.
const int g = 3,p1 = 469762049,p2 = 998244353,p3 = 1004535809;
// n, m: degrees; p: output modulus; N = 2^l: transform length; rev: bit reversal.
int n,m,p,N = 1,l,rev[maxn];
// a, b: input polynomials; t1, t2: scratch copies; d[k]: product modulo the k-th prime.
int a[maxn],b[maxn],t1[maxn],t2[maxn],d[3][maxn];
// Reads the next non-negative decimal integer from stdin, skipping non-digit characters.
int read(){
    char ch = getchar();
    while(ch < '0' || ch > '9') ch = getchar();
    int value = 0;
    for(; ch >= '0' && ch <= '9'; ch = getchar()) value = value * 10 + (ch ^ 48);
    return value;
}
// Returns x^k mod `mod` by binary exponentiation.
int qpow(int x,int k,int mod){
    long long result = 1,base = x;
    for(; k; k >>= 1){
        if(k & 1) result = result * base % mod;
        base = base * base % mod;
    }
    return result;
}
// In-place number-theoretic transform of a[0..N-1] modulo `mod`. op != 0 is the forward
// transform; op == 0 the inverse, done as forward + scale by N^{-1} + reversal of a[1..N-1].
void DNT(int *a,int op,int mod){
    for(int i = 0; i < N; i ++)
        if(i < rev[i]) swap(a[i],a[rev[i]]);
    for(int half = 1; half < N; half <<= 1){
        const int step = qpow(g,(mod - 1) / (half << 1),mod);
        for(int start = 0; start < N; start += half << 1){
            int w = 1;
            for(int k = start; k < start + half; k ++){
                const int u = a[k];
                const int v = 1ll * w * a[k + half] % mod;
                a[k] = (u + v) % mod;
                a[k + half] = (u - v + mod) % mod;
                w = 1ll * w * step % mod;
            }
        }
    }
    if(op) return;
    const int inv = qpow(N,mod - 2,mod);
    for(int i = 0; i < N; i ++) a[i] = 1ll * a[i] * inv % mod;
    for(int lo = 1,hi = N - 1; lo < hi; lo ++,hi --) swap(a[lo],a[hi]);
}
// Cyclic convolution a * b modulo `mod`, stored in a (b is overwritten by its transform).
void NTT(int *a,int *b,int mod){
    DNT(a,1,mod);
    DNT(b,1,mod);
    for(int i = 0; i < N; i ++) a[i] = 1ll * a[i] * b[i] % mod;
    DNT(a,0,mod);
}
// Writes (a * b) mod `mod` into s, leaving a and b untouched.
void solve(int *a,int * b,int *s,int mod){
    for(int i = 0; i < N; i ++){
        t1[i] = a[i];
        t2[i] = b[i];
    }
    NTT(t1,t2,mod);
    for(int i = 0; i < N; i ++) s[i] = t1[i];
}
// Rebuilds the value with residues x, y, z modulo p1, p2, p3 and returns it modulo p.
int CRT(long long x,long long y,long long z){
    // Merge the first two residues: t == x (mod p1), t == y (mod p2), 0 <= t < p1 * p2.
    const long long k1 = (y - x + p2) % p2 * qpow(p1,p2 - 2,p2) % p2;
    const long long t = k1 * p1 + x;
    // Lift with p3: the full value is t + k2 * p1 * p2.
    const long long k2 = (z - t % p3 + p3) % p3 * qpow(1ll * p1 * p2 % p3,p3 - 2,p3) % p3;
    return (k2 * (1ll * p1 * p2 % p) % p + t) % p;
}
// Reads two polynomials and a modulus p, multiplies them modulo three NTT primes and
// prints every product coefficient modulo p (three-modulus CRT).
int main(){
    n = read(),m = read(),p = read();
    // N = 2^l: smallest power of two above n + m.
    while(N <= n + m) N <<= 1,l ++;
    // Bit-reversal permutation. rev[0] is already 0; starting at 1 avoids evaluating
    // (0 & 1) << (l - 1) with l = 0 (a negative shift count) when N = 1.
    for(int i = 1; i < N; i ++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (l - 1));
    for(int i = 0; i <= n; i ++) a[i] = read();
    for(int i = 0; i <= m; i ++) b[i] = read();
    // Multiply modulo each prime, then merge the residues of every coefficient.
    solve(a,b,d[0],p1);
    solve(a,b,d[1],p2);
    solve(a,b,d[2],p3);
    for(int i = 0; i <= n + m; i ++) printf("%d ",CRT(d[0][i],d[1][i],d[2][i]));
    return 0;
}