欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

【树状数组】普通、二维、离线树状数组的(单点修改,单点查询,区间修改,区间查询)模板及应用总结

程序员文章站 2022-03-11 19:01:47
...

树状数组

用于快速高效的计算与前缀和相关的信息

lowbit

int lowbit( int i ) { return i & -i; }

l o w b i t \rm lowbit lowbit:返回 x x x二进制最低位为 1 1 1的位置的值

e.g. 40=101000lowbit(40)=8

线段树与树状数组

因为涉及 l o w b i t \rm lowbit lowbit,所以树状数组的下标一定从 1 1 1开始,而不是 0 0 0

线段树用mid=(l+r)>>1进行 log ⁡ \log log的优化

树状数组的通过 ± l o w b i t ( i ) ±\rm lowbit(i) ±lowbit(i)进行二进制位的进/退 1 1 1

时间复杂度同样都是 O ( n log ⁡ n ) O(n\log n) O(nlogn)

但一般来说树状数组的空间都是 O ( N ) O(N) O(N),不会像线段树有 N < < 2 N<<2 N<<2的大空间

线段树因为其结构原因有更多的应用:优化建图,线段树分治 . . . ... ...

但是树状数组就比较死板了,就是跟静态区间/单点挂钩

那么应用广点,消耗的代价(更大空间)多点也是可以理解的了

很多情况下树状数组和线段树没有什么区别,可以互换

单点修改

void add( int i, int val ) {
    for( ;i <= n;i += lowbit( i ) )
        t[i] += val;
}

区间查询

int query( int i ) {//求的是[1,i]的前缀和
    int ans = 0;
    for( ;i;i -= lowbit( i ) ) 
        ans += t[i];
    return ans;
}
int query( int l, int r ) {
    return query( r ) - query( l - 1 );
}

一般的写法都是维护前缀和,所以修改是 + l o w b i t +\rm lowbit +lowbit,查询是 − l o w b i t -\rm lowbit lowbit

但有些时候题目反而是跟后缀挂钩,这个时候有两种选择

  • 强制后缀转前缀

    每次传入 i i i的时候,暴力变成 i = n − i + 1 i=n-i+1 i=ni+1,然后进行add query的操作

  • 直接反转使用树状数组

    修改直接 − l o w b i t -\rm lowbit lowbit,查询 + l o w b i t +\rm lowbit +lowbit

只要维护了意义上的相对即可

区间修改

原序列 a 1 , a 2 , . . . , a n a_1,a_2,...,a_n a1,a2,...,an,定义差分数组 c i = a i − a i − 1 c_i=a_i-a_{i-1} ci=aiai1, 则 a i = ∑ j = 1 i c j a_i=\sum_{j=1}^ic_j ai=j=1icj

那么修改区间 [ l , r ] [l,r] [l,r],加上 w w w,相当于在 c l c_l cl w w w,在 c r + 1 c_{r+1} cr+1 w w w

这就将区间修改转化成了两次的单点修改

达到 [ l , r ] [l,r] [l,r]区间的数加 w w w的效果

区间求和

∑ i = 1 n a i = ∑ i = 1 n ∑ j = 1 i c j = ∑ i = 1 n ( n − i + 1 ) c i = ( c 1 ) + ( c 1 + c 2 ) + . . . + ( c 1 + c 2 + . . . + c n ) \sum_{i=1}^na_i=\sum_{i=1}^n\sum_{j=1}^ic_j=\sum_{i=1}^n(n-i+1)c_i \\=(c_1)+(c_1+c_2)+...+(c_1+c_2+...+c_n) i=1nai=i=1nj=1icj=i=1n(ni+1)ci=(c1)+(c1+c2)+...+(c1+c2+...+cn)

= ( n + 1 ) ∗ ( c 1 + c 2 + . . . + c n ) − ( c 1 ⏞ 1 + c 2 + c 2 ⏞ 2 + . . . + c n + c n ⏞ n ) =(n+1)*(c_1+c_2+...+c_n)-(\overbrace{c_1}^{1}+\overbrace{c_2+c_2}^{2}+...\overbrace{+c_n+c_n}^{n}) =(n+1)(c1+c2+...+cn)(c1 1+c2+c2 2+...+cn+cn n)

= ( n + 1 ) ∗ ∑ i = 1 n c i − ∑ i = 1 n c i ∗ i =(n+1)*\sum_{i=1}^nc_i-\sum_{i=1}^nc_i*i =(n+1)i=1ncii=1ncii
所以只需要用两个树状数组,分别维护 c i c_i ci c i ∗ i c_i*i cii即可

LOJ:区间修改区间查询

#include <cstdio>
#define int long long
#define maxn 1000005
int n, Q;
int a[maxn], t1[maxn], t2[maxn];

int lowbit( int x ) { return x & -x; }

