4.13 ~ 4.20做题情况
这周主要在做莫队及其各种变形算法和FFT/NTT的一些板子基础题
小B的询问:莫队
最最基础的莫队,没啥好说的
#include <iostream>
#include <cstdio>
#include <cmath>
#include <algorithm>
using namespace std;
// Capacity for the sequence, the queries and the value range (with slack).
const int maxn = 5e4 + 50;

// n = sequence length, m = number of queries; k is read from input but not used.
int n, m, k;
// Current Mo window [l, r]; it starts empty as [1, 0].
int l = 1, r;
// sqt = block size; sum = sum of c[v]^2 over all values v inside the window.
int sqt, sum;
int a[maxn];    // the sequence
int c[maxn];    // c[v] = occurrences of value v inside the window
int pos[maxn];  // pos[i] = 1-based block of position i
int ans[maxn];  // ans[j] = answer of the j-th input query

// An offline query [l, r]; pos is its 1-based index in the input.
struct query {
    int l, r, pos;
} q[maxn];

// Reads the next non-negative decimal integer from stdin, skipping any non-digit prefix
// (a '-' sign is skipped as well, so negative numbers are not supported).
int read() {
    char ch = getchar();
    while (ch < '0' || ch > '9') ch = getchar();
    int value = 0;
    for (; ch >= '0' && ch <= '9'; ch = getchar()) value = value * 10 + (ch - '0');
    return value;
}

// Mo's order with the odd-even trick: by block of l; inside an odd block by ascending r,
// inside an even block by descending r, so r sweeps back and forth between blocks.
bool cmp(query a, query b) {
    if (pos[a.l] == pos[b.l]) return (pos[a.l] & 1) ? (a.r < b.r) : (a.r > b.r);
    return pos[a.l] < pos[b.l];
}

// Adds one occurrence of value x: (c + 1)^2 - c^2 = 2c + 1.
inline void add(int x) {
    sum += 2 * c[x]++ + 1;
}

// Removes one occurrence of value x: c^2 - (c - 1)^2 = 2c - 1.
inline void del(int x) {
    sum -= 2 * c[x]-- - 1;
}
// Entry point: reads n, m, k and the sequence, sorts the queries into Mo's order, slides
// the window [l, r] onto each query and records `sum` (the sum of squared occurrence
// counts kept by add/del), then prints the answers in input order.
// k is read but not used anywhere in this program.
int main(){
n = read(),m = read(),k = read(),sqt = sqrt(n);
// Blocks of size floor(sqrt(n)); pos[i] is the 1-based block of position i.
for(int i = 1; i <= n; i ++) a[i] = read(),pos[i] = (i - 1) / sqt + 1;
// q[i].pos remembers the input index so answers can be written back in input order.
for(int i = 1; i <= m; i ++) q[i].l = read(),q[i].r = read(),q[i].pos = i;
sort(q + 1,q + m + 1,cmp);
for(int i = 1; i <= m; i ++){
// Move the window to [q[i].l, q[i].r]. Shrinking on the left first can briefly make
// l > r + 1 and push some c[] below zero, but add/del apply exact difference formulas,
// so sum is right again once all four loops have finished.
while(l < q[i].l) del(a[l]),l ++;
while(l > q[i].l) l --,add(a[l]);
while(r < q[i].r) r ++,add(a[r]);
while(r > q[i].r) del(a[r]),r --;
ans[q[i].pos] = sum;
}
for(int i = 1; i <= m; i ++) printf("%d\n",ans[i]);
return 0;
}
大爷的字符串题:莫队
本题主要考察了选手的语文能力:题目的转化十分巧妙(答案就是区间众数出现次数的相反数),转化之后就变成了一道普通的莫队。
#include <iostream>
#include <cstdio>
#include <cmath>
#include <algorithm>
using namespace std;
const int maxn = 2e5 + 50;

int n, m;
int l = 1, r;        // current window [l, r], initially empty
int len;             // number of distinct values after compression
int sqt;             // block size
int sum;             // maximum frequency of any value inside the window
int a[maxn];         // compressed sequence
int pos[maxn];       // block of each position
int t[maxn];         // sorted copy used for compression
int num[maxn];       // num[v] = occurrences of value v in the window
int cnt[maxn];       // cnt[f] = number of values occurring exactly f times (f >= 1)
int ans[maxn];

// An offline query [l, r]; pos is its 1-based index in the input.
struct query {
    int l, r, pos;
} q[maxn];

// Reads the next non-negative decimal integer from stdin, skipping any non-digit prefix.
int read() {
    char ch = getchar();
    while (ch < '0' || ch > '9') ch = getchar();
    int value = 0;
    for (; ch >= '0' && ch <= '9'; ch = getchar()) value = value * 10 + (ch - '0');
    return value;
}

// Plain Mo's order: by block of l, then by ascending r.
bool cmp(query a, query b) {
    return pos[a.l] != pos[b.l] ? pos[a.l] < pos[b.l] : a.r < b.r;
}

// Adds one occurrence of value x; if x was already at the maximum frequency,
// the maximum grows by one.
void add(int x) {
    const int before = num[x];
    if (before == sum) ++sum;
    --cnt[before];
    num[x] = before + 1;
    ++cnt[before + 1];
}

// Removes one occurrence of value x; the maximum drops only if x was the single
// value holding it.
void del(int x) {
    const int before = num[x];
    if (before == sum && cnt[before] == 1) --sum;
    --cnt[before];
    num[x] = before - 1;
    ++cnt[before - 1];
}
// Entry point: reads the sequence, compresses the values, answers every query offline
// with Mo's algorithm and prints the negated maximum frequency of each range.
int main(){
    n = read(), m = read(), sqt = sqrt(n);
    for (int i = 1; i <= n; i++) a[i] = t[i] = read(), pos[i] = (i - 1) / sqt + 1;
    // Coordinate-compress the values into 1..len.
    sort(t + 1, t + n + 1);
    len = unique(t + 1, t + n + 1) - t - 1;
    for (int i = 1; i <= n; i++) a[i] = lower_bound(t + 1, t + len + 1, a[i]) - t;
    for (int i = 1; i <= m; i++) q[i].l = read(), q[i].r = read(), q[i].pos = i;
    sort(q + 1, q + m + 1, cmp);
    // Iterate over the m queries. The old bound `i <= n` processed zero-filled fake
    // queries when m < n and skipped real ones when m > n.
    for (int i = 1; i <= m; i++) {
        // Expand first, then shrink. The window then never drops an element it does not
        // hold, so num[] never goes below 0 and del() never touches cnt[-1].
        while (l > q[i].l) add(a[--l]);
        while (r < q[i].r) add(a[++r]);
        while (l < q[i].l) del(a[l++]);
        while (r > q[i].r) del(a[r--]);
        ans[q[i].pos] = sum;
    }
    for (int i = 1; i <= m; i++) printf("%d\n", -ans[i]);
    return 0;
}
小清新人渣的本愿:莫队,bitset
用到了类似状压的思想:由于值域(也就是要压的位数)高达 $10^5$,所以不能直接用普通整型,可以用一个 bitset 维护(表示第一次听说 bitset 这种东西)。其中加法询问的转化比较巧妙(再开一个翻转的 bitset,$b_{N-x}=1$ 表示 $x$ 出现过),不过题解说这都是很套路的东西?
#include <iostream>
#include <cstdio>
#include <cmath>
#include <algorithm>
#include <bitset>
using namespace std;
const int maxn = 1e5 + 50,N = 1e5;
// [l, r] is the current Mo window and must start empty as [1, 0]. With the previous
// l = 0, the window [0, 0] nominally held position 0 (value 0) without counting it.
// The first shrink then called del(0) and left cnt[0] at -1 for the whole run, so an
// input value 0 would never be recorded in a / b.
int n,m,sqt,l = 1,r,num[maxn],pos[maxn],cnt[maxn],ans[maxn];
// a[v] = 1 iff value v occurs in the window; b is its mirror: b[N - v] = 1 iff v occurs.
bitset <maxn> a,b;
// Offline query: opt 1 / 2 / 3 ask whether two values in [l, r] have difference / sum /
// product equal to x; pos is the input index.
struct query{
    int l,r,x,opt,pos;
}q[maxn];

// Reads the next non-negative decimal integer from stdin, skipping any non-digit prefix.
int read(){
    int x = 0;
    char c = getchar();
    while(c < '0' || c > '9') c = getchar();
    while(c >= '0' && c <= '9') x = x * 10 + (c ^ 48),c = getchar();
    return x;
}

// Mo's order with the odd-even trick on r.
bool cmp(query a,query b){
    if(pos[a.l] != pos[b.l]) return pos[a.l] < pos[b.l];
    if(pos[a.l] & 1) return a.r < b.r;
    return a.r > b.r;
}

// Adds one occurrence of value x, setting its bits when it first appears.
void add(int x){
    if(!cnt[x]) a[x] = 1,b[N - x] = 1;
    cnt[x] ++;
}

// Removes one occurrence of value x, clearing its bits when none remain.
void del(int x){
    cnt[x] --;
    if(!cnt[x]) a[x] = 0,b[N - x] = 0;
}
// Entry point: reads the sequence and the queries (opt, l, r, x), answers them offline
// with Mo's algorithm using the bitsets a (values present) and b (mirrored values), and
// prints "hana" for yes and "bi" for no in input order.
int main(){
n = read(),m = read(),sqt = sqrt(n);
for(int i = 1; i <= n; i ++) num[i] = read(),pos[i] = (i - 1) / sqt + 1;
for(int i = 1; i <= m; i ++) q[i].opt = read(),q[i].l = read(),q[i].r = read(),q[i].x = read(),q[i].pos = i;
sort(q + 1,q + m + 1,cmp);
for(int i = 1; i <= m; i ++){
while(l < q[i].l) del(num[l]),l ++;
while(l > q[i].l) l --,add(num[l]);
while(r < q[i].r) r ++,add(num[r]);
while(r > q[i].r) del(num[r]),r --;
if(q[i].opt == 1){
// Bit v of (a << x) is a[v - x]: a nonzero AND means some v and v - x are both present,
// i.e. two values differ by x.
if((a & (a << q[i].x)).count()) ans[q[i].pos] = 1;
}else if(q[i].opt == 2){
// b[N - u] = 1 iff u is present, so bit v of (b >> (N - x)) is set iff x - v is present:
// a nonzero AND means some v and x - v are both present, i.e. two values sum to x.
if((a & (b >> (N - q[i].x))).count()) ans[q[i].pos] = 1;
}else{
// Product x: try every divisor j with j * j <= x and check that j and x / j are present.
// NOTE(review): for x = 0 this loop never runs, so the answer is always "bi" — confirm
// whether 0 can appear in the input.
for(int j = 1; j * j <= q[i].x; j ++){
if(q[i].x % j == 0 && a[j] && a[q[i].x / j]){
ans[q[i].pos] = 1;
break;
}
}
}
}
for(int i = 1; i <= m; i ++){
if(ans[i]) printf("hana\n");
else printf("bi\n");
}
return 0;
}
COT2 - Count on a tree II:树上莫队模板
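求出树的欧拉序(每个点进入、离开各记一次,位置记为 $fst_u$、$lst_u$)即可转化成区间问题:设 $fst_x<fst_y$,若 $\operatorname{lca}(x,y)=x$,对应区间 $[fst_x,fst_y]$;否则对应区间 $[lst_x,fst_y]$,再单独算上 lca。区间里出现两次的点不在路径上,用 vis 异或抵消即可。然后套莫队就好了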
我在想能不能用LCT做
#include <iostream>
#include <cstdio>
#include <cmath>
#include <algorithm>
using namespace std;
const int maxn = 4e4 + 50, maxm = 1e5 + 50;

// x, y: scratch for reading; [l, r]: Mo window over the Euler tour (starts empty as [1, 0]);
// sum: number of distinct colors currently counted; len: number of distinct colors;
// cnt: edge counter while building the graph, then reused as the Euler-tour length.
int n, m, x, y;
int l = 1, r;
int sum, len, cnt;
int num[maxn];             // compressed color of each vertex
int t[maxn];               // sorted colors used for compression
int a[2 * maxn];           // a[i] = vertex at Euler-tour position i
int pos[2 * maxn];         // block of each Euler-tour position
int fst[maxn], lst[maxn];  // entry / exit position of each vertex in the tour
int last[maxn];            // head edge of each vertex's adjacency list (0 = none)
int d[maxn];               // depth of each vertex
int vis[maxn];             // 1 iff the vertex is currently counted
int tot[maxn];             // tot[c] = counted vertices with color c
int ans[maxm];
int f[maxn][20];           // f[u][i] = 2^i-th ancestor of u (0 past the root)

// Adjacency-list edge: target vertex v and the next edge index from the same source.
struct edge {
    int v, nxt;
} e[2 * maxn];

// Path query mapped to the Euler-tour range [l, r]; lca is an extra vertex to count
// separately (0 when none); pos is the input index.
struct query {
    int l, r, lca, pos;
} q[maxm];

// Reads the next non-negative decimal integer from stdin, skipping any non-digit prefix.
int read() {
    char ch = getchar();
    while (ch < '0' || ch > '9') ch = getchar();
    int value = 0;
    for (; ch >= '0' && ch <= '9'; ch = getchar()) value = value * 10 + (ch - '0');
    return value;
}

// Mo's order with the odd-even trick on r.
bool cmp(query a, query b) {
    if (pos[a.l] == pos[b.l]) return (pos[a.l] & 1) ? (a.r < b.r) : (a.r > b.r);
    return pos[a.l] < pos[b.l];
}

// Prepends the directed edge x -> y to x's adjacency list.
inline void insert(int x, int y) {
    e[++cnt].v = y;
    e[cnt].nxt = last[x];
    last[x] = cnt;
}
// Recursive DFS from u (parent fa). It records the Euler tour in a[] (each vertex twice:
// at entry, position fst[u], and at exit, position lst[u]), the depth d[u] = d[fa] + 1,
// and the binary-lifting table f[u][*]. Recursion depth equals the tree height.
void dfs(int u,int fa){
a[++cnt] = u,fst[u] = cnt,d[u] = d[fa] + 1,f[u][0] = fa;
// Fill f[u][i] while the 2^i-th ancestor exists; higher entries keep their zero init.
for(int i = 1; f[f[u][i - 1]][i - 1]; i ++) f[u][i] = f[f[u][i - 1]][i - 1];
for(int i = last[u]; i; i = e[i].nxt){
int v = e[i].v;
if(v != fa) dfs(v,u);
}
a[++cnt] = u,lst[u] = cnt;
}
// Lowest common ancestor by binary lifting. It uses jumps 2^16 .. 2^0, so it assumes
// depth differences below 2^17. Vertex 0 (above the root) has d[0] = 0 and is never
// jumped onto.
int Lca(int x,int y){
if(d[x] < d[y]) swap(x,y);
// Lift x to y's depth.
for(int i = 16; i >= 0; i --) if(d[f[x][i]] >= d[y]) x = f[x][i];
if(x == y) return x;
// Lift both while their ancestors differ; they end just below the LCA.
for(int i = 16; i >= 0; i --) if(f[x][i] != f[y][i]) x = f[x][i],y = f[y][i];
return f[x][0];
}
// Toggles whether vertex x is counted. A vertex seen twice in the Euler window cancels
// itself out. sum tracks how many colors have tot > 0. x == 0 (no separate LCA) is ignored.
void update(int x){
if(!x) return;
vis[x] ^= 1;
if(!vis[x]){
tot[num[x]] --;
if(!tot[num[x]]) sum --;
}else{
tot[num[x]] ++;
if(tot[num[x]] == 1) sum ++;
}
}
// Entry point: reads the tree and its vertex colors, maps every path query onto an
// Euler-tour range (plus the LCA when it is not an endpoint), answers them with Mo's
// algorithm and prints the number of distinct colors on each path.
int main(){
    n = read(), m = read();
    // Coordinate-compress the colors into 1..len.
    for (int i = 1; i <= n; i++) num[i] = t[i] = read();
    sort(t + 1, t + n + 1);
    len = unique(t + 1, t + n + 1) - t - 1;
    for (int i = 1; i <= n; i++) num[i] = lower_bound(t + 1, t + len + 1, num[i]) - t;
    for (int i = 1; i < n; i++) {
        x = read(), y = read();
        insert(x, y), insert(y, x);
    }
    cnt = 0;  // cnt was the edge counter; from here on it is the Euler-tour length
    dfs(1, 0);
    // pos[j] = 1-based block of Euler position j. Only positions 1..cnt are assigned.
    // The old nested loop filled whole blocks and could write up to sqt - 1 entries past
    // cnt, relying on array slack. It also shadowed the global array t with a local int.
    int sqt = sqrt(cnt);
    for (int j = 1; j <= cnt; j++) pos[j] = (j - 1) / sqt + 1;
    for (int i = 1; i <= m; i++) {
        x = read(), y = read(), q[i].pos = i;
        int anc = Lca(x, y);
        if (fst[x] > fst[y]) swap(x, y);
        // x is an ancestor of y: the range [fst[x], fst[y]] already contains the whole path.
        if (anc == x) q[i].l = fst[x], q[i].r = fst[y];
        // Otherwise [lst[x], fst[y]] covers the path except the LCA, which is added separately.
        else q[i].l = lst[x], q[i].r = fst[y], q[i].lca = anc;
    }
    sort(q + 1, q + m + 1, cmp);
    for (int i = 1; i <= m; i++) {
        // update() toggles, so the order of these moves does not affect the final state.
        while (l < q[i].l) update(a[l]), l++;
        while (l > q[i].l) l--, update(a[l]);
        while (r < q[i].r) r++, update(a[r]);
        while (r > q[i].r) update(a[r]), r--;
        update(q[i].lca);
        ans[q[i].pos] = sum;
        update(q[i].lca);
    }
    for (int i = 1; i <= m; i++) printf("%d\n", ans[i]);
    return 0;
}
[国家集训队] 数颜色:带修莫队模板,毒瘤卡常
谁来告诉我块大小到底取 $n^{2/3}$ 还是 $\sqrt[3]{nt}$ 啊($t$ 为修改次数;下面的代码取的是后者)
在原算法的基础上加一维时间,表示当前进行了几个修改操作,移动的时候再考虑上时间即可。有个小优化是修改时直接swap要修改的值与当前值,这样就不用判断当前是修改还是撤回修改,减小了代码量。
真的就是毒瘤题,实在卡不过去了,就从 pb 大佬的剪贴板里粘了一大堆(46 行)`#pragma GCC optimize`,不过就算这样最大点也只差 0.02s 啊!!!
人傻常数大的代码
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#pragma GCC optimize("Ofast")
#pragma GCC optimize("inline")
#pragma GCC optimize("-fgcse")
#pragma GCC optimize("-fgcse-lm")
#pragma GCC optimize("-fipa-sra")
#pragma GCC optimize("-ftree-pre")
#pragma GCC optimize("-ftree-vrp")
#pragma GCC optimize("-fpeephole2")
#pragma GCC optimize("-ffast-math")
#pragma GCC optimize("-fsched-spec")
#pragma GCC optimize("unroll-loops")
#pragma GCC optimize("-falign-jumps")
#pragma GCC optimize("-falign-loops")
#pragma GCC optimize("-falign-labels")
#pragma GCC optimize("-fdevirtualize")
#pragma GCC optimize("-fcaller-saves")
#pragma GCC optimize("-fcrossjumping")
#pragma GCC optimize("-fthread-jumps")
#pragma GCC optimize("-funroll-loops")
#pragma GCC optimize("-fwhole-program")
#pragma GCC optimize("-freorder-blocks")
#pragma GCC optimize("-fschedule-insns")
#pragma GCC optimize("inline-functions")
#pragma GCC optimize("-ftree-tail-merge")
#pragma GCC optimize("-fschedule-insns2")
#pragma GCC optimize("-fstrict-aliasing")
#pragma GCC optimize("-fstrict-overflow")
#pragma GCC optimize("-falign-functions")
#pragma GCC optimize("-fcse-skip-blocks")
#pragma GCC optimize("-fcse-follow-jumps")
#pragma GCC optimize("-fsched-interblock")
#pragma GCC optimize("-fpartial-inlining")
#pragma GCC optimize("no-stack-protector")
#pragma GCC optimize("-freorder-functions")
#pragma GCC optimize("-findirect-inlining")
#pragma GCC optimize("-frerun-cse-after-loop")
#pragma GCC optimize("inline-small-functions")
#pragma GCC optimize("-finline-small-functions")
#pragma GCC optimize("-ftree-switch-conversion")
#pragma GCC optimize("-foptimize-sibling-calls")
#pragma GCC optimize("-fexpensive-optimizations")
#pragma GCC optimize("-funsafe-loop-optimizations")
#pragma GCC optimize("inline-functions-called-once")
#pragma GCC optimize("-fdelete-null-pointer-checks")
#include <iostream>
#include <cstdio>
#include <cmath>
#include <algorithm>
using namespace std;
const int maxn = 1.4e5,maxv = 1e6 + 50;
// [l, r]: current Mo window (starts empty as [1, 0]); t: number of modifications applied;
// sum: distinct colors in the window; cntq / cntr: number of queries / modifications;
// len: block size; num: current colors; cnt[c]: occurrences of color c in the window.
// (x and y are not used by this program.)
int n,m,x,y,l = 1,r,t,sum,cntq,cntr,len,num[maxn],cnt[maxv],ans[maxn];
// Modification: set position pos to color val. When applied, val is swapped with the
// old color so the same record can undo it.
struct node{
    int pos,val;
}a[maxn];
// Query [l, r] issued after the first t modifications; pos is the query's index.
struct query{
    int l,r,t,pos;
}q[maxn];

// Reads the next non-negative decimal integer from stdin, skipping any non-digit prefix.
int read(){
    int x = 0;
    char c = getchar();
    while(c < '0' || c > '9') c = getchar();
    while(c >= '0' && c <= '9') x = x * 10 + (c ^ 48),c = getchar();
    return x;
}

// Writes a non-negative integer followed by '\n'. The digit loop is do/while so that 0
// prints as "0"; the old while loop printed an empty line for 0.
void print(int x){
    int h = 0,c[20];
    do c[++h] = x % 10,x /= 10; while(x);
    for(int i = h; i >= 1; i --) putchar(c[i] + '0');
    putchar('\n');
}

// Mo's order with time: by block of l, then by block of r, then by time t.
bool cmp(query a,query b){
    if(a.l / len != b.l / len) return a.l / len < b.l / len;
    else if(a.r / len != b.r / len) return a.r / len < b.r / len;
    else return a.t < b.t;
}
// Readable form of the window updates that main() inlines:
// void add(int x){
// if(!cnt[x]) sum ++;
// cnt[x] ++;
// }
// void del(int x){
// cnt[x] --;
// if(!cnt[x]) sum --;
// }
// Entry point: reads the colors and the operations ("Q l r" queries, "R pos col"
// modifications), answers the queries with Mo's algorithm extended by a time dimension,
// and prints the number of distinct colors for each query.
int main(){
    n = read(), m = read();
    for (int i = 1; i <= n; i++) num[i] = read();
    for (int i = 1; i <= m; i++) {
        // Skip whitespace (including '\r' from CRLF input) before the operation letter.
        // The original used the single character right after the previous number.
        int op = getchar();
        while (op == ' ' || op == '\n' || op == '\r' || op == '\t') op = getchar();
        if (op == 'Q') cntq++, q[cntq].l = read(), q[cntq].r = read(), q[cntq].t = cntr, q[cntq].pos = cntq;
        else cntr++, a[cntr].pos = read(), a[cntr].val = read();
    }
    // Block size cbrt(n * cntr). With no modifications log(0) is -inf, which made len 0
    // and cmp divide by zero (and the Ofast / -ffast-math pragmas above make infinities
    // unreliable anyway). Fall back to sqrt(n) then, and never go below 1.
    if (cntr > 0) len = ceil(exp((log(n) + log(cntr)) / 3));
    else len = sqrt(n);
    if (len < 1) len = 1;
    sort(q + 1, q + cntq + 1, cmp);
    for (int i = 1; i <= cntq; i++) {
        // Inlined add/del: sum counts colors whose cnt is >= 1.
        while (l < q[i].l) sum -= !--cnt[num[l++]];
        while (l > q[i].l) sum += !cnt[num[--l]]++;
        while (r < q[i].r) sum += !cnt[num[++r]]++;
        while (r > q[i].r) sum -= !--cnt[num[r--]];
        // Move in time. Applying or undoing modification t is the same operation: swap the
        // stored color with the current one (updating counts if the position is in [l, r]).
        while (t != q[i].t) {
            if (t < q[i].t) t++;
            if (l <= a[t].pos && r >= a[t].pos) {
                cnt[num[a[t].pos]]--;
                if (!cnt[num[a[t].pos]]) sum--;
                if (!cnt[a[t].val]) sum++;
                cnt[a[t].val]++;
            }
            swap(num[a[t].pos], a[t].val);
            if (t > q[i].t) t--;
        }
        ans[q[i].pos] = sum;
    }
    for (int i = 1; i <= cntq; i++) print(ans[i]);
    return 0;
}
回滚莫队模板
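用于某些不好删除但可以插入的题。按左端点所在块处理询问:若询问的右端点也在同一块内,直接暴力计算;否则,由于左端点在同一块的询问按右端点单调递增排序,右端点可以从上一个右端点继续向右扩展,左端点则每次从块的右端暴力向左扩展,回答完再把左边的修改撤销(回滚)。这样就避免了莫队的删除操作。
考虑时间复杂度,凭感觉取块的大小为 $\sqrt{n}$:
对于左端点,每个询问至多从块右端扩展到块左端,时间为 $O(\sqrt{n})$,有 $m$ 个询问,则总复杂度为 $O(m\sqrt{n})$
对于右端点,每块至多从块右端扩展到 $n$ 号点,单块时间 $O(n)$,共 $\sqrt{n}$ 块,总复杂度 $O(n\sqrt{n})$
$n$、$m$ 一般同阶,故总复杂度 $O(n\sqrt{n})$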
#include <iostream>
#include <cstdio>
#include <cmath>
#include <algorithm>
using namespace std;
const int maxn = 2e5 + 50, inf = 1e9;

int n, m;
int l, r;                // left / right pointers of the rolling window
int len, tot;            // block size, number of distinct values
int mx, tmx;             // current best rb - lb, and its saved copy before a left expansion
int a[maxn];             // compressed sequence
int t[maxn];             // sorted copy used for compression
int pos[maxn];           // block of each position
int tl[maxn], tr[maxn];  // saved lb / rb values used to roll back left expansions
int lb[maxn], rb[maxn];  // leftmost / rightmost position of each value in the window
int ans[maxn];

// An offline query [l, r]; pos is its 1-based index in the input.
struct query {
    int l, r, pos;
} q[maxn];

// Reads the next non-negative decimal integer from stdin, skipping any non-digit prefix.
int read() {
    char ch = getchar();
    while (ch < '0' || ch > '9') ch = getchar();
    int value = 0;
    for (; ch >= '0' && ch <= '9'; ch = getchar()) value = value * 10 + (ch - '0');
    return value;
}

// Rollback-Mo order: by block of l, then by ascending r (r must be monotone per block).
bool cmp(query a, query b) {
    return pos[a.l] == pos[b.l] ? a.r < b.r : pos[a.l] < pos[b.l];
}

// Inserts position x: widens the [lb, rb] span of its value to include x.
inline void update(int x) {
    int &lo = lb[a[x]], &hi = rb[a[x]];
    if (x < lo) lo = x;
    if (x > hi) hi = x;
}
// Rollback Mo's algorithm. For every query [l, r] it computes the largest rb - lb, the
// distance between the first and last occurrence of one value inside the range. Only
// insertions are performed; left-side insertions are undone by restoring the saved lb / rb.
int main(){
    n = read(), len = sqrt(n);
    for (int i = 1; i <= n; i++) pos[i] = (i - 1) / len + 1;
    for (int i = 1; i <= n; i++) a[i] = t[i] = read();
    sort(t + 1, t + n + 1);
    tot = unique(t + 1, t + n + 1) - t - 1;
    for (int i = 1; i <= n; i++) a[i] = lower_bound(t + 1, t + tot + 1, a[i]) - t;
    m = read();
    for (int i = 1; i <= m; i++) q[i].l = read(), q[i].r = read(), q[i].pos = i;
    sort(q + 1, q + m + 1, cmp);
    for (int i = 1; i <= n; i++) lb[i] = inf, rb[i] = 0;
    for (int k = 1, i = 1; k <= pos[n]; k++) {
        // Queries lying entirely inside block k (they sort first because r ascends): brute force.
        while (i <= m && q[i].r <= k * len) {
            for (int j = q[i].l; j <= q[i].r; j++) update(j), mx = max(mx, rb[a[j]] - lb[a[j]]);
            ans[q[i].pos] = mx;
            for (int j = q[i].l; j <= q[i].r; j++) lb[a[j]] = inf, rb[a[j]] = 0;
            mx = 0, i++;
        }
        r = k * len;
        // If block k reaches the end of the array, every query starting here was brute-forced.
        // Compare against n (the array length), not m (the query count). With m <= r < n the
        // old check skipped this block's crossing queries, which were then never answered.
        if (r >= n) continue;
        while (i <= m && pos[q[i].l] == k) {
            l = k * len + 1;
            // The right part only grows within a block.
            while (r < q[i].r) r++, update(r), mx = max(mx, rb[a[r]] - lb[a[r]]);
            tmx = mx;
            // Save what the left part will touch, extend left, answer, then roll back.
            for (int j = q[i].l; j <= k * len; j++) tl[a[j]] = lb[a[j]], tr[a[j]] = rb[a[j]];
            while (l > q[i].l) l--, update(l), mx = max(mx, rb[a[l]] - lb[a[l]]);
            for (int j = q[i].l; j <= k * len; j++) lb[a[j]] = tl[a[j]], rb[a[j]] = tr[a[j]];
            ans[q[i].pos] = mx, mx = tmx, i++;
        }
        // Reset before the next block.
        mx = 0;
        for (int j = k * len + 1; j <= r; j++) lb[a[j]] = inf, rb[a[j]] = 0;
    }
    for (int i = 1; i <= m; i++) printf("%d\n", ans[i]);
    return 0;
}
歴史の研究:回滚莫队
国外OJ真烦,不给错误信息,因为long long调了三个小时
真就不开long long见祖宗
#include <iostream>
#include <cstdio>
#include <cmath>
#include <algorithm>
using namespace std;
const int maxn = 1e5 + 50;

int n, m, l, r, len, tot;      // sizes, window pointers, block size, distinct values
int a[maxn];                   // compressed value of each position
int num[maxn];                 // original value of each position
int t[maxn];                   // sorted values used for compression
int pos[maxn];                 // block of each position
int tmp[maxn];                 // saved counts used to roll back left expansions
int cnt[maxn];                 // cnt[v] = occurrences of compressed value v in the window
long long mx, tmx, ans[maxn];  // current answer, saved answer, answers per query

// An offline query [l, r]; pos is its 1-based index in the input.
struct query {
    int l, r, pos;
} q[maxn];

// Reads the next non-negative decimal integer from stdin, skipping any non-digit prefix.
int read() {
    char ch = getchar();
    while (ch < '0' || ch > '9') ch = getchar();
    int value = 0;
    for (; ch >= '0' && ch <= '9'; ch = getchar()) value = value * 10 + (ch - '0');
    return value;
}

// Rollback-Mo order: by block of l, then by ascending r.
bool cmp(query a, query b) {
    return pos[a.l] == pos[b.l] ? a.r < b.r : pos[a.l] < pos[b.l];
}
// Rollback Mo's algorithm. For every query [l, r] it computes max over values v of
// v * (occurrences of v in [l, r]). Only insertions are performed; left-side insertions
// are undone by restoring the saved counts.
int main(){
    n = read(), m = read(), len = sqrt(n);
    for (int i = 1; i <= n; i++) pos[i] = (i - 1) / len + 1;
    for (int i = 1; i <= n; i++) num[i] = t[i] = read();
    sort(t + 1, t + n + 1);
    tot = unique(t + 1, t + n + 1) - t - 1;
    for (int i = 1; i <= n; i++) a[i] = lower_bound(t + 1, t + tot + 1, num[i]) - t;
    for (int i = 1; i <= m; i++) q[i].l = read(), q[i].r = read(), q[i].pos = i;
    sort(q + 1, q + m + 1, cmp);
    for (int k = 1, i = 1; k <= pos[n]; k++) {
        // Queries lying entirely inside block k: brute force, then clear the counts.
        while (i <= m && q[i].r <= k * len) {
            for (int j = q[i].l; j <= q[i].r; j++) cnt[a[j]]++, mx = max(mx, 1ll * num[j] * cnt[a[j]]);
            ans[q[i].pos] = mx;
            for (int j = q[i].l; j <= q[i].r; j++) cnt[a[j]] = 0;
            mx = 0, i++;
        }
        r = k * len;
        // If block k reaches the end of the array, every query starting here was brute-forced.
        // Compare against n (the array length), not m (the query count). With m <= r < n the
        // old check skipped this block's crossing queries, which were then never answered.
        if (r >= n) continue;
        while (i <= m && pos[q[i].l] == k) {
            l = k * len + 1;
            // The right part only grows within a block (queries sorted by r).
            while (r < q[i].r) r++, cnt[a[r]]++, mx = max(mx, 1ll * num[r] * cnt[a[r]]);
            tmx = mx;
            // Save the counts the left part will touch, extend left, answer, roll back.
            for (int j = q[i].l; j <= k * len; j++) tmp[a[j]] = cnt[a[j]];
            while (l > q[i].l) l--, cnt[a[l]]++, mx = max(mx, 1ll * num[l] * cnt[a[l]]);
            for (int j = q[i].l; j <= k * len; j++) cnt[a[j]] = tmp[a[j]];
            ans[q[i].pos] = mx, mx = tmx, i++;
        }
        // Reset before the next block.
        mx = 0;
        for (int j = k * len + 1; j <= r; j++) cnt[a[j]] = 0;
    }
    for (int i = 1; i <= m; i++) printf("%lld\n", ans[i]);
    return 0;
}
[WC2013]糖果公园:树上莫队+带修莫队
莫队毕业题,融合了几个算法。要注意移动时间的时候也要判欧拉序里的 vis 标记(只有当前计入路径的点,修改才会影响答案)。
回滚莫队:我不配拥有姓名
#include <iostream>
#include <cstdio>
#include <cmath>
#include <algorithm>
using namespace std;
const int maxn = 1e5 + 50, maxp = 2 * maxn, maxv = 1e6;

// n vertices, m candy types, Q operations; x, y, type: scratch for reading;
// [l, r]: Mo window over the Euler tour (starts empty as [1, 0]); t: modifications applied;
// len: block size; ecnt: edge counter; pcnt: Euler-tour length;
// cntr / cntq: number of modifications / queries read.
int n, m, Q, x, y;
int l = 1, r, t, len, type;
int ecnt, pcnt, cntr, cntq;
int v[maxn];              // v[c]: value of candy type c
int w[maxn];              // w[k]: multiplier applied to the k-th counted vertex of one type
int c[maxn];              // c[u]: candy type at vertex u
int p[maxp];              // p[i]: vertex at Euler-tour position i
int pos[maxp];            // block of each Euler-tour position
int fst[maxn], lst[maxn]; // entry / exit position of each vertex in the tour
int last[maxn];           // head edge of each vertex's adjacency list (0 = none)
int d[maxn];              // depth of each vertex
int f[maxn][20];          // f[u][i] = 2^i-th ancestor of u (0 past the root)
int cnt[maxv];            // cnt[c]: counted vertices of type c
int vis[maxn];            // 1 iff the vertex is currently counted
long long sum, ans[maxn];

// Adjacency-list edge: target vertex v and the next edge index from the same source.
struct edge {
    int v, nxt;
} e[2 * maxn];

// Modification: set the type of vertex pos to val (swapped with the old type when applied).
struct revise {
    int pos, val;
} a[maxn];

// Path query over the Euler range [l, r] after the first t modifications; lca is an extra
// vertex to count separately (0 when none); pos is the query's index.
struct query {
    int l, r, t, lca, pos;
} q[maxn];

// Reads the next non-negative decimal integer from stdin, skipping any non-digit prefix.
int read() {
    char ch = getchar();
    while (ch < '0' || ch > '9') ch = getchar();
    int value = 0;
    for (; ch >= '0' && ch <= '9'; ch = getchar()) value = value * 10 + (ch - '0');
    return value;
}

// Prepends the directed edge x -> y to x's adjacency list.
inline void insert(int x, int y) {
    e[++ecnt].v = y;
    e[ecnt].nxt = last[x];
    last[x] = ecnt;
}

// Mo's order with time: by block of l, then by block of r, then by time t.
bool cmp(query a, query b) {
    if (pos[a.l] != pos[b.l]) return pos[a.l] < pos[b.l];
    if (pos[a.r] != pos[b.r]) return pos[a.r] < pos[b.r];
    return a.t < b.t;
}
// Recursive DFS from u (parent fa). It records the Euler tour in p[] (each vertex twice:
// at entry, position fst[u], and at exit, position lst[u]), the depth d[u] = d[fa] + 1,
// and the full binary-lifting row f[u][0..17]. Recursion depth equals the tree height.
void dfs(int u,int fa){
pcnt ++,p[pcnt] = u,fst[u] = pcnt,d[u] = d[fa] + 1,f[u][0] = fa;
for(int i = 1; i <= 17; i ++) f[u][i] = f[f[u][i - 1]][i - 1];
for(int i = last[u]; i; i = e[i].nxt){
// This local v shadows the global value array v[] inside the loop.
int v = e[i].v;
if(v != fa) dfs(v,u);
}
pcnt ++,p[pcnt] = u,lst[u] = pcnt;
}
// Lowest common ancestor by binary lifting with jumps 2^17 .. 2^0. Vertex 0 (above the
// root) has d[0] = 0 and is never jumped onto.
int Lca(int x,int y){
if(d[x] < d[y]) swap(x,y);
// Lift x to y's depth.
for(int i = 17; i >= 0; i --) if(d[f[x][i]] >= d[y]) x = f[x][i];
if(x == y) return x;
// Lift both while their ancestors differ; they end just below the LCA.
for(int i = 17; i >= 0; i --) if(f[x][i] != f[y][i]) x = f[x][i],y = f[y][i];
return f[x][0];
}
// Toggles whether vertex x is counted. When counted, its type's count k rises and
// w[k] * v[type] is added to sum; un-counting subtracts the same term and lowers k.
// x == 0 (no separate LCA) is ignored.
inline void update(int x){
if(!x) return;
vis[x] ^= 1;
if(vis[x]) cnt[c[x]] ++,sum += 1ll * w[cnt[c[x]]] * v[c[x]];
else sum -= 1ll * w[cnt[c[x]]] * v[c[x]],cnt[c[x]] --;
}
// Entry point: reads candy values v, multipliers w, the tree and the vertex types, then
// processes Q operations (type 0: modification, otherwise path query). Queries are answered
// with Mo's algorithm on the Euler tour plus a time dimension, and the answers are printed.
int main(){
n = read(),m = read(),Q = read();
for(int i = 1; i <= m; i ++) v[i] = read();
for(int i = 1; i <= n; i ++) w[i] = read();
for(int i = 1; i < n; i ++){
x = read(),y = read();
insert(x,y),insert(y,x);
}
dfs(1,0);
// Block size pcnt^(2/3) over the Euler tour of length pcnt = 2n.
len = pow(pcnt,2.0 / 3);
for(int i = 1; i <= pcnt; i ++) pos[i] = (i - 1) / len + 1;
for(int i = 1; i <= n; i ++) c[i] = read();
for(int i = 1; i <= Q; i ++){
type = read(),x = read(),y = read();
if(type == 0) cntr ++,a[cntr].pos = x,a[cntr].val = y;
else{
cntq ++,q[cntq].t = cntr,q[cntq].pos = cntq;
if(fst[x] > fst[y]) swap(x,y);
// This local t (the LCA) shadows the global time counter t inside this block.
int t = Lca(x,y);
// x is an ancestor of y: [fst[x], fst[y]] already covers the path.
if(t == x) q[cntq].l = fst[x],q[cntq].r = fst[y];
// Otherwise [lst[x], fst[y]] covers the path except the LCA, which is added separately.
else q[cntq].l = lst[x],q[cntq].r = fst[y],q[cntq].lca = t;
}
}
sort(q + 1,q + cntq + 1,cmp);
for(int i = 1; i <= cntq; i ++){
while(l < q[i].l) update(p[l]),l ++;
while(l > q[i].l) l --,update(p[l]);
while(r < q[i].r) r ++,update(p[r]);
while(r > q[i].r) update(p[r]),r --;
// Move in time. A modification only changes the sum if its vertex is currently counted
// (vis), not merely inside [l, r]. Swapping makes apply and undo the same operation.
while(t != q[i].t){
if(t < q[i].t) t ++;
if(vis[a[t].pos]){
sum -= 1ll * w[cnt[c[a[t].pos]]] * v[c[a[t].pos]],cnt[c[a[t].pos]] --;
cnt[a[t].val] ++,sum += 1ll * w[cnt[a[t].val]] * v[a[t].val];
}
swap(c[a[t].pos],a[t].val);
if(t > q[i].t) t --;
}
// Temporarily count the LCA (no-op when lca == 0), record the answer, then remove it.
update(q[i].lca);
ans[q[i].pos] = sum;
update(q[i].lca);
}
for(int i = 1; i <= cntq; i ++) printf("%lld\n",ans[i]);
return 0;
}
FFT模板
看了半天证明,打代码时才发现并没什么用
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
using namespace std;
const int maxn = 2.1e6 + 50;
const double Pi = acos(-1);
// n, m: degrees of the two polynomials; N: transform length (a power of two > n + m);
// l = log2(N); rev: bit-reversal permutation of 0..N-1.
// The original also declared ans[maxn] and s[maxn], which this program never uses
// (about 10 MB of static storage); they are dropped.
int n, m, N, l, rev[maxn];

// Complex number x + y*i.
struct cp{
    double x, y;
} a[maxn], b[maxn];

// Complex arithmetic built with standard brace initialization. The original used
// `(cp){...}`, a C99 compound literal that is only a GNU extension in C++.
cp operator + (cp a, cp b){
    return cp{a.x + b.x, a.y + b.y};
}
cp operator - (cp a, cp b){
    return cp{a.x - b.x, a.y - b.y};
}
cp operator * (cp a, cp b){
    return cp{a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x};
}

// Reads the next non-negative decimal integer from stdin, skipping any non-digit prefix.
int read(){
    int x = 0;
    char c = getchar();
    while(c < '0' || c > '9') c = getchar();
    while(c >= '0' && c <= '9') x = x * 10 + (c ^ 48),c = getchar();
    return x;
}
// In-place iterative radix-2 FFT of a[0..N-1]; N must be a power of two and rev[] its
// bit-reversal permutation. p = 1 is the forward transform. p = -1 is the inverse
// transform, after which only the real parts are divided by N (imaginary parts stay unscaled).
void FFT(cp *a,int p){
// Reorder into bit-reversed order so the butterflies can work in place.
for(int i = 0; i < N; i ++) if(i < rev[i]) swap(a[i],a[rev[i]]);
// i is the half-length of the blocks being merged; t = cos(pi / i) + p * sin(pi / i) * I
// is the principal (2i)-th root of unity (conjugated for the inverse).
for(int i = 1; i < N; i <<= 1){
cp t = (cp){cos(Pi / i),p * sin(Pi / i)};
for(int j = 0; j < N; j += i * 2){
// w runs through t^0, t^1, ..., t^(i-1) by repeated multiplication.
cp w = (cp){1,0};
for(int k = j; k < j + i; k ++){
// Butterfly on the pair (a[k], a[k + i]).
cp t1 = a[k],t2 = w * a[k + i];
a[k] = t1 + t2,a[k + i] = t1 - t2;
w = w * t;
}
}
}
if(p == -1) for(int i = 0; i < N; i ++) a[i].x /= N;
}
// Entry point: reads two polynomials of degrees n and m (non-negative integer coefficients,
// lowest degree first), multiplies them with FFT and prints the n + m + 1 product coefficients.
int main(){
    n = read(), m = read();
    for (int i = 0; i <= n; i++) a[i].x = read();
    for (int i = 0; i <= m; i++) b[i].x = read();
    // Smallest power of two N > n + m (the product has n + m + 1 coefficients); l = log2(N).
    N = 1;
    while (N <= n + m) N <<= 1, l++;
    // rev[0] is already 0, and only indices 1..N-1 are needed. The old `i <= N` bound also
    // ran when N = 1 (n = m = 0); there l = 0 and `(i & 1) << (l - 1)` is a negative shift (UB).
    for (int i = 1; i < N; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (l - 1));
    FFT(a, 1), FFT(b, 1);
    for (int i = 0; i < N; i++) a[i] = a[i] * b[i];
    FFT(a, -1);
    // Round to the nearest integer; coefficients are non-negative because read() only
    // parses non-negative numbers.
    for (int i = 0; i <= n + m; i++) printf("%d ", int(a[i].x + 0.5));
    return 0;
}
A*B Problem:FFT
听说,恶臭的题号与恶臭的样例更配哦(雾
用 FFT 加速高精乘:把每个数看作一个多项式,系数为数的每一位(将 $x=10$ 代入即得这个数的值),用 FFT 求出两多项式之积,再逐位进位即可
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
using namespace std;
const int maxn = 2.1e6 + 50;
const double Pi = acos(-1);

int n, m;        // digit counts of the two factors
int N, l;        // transform length (power of two) and log2(N)
int rev[maxn];   // bit-reversal permutation of 0..N-1
int ans[maxn];   // product digits, least significant first
char s[maxn];    // input buffer for one factor

// Complex number x + y*i.
struct cp {
    double x, y;
} a[maxn], b[maxn];

// Component-wise sum.
cp operator + (cp a, cp b) {
    cp res;
    res.x = a.x + b.x;
    res.y = a.y + b.y;
    return res;
}

// Component-wise difference.
cp operator - (cp a, cp b) {
    cp res;
    res.x = a.x - b.x;
    res.y = a.y - b.y;
    return res;
}

// Complex product (ac - bd) + (ad + bc)i.
cp operator * (cp a, cp b) {
    cp res;
    res.x = a.x * b.x - a.y * b.y;
    res.y = a.x * b.y + a.y * b.x;
    return res;
}
// In-place iterative radix-2 FFT of a[0..N-1]; N must be a power of two and rev[] its
// bit-reversal permutation. p = 1 is the forward transform. p = -1 is the inverse
// transform, after which only the real parts are divided by N (imaginary parts stay unscaled).
void FFT(cp *a,int p){
// Reorder into bit-reversed order so the butterflies can work in place.
for(int i = 0; i < N; i ++) if(i < rev[i]) swap(a[i],a[rev[i]]);
// i is the half-length of the blocks being merged; t = cos(pi / i) + p * sin(pi / i) * I
// is the principal (2i)-th root of unity (conjugated for the inverse).
for(int i = 1; i < N; i <<= 1){
cp t = (cp){cos(Pi / i),p * sin(Pi / i)};
for(int j = 0; j < N; j += i * 2){
// w runs through t^0, t^1, ..., t^(i-1) by repeated multiplication.
cp w = (cp){1,0};
for(int k = j; k < j + i; k ++){
// Butterfly on the pair (a[k], a[k + i]).
cp t1 = a[k],t2 = w * a[k + i];
a[k] = t1 + t2,a[k + i] = t1 - t2;
w = w * t;
}
}
}
if(p == -1) for(int i = 0; i < N; i ++) a[i].x /= N;
}
// Entry point: multiplies two non-negative decimal integers read as strings. Digits become
// polynomial coefficients (least significant first), the product is computed with FFT,
// and carries are propagated before printing.
int main(){
    scanf("%s", s);
    n = strlen(s);
    for (int i = 0; i < n; i++) a[i].x = s[n - i - 1] - '0';
    scanf("%s", s);
    m = strlen(s);
    for (int i = 0; i < m; i++) b[i].x = s[m - i - 1] - '0';
    // The product has at most n + m digits; N is the smallest power of two >= n + m.
    N = 1;
    while (N < n + m) N <<= 1, l++;
    // Only rev[1..N-1] is needed (rev[0] = 0). The old `i <= N` bound also computed rev[N],
    // outside the transform and in bounds only thanks to array slack.
    for (int i = 1; i < N; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (l - 1));
    FFT(a, 1), FFT(b, 1);
    for (int i = 0; i < N; i++) a[i] = a[i] * b[i];
    FFT(a, -1);
    // Round each coefficient and propagate carries; ans[N] receives the final carry.
    for (int i = 0; i < N; i++) {
        ans[i] += int(a[i].x + 0.5);
        ans[i + 1] += ans[i] / 10;
        ans[i] %= 10;
    }
    // Strip leading zeros but keep at least one digit.
    while (!ans[N] && N) N--;
    for (int i = N; i >= 0; i--) printf("%d", ans[i]);
    return 0;
}
[ZJOI2014]力:FFT
#include <iostream>
#include <cstdio>
#include <cmath>
using namespace std;
const int maxn = 3e5 + 50;
const double Pi = acos(-1);

int n;          // number of charges
int m = 1, l;   // transform length (power of two, grown in main) and log2(m)
int rev[maxn];  // bit-reversal permutation of 0..m-1

// Complex number x + y*i with the three operations the FFT needs.
struct cp {
    double x, y;
    cp operator + (cp a) const {
        cp res = {a.x + x, a.y + y};
        return res;
    }
    cp operator - (cp a) const {
        cp res = {x - a.x, y - a.y};
        return res;
    }
    cp operator * (cp a) const {
        cp res = {x * a.x - y * a.y, x * a.y + y * a.x};
        return res;
    }
} a[maxn], b[maxn], t[maxn];
// In-place iterative radix-2 FFT of a[0..m-1]; m must be a power of two and rev[] its
// bit-reversal permutation. p = 1 is the forward transform. p = -1 is the inverse
// transform, after which only the real parts are divided by m (imaginary parts stay unscaled).
void FFT(cp *a,int p){
// Reorder into bit-reversed order so the butterflies can work in place.
for(int i = 0; i < m; i ++) if(i < rev[i]) swap(a[i],a[rev[i]]);
// i is the half-length of the blocks being merged; w1 is the principal (2i)-th root of
// unity (conjugated for the inverse).
for(int i = 1; i < m; i <<= 1){
cp w1 = {cos(Pi / i),p * sin(Pi / i)};
for(int j = 0; j < m; j += 2 * i){
// w runs through w1^0, ..., w1^(i-1) by repeated multiplication.
cp w = {1,0};
for(int k = j; k < j + i; k ++){
// Butterfly on the pair (a[k], a[k + i]).
cp t1 = a[k],t2 = w * a[k + i];
a[k] = t1 + t2,a[k + i] = t1 - t2;
w = w * w1;
}
}
}
if(p == -1) for(int i = 0; i < m; i ++) a[i].x /= m;
}
// Entry point: reads n charges q_1..q_n and prints
// E_j = sum_{i<j} q_i / (j-i)^2 - sum_{i>j} q_i / (i-j)^2 for every j.
// a * t gives the first sum at index j; b (the charges reversed, b[n-i] = q_i) times t
// gives the second sum at index n - j.
int main(){
    scanf("%d", &n);
    // a and t occupy indices 1..n, so a * t reaches index 2n, and the transform length must
    // exceed 2n to avoid wrap-around. The old bound `m < 2n - 1` left m = 1 for n = 1:
    // index 1 was then never transformed and E_1 came out as q_1 instead of 0.
    while (m <= 2 * n) m <<= 1, l++;
    for (int i = 1; i < m; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (l - 1));
    for (int i = 1; i <= n; i++) scanf("%lf", &a[i].x), b[n - i] = a[i], t[i].x = 1.0 / i / i;
    FFT(a, 1), FFT(b, 1), FFT(t, 1);
    for (int i = 0; i < m; i++) a[i] = a[i] * t[i], b[i] = b[i] * t[i];
    FFT(a, -1), FFT(b, -1);
    for (int i = 1; i <= n; i++) printf("%.3f\n", a[i].x - b[n - i].x);
    return 0;
}
[AH/HNOI2017]礼物:FFT
#include <iostream>
#include <cstdio>
#include <cmath>
using namespace std;
const int maxn = 3e5 + 50;
const long long inf = 1e18;
const double Pi = acos(-1);

int n, m;            // bracelet length; the brightness shift j ranges over [-m, m]
int N = 1, l;        // transform length (power of two) and log2(N)
int t;               // scratch for reading
int rev[maxn];       // bit-reversal permutation of 0..N-1
long long s1, s2;    // sum of squares of both bracelets; sum(x) - sum(y)
long long ans = inf; // best value found so far

// Reads the next non-negative decimal integer from stdin, skipping any non-digit prefix.
int read() {
    char ch = getchar();
    while (ch < '0' || ch > '9') ch = getchar();
    int value = 0;
    for (; ch >= '0' && ch <= '9'; ch = getchar()) value = value * 10 + (ch - '0');
    return value;
}

// Complex number x + y*i with the three operations the FFT needs.
struct cp {
    double x, y;
    cp operator + (cp a) const {
        cp res = {a.x + x, a.y + y};
        return res;
    }
    cp operator - (cp a) const {
        cp res = {x - a.x, y - a.y};
        return res;
    }
    cp operator * (cp a) const {
        cp res = {x * a.x - y * a.y, x * a.y + y * a.x};
        return res;
    }
} a[maxn], b[maxn];
// In-place iterative radix-2 FFT of a[0..N-1]; N must be a power of two and rev[] its
// bit-reversal permutation. p = 1 is the forward transform. p = -1 is the inverse
// transform, after which each real part is divided by N and rounded to an integer.
void FFT(cp *a,int p){
// Reorder into bit-reversed order so the butterflies can work in place.
for(int i = 0; i < N; i ++) if(i < rev[i]) swap(a[i],a[rev[i]]);
// i is the half-length of the blocks being merged; w1 is the principal (2i)-th root of
// unity (conjugated for the inverse).
for(int i = 1; i < N; i <<= 1){
cp w1 = {cos(Pi / i),p * sin(Pi / i)};
for(int j = 0; j < N; j += 2 * i){
// w runs through w1^0, ..., w1^(i-1) by repeated multiplication.
cp w = {1,0};
for(int k = j; k < j + i; k ++){
// Butterfly on the pair (a[k], a[k + i]).
cp t1 = a[k],t2 = w * a[k + i];
a[k] = t1 + t2,a[k + i] = t1 - t2;
w = w * w1;
}
}
}
// Adding 0.5 then truncating rounds correctly only for non-negative results; the products
// here come from non-negative inputs (read() only parses non-negative numbers).
if(p == -1) for(int i = 0; i < N; i ++) a[i].x = (long long)(a[i].x / N + 0.5);
}
// Entry point: reads n, m and both bracelets x and y, then minimizes over every rotation
// and every shift j in [-m, m] (added to x):
//   sum((x_k + j - y_k)^2) = s1 + n*j^2 + 2*j*s2 - 2*sum(x_k * y_k),
// where s1 = sum x^2 + sum y^2 and s2 = sum x - sum y. The cross terms for all rotations
// come from a single FFT product.
int main(){
n = read(),m = read();
// Product indices go up to 3n, so N > 3n avoids wrap-around.
while(N <= 3 * n) N <<= 1,l ++;
for(int i = 1; i < N; i ++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (l - 1));
// a holds the first bracelet twice (a[1..2n]), so every rotation is a contiguous window.
for(int i = 1; i <= n; i ++) t = read(),a[i].x = a[i + n].x = t,s1 += 1ll * t * t,s2 += t;
// b holds the second bracelet reversed: b[n - i + 1] = y_i.
for(int i = 1; i <= n; i ++) t = read(),b[n - i + 1].x = t,s1 += 1ll * t * t,s2 -= t;
FFT(a,1),FFT(b,1);
for(int i = 0; i < N; i ++) a[i] = a[i] * b[i];
FFT(a,-1);
// a[n + i].x = sum_j x_{i + j - 1} * y_j: the cross term for rotation i - 1.
for(int i = 1; i <= n; i ++)
for(int j = -m; j <= m; j ++)
ans = min(ans,1ll * n * j * j + s1 + 2ll * s2 * j - 2 * (long long)a[n + i].x);
printf("%lld",ans);
return 0;
}
任意模数NTT模板
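谁来教我 NTT 模数和原根怎么背(下面用的三个模数 $469762049=7\cdot2^{26}+1$、$998244353=119\cdot2^{23}+1$、$1004535809=479\cdot2^{21}+1$,原根都取 $3$)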
#include <iostream>
#include <cstdio>
using namespace std;
const int maxn = 3e5 + 50;
// Three NTT-friendly primes; g = 3 is used as the primitive root for all of them.
// The exact coefficients are rebuilt from the three residues by CRT.
const int g = 3, p1 = 469762049, p2 = 998244353, p3 = 1004535809;

int n, m, p;             // degrees of the two polynomials, output modulus
int N = 1, l;            // transform length (power of two) and log2(N)
int rev[maxn];           // bit-reversal permutation of 0..N-1
int a[maxn], b[maxn];    // input coefficients
int t1[maxn], t2[maxn];  // scratch copies transformed by solve()
int d[3][maxn];          // d[k][i] = i-th product coefficient modulo the k-th prime

// Reads the next non-negative decimal integer from stdin, skipping any non-digit prefix.
int read() {
    char ch = getchar();
    while (ch < '0' || ch > '9') ch = getchar();
    int value = 0;
    for (; ch >= '0' && ch <= '9'; ch = getchar()) value = value * 10 + (ch - '0');
    return value;
}

// Returns x^k mod `mod` by binary exponentiation (64-bit intermediates).
int qpow(int x, int k, int mod) {
    long long result = 1, base = x;
    for (; k; k >>= 1, base = base * base % mod)
        if (k & 1) result = result * base % mod;
    return result;
}
// In-place iterative number-theoretic transform of a[0..N-1] modulo the prime `mod`, using
// g as the root; N must be a power of two dividing mod - 1 and rev[] its bit-reversal
// permutation. op != 0: forward transform. op == 0: inverse transform, done as a forward
// transform, then scaling by N^{-1} and reversing a[1..N-1] (evaluation at the inverse roots).
void DNT(int *a,int op,int mod){
// Reorder into bit-reversed order so the butterflies can work in place.
for(int i = 0; i < N; i ++) if(i < rev[i]) swap(a[i],a[rev[i]]);
for(int i = 1; i < N; i <<= 1){
// w1 = g^((mod - 1) / (2i)) is a principal (2i)-th root of unity mod `mod`.
int w1 = qpow(g,(mod - 1) / (2 * i),mod);
for(int j = 0; j < N; j += 2 * i){
int w = 1;
for(int k = j; k < j + i; k ++){
// Butterfly on (a[k], a[k + i]). NOTE(review): t1 + t2 is computed in int; this
// assumes the incoming a[k] is small enough that the sum stays below INT_MAX.
int t1 = a[k],t2 = 1ll * w * a[k + i] % mod;
a[k] = (t1 + t2) % mod,a[k + i] = (t1 - t2 + mod) % mod;
w = 1ll * w * w1 % mod;
}
}
}
if(!op){
int inv = qpow(N,mod - 2,mod);
for(int i = 0; i < N; i ++) a[i] = 1ll * a[i] * inv % mod;
// Reverse a[1..N-1]; the i = N / 2 step swaps that element with itself.
for(int i = 1; i <= N / 2; i ++) swap(a[i],a[N - i]);
}
}
// Cyclic convolution of a and b modulo `mod` over the first N entries.
// The result is left in a; b is overwritten with its transform.
void NTT(int *a, int *b, int mod) {
    DNT(a, 1, mod);
    DNT(b, 1, mod);
    for (int i = 0; i < N; i++) a[i] = (long long)a[i] * b[i] % mod;
    DNT(a, 0, mod);
}

// Convolves copies of a and b modulo `mod` (the inputs are left untouched) and writes the
// N resulting coefficients into s.
void solve(int *a, int * b, int *s, int mod) {
    for (int i = 0; i < N; i++) t1[i] = a[i];
    for (int j = 0; j < N; j++) t2[j] = b[j];
    NTT(t1, t2, mod);
    for (int i = 0; i < N; i++) s[i] = t1[i];
}
// Combines residues x (mod p1), y (mod p2) and z (mod p3) into the unique X < p1*p2*p3 with
// those residues, and returns X mod p.
// Step 1: merge x and y into t = X mod (p1*p2), which fits in long long (about 4.7e17).
// Step 2: X = t + k*p1*p2 with k = (z - t) * (p1*p2)^{-1} mod p3; only k*p1*p2 is reduced mod p.
int CRT(long long x,long long y,long long z){
    // Both inverses depend only on the fixed primes. Compute them once instead of running
    // two modular exponentiations for every output coefficient.
    static const long long inv_p1 = qpow(p1, p2 - 2, p2);                    // p1^{-1} mod p2
    static const long long inv_p1p2 = qpow(1ll * p1 * p2 % p3, p3 - 2, p3);  // (p1*p2)^{-1} mod p3
    long long t = (y - x + p2) % p2 * inv_p1 % p2 * p1 + x;
    long long k = (z - t % p3 + p3) % p3 * inv_p1p2 % p3;
    return (k * (1ll * p1 * p2 % p) % p + t) % p;
}
// Entry point: reads degrees n, m and modulus p, then both coefficient lists. It convolves
// them modulo three NTT primes and prints each product coefficient mod p, recovered by CRT.
int main(){
    n = read(), m = read(), p = read();
    while (N <= n + m) N <<= 1, l++;
    // rev[0] is already 0. Starting at 1 also avoids evaluating `0 << (l - 1)` with l = 0
    // (N = 1 when n = m = 0), a shift by a negative count, which is UB.
    for (int i = 1; i < N; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (l - 1));
    for (int i = 0; i <= n; i++) a[i] = read();
    for (int i = 0; i <= m; i++) b[i] = read();
    solve(a, b, d[0], p1);
    solve(a, b, d[1], p2);
    solve(a, b, d[2], p3);
    for (int i = 0; i <= n + m; i++) printf("%d ", CRT(d[0][i], d[1][i], d[2][i]));
    return 0;
}
推荐阅读