【树状数组】普通、二维、离线树状数组的(单点修改,单点查询,区间修改,区间查询)模板及应用总结
文章目录
树状数组
用于快速高效的计算与前缀和相关的信息
lowbit
int lowbit( int i ) { return i & -i; }
l o w b i t \rm lowbit lowbit:返回 x x x二进制最低位为 1 1 1的位置的值
e.g.
40=101000
,lowbit(40)=8
线段树与树状数组
因为涉及 l o w b i t \rm lowbit lowbit,所以树状数组的下标一定从 1 1 1开始,而不是 0 0 0
线段树用mid=(l+r)>>1
进行
log
\log
log的优化
树状数组的通过 ± l o w b i t ( i ) ±\rm lowbit(i) ±lowbit(i)进行二进制位的进/退 1 1 1
时间复杂度同样都是 O ( n log n ) O(n\log n) O(nlogn)
但一般来说树状数组的空间都是 O ( N ) O(N) O(N),不会像线段树有 N < < 2 N<<2 N<<2的大空间
线段树因为其结构原因有更多的应用:优化建图,线段树分治 . . . ... ...
但是树状数组就比较死板了,就是跟静态区间/单点挂钩
那么应用广点,消耗的代价(更大空间)多点也是可以理解的了
很多情况下树状数组和线段树没有什么区别,可以互换
单点修改
void add( int i, int val ) {
for( ;i <= n;i += lowbit( i ) )
t[i] += val;
}
区间查询
int query( int i ) {//求的是[1,i]的前缀和
int ans = 0;
for( ;i;i -= lowbit( i ) )
ans += t[i];
return ans;
}
int query( int l, int r ) {
return query( r ) - query( l - 1 );
}
一般的写法都是维护前缀和,所以修改是 + l o w b i t +\rm lowbit +lowbit,查询是 − l o w b i t -\rm lowbit −lowbit
但有些时候题目反而是跟后缀挂钩,这个时候有两种选择
-
强制后缀转前缀
每次传入 i i i的时候,暴力变成 i = n − i + 1 i=n-i+1 i=n−i+1,然后进行
add
query
的操作 -
直接反转使用树状数组
修改直接 − l o w b i t -\rm lowbit −lowbit,查询 + l o w b i t +\rm lowbit +lowbit
只要维护了意义上的相对即可
区间修改
原序列 a 1 , a 2 , . . . , a n a_1,a_2,...,a_n a1,a2,...,an,定义差分数组 c i = a i − a i − 1 c_i=a_i-a_{i-1} ci=ai−ai−1, 则 a i = ∑ j = 1 i c j a_i=\sum_{j=1}^ic_j ai=∑j=1icj
那么修改区间 [ l , r ] [l,r] [l,r],加上 w w w,相当于在 c l c_l cl加 w w w,在 c r + 1 c_{r+1} cr+1减 w w w
这就将区间修改转化成了两次的单点修改
达到 [ l , r ] [l,r] [l,r]区间的数加 w w w的效果
区间求和
∑ i = 1 n a i = ∑ i = 1 n ∑ j = 1 i c j = ∑ i = 1 n ( n − i + 1 ) c i = ( c 1 ) + ( c 1 + c 2 ) + . . . + ( c 1 + c 2 + . . . + c n ) \sum_{i=1}^na_i=\sum_{i=1}^n\sum_{j=1}^ic_j=\sum_{i=1}^n(n-i+1)c_i \\=(c_1)+(c_1+c_2)+...+(c_1+c_2+...+c_n) i=1∑nai=i=1∑nj=1∑icj=i=1∑n(n−i+1)ci=(c1)+(c1+c2)+...+(c1+c2+...+cn)
= ( n + 1 ) ∗ ( c 1 + c 2 + . . . + c n ) − ( c 1 ⏞ 1 + c 2 + c 2 ⏞ 2 + . . . + c n + c n ⏞ n ) =(n+1)*(c_1+c_2+...+c_n)-(\overbrace{c_1}^{1}+\overbrace{c_2+c_2}^{2}+...\overbrace{+c_n+c_n}^{n}) =(n+1)∗(c1+c2+...+cn)−(c1 1+c2+c2 2+...+cn+cn n)
=
(
n
+
1
)
∗
∑
i
=
1
n
c
i
−
∑
i
=
1
n
c
i
∗
i
=(n+1)*\sum_{i=1}^nc_i-\sum_{i=1}^nc_i*i
=(n+1)∗i=1∑nci−i=1∑nci∗i
所以只需要用两个树状数组,分别维护
c
i
c_i
ci和
c
i
∗
i
c_i*i
ci∗i即可
#include <cstdio>
#define int long long
#define maxn 1000005
int n, Q;
int a[maxn], t1[maxn], t2[maxn];
int lowbit( int x ) { return x & -x; }
void modify( int x, int val ) {
for( int i = x;i <= n;i += lowbit( i ) )
t1[i] += val, t2[i] += val * x;
}
int query( int x ) {
int ans = 0;
for( int i = x;i;i -= lowbit( i ) )
ans += ( x + 1 ) * t1[i] - t2[i];
return ans;
}
signed main() {
scanf( "%lld %lld", &n, &Q );
for( int i = 1;i <= n;i ++ ) {
scanf( "%lld", &a[i] );
modify( i, a[i] - a[i - 1] );
}
int opt, l, r, x;
while( Q -- ) {
scanf( "%lld %lld %lld", &opt, &l, &r );
if( opt & 1 ) {
scanf( "%lld", &x );
modify( l, x ), modify( r + 1, -x );
}
else printf( "%lld\n", query( r ) - query( l - 1 ) );
}
return 0;
}
二维树状数组
既然树状数组是前缀和的工具,那么二维树状数组就相当于与二维差分
树状数组嵌树状数组的感觉,查询就是用二维差分计算围成的面积
void modify( int x, int y, int val ) {
for( int i = x;i <= n;i += lowbit( i ) )
for( int j = y;j <= m;j += lowbit( j ) )
t[i][j] += val;
}
int query( int x, int y ) {
int ans = 0;
for( int i = x;i;i -= lowbit( i ) )
for( int j = y;j;j -= lowbit( j ) )
ans += t[i][j];
return ans;
}
modify( x1, y1, k );
query( x2, y2 ) - query( x2, y1 - 1 ) - query( x1 - 1, y2 ) + query( x1 - 1, y1 - 1 );
离线树状数组
离线树状数组求区间不同数的个数/值和
都是将询问按照 l , r \rm l,r l,r排序,然后记录 i i i的上一个/下一个位置
将指针拨到询问的端点处,删去上一个位置/加入下一个位置
从而做到 1 1 1的个数差,满足不同数只记录一次的要求,自然树状数组就能维护
例题的最后两题就是如此,看代码比较清晰能够理解
例题
POJ:stars
题目已经保证了 y y y递增,那么树状数组维护 x x x,每次查询比 x x x小的星星有多少个即可
#include <cstdio>
#include <iostream>
using namespace std;
#define maxn 32005
int n, N;
int ans[maxn], t[maxn], x[maxn], y[maxn];
int lowbit( int i ) { return i & -i; }
void modify( int i ) {
for( ;i <= N;i += lowbit( i ) ) t[i] ++;
}
int query( int i ) {
int ret = 0;
for( ;i;i -= lowbit( i ) ) ret += t[i];
return ret;
}
int main() {
scanf( "%d", &n );
for( int i = 1;i <= n;i ++ ) {
scanf( "%d %d", &x[i], &y[i] );
x[i] ++, y[i] ++, N = max( N, x[i] );
//x,y值域包含0 树状数组不能从0开始 所以整体+1
}
for( int i = 1;i <= n;i ++ )
ans[query( x[i] )] ++, modify( x[i] );
for( int i = 0;i < n;i ++ )
printf( "%d\n", ans[i] );
return 0;
}
MooFest
按 v v v排序,维护两个树状数组,一个是小于当前位置的位置和,一个是大于当前位置的位置和,这样就避免了距离带的绝对值
#include <cstdio>
#include <iostream>
#include <algorithm>
using namespace std;
#define int long long
#define maxn 20005
struct node {
int val, pos;
node(){}
node( int Val, int Pos ) {
val = Val, pos = Pos;
}
}cow[maxn];
struct Node {
int cnt, sumd;
Node(){}
Node( int Cnt, int Sumd ) {
cnt = Cnt, sumd = Sumd;
}
}t1[maxn], t2[maxn];
int n, N;
int lowbit( int i ) { return i & -i; }
void modify1( int i, int val ) {
for( ;i <= N;i += lowbit( i ) )
t1[i].cnt ++, t1[i].sumd += val;
}
void modify2( int i, int val ) {
for( ;i;i -= lowbit( i ) )
t2[i].cnt ++, t2[i].sumd += val;
}
Node query1( int i ) {
Node ans( 0, 0 );
for( ;i;i -= lowbit( i ) )
ans.cnt += t1[i].cnt, ans.sumd += t1[i].sumd;
return ans;
}
Node query2( int i ) {
Node ans( 0, 0 );
for( ;i <= N;i += lowbit( i ) )
ans.cnt += t2[i].cnt, ans.sumd += t2[i].sumd;
return ans;
}
bool cmp( node x, node y ) { return x.val < y.val; }
signed main() {
scanf( "%lld", &n );
for( int i = 1;i <= n;i ++ ) {
scanf( "%lld %lld", &cow[i].val, &cow[i].pos );
N = max( cow[i].pos, N );
}
N ++;
sort( cow + 1, cow + n + 1, cmp );
int ans = 0;
for( int i = 1;i <= n;i ++ ) {
Node t = query1( cow[i].pos );
ans += ( cow[i].pos * t.cnt - t.sumd ) * cow[i].val;
t = query2( cow[i].pos );
ans += ( t.sumd - t.cnt * cow[i].pos ) * cow[i].val;
modify1( cow[i].pos, cow[i].pos );
modify2( cow[i].pos, cow[i].pos );
}
printf( "%lld\n", ans );
return 0;
}
[SDOI2009]HH的项链
离线树状数组求区间不同数个数
将询问按 l l l排序,对于每个位置 i i i记录下一个与该位置值相等的位置,每一次到 i i i就把下一次的位置加进去
询问区间左端点以前的自然都要加,这样区间查询相减,就知道下一次的位置在不在区间内,就恰好为 1 1 1
#include <cstdio>
#include <algorithm>
using namespace std;
#define maxn 1000005
struct node {
int l, r, id;
}q[maxn];
int n, m;
int a[maxn], t[maxn], lst[maxn], nxt[maxn], ans[maxn];
bool vis[maxn];
void read( int &x ) {
x = 0; char s = getchar();
while( s < '0' or s > '9' ) s = getchar();
while( '0' <= s and s <= '9' ) {
x = ( x << 1 ) + ( x << 3 ) + ( s ^ 48 );
s = getchar();
}
}
int lowbit( int i ) { return i & -i; }
void add( int i ) {
for( ;i < maxn;i += lowbit( i ) ) t[i] ++;
}
int query( int i ) {
int ret = 0;
for( ;i;i -= lowbit( i ) ) ret += t[i];
return ret;
}
int main() {
read( n );
for( int i = 1;i <= n;i ++ ) read( a[i] );
read( m );
for( int i = 1;i <= m;i ++ )
read( q[i].l ), read( q[i].r ), q[i].id = i;
sort( q + 1, q + m + 1, []( node x, node y ) { return x.l < y.l; } );
for( int i = 1;i <= n;i ++ )
if( ! vis[a[i]] ) add( i ), vis[a[i]] = 1;
for( int i = n;i;i -- ) {
if( ! lst[a[i]] ) nxt[i] = maxn;
else nxt[i] = lst[a[i]];
lst[a[i]] = i;
}
int pos = 1;
for( int i = 1;i <= m;i ++ ) {
while( pos < q[i].l ) add( nxt[pos] ), pos ++;
ans[q[i].id] = query( q[i].r ) - query( q[i].l - 1 );
}
for( int i = 1;i <= m;i ++ )
printf( "%d\n", ans[i] );
return 0;
}
Turing Tree
离线树状数组求区间不同数的和
与不同数的个数一致的思路,这里值域比较大,就记录下标
按照 r r r排序也可
#include <map>
#include <cstdio>
#include <algorithm>
using namespace std;
#define maxn 30005
#define maxm 100005
#define int long long
struct node {
int l, r, id;
}q[maxm];
map < int, int > lst;
int T, n, m;
int a[maxn], t[maxn], ans[maxm];
int lowbit( int i ) { return i & -i; }
void add( int i, int val ) {
for( ;i <= n;i += lowbit( i ) ) t[i] += val;
}
int query( int i ) {
int ret = 0;
for( ;i;i -= lowbit( i ) ) ret += t[i];
return ret;
}
signed main() {
scanf( "%lld", &T );
while( T -- ) {
scanf( "%lld", &n );
for( int i = 1;i <= n;i ++ )
scanf( "%lld", &a[i] ), t[i] = 0;
scanf( "%lld", &m );
for( int i = 1;i <= m;i ++ )
scanf( "%lld %lld", &q[i].l, &q[i].r ), q[i].id = i;
sort( q + 1, q + m + 1, []( node x, node y ) { return x.r < y.r; } );
lst.clear();
int j = 1;
for( int i = 1;i <= m;i ++ ) {
for( ;j <= q[i].r;j ++ ) {
if( lst[a[j]] ) add( lst[a[j]], -a[j] );
add( j, a[j] ), lst[a[j]] = j;
}
ans[q[i].id] = query( q[i].r ) - query( q[i].l - 1 );
}
for( int i = 1;i <= m;i ++ )
printf( "%lld\n", ans[i] );
}
return 0;
}
上一篇: CD-HIT的使用
下一篇: 【Python SQLAlchemy】