void modify( int x, int val ) {
	for( int i = x;i <= n;i += lowbit( i ) )
		t1[i] += val, t2[i] += val * x;
}

int query( int x ) {
	int ans = 0;
	for( int i = x;i;i -= lowbit( i ) )
		ans += ( x + 1 ) * t1[i] - t2[i];
	return ans;
}

signed main() {
	scanf( "%lld %lld", &n, &Q );
	for( int i = 1;i <= n;i ++ ) {
		scanf( "%lld", &a[i] );
		modify( i, a[i] - a[i - 1] );
	}
	int opt, l, r, x;
	while( Q -- ) {
		scanf( "%lld %lld %lld", &opt, &l, &r );
		if( opt & 1 ) {
			scanf( "%lld", &x );
			modify( l, x ), modify( r + 1, -x );
		}
		else printf( "%lld\n", query( r ) - query( l - 1 ) );
	}
	return 0;
}

二维树状数组

既然树状数组是前缀和的工具,那么二维树状数组就相当于与二维差分

树状数组嵌树状数组的感觉,查询就是用二维差分计算围成的面积

void modify( int x, int y, int val ) {
	for( int i = x;i <= n;i += lowbit( i ) )
		for( int j = y;j <= m;j += lowbit( j ) )
			t[i][j] += val;
}
int query( int x, int y ) {
	int ans = 0;
	for( int i = x;i;i -= lowbit( i ) )
		for( int j = y;j;j -= lowbit( j ) )
			ans += t[i][j];
	return ans;
}

modify( x1, y1, k );
query( x2, y2 ) - query( x2, y1 - 1 ) - query( x1 - 1, y2 ) + query( x1 - 1, y1 - 1 );

离线树状数组

离线树状数组求区间不同数的个数/值和

都是将询问按照 l , r \rm l,r l,r排序,然后记录 i i i的上一个/下一个位置

将指针拨到询问的端点处,删去上一个位置/加入下一个位置

从而做到 1 1 1的个数差,满足不同数只记录一次的要求,自然树状数组就能维护

例题的最后两题就是如此,看代码比较清晰能够理解

例题

POJ:stars

POJ2352

题目已经保证了 y y y递增,那么树状数组维护 x x x,每次查询比 x x x小的星星有多少个即可

#include <cstdio>
#include <iostream>
using namespace std;
#define maxn 32005
int n, N;
int ans[maxn], t[maxn], x[maxn], y[maxn];


int lowbit( int i ) { return i & -i; }

void modify( int i ) {
	for( ;i <= N;i += lowbit( i ) ) t[i] ++;
}

int query( int i ) {
	int ret = 0;
	for( ;i;i -= lowbit( i ) ) ret += t[i];
	return ret;
}

int main() {
	scanf( "%d", &n );
	for( int i = 1;i <= n;i ++ ) {
		scanf( "%d %d", &x[i], &y[i] );
		x[i] ++, y[i] ++, N = max( N, x[i] );
		//x,y值域包含0 树状数组不能从0开始 所以整体+1
    }
	for( int i = 1;i <= n;i ++ )
		ans[query( x[i] )] ++, modify( x[i] );
	for( int i = 0;i < n;i ++ )
		printf( "%d\n", ans[i] );
	return 0;
}

MooFest

POJ1990

v v v排序,维护两个树状数组,一个是小于当前位置的位置和,一个是大于当前位置的位置和,这样就避免了距离带的绝对值

#include <cstdio>
#include <iostream>
#include <algorithm>
using namespace std;
#define int long long
#define maxn 20005
struct node {
	int val, pos;
	node(){}
	node( int Val, int Pos ) {
		val = Val, pos = Pos;
	}
}cow[maxn];
struct Node {
	int cnt, sumd;
	Node(){}
	Node( int Cnt, int Sumd ) {
		cnt = Cnt, sumd = Sumd;
	}
}t1[maxn], t2[maxn];
int n, N;

int lowbit( int i ) { return i & -i; }

void modify1( int i, int val ) {
	for( ;i <= N;i += lowbit( i ) )
		t1[i].cnt ++, t1[i].sumd += val;
}

void modify2( int i, int val ) {
	for( ;i;i -= lowbit( i ) )
		t2[i].cnt ++, t2[i].sumd += val;
}

Node query1( int i ) {
	Node ans( 0, 0 );
	for( ;i;i -= lowbit( i ) )
		ans.cnt += t1[i].cnt, ans.sumd += t1[i].sumd;
	return ans;
}

Node query2( int i ) {
	Node ans( 0, 0 );
	for( ;i <= N;i += lowbit( i ) )
		ans.cnt += t2[i].cnt, ans.sumd += t2[i].sumd;
	return ans;
}

bool cmp( node x, node y ) { return x.val < y.val; } 

