线段树详解
线段树详解
问题引入
给定一段数组 [ d 1 , d 2 , . . . , d n ] \left [ d_{1},d_{2},...,d_{n} \right ] [d1,d2,...,dn],定义下列操作:
- 单个元素修改:将 d n d_{n} dn加上或减去一个值
- 区间元素修改:给定左闭右开区间 [ a , b ) \left [ a,b \right ) [a,b),将下标在区间 [ a , b ) \left [ a,b \right ) [a,b)中的元素分别加上或减去一个相同的值
- 单个元素查询:查询 d n d_{n} dn的值
- 区间元素查询:给定左闭右开区间 [ a , b ) \left [ a,b \right ) [a,b),返回 ∑ i = a b − 1 d i \sum^{b-1}_{i = a} d_{i} ∑i=ab−1di的值
如果使用简单的for循环去做,那么必将浪费很多时间,此时线段树将是最好的选择
线段树的定义
所谓线段树,就是将一段数组进行二分,一段数组一次分成两组,类似于分割线段,因此得名线段树,下图以8个元素的数组展示线段树的数据结构
以上我们将数组 [ 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 ] \left [ 1,2,3,4,5,6,7,8 \right ] [1,2,3,4,5,6,7,8]不断的进行二分,得到一颗二叉树,其中二叉树依次按照从上到下从左到右的顺序从1开始编号,并且每个二叉树节点的值就是该二叉树节点所代表区间的和
这种结构就叫做线段树
线段树的构建
为了方便,我们不使用指针式二叉树数据结构,而是用一个节点数组储存二叉数结构
仿照满二叉树的特性,我们也把线段树看成一颗满二叉树
所以线段树也有以下性质:
- 一个编号为 i i i的节点的左子树节点的编号为 i × 2 i \times 2 i×2
- 一个编号为 i i i的节点的右子树节点的编号为 i × 2 + 1 i \times 2 + 1 i×2+1
所以我们可以写出如下代码:
注意我们使用的一切区间都是左闭右开区间,主要是为了方便好写
节点数据结构:
/**
* 节点数据结构
* l: 节点代表区间的区间左端点位置(包括)
* r: 节点代表区间的区间右端点位置(不包括)
* sum: 节点代表区间的区间和
*/
struct TreeNode
{
int l;
int r;
int sum;
};
建立线段树:
/**
* 根据原数组自底向上地建立一颗线段树
* arr: 原始数组
* tree: 线段树的节点数组
* i: 线段树根节点的编号
* l: 原始数组的区间左端点位置(包括)
* r: 原始数组的区间右端点位置(不包括)
*/
void buildTree(int *arr, TreeNode *tree, int i, int l, int r)
{
// 定义节点代表的区间
tree[i].l = l;
tree[i].r = r;
// 如果区间长度为1,则此节点是叶子节点
if (l == r - 1)
{
// 叶子节点的区间和就是单个元素的值
tree[i].sum = arr[l];
}
else
{
//如果不是叶子节点
// mid为区间[l,r)的区间中点
// 左移一位的意义为将这个数乘以2,右移一位的意义为将这个数除以2并取整
int mid = l + ((r - l) >> 1);
// 建立左子树
buildTree(arr, tree, i << 1, l, mid);
// 建立右子树
buildTree(arr, tree, (i << 1) + 1, mid, r);
// 该节点区间和为左子树的区间和加上右子树的区间和,自底向上
tree[i].sum = tree[i << 1].sum + tree[(i << 1) + 1].sum;
}
}
理解了上面的代码之后我们就可以使用我们的buildTree方法了:
int main()
{
// 原始数组
int arr[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
// 长度为15 * 4 = 40的节点数组
TreeNode tree[60];
// 建立线段树,根节点的编号为1,区间为[0,15)
buildTree(arr, tree, 1, 0, 15);
return 0;
}
我们来讨论为什么要开长度为40的节点数组,也就是为什么要四倍建树:
上文提到我们把线段树看成满二叉树,满二叉树的节点数取决于树的高度,满二叉树的第一层的节点数为1,第二层为2,依次为4,8…根据等比数列求和公式可得:
Q = 2 n − 1 Q = 2^{n} - 1 Q=2n−1
其中 Q Q Q为满二叉树的节点数, n n n为树的高度
我们要求的就是树的高度,我们知道每次我们将一个数组进行二分的时候,线段树的高度就会+1:
L × 1 2 n = 1 L \times \frac{1}{2^{n}} = 1 L×2n1=1
其中 L L L为原始数组的长度, n n n为将原始数组二分几次能将这个原始数组的长度变成1,然后变形得到
n = l o g 2 L n = log_{2}L n=log2L
得到线段树层数公式:
n = ⌈ l o g 2 L ⌉ + 1 n = \lceil log_{2}L \rceil + 1 n=⌈log2L⌉+1
带入 Q = 2 n − 1 Q = 2^{n} - 1 Q=2n−1中得到:
Q = 2 ⌈ l o g 2 L ⌉ + 1 − 1 Q = 2^{\lceil log_{2}L \rceil + 1} - 1 Q=2⌈log2L⌉+1−1
我们对其进行放缩一点点,消掉取整符号:
Q = 2 ⌈ l o g 2 L ⌉ + 1 − 1 ⩽ 2 l o g 2 L + 2 − 1 = 4 × L − 1 ⩽ 4 × L Q = 2^{\lceil log_{2}L \rceil + 1} - 1 \leqslant 2^{log_{2}L + 2} - 1 = 4 \times L - 1 \leqslant 4 \times L Q=2⌈log2L⌉+1−1⩽2log2L+2−1=4×L−1⩽4×L
证毕,因此我们要四倍建树,才能满足满二叉树的节点编号相关的这一性质
线段树的单点查询
这很容易,就像在二叉搜索树中查找目标节点一样:
/**
* 在线段树中查找目标节点的值
* tree: 线段树的节点数组
* i: 线段树的根节点
* target: 目标节点在原始数组的下标索引值
*/
int queryNode(TreeNode *tree, int i, int target)
{
if (tree[i].l == tree[i].r - 1)
{
// 达到叶子节点,找到了
return tree[i].sum;
}
else
{
// 非叶子节点
if (tree[i << 1].l <= target && target < tree[i << 1].r)
{
// 在左子树里面
return queryNode(tree, i << 1, target);
}
else
{
// 在右子树里面
return queryNode(tree, (i << 1) + 1, target);
}
}
}
线段树的区间查询
/**
* 在线段树中查找目标区间的区间和
* tree: 线段树的节点数组
* i: 线段树的根节点
* l: 目标区间的区间左端点位置(包括)
* r: 目标区间的区间右端点位置(不包括)
*/
int queryInterval(TreeNode *tree, int i, int l, int r)
{
if (tree[i].l >= l && tree[i].r <= r)
{
// 该区间为目标区间的子集,即该区间完全被目标区间所覆盖,直接返回即可
return tree[i].sum;
}
else if (r <= tree[i].l || l >= tree[i].r)
{
// 该区间和目标区间的交集为空集,即该区间和目标区间毫无关联,返回0
return 0;
}
else
{
int sum = 0;
// 如果目标区间和左子树区间有交集,那么就查查左子树
if (l < tree[i << 1].r)
sum += queryInterval(tree, i << 1, l, r);
// 如果目标区间和右子树区间有交集,那么就查查右子树
if (r > tree[(i << 1) + 1].l)
sum += queryInterval(tree, (i << 1) + 1, l, r);
return sum;
}
}
线段树的单点修改
类比单点查询:
/**
* 在线段树中修改目标节点的值
* tree: 线段树的节点数组
* i: 线段树的根节点
* target: 目标节点在原始数组的下标索引值
* k: 增量
*/
void addNode(TreeNode *tree, int target, int i, int k)
{
// 如果是叶子节点,说明我们找到了
if (tree[i].l == tree[i].r - 1)
{
tree[i].sum += k;
return;
}
else
{
if (target < tree[i << 1].r)
addNode(tree, target, i << 1, k); // 查找左子树
else
addNode(tree, target, (i << 1) + 1, k); // 查找右子树
// 因为子树的sum已经修改了,当前节点的sum也要更新
tree[i].sum = tree[i << 1].sum + tree[(i << 1) + 1].sum;
}
}
线段树的区间修改
/**
* 在线段树中修改目标节点的值
* tree: 线段树的节点数组
* i: 线段树的根节点
* l: 目标区间的区间左端点位置(包括)
* r: 目标区间的区间右端点位置(不包括)
* k: 增量
*/
void AddInterval(TreeNode *tree, int l, int r, int i, int k)
{
// 如果是叶子节点,修改sum即可
if (tree[i].l == tree[i].r - 1)
{
tree[i].sum += k;
return;
}
else if (r <= tree[i].l || l >= tree[i].r)
{
// 毫不相关什么都不干
return;
}
else
{
if (l < tree[i << 1].r)
AddInterval(tree, l, r, i << 1, k); // 加左子树
if (r > tree[(i << 1) + 1].l)
AddInterval(tree, l, r, (i << 1) + 1, k); // 加右子树
//更细当前节点的sum值
tree[i].sum = tree[i << 1].sum + tree[(i << 1) + 1].sum;
}
}
好了,目前为止,四个操作我们都实现了,代码也能正常工作,但是我们考虑线段树的区间修改的这个操作:
对原来子数组 [ 2 , 3 , 4 ] \left [2,3,4 \right ] [2,3,4]进行相加,正常使用for循环修改数组只需要修改元素即可,使用了线段树,不光要用修改当前区间还需要修改子区间,效率不是反而下降了吗?
线段树的优化:Pushdown
我们引入一个概念:懒操作
懒操作:如果一个操作需要影响一个关键元素和其他关联元素,但是我们无法确保这些关联元素之后是否会使用,可以将关键元素进行标记,如果以后需要用到这些关联元素则查找关键元素的标记进行更新
因此我们需要修改一下节点数据结构,增加一个lazy字段,代表懒标记
/**
* 节点数据结构
* l: 节点代表区间的区间左端点位置(包括)
* r: 节点代表区间的区间右端点位置(不包括)
* sum: 节点代表区间的区间和
* lazy: 懒标记
*/
struct TreeNode
{
int l;
int r;
int sum;
int lazy;
};
然后我们区间修改就可以这么写代码
/**
* 在线段树中修改目标节点的值
* tree: 线段树的节点数组
* i: 线段树的根节点
* l: 目标区间的区间左端点位置(包括)
* r: 目标区间的区间右端点位置(不包括)
* k: 增量
*/
void AddInterval(TreeNode *tree, int l, int r, int i, int k)
{
if (l <= tree[i].l && r >= tree[i].r)
{
// 当前区间是目标区间的子集,直接记录懒标记,和更新sum,而不去更新子树
tree[i].lazy = k;
tree[i].sum += (tree[i].r - tree[i].l) * k; // 区间长度乘以增量
}
else if (r <= tree[i].l || l >= tree[i].r)
{
// 毫不相关
return;
}
else
{
// 同理
if (l < tree[i << 1].r)
AddInterval(tree, l, r, i << 1, k);
if (r > tree[i << 1].r)
AddInterval(tree, l, r, (i << 1) + 1, k);
tree[i].sum = tree[i << 1].sum + tree[(i << 1) + 1].sum;
}
}
等等,我们是不是忘记使用懒标记了。那我们什么时候该使用懒标记呢,当使用该节点的子节点的时候,我们无法知道子节点是不是已经更新了,此时我们才需要使用懒标记
我们定义使用懒标记的pushdown操作:
/**
* 使用懒标记
* tree: 线段树的节点数组
* i: 节点编号
*/
void pushdown(TreeNode *tree, int i)
{
// 向下子树传递懒标记
tree[i << 1].lazy = tree[i].lazy;
tree[(i << 1) + 1].lazy = tree[i].lazy;
// 更新子树的sum值
tree[i << 1].sum += (tree[i << 1].r - tree[i << 1].l) * tree[i].lazy;
tree[(i << 1) + 1].sum += (tree[(i << 1) + 1].r - tree[(i << 1) + 1].l) * tree[i].lazy;
// 清空懒标记
tree[i].lazy = 0;
}
现在,我们只需要在访问子树的时候,加上pushdown操作即可:
/**
* 节点数据结构
* l: 节点代表区间的区间左端点位置(包括)
* r: 节点代表区间的区间右端点位置(不包括)
* sum: 节点代表区间的区间和
*/
struct TreeNode
{
int l;
int r;
int sum;
int lazy;
};
/**
* 使用懒标记
* tree: 线段树的节点数组
* i: 节点编号
*/
void pushdown(TreeNode *tree, int i)
{
// 向下子树传递懒标记
tree[i << 1].lazy = tree[i].lazy;
tree[(i << 1) + 1].lazy = tree[i].lazy;
// 更新子树的sum值
tree[i << 1].sum += (tree[i << 1].r - tree[i << 1].l) * tree[i].lazy;
tree[(i << 1) + 1].sum += (tree[(i << 1) + 1].r - tree[(i << 1) + 1].l) * tree[i].lazy;
// 清空懒标记
tree[i].lazy = 0;
}
/**
* 根据原数组自底向上地建立一颗线段树
* arr: 原始数组
* tree: 线段树的节点数组
* i: 线段树根节点的编号
* l: 原始数组的区间左端点位置(包括)
* r: 原始数组的区间右端点位置(不包括)
*/
void buildTree(int *arr, TreeNode *tree, int i, int l, int r)
{
// 定义节点代表的区间
tree[i].l = l;
tree[i].r = r;
tree[i].lazy = 0;
// 如果区间长度为1,则此节点是叶子节点
if (l == r - 1)
{
// 叶子节点的区间和就是单个元素的值
tree[i].sum = arr[l];
}
else
{
//如果不是叶子节点
// mid为区间[l,r)的区间中点
// 左移一位的意义为将这个数乘以2,右移一位的意义为将这个数除以2并取整
int mid = l + ((r - l) >> 1);
// 建立左子树
buildTree(arr, tree, i << 1, l, mid);
// 建立右子树
buildTree(arr, tree, (i << 1) + 1, mid, r);
// 该节点区间和为左子树的区间和加上右子树的区间和,自底向上
tree[i].sum = tree[i << 1].sum + tree[(i << 1) + 1].sum;
}
}
/**
* 在线段树中查找目标节点的值
* tree: 线段树的节点数组
* i: 线段树的根节点
* target: 目标节点在原始数组的下标索引值
*/
int queryNode(TreeNode *tree, int i, int target)
{
if (tree[i].l == tree[i].r - 1)
{
// 达到叶子节点,找到了
return tree[i].sum;
}
else
{
// 要使用子树
pushdown(tree, i);
// 非叶子节点
if (tree[i << 1].l <= target && target < tree[i << 1].r)
{
// 在左子树里面
return queryNode(tree, i << 1, target);
}
else
{
// 在右子树里面
return queryNode(tree, (i << 1) + 1, target);
}
}
}
/**
* 在线段树中查找目标区间的区间和
* tree: 线段树的节点数组
* i: 线段树的根节点
* l: 目标区间的区间左端点位置(包括)
* r: 目标区间的区间右端点位置(不包括)
*/
int queryInterval(TreeNode *tree, int i, int l, int r)
{
if (tree[i].l >= l && tree[i].r <= r)
{
// 该区间为目标区间的子集,即该区间完全被目标区间所覆盖,直接返回即可
return tree[i].sum;
}
else if (r <= tree[i].l || l >= tree[i].r)
{
// 该区间和目标区间的交集为空集,即该区间和目标区间毫无关联,返回0
return 0;
}
else
{
int sum = 0;
// 要使用子树
pushdown(tree, i);
// 如果目标区间和左子树区间有交集,那么就查查左子树
if (l < tree[i << 1].r)
sum += queryInterval(tree, i << 1, l, r);
// 如果目标区间和右子树区间有交集,那么就查查右子树
if (r > tree[(i << 1) + 1].l)
sum += queryInterval(tree, (i << 1) + 1, l, r);
return sum;
}
}
/**
* 在线段树中修改目标节点的值
* tree: 线段树的节点数组
* i: 线段树的根节点
* target: 目标节点在原始数组的下标索引值
* k: 增量
*/
void addNode(TreeNode *tree, int target, int i, int k)
{
// 如果是叶子节点,说明我们找到了
if (tree[i].l == tree[i].r - 1)
{
tree[i].sum += k;
return;
}
else
{
// 要使用子树
pushdown(tree, i);
if (target < tree[i << 1].r)
addNode(tree, target, i << 1, k); // 查找左子树
else
addNode(tree, target, (i << 1) + 1, k); // 查找右子树
// 因为子树的sum已经修改了,当前节点的sum也要更新
tree[i].sum = tree[i << 1].sum + tree[(i << 1) + 1].sum;
}
}
/**
* 在线段树中修改目标节点的值
* tree: 线段树的节点数组
* i: 线段树的根节点
* l: 目标区间的区间左端点位置(包括)
* r: 目标区间的区间右端点位置(不包括)
* k: 增量
*/
void addInterval(TreeNode *tree, int l, int r, int i, int k)
{
if (l <= tree[i].l && r >= tree[i].r)
{
// 当前区间是目标区间的子集,直接记录懒标记,和更新sum,而不去更新子树
tree[i].lazy = k;
tree[i].sum += (tree[i].r - tree[i].l) * k; // 区间长度乘以增量
}
else if (r <= tree[i].l || l >= tree[i].r)
{
// 毫不相关
return;
}
else
{
// 要使用子树
pushdown(tree, i);
// 同理
if (l < tree[i << 1].r)
addInterval(tree, l, r, i << 1, k);
if (r > tree[i << 1].r)
addInterval(tree, l, r, (i << 1) + 1, k);
tree[i].sum = tree[i << 1].sum + tree[(i << 1) + 1].sum;
}
}
如果还不理解可以参考一下图片