欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

【题解】LuoGu6859:蝴蝶与花

程序员文章站 2022-03-27 16:02:51
原题传送门这是一个超难题的模板,但是有一个性质,ai=1/2a_i=1/2ai​=1/2,我敏锐的感觉到这道题可能是基于这个性质来做的先令l=1l=1l=1,看看是否存在rrr,使得al+...+ar=sa_l+...+a_r=sal​+...+ar​=s发现不存在这个rrr,当且仅当存在rrr,使得∑i=lrai=s−1且ar+1=2\sum_{i=l}^{r}a_i=s-1且a_{r+1}=2∑i=lr​ai​=s−1且ar+1​=2如果存在rrr满足,当然直接输出,若是不存在必定形如考虑移...

原题传送门
这是一个超难题的模板,但是有一个性质, a i = 1 / 2 a_i=1/2 ai=1/2,我敏锐的感觉到这道题可能是基于这个性质来做的

先令 l = 1 l=1 l=1,看看是否存在 r r r,使得 a l + . . . + a r = s a_l+...+a_r=s al+...+ar=s
发现不存在这个 r r r,当且仅当存在 r r r,使得 ∑ i = l r a i = s − 1 且 a r + 1 = 2 \sum_{i=l}^{r}a_i=s-1且a_{r+1}=2 i=lrai=s1ar+1=2
如果存在 r r r满足,当然直接输出,若是不存在
必定形如【题解】LuoGu6859:蝴蝶与花

考虑移动 l , r l,r l,r,然后这个时候一个必定的条件是 a r + 1 = 2 a_{r+1}=2 ar+1=2
如果 a 1 = 1 a_1=1 a1=1,可以使 l + 1 , r + 1 l+1,r+1 l+1,r+1,这样整段区间和+1,得到答案
如果 a 1 = 2 , a r + 2 = 1 a_1=2,a_{r+2}=1 a1=2,ar+2=1,可以 l + 1 , r + 2 l+1,r+2 l+1,r+2,也能得到答案
发现其实答案跟l右端第一个1的位置与r右端第一个1的位置有关

p 1 , p 2 p1,p2 p1,p2分别代表 l , r l,r l,r右端第一个1的位置
若不存在另行讨论
那么又可以得到从 l − > p 1 l->p1 l>p1之间有 l s u m = p 1 − 1 lsum = p1-1 lsum=p11个2
r − > p 2 r->p2 r>p2之间有 r s u m = p 2 − r − 1 rsum=p2-r-1 rsum=p2r1个2

讨论 l s u m , r s u m lsum,rsum lsum,rsum的大小(具体原因可以在草稿纸上得出)

  • l s u m < r s u m : l = p 1 + 1 , r = p 1 + r lsum<rsum:l=p1+1,r=p1+r lsum<rsum:l=p1+1,r=p1+r
  • l s u m = r s u m : l = p 1 , r = p 2 lsum=rsum:l=p1,r=p2 lsum=rsum:l=p1,r=p2
  • l s u m > r s u m : l = r s u m + 1 , r = p 2 lsum>rsum:l=rsum+1,r=p2 lsum>rsum:l=rsum+1,r=p2

题目就做出来了

问题来到如何求出 r , p 1 , p 2 r,p1,p2 r,p1,p2,这个可以用数据结构直接来维护
如果用树状数组,需要套上二分,复杂度 O ( n l o g 2 n ) O(nlog^2n) O(nlog2n)
如果用线段树,直接 O ( n l o g n ) O(nlogn) O(nlogn)

我两个都写了

Code树状数组(这个要开O2才能过):

#include <bits/stdc++.h>
#define maxn 2000010
using namespace std;
int tree1[maxn], tree2[maxn], n, m, a[maxn];

inline int read(){
	int s = 0, w = 1;
	char c = getchar();
	for (; !isdigit(c); c = getchar()) if (c == '-') w = -1;
	for (; isdigit(c); c = getchar()) s = (s << 1) + (s << 3) + (c ^ 48);
	return s * w;
}

int lowbit(int x){ return x & -x; }
void update1(int x, int y){ for (; x <= n; x += lowbit(x)) tree1[x] += y; }
void update2(int x, int y){ for (; x <= n; x += lowbit(x)) tree2[x] += y; }
int query1(int x){ int s = 0; for (; x; x -= lowbit(x)) s += tree1[x]; return s; }
int query2(int x){ int s = 0; for (; x; x -= lowbit(x)) s += tree2[x]; return s; }

int find(int s){
	int l = 1, r = n, ans = -1;
	while (l <= r){
		int mid = (l + r) >> 1;
		if (query1(mid) <= s) ans = mid, l = mid + 1; else r = mid - 1;
	}
	return ans;
}

int find1(int x){
	int l = x, r = n, ans = -1, base = query2(x - 1);
	while (l <= r){
		int mid = (l + r) >> 1;
		if (query2(mid) > base) ans = mid, r = mid - 1; else l = mid + 1;
	}
	return ans;
}

