【题解】LuoGu6859:蝴蝶与花
原题传送门
这是一个超难题的模板,但是有一个性质:$a_i = 1/2$(即 $a_i \in \{1, 2\}$),我敏锐的感觉到这道题可能是基于这个性质来做的
先令 $l=1$,看看是否存在 $r$,使得 $a_l+\dots+a_r=s$
发现不存在这个 $r$,当且仅当存在 $r$,使得 $\sum_{i=l}^{r}a_i=s-1$ 且 $a_{r+1}=2$
如果存在 $r$ 满足,当然直接输出;若是不存在,
必定形如 $a_1+\dots+a_r=s-1$ 且 $a_{r+1}=2$
考虑移动 $l,r$,然后这个时候一个必定的条件是 $a_{r+1}=2$
如果 $a_1=1$,可以使 $l+1,\ r+1$(左端去掉一个 $1$,右端加入一个 $2$),这样整段区间和 $+1$,得到答案
如果 $a_1=2,\ a_{r+2}=1$,可以 $l+1,\ r+2$(左端去掉一个 $2$,右端加入一个 $2$ 和一个 $1$),也能得到答案
发现其实答案跟l右端第一个1的位置与r右端第一个1的位置有关
记 $p_1,p_2$ 分别代表 $l,r$ 右端第一个 $1$ 的位置
若不存在另行讨论
那么又可以得到从 $l \to p_1$ 之间有 $\mathit{lsum} = p_1-1$ 个 $2$
$r \to p_2$ 之间有 $\mathit{rsum}=p_2-r-1$ 个 $2$
讨论 $\mathit{lsum},\mathit{rsum}$ 的大小(具体原因可以在草稿纸上得出)
- $\mathit{lsum}<\mathit{rsum}$:$l=p_1+1,\ r=p_1+r$
- $\mathit{lsum}=\mathit{rsum}$:$l=p_1,\ r=p_2$
- $\mathit{lsum}>\mathit{rsum}$:$l=\mathit{rsum}+1,\ r=p_2$
题目就做出来了
问题来到如何求出 $r,p_1,p_2$,这个可以用数据结构直接来维护
如果用树状数组,需要套上二分,复杂度 $O(n\log^2 n)$
如果用线段树,直接 $O(n\log n)$
我两个都写了
Code树状数组(这个要开O2才能过):
#include <bits/stdc++.h>
#define maxn 2000010
using namespace std;
// tree1: Fenwick tree fed with the values a[i] (prefix sums of the array).
// tree2: Fenwick tree fed with the indicator (a[i] == 1) (prefix counts of ones).
// n: number of elements, m: number of operations, a: current values (1-indexed).
int tree1[maxn], tree2[maxn], n, m, a[maxn];
// Reads the next decimal integer from stdin.
// Skips every character that is not a digit; a '-' seen while skipping makes
// the result negative. Returns the parsed value (0 if the input ends before
// any digit is found).
inline int read(){
    int value = 0, sign = 1;
    // getchar() returns int: keeping it as int lets EOF be told apart from
    // real characters and keeps the isdigit() argument in its valid domain.
    int ch = getchar();
    while (ch != EOF && !isdigit(ch)){
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    while (ch != EOF && isdigit(ch)){
        value = value * 10 + (ch - '0');
        ch = getchar();
    }
    return value * sign;
}
// Returns the value of the lowest set bit of x (0 when x is 0).
int lowbit(int x){
    const int negated = -x;
    return negated & x;
}
// Adds y at index x of the Fenwick tree tree1, touching every covering node up to n.
void update1(int x, int y){
    int idx = x;
    while (idx <= n){
        tree1[idx] += y;
        idx += lowbit(idx);
    }
}
// Adds y at index x of the Fenwick tree tree2, touching every covering node up to n.
void update2(int x, int y){
    int node = x;
    while (!(node > n)){
        tree2[node] = tree2[node] + y;
        node = node + lowbit(node);
    }
}
// Returns the prefix total stored in tree1 for indices 1..x.
int query1(int x){
    int total = 0;
    while (x != 0){
        total += tree1[x];
        x -= lowbit(x);
    }
    return total;
}
// Returns the prefix total stored in tree2 for indices 1..x.
int query2(int x){
    int acc = 0;
    int idx = x;
    while (idx != 0){
        acc = acc + tree2[idx];
        idx = idx - lowbit(idx);
    }
    return acc;
}
// Binary search over prefix sums: returns the largest index p in [1, n]
// with query1(p) <= s, or -1 when no such index exists.
int find(int s){
    int lo = 1;
    int hi = n;
    int best = -1;
    while (lo <= hi){
        const int mid = (lo + hi) >> 1;
        if (query1(mid) > s){
            hi = mid - 1;
        } else {
            best = mid;
            lo = mid + 1;
        }
    }
    return best;
}
// Binary search over prefix counts: returns the smallest index p in [x, n]
// with query2(p) > query2(x - 1), i.e. the first position at or after x that
// tree2 counts, or -1 when there is none.
int find1(int x){
    const int base = query2(x - 1);
    int lo = x;
    int hi = n;
    int first = -1;
    while (lo <= hi){
        const int mid = (lo + hi) >> 1;
        if (query2(mid) <= base){
            lo = mid + 1;
        } else {
            first = mid;
            hi = mid - 1;
        }
    }
    return first;
}
// Driver of the Fenwick-tree solution.
// Reads n and m, then the n values, then m operations:
//   "A s"   : print "l r" such that a[l] + ... + a[r] == s, or "none";
//   "C x y" : assign a[x] = y.
int main(){
    n = read();
    m = read();
    for (int i = 1; i <= n; ++i){
        a[i] = read();
        update1(i, a[i]);
        update2(i, a[i] == 1);
    }
    // Position of the first 1 in the array (-1 when there is none); refreshed
    // after every modification that actually changes a value.
    int p1 = find1(1);
    while (m--){
        // Skip to the next command letter. The character is held in an int so
        // EOF is recognised and truncated input ends the loop instead of
        // spinning forever.
        int op = getchar();
        while (op != 'C' && op != 'A' && op != EOF) op = getchar();
        if (op == EOF) break;
        const int x = read();

        if (op == 'C'){
            const int y = read();
            if (a[x] != y){
                update1(x, y - a[x]);
                update2(x, y == 2 ? -1 : 1);
                a[x] = y;
                p1 = find1(1);
            }
            continue;
        }

        // Query for a segment whose sum is x.
        if (x == 1){
            // A sum of 1 is exactly one element equal to 1.
            if (p1 == -1) puts("none");
            else printf("%d %d\n", p1, p1);
            continue;
        }
        if (query1(n) < x){
            puts("none");
            continue;
        }
        // Longest prefix whose sum does not exceed x.
        const int pos = find(x);
        if (pos < 1){
            puts("none");
            continue;
        }
        if (query1(pos) == x){
            printf("%d %d\n", 1, pos);
            continue;
        }
        // From here the prefix 1..pos sums to x - 1 and a[pos + 1] is 2, so
        // the window has to be slid right until a 1 leaves or enters it.
        if (p1 == -1){
            puts("none");
            continue;
        }
        const int p2 = find1(pos + 1);  // first 1 to the right of pos
        if (p2 == -1){
            // Only 2s remain on the right: the window must move past p1,
            // which needs p1 more elements after pos.
            if (n - pos <= p1 - 1) puts("none");
            else printf("%d %d\n", p1 + 1, pos + p1);
            continue;
        }
        const int l1 = p1 - 1;        // 2s in front of the first 1
        const int l2 = p2 - pos - 1;  // 2s between pos and p2
        if (l1 < l2) printf("%d %d\n", p1 + 1, pos + p1);
        else if (l1 == l2) printf("%d %d\n", p1, p2);
        else printf("%d %d\n", l2 + 1, p2);
    }
    return 0;
}
Code线段树:
#include <bits/stdc++.h>
#define maxn 2000010
#define ls rt << 1
#define rs rt << 1 | 1
using namespace std;
// n: number of elements, m: number of operations, a: current values (1-indexed).
// ifequal: set to 1 by findpos when a prefix sum matches the target exactly;
// main resets it to 0 before each findpos call.
int n, m, a[maxn], ifequal;
// One segment-tree node.
struct Seg{
    int l;    // left end of the covered index range
    int r;    // right end of the covered index range
    int sum;  // total of the values in [l, r]
    int cnt;  // how many values in [l, r] equal 1
}seg[maxn << 2];
// Reads the next decimal integer from stdin.
// Non-digit characters are skipped; a '-' among them negates the result.
// Returns 0 when the input ends before any digit is seen.
inline int read(){
    // The character stays an int so EOF is detectable (the loop cannot spin
    // forever on truncated input) and isdigit() never sees a negative char.
    int ch = getchar();
    bool negative = false;
    for (; ch != EOF && !isdigit(ch); ch = getchar()){
        if (ch == '-') negative = true;
    }
    int value = 0;
    for (; ch != EOF && isdigit(ch); ch = getchar()){
        value = value * 10 + (ch - '0');
    }
    return negative ? -value : value;
}
void pushup(int rt){
seg[rt].sum = seg[ls].sum + seg[rs].sum;
seg[rt].cnt = seg[ls].cnt + seg[rs].cnt;
}
void build(int rt, int l, int r){
seg[rt].l = l, seg[rt].r = r;
if (l == r){
seg[rt].sum = a[l], seg[rt].cnt = (a[l] == 1);
return;
}
int mid = (l + r) >> 1;
build(ls, l, mid), build(rs, mid + 1, r);
pushup(rt);
}
void update(int rt, int x, int y){
if (seg[rt].l > x || seg[rt].r < x) return;
if (seg[rt].l == seg[rt].r){
seg[rt].sum = y;
seg[rt].cnt = (y == 1);
return;
}
update(ls, x, y), update(rs, x, y);
pushup(rt);
}
int query(int rt, int x){
if (seg[rt].r <= x) return seg[rt].cnt;
if (seg[rt].l > x) return 0;
return query(ls, x) + query(rs, x);
}
int query2(int rt, int x){
if (seg[rt].r <= x) return seg[rt].sum;
if (seg[rt].l > x) return 0;
return query2(ls, x) + query2(rs, x);
}
int find1(int rt, int x){
if (seg[rt].l == seg[rt].r) return seg[rt].cnt > x ? seg[rt].l : -1;
if (seg[ls].cnt > x) return find1(ls, x);
else return find1(rs, x - seg[ls].cnt);
}
int findpos(int rt, int x){
if (seg[rt].sum == x){
ifequal = 1;
return seg[rt].r;
}
if (seg[rt].sum < x && seg[rt].sum + a[seg[rt].r + 1] > x) return seg[rt].r;
if (seg[rt].l == seg[rt].r) return -1;
if (seg[ls].sum >= x) return findpos(ls, x);
else if (seg[ls].sum < x && seg[ls].sum + a[seg[ls].r + 1] > x) return seg[ls].r;
else return findpos(rs, x - seg[ls].sum);
}
// Driver of the recursive segment-tree solution.
// Reads n and m, then the n values, then m operations:
//   "A s"   : print "l r" such that a[l] + ... + a[r] == s, or "none";
//   "C x y" : assign a[x] = y.
int main(){
    n = read();
    m = read();
    for (int i = 1; i <= n; ++i) a[i] = read();
    build(1, 1, n);
    // Position of the first 1 in the array (-1 when there is none).
    int p1 = find1(1, 0);
    while (m--){
        // Skip to the next command letter; an int is used so that EOF is
        // detected and truncated input stops the loop instead of hanging it.
        int op = getchar();
        while (op != 'C' && op != 'A' && op != EOF) op = getchar();
        if (op == EOF) break;
        const int x = read();

        if (op == 'C'){
            const int y = read();
            if (a[x] != y){
                update(1, x, y);
                p1 = find1(1, 0);
                a[x] = y;
            }
            continue;
        }

        // Query for a segment whose sum is x.
        if (x == 1){
            // A sum of 1 is exactly one element equal to 1.
            if (p1 == -1) puts("none");
            else printf("%d %d\n", p1, p1);
            continue;
        }
        if (seg[1].sum < x){
            puts("none");
            continue;
        }
        ifequal = 0;
        const int pos = findpos(1, x);
        if (pos < 1){
            puts("none");
            continue;
        }
        if (ifequal){
            // The prefix 1..pos already sums to x.
            printf("%d %d\n", 1, pos);
            continue;
        }
        // The prefix 1..pos sums to x - 1 and a[pos + 1] is 2: slide the
        // window right until a 1 leaves it or enters it.
        if (p1 == -1){
            puts("none");
            continue;
        }
        // First 1 to the right of pos: the one after all ones in 1..pos.
        const int p2 = find1(1, query(1, pos));
        if (p2 == -1){
            if (n - pos <= p1 - 1) puts("none");
            else printf("%d %d\n", p1 + 1, pos + p1);
            continue;
        }
        const int l1 = p1 - 1;        // 2s in front of the first 1
        const int l2 = p2 - pos - 1;  // 2s between pos and p2
        if (l1 < l2) printf("%d %d\n", p1 + 1, pos + p1);
        else if (l1 == l2) printf("%d %d\n", p1, p2);
        else printf("%d %d\n", l2 + 1, p2);
    }
    return 0;
}
后来我翻一翻题解,发现大家都把线段树的几个部分改成非递归版本了,好像可以更快,我也改了一下
Code:
#include <bits/stdc++.h>
#define maxn 2000010
#define ls rt << 1
#define rs rt << 1 | 1
using namespace std;
// Segment-tree node: the covered range [l, r] plus two aggregates over it.
struct Seg{
    int l;
    int r;
    int sum;  // total of the values in the range
    int cnt;  // number of values in the range that equal 1
}seg[maxn << 2];
// n: number of elements, m: number of operations, a: current values (1-indexed).
// ifequal: raised by findpos when a prefix sum equals the target exactly;
// main clears it before every findpos call.
int n, m, a[maxn], ifequal;
// Reads the next decimal integer from stdin.
// Everything up to the first digit is skipped; a '-' in that skipped part
// makes the result negative. Returns 0 if the input ends before a digit.
inline int read(){
    int sign = 1;
    // int, not char: EOF must stay distinguishable so the skip loop cannot
    // run forever, and isdigit() requires a non-negative character or EOF.
    int ch = getchar();
    while (ch != EOF && !isdigit(ch)){
        if (ch == '-') sign = -1;
        ch = getchar();
    }
    int magnitude = 0;
    while (ch != EOF && isdigit(ch)){
        magnitude = magnitude * 10 + (ch - '0');
        ch = getchar();
    }
    return sign * magnitude;
}
// Refreshes the aggregates of node rt from its left and right children.
void pushup(int rt){
    const int left = rt << 1;
    const int right = left | 1;
    seg[rt].sum = seg[left].sum + seg[right].sum;
    seg[rt].cnt = seg[left].cnt + seg[right].cnt;
}
// Initialises node rt to cover [l, r] and builds everything below it from a.
void build(int rt, int l, int r){
    seg[rt].l = l;
    seg[rt].r = r;
    if (l == r){
        // Leaf: one element.
        const int value = a[l];
        seg[rt].sum = value;
        seg[rt].cnt = (value == 1) ? 1 : 0;
        return;
    }
    const int mid = (l + r) >> 1;
    const int left = rt << 1;
    const int right = left | 1;
    build(left, l, mid);
    build(right, mid + 1, r);
    pushup(rt);
}
void update(int rt, int x, int y){
if (seg[rt].l == seg[rt].r){
seg[rt].sum = y, seg[rt].cnt = (y == 1);
return;
}
if (seg[ls].r >= x) update(ls, x, y);
else update(rs, x, y);
pushup(rt);
}
int query(int rt, int x){
if (seg[rt].r <= x) return seg[rt].cnt;
if (seg[rt].l > x) return 0;
return query(ls, x) + query(rs, x);
}
// Iteratively walks down from rt towards the leaf holding the (x + 1)-th value
// equal to 1; returns that leaf's index, or -1 if the leaf reached does not
// qualify.
int find1(int rt, int x){
    while (seg[rt].l != seg[rt].r){
        const int leftOnes = seg[rt << 1].cnt;
        if (leftOnes > x){
            rt = rt << 1;
        } else {
            // Skip the ones on the left and continue in the right child.
            x -= leftOnes;
            rt = rt << 1 | 1;
        }
    }
    return seg[rt].cnt > x ? seg[rt].l : -1;
}
int findpos(int rt, int x){
while (1){
if (seg[rt].sum == x){
ifequal = 1;
return seg[rt].r;
}
if (seg[rt].sum < x && seg[rt].sum + a[seg[rt].r + 1] > x) return seg[rt].r;
if (seg[rt].l == seg[rt].r) return -1;
if (seg[ls].sum >= x) rt = ls;
else if (seg[ls].sum < x && seg[ls].sum + a[seg[ls].r + 1] > x) return seg[ls].r;
else x -= seg[ls].sum, rt = rs;
}
return -1;
}
// Driver of the segment-tree solution with iterative descents.
// Reads n and m, then the n values, then m operations:
//   "A s"   : print "l r" such that a[l] + ... + a[r] == s, or "none";
//   "C x y" : assign a[x] = y.
int main(){
    n = read();
    m = read();
    for (int i = 1; i <= n; ++i) a[i] = read();
    build(1, 1, n);
    // Position of the first 1 in the array (-1 when there is none).
    int p1 = find1(1, 0);
    for (; m > 0; --m){
        // Find the next command letter. getchar()'s int result is kept as is,
        // so EOF ends the processing instead of looping forever.
        int op = getchar();
        while (op != 'C' && op != 'A' && op != EOF) op = getchar();
        if (op == EOF) break;
        const int x = read();

        if (op != 'A'){
            // Modification "C x y".
            const int y = read();
            if (a[x] != y){
                update(1, x, y);
                a[x] = y;
                p1 = find1(1, 0);
            }
            continue;
        }

        // Query for a segment whose sum is x.
        if (x == 1){
            // A sum of 1 is exactly one element equal to 1.
            if (p1 == -1) puts("none");
            else printf("%d %d\n", p1, p1);
            continue;
        }
        if (seg[1].sum < x){
            puts("none");
            continue;
        }
        ifequal = 0;
        const int pos = findpos(1, x);
        if (pos < 1){
            puts("none");
            continue;
        }
        if (ifequal){
            // The prefix 1..pos already sums to x.
            printf("%d %d\n", 1, pos);
            continue;
        }
        // The prefix 1..pos sums to x - 1 and a[pos + 1] is 2: slide the
        // window right until a 1 leaves it or enters it.
        if (p1 == -1){
            puts("none");
            continue;
        }
        // First 1 to the right of pos: the one after all ones in 1..pos.
        const int p2 = find1(1, query(1, pos));
        if (p2 == -1){
            if (n - pos <= p1 - 1) puts("none");
            else printf("%d %d\n", p1 + 1, pos + p1);
            continue;
        }
        const int l1 = p1 - 1;        // 2s in front of the first 1
        const int l2 = p2 - pos - 1;  // 2s between pos and p2
        if (l1 < l2) printf("%d %d\n", p1 + 1, pos + p1);
        else if (l1 == l2) printf("%d %d\n", p1, p2);
        else printf("%d %d\n", l2 + 1, p2);
    }
    return 0;
}
本文地址:https://blog.csdn.net/ModestCoder_/article/details/109271870