signed main() {
	scanf( "%lld", &n );
	for( int i = 1;i <= n;i ++ ) {
		scanf( "%lld %lld", &cow[i].val, &cow[i].pos );
		N = max( cow[i].pos, N );
	}
	N ++;
	sort( cow + 1, cow + n + 1, cmp );
	int ans = 0;
	for( int i = 1;i <= n;i ++ ) {
		Node t = query1( cow[i].pos );
		ans += ( cow[i].pos * t.cnt - t.sumd ) * cow[i].val;
		t = query2( cow[i].pos );
		ans += ( t.sumd - t.cnt * cow[i].pos ) * cow[i].val;
		modify1( cow[i].pos, cow[i].pos );
		modify2( cow[i].pos, cow[i].pos );
	}
	printf( "%lld\n", ans );
	return 0;
}

[SDOI2009]HH的项链

Luogu1972

离线树状数组求区间不同数个数

将询问按 l l l排序,对于每个位置 i i i记录下一个与该位置值相等的位置,每一次到 i i i就把下一次的位置加进去

询问区间左端点以前的自然都要加,这样区间查询相减,就知道下一次的位置在不在区间内,就恰好为 1 1 1

#include <cstdio>
#include <algorithm>
using namespace std;
#define maxn 1000005
struct node {
	int l, r, id;
}q[maxn];
int n, m;
int a[maxn], t[maxn], lst[maxn], nxt[maxn], ans[maxn];
bool vis[maxn];

void read( int &x ) {
	x = 0; char s = getchar();
	while( s < '0' or s > '9' ) s = getchar();
	while( '0' <= s and s <= '9' ) {
		x = ( x << 1 ) + ( x << 3 ) + ( s ^ 48 );
		s = getchar();
	}
}

int lowbit( int i ) { return i & -i; }

void add( int i ) {
	for( ;i < maxn;i += lowbit( i ) ) t[i] ++;
}

int query( int i ) {
	int ret = 0;
	for( ;i;i -= lowbit( i ) ) ret += t[i];
	return ret;
}

int main() {
	read( n );
	for( int i = 1;i <= n;i ++ ) read( a[i] );
	read( m );
	for( int i = 1;i <= m;i ++ ) 
		read( q[i].l ), read( q[i].r ), q[i].id = i;
	sort( q + 1, q + m + 1, []( node x, node y ) { return x.l < y.l; } );
	for( int i = 1;i <= n;i ++ )
		if( ! vis[a[i]] ) add( i ), vis[a[i]] = 1;
	for( int i = n;i;i -- ) {
		if( ! lst[a[i]] ) nxt[i] = maxn;
		else nxt[i] = lst[a[i]];
		lst[a[i]] = i;
	}
	int pos = 1;
	for( int i = 1;i <= m;i ++ ) {
		while( pos < q[i].l ) add( nxt[pos] ), pos ++;
		ans[q[i].id] = query( q[i].r ) - query( q[i].l - 1 );
	}
	for( int i = 1;i <= m;i ++ )
		printf( "%d\n", ans[i] );
	return 0;
}

Turing Tree

HDU3333

离线树状数组求区间不同数的和

与不同数的个数一致的思路,这里值域比较大,就记录下标

按照 r r r排序也可

#include <map>
#include <cstdio>
#include <algorithm>
using namespace std;
#define maxn 30005
#define maxm 100005
#define int long long
struct node {
	int l, r, id;
}q[maxm];
map < int, int > lst;
int T, n, m;
int a[maxn], t[maxn], ans[maxm];

int lowbit( int i ) { return i & -i; }

void add( int i, int val ) {
	for( ;i <= n;i += lowbit( i ) ) t[i] += val;
}

int query( int i ) {
	int ret = 0;
	for( ;i;i -= lowbit( i ) ) ret += t[i];
	return ret;
}

signed main() {
	scanf( "%lld", &T );
	while( T -- ) {
		scanf( "%lld", &n );
		for( int i = 1;i <= n;i ++ )
			scanf( "%lld", &a[i] ), t[i] = 0;
		scanf( "%lld", &m );
		for( int i = 1;i <= m;i ++ )
			scanf( "%lld %lld", &q[i].l, &q[i].r ), q[i].id = i;
		sort( q + 1, q + m + 1, []( node x, node y ) { return x.r < y.r; } );
		lst.clear();
		int j = 1;
		for( int i = 1;i <= m;i ++ ) {
			for( ;j <= q[i].r;j ++ ) {
				if( lst[a[j]] ) add( lst[a[j]], -a[j] );
				add( j, a[j] ), lst[a[j]] = j;
			}
			ans[q[i].id] = query( q[i].r ) - query( q[i].l - 1 );
		}
		for( int i = 1;i <= m;i ++ )
			printf( "%lld\n", ans[i] );
	}
	return 0;
}