int main(){
	n = read(), m = read();
	for (int i = 1; i <= n; ++i){
		a[i] = read();
		update1(i, a[i]), update2(i, a[i] == 1);
	}
	int p1 = find1(1);
	while (m--){
		char c = getchar();
		for (; c != 'C' && c != 'A'; c = getchar());
		int x = read();
		if (c == 'A'){
			if (x == 1){
				if (p1 == -1) puts("none");
				else printf("%d %d\n", p1, p1); continue;
			}
			if (query1(n) < x){
				puts("none"); continue;
			}
			int pos = find(x);
			if (pos < 1){
				puts("none"); continue;
			} else
			if (query1(pos) == x){
				printf("%d %d\n", 1, pos); continue;
			}
			int p2 = find1(pos + 1);
			if (p1 == -1){
				puts("none"); continue;
			}
			if (p2 == -1){
				if (n - pos <= p1 - 1) puts("none");
				else printf("%d %d\n", p1 + 1, pos + p1);
				continue;
			}
			int l1 = p1 - 1, l2 = p2 - pos - 1;
			if (l1 < l2) printf("%d %d\n", p1 + 1, pos + p1);
			else if (l1 == l2) printf("%d %d\n", p1, p2);
			else printf("%d %d\n", l2 + 1, p2);
		} else{
			int y = read();
			if (a[x] != y){
				update1(x, -a[x] + y);
				if (y == 2) update2(x, -1); else update2(x, 1);
				a[x] = y;
				p1 = find1(1);
			}
		}
	}
	return 0;
}

Code线段树:

#include <bits/stdc++.h>
#define maxn 2000010
#define ls rt << 1
#define rs rt << 1 | 1
using namespace std;
int n, m, a[maxn], ifequal;
struct Seg{
	int l, r, sum, cnt;
}seg[maxn << 2];

inline int read(){
	int s = 0, w = 1;
	char c = getchar();
	for (; !isdigit(c); c = getchar()) if (c == '-') w = -1;
	for (; isdigit(c); c = getchar()) s = (s << 1) + (s << 3) + (c ^ 48);
	return s * w;
}

void pushup(int rt){
	seg[rt].sum = seg[ls].sum + seg[rs].sum;
	seg[rt].cnt = seg[ls].cnt + seg[rs].cnt;
}

void build(int rt, int l, int r){
	seg[rt].l = l, seg[rt].r = r;
	if (l == r){
		seg[rt].sum = a[l], seg[rt].cnt = (a[l] == 1);
		return;
	}
	int mid = (l + r) >> 1;
	build(ls, l, mid), build(rs, mid + 1, r);
	pushup(rt);
}

void update(int rt, int x, int y){
	if (seg[rt].l > x || seg[rt].r < x) return;
	if (seg[rt].l == seg[rt].r){
		seg[rt].sum = y;
		seg[rt].cnt = (y == 1);
		return;
	}
	update(ls, x, y), update(rs, x, y);
	pushup(rt);
}

int query(int rt, int x){
	if (seg[rt].r <= x) return seg[rt].cnt;
	if (seg[rt].l > x) return 0;
	return query(ls, x) + query(rs, x);
}

int query2(int rt, int x){
	if (seg[rt].r <= x) return seg[rt].sum;
	if (seg[rt].l > x) return 0;
	return query2(ls, x) + query2(rs, x);
}

int find1(int rt, int x){
	if (seg[rt].l == seg[rt].r) return seg[rt].cnt > x ? seg[rt].l : -1;
	if (seg[ls].cnt > x) return find1(ls, x);
	else return find1(rs, x - seg[ls].cnt);
}

int findpos(int rt, int x){
	if (seg[rt].sum == x){
		ifequal = 1;
		return seg[rt].r;
	}
	if (seg[rt].sum < x && seg[rt].sum + a[seg[rt].r + 1] > x) return seg[rt].r;
	if (seg[rt].l == seg[rt].r) return -1;
	if (seg[ls].sum >= x) return findpos(ls, x);
	else if (seg[ls].sum < x && seg[ls].sum + a[seg[ls].r + 1] > x) return seg[ls].r;
	else return findpos(rs, x - seg[ls].sum);
}

int main(){
	n = read(), m = read();
	for (int i = 1; i <= n; ++i) a[i] = read();
	build(1, 1, n);
	int p1 = find1(1, 0);
	while (m--){
		char c = getchar();
		for (; c != 'C' && c != 'A'; c = getchar());
		int x = read();
		if (c == 'A'){
			if (x == 1){
				if (p1 == -1) puts("none");
				else printf("%d %d\n", p1, p1); continue;
			}
			if (seg[1].sum < x){
				puts("none"); continue;
			}
			ifequal = 0;
			int pos = findpos(1, x);
			if (pos < 1){
				puts("none"); continue;
			} else if (ifequal){
				printf("%d %d\n", 1, pos); continue;
			}
			int p2 = find1(1, query(1, pos));
			if (p1 == -1){
				puts("none"); continue;
			}
			if (p2 == -1){
				if (n - pos <= p1 - 1) puts("none");
				else printf("%d %d\n", p1 + 1, pos + p1);
				continue;
			}
			int l1 = p1 - 1, l2 = p2 - pos - 1;
			if (l1 < l2) printf("%d %d\n", p1 + 1, pos + p1);
			else if (l1 == l2) printf("%d %d\n", p1, p2);
			else printf("%d %d\n", l2 + 1, p2);
		} else{
			int y = read();
			if (a[x] != y) update(1, x, y), p1 = find1(1, 0), a[x] = y;
		}
	}
	return 0;
}

后来我翻一翻题解,发现大家都把线段树的几个部分改成非递归版本了,好像可以更快,我也改了一下

Code:

#include <bits/stdc++.h>
#define maxn 2000010
#define ls rt << 1
#define rs rt << 1 | 1
using namespace std;
struct Seg{
	int l, r, sum, cnt;
}seg[maxn << 2];
int  n, m, a[maxn], ifequal;

inline int read(){
	int s = 0, w = 1;
	char c = getchar();
	for (; !isdigit(c); c = getchar()) if (c == '-') w = -1;
	for (; isdigit(c); c = getchar()) s = (s << 1) + (s << 3) + (c ^ 48);
	return s * w;
}

void pushup(int rt){
	seg[rt].sum = seg[ls].sum + seg[rs].sum;
	seg[rt].cnt = seg[ls].cnt + seg[rs].cnt;
}

void build(int rt, int l, int r){
	seg[rt].l = l, seg[rt].r = r;
	if (l == r){
		seg[rt].sum = a[l], seg[rt].cnt = a[l] == 1;
		return;
	}
	int mid = (l + r) >> 1;
	build(ls, l, mid), build(rs, mid + 1, r);
	pushup(rt);
}

void update(int rt, int x, int y){
	if (seg[rt].l == seg[rt].r){
		seg[rt].sum = y, seg[rt].cnt = (y == 1);
		return;
	}
	if (seg[ls].r >= x) update(ls, x, y);
	else update(rs, x, y);
	pushup(rt);
}

int query(int rt, int x){
	if (seg[rt].r <= x) return seg[rt].cnt;
	if (seg[rt].l > x) return 0;
	return query(ls, x) + query(rs, x);
}

int find1(int rt, int x){
	while (1){
		if (seg[rt].l == seg[rt].r){
			if (seg[rt].cnt > x) return seg[rt].l;
			else return -1;
		}
		if (seg[ls].cnt > x) rt = ls;
		else x -= seg[ls].cnt, rt = rs;
	}
	return -1;
}

int findpos(int rt, int x){
	while (1){
		if (seg[rt].sum == x){
			ifequal = 1;
			return seg[rt].r;
		}
		if (seg[rt].sum < x && seg[rt].sum + a[seg[rt].r + 1] > x) return seg[rt].r;
		if (seg[rt].l == seg[rt].r) return -1;
		if (seg[ls].sum >= x) rt = ls;
		else if (seg[ls].sum < x && seg[ls].sum + a[seg[ls].r + 1] > x) return seg[ls].r;
		else x -= seg[ls].sum, rt = rs;
	}
	return -1;
}

int main(){
	n = read(), m = read();
	for (int i = 1; i <= n; ++i) a[i] = read();
	build(1, 1, n);
	int p1 = find1(1, 0);
	while (m--){
		char c = getchar();
		for (; c != 'C' && c != 'A'; c = getchar());
		int x = read();
		if (c == 'A'){
			if (x == 1){
				if (p1 == -1) puts("none");
				else printf("%d %d\n", p1, p1); continue;
			}
			if (seg[1].sum < x){
				puts("none"); continue;
			}
			ifequal = 0;
			int pos = findpos(1, x);
			if (pos < 1){
				puts("none"); continue;
			} else
			if (ifequal){
				printf("%d %d\n", 1, pos); continue;
			}
			int p2 = find1(1, query(1, pos));
			if (p1 == -1){
				puts("none"); continue;
			}
			if (p2 == -1){
				if (n - pos <= p1 - 1) puts("none");
				else printf("%d %d\n", p1 + 1, pos + p1);
				continue;
			}
			int l1 = p1 - 1, l2 = p2 - pos - 1;
			if (l1 < l2) printf("%d %d\n", p1 + 1, pos + p1);
			else if (l1 == l2) printf("%d %d\n", p1, p2);
			else printf("%d %d\n", l2 + 1, p2);
		} else{
			int y = read();
			if (a[x] != y){
				update(1, x, y);
				a[x] = y;
				p1 = find1(1, 0);
			}
		}
	}
	return 0;
}

本文地址:https://blog.csdn.net/ModestCoder_/article/details/109271870