欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

线段树详解

程序员文章站 2022-07-13 11:19:20
...

线段树详解

问题引入

给定一段数组 [ d 1 , d 2 , . . . , d n ] \left [ d_{1},d_{2},...,d_{n} \right ] [d1,d2,...,dn],定义下列操作:

  1. 单个元素修改:将 d n d_{n} dn加上或减去一个值
  2. 区间元素修改:给定左闭右开区间 [ a , b ) \left [ a,b \right ) [a,b),将下标在区间 [ a , b ) \left [ a,b \right ) [a,b)中的元素分别加上或减去一个相同的值
  3. 单个元素查询:查询 d n d_{n} dn的值
  4. 区间元素查询:给定左闭右开区间 [ a , b ) \left [ a,b \right ) [a,b),返回 ∑ i = a b − 1 d i \sum^{b-1}_{i = a} d_{i} i=ab1di的值

如果使用简单的for循环去做,那么必将浪费很多时间,此时线段树将是最好的选择

线段树的定义

所谓线段树,就是将一段数组进行二分,一段数组一次分成两组,类似于分割线段,因此得名线段树,下图以8个元素的数组展示线段树的数据结构

线段树详解

以上我们将数组 [ 1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 ] \left [ 1,2,3,4,5,6,7,8 \right ] [1,2,3,4,5,6,7,8]不断的进行二分,得到一颗二叉树,其中二叉树依次按照从上到下从左到右的顺序从1开始编号,并且每个二叉树节点的值就是该二叉树节点所代表区间的和

这种结构就叫做线段树

线段树的构建

为了方便,我们不使用指针式二叉树数据结构,而是用一个节点数组储存二叉数结构

仿照满二叉树的特性,我们也把线段树看成一颗满二叉树

所以线段树也有以下性质:

  1. 一个编号为 i i i的节点的左子树节点的编号为 i × 2 i \times 2 i×2
  2. 一个编号为 i i i的节点的右子树节点的编号为 i × 2 + 1 i \times 2 + 1 i×2+1

所以我们可以写出如下代码:

注意我们使用的一切区间都是左闭右开区间,主要是为了方便好写

节点数据结构:

/**
 * 节点数据结构
 * l: 节点代表区间的区间左端点位置(包括)
 * r: 节点代表区间的区间右端点位置(不包括)
 * sum: 节点代表区间的区间和
 */
struct TreeNode
{
    int l;
    int r;
    int sum;
};

建立线段树:

/**
 * 根据原数组自底向上地建立一颗线段树
 * arr: 原始数组
 * tree: 线段树的节点数组
 * i: 线段树根节点的编号
 * l: 原始数组的区间左端点位置(包括)
 * r: 原始数组的区间右端点位置(不包括)
 */
void buildTree(int *arr, TreeNode *tree, int i, int l, int r)
{
    // 定义节点代表的区间
    tree[i].l = l;
    tree[i].r = r;
    // 如果区间长度为1,则此节点是叶子节点
    if (l == r - 1)
    {
        // 叶子节点的区间和就是单个元素的值
        tree[i].sum = arr[l];
    }
    else
    {
        //如果不是叶子节点
        // mid为区间[l,r)的区间中点
        // 左移一位的意义为将这个数乘以2,右移一位的意义为将这个数除以2并取整
        int mid = l + ((r - l) >> 1);
        // 建立左子树
        buildTree(arr, tree, i << 1, l, mid);
        // 建立右子树
        buildTree(arr, tree, (i << 1) + 1, mid, r);
        // 该节点区间和为左子树的区间和加上右子树的区间和,自底向上
        tree[i].sum = tree[i << 1].sum + tree[(i << 1) + 1].sum;
    }
}

理解了上面的代码之后我们就可以使用我们的buildTree方法了:

int main()
{
    // 原始数组
    int arr[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
    // 长度为15 * 4 = 40的节点数组
    TreeNode tree[60];
    // 建立线段树,根节点的编号为1,区间为[0,15)
    buildTree(arr, tree, 1, 0, 15);
    return 0;
}

我们来讨论为什么要开长度为40的节点数组,也就是为什么要四倍建树:

上文提到我们把线段树看成满二叉树,满二叉树的节点数取决于树的高度,满二叉树的第一层的节点数为1,第二层为2,依次为4,8…根据等比数列求和公式可得:

Q = 2 n − 1 Q = 2^{n} - 1 Q=2n1

其中 Q Q Q为满二叉树的节点数, n n n为树的高度

我们要求的就是树的高度,我们知道每次我们将一个数组进行二分的时候,线段树的高度就会+1:

L × 1 2 n = 1 L \times \frac{1}{2^{n}} = 1 L×2n1=1

其中 L L L为原始数组的长度, n n n为将原始数组二分几次能将这个原始数组的长度变成1,然后变形得到

n = l o g 2 L n = log_{2}L n=log2L

得到线段树层数公式:

n = ⌈ l o g 2 L ⌉ + 1 n = \lceil log_{2}L \rceil + 1 n=log2L+1

带入 Q = 2 n − 1 Q = 2^{n} - 1 Q=2n1中得到:

Q = 2 ⌈ l o g 2 L ⌉ + 1 − 1 Q = 2^{\lceil log_{2}L \rceil + 1} - 1 Q=2log2L+11

我们对其进行放缩一点点,消掉取整符号:

Q = 2 ⌈ l o g 2 L ⌉ + 1 − 1 ⩽ 2 l o g 2 L + 2 − 1 = 4 × L − 1 ⩽ 4 × L Q = 2^{\lceil log_{2}L \rceil + 1} - 1 \leqslant 2^{log_{2}L + 2} - 1 = 4 \times L - 1 \leqslant 4 \times L Q=2log2L+112log2L+21=4×L14×L

证毕,因此我们要四倍建树,才能满足满二叉树的节点编号相关的这一性质

线段树的单点查询

这很容易,就像在二叉搜索树中查找目标节点一样:

/**
 * 在线段树中查找目标节点的值
 * tree: 线段树的节点数组
 * i: 线段树的根节点
 * target: 目标节点在原始数组的下标索引值
 */
int queryNode(TreeNode *tree, int i, int target)
{
    if (tree[i].l == tree[i].r - 1)
    {
        // 达到叶子节点,找到了
        return tree[i].sum;
    }
    else
    {
        // 非叶子节点
        if (tree[i << 1].l <= target && target < tree[i << 1].r)
        {
            // 在左子树里面
            return queryNode(tree, i << 1, target);
        }
        else
        {
            // 在右子树里面
            return queryNode(tree, (i << 1) + 1, target);
        }
    }
}

线段树的区间查询

/**
 * 在线段树中查找目标区间的区间和
 * tree: 线段树的节点数组
 * i: 线段树的根节点
 * l: 目标区间的区间左端点位置(包括)
 * r: 目标区间的区间右端点位置(不包括)
 */
int queryInterval(TreeNode *tree, int i, int l, int r)
{
    if (tree[i].l >= l && tree[i].r <= r)
    {
        // 该区间为目标区间的子集,即该区间完全被目标区间所覆盖,直接返回即可
        return tree[i].sum;
    }
    else if (r <= tree[i].l || l >= tree[i].r)
    {
        // 该区间和目标区间的交集为空集,即该区间和目标区间毫无关联,返回0
        return 0;
    }
    else
    {
        int sum = 0;
        // 如果目标区间和左子树区间有交集,那么就查查左子树
        if (l < tree[i << 1].r)
            sum += queryInterval(tree, i << 1, l, r);
        // 如果目标区间和右子树区间有交集,那么就查查右子树
        if (r > tree[(i << 1) + 1].l)
            sum += queryInterval(tree, (i << 1) + 1, l, r);
        return sum;
    }
}

线段树的单点修改

类比单点查询:

/**
 * 在线段树中修改目标节点的值
 * tree: 线段树的节点数组
 * i: 线段树的根节点
 * target: 目标节点在原始数组的下标索引值
 * k: 增量
 */
void addNode(TreeNode *tree, int target, int i, int k)
{
    // 如果是叶子节点,说明我们找到了
    if (tree[i].l == tree[i].r - 1)
    {
        tree[i].sum += k;
        return;
    }
    else
    {
        if (target < tree[i << 1].r)
            addNode(tree, target, i << 1, k); // 查找左子树
        else
            addNode(tree, target, (i << 1) + 1, k); // 查找右子树
        // 因为子树的sum已经修改了,当前节点的sum也要更新
        tree[i].sum = tree[i << 1].sum + tree[(i << 1) + 1].sum;
    }
}

线段树的区间修改

/**
 * 在线段树中修改目标节点的值
 * tree: 线段树的节点数组
 * i: 线段树的根节点
 * l: 目标区间的区间左端点位置(包括)
 * r: 目标区间的区间右端点位置(不包括)
 * k: 增量
 */
void AddInterval(TreeNode *tree, int l, int r, int i, int k)
{
    // 如果是叶子节点,修改sum即可
    if (tree[i].l == tree[i].r - 1)
    {
        tree[i].sum += k;
        return;
    }
    else if (r <= tree[i].l || l >= tree[i].r)
    {
        // 毫不相关什么都不干
        return;
    }
    else
    {
        if (l < tree[i << 1].r)
            AddInterval(tree, l, r, i << 1, k); // 加左子树
        if (r > tree[(i << 1) + 1].l)
            AddInterval(tree, l, r, (i << 1) + 1, k); // 加右子树
        //更细当前节点的sum值
        tree[i].sum = tree[i << 1].sum + tree[(i << 1) + 1].sum;
    }
}

好了,目前为止,四个操作我们都实现了,代码也能正常工作,但是我们考虑线段树的区间修改的这个操作:

线段树详解

对原来子数组 [ 2 , 3 , 4 ] \left [2,3,4 \right ] [2,3,4]进行相加,正常使用for循环修改数组只需要修改元素即可,使用了线段树,不光要用修改当前区间还需要修改子区间,效率不是反而下降了吗?

线段树的优化:Pushdown

我们引入一个概念:懒操作

懒操作:如果一个操作需要影响一个关键元素和其他关联元素,但是我们无法确保这些关联元素之后是否会使用,可以将关键元素进行标记,如果以后需要用到这些关联元素则查找关键元素的标记进行更新

因此我们需要修改一下节点数据结构,增加一个lazy字段,代表懒标记

/**
 * 节点数据结构
 * l: 节点代表区间的区间左端点位置(包括)
 * r: 节点代表区间的区间右端点位置(不包括)
 * sum: 节点代表区间的区间和
 * lazy: 懒标记
 */
struct TreeNode
{
    int l;
    int r;
    int sum;
    int lazy;
};

然后我们区间修改就可以这么写代码

/**
 * 在线段树中修改目标节点的值
 * tree: 线段树的节点数组
 * i: 线段树的根节点
 * l: 目标区间的区间左端点位置(包括)
 * r: 目标区间的区间右端点位置(不包括)
 * k: 增量
 */
void AddInterval(TreeNode *tree, int l, int r, int i, int k)
{
    if (l <= tree[i].l && r >= tree[i].r)
    {
        // 当前区间是目标区间的子集,直接记录懒标记,和更新sum,而不去更新子树
        tree[i].lazy = k;
        tree[i].sum += (tree[i].r - tree[i].l) * k; // 区间长度乘以增量
    }
    else if (r <= tree[i].l || l >= tree[i].r)
    {
        // 毫不相关
        return;
    }
    else
    {
        // 同理
        if (l < tree[i << 1].r)
            AddInterval(tree, l, r, i << 1, k);
        if (r > tree[i << 1].r)
            AddInterval(tree, l, r, (i << 1) + 1, k);
        tree[i].sum = tree[i << 1].sum + tree[(i << 1) + 1].sum;
    }
}

等等,我们是不是忘记使用懒标记了。那我们什么时候该使用懒标记呢,当使用该节点的子节点的时候,我们无法知道子节点是不是已经更新了,此时我们才需要使用懒标记

我们定义使用懒标记的pushdown操作:

/**
 * 使用懒标记
 * tree: 线段树的节点数组
 * i: 节点编号
 */
void pushdown(TreeNode *tree, int i)
{
    // 向下子树传递懒标记
    tree[i << 1].lazy = tree[i].lazy;
    tree[(i << 1) + 1].lazy = tree[i].lazy;
    // 更新子树的sum值
    tree[i << 1].sum += (tree[i << 1].r - tree[i << 1].l) * tree[i].lazy;
    tree[(i << 1) + 1].sum += (tree[(i << 1) + 1].r - tree[(i << 1) + 1].l) * tree[i].lazy;
    // 清空懒标记
    tree[i].lazy = 0;
}

现在,我们只需要在访问子树的时候,加上pushdown操作即可:

/**
 * 节点数据结构
 * l: 节点代表区间的区间左端点位置(包括)
 * r: 节点代表区间的区间右端点位置(不包括)
 * sum: 节点代表区间的区间和
 */
struct TreeNode
{
    int l;
    int r;
    int sum;
    int lazy;
};
/**
 * 使用懒标记
 * tree: 线段树的节点数组
 * i: 节点编号
 */
void pushdown(TreeNode *tree, int i)
{
    // 向下子树传递懒标记
    tree[i << 1].lazy = tree[i].lazy;
    tree[(i << 1) + 1].lazy = tree[i].lazy;
    // 更新子树的sum值
    tree[i << 1].sum += (tree[i << 1].r - tree[i << 1].l) * tree[i].lazy;
    tree[(i << 1) + 1].sum += (tree[(i << 1) + 1].r - tree[(i << 1) + 1].l) * tree[i].lazy;
    // 清空懒标记
    tree[i].lazy = 0;
}
/**
 * 根据原数组自底向上地建立一颗线段树
 * arr: 原始数组
 * tree: 线段树的节点数组
 * i: 线段树根节点的编号
 * l: 原始数组的区间左端点位置(包括)
 * r: 原始数组的区间右端点位置(不包括)
 */
void buildTree(int *arr, TreeNode *tree, int i, int l, int r)
{
    // 定义节点代表的区间
    tree[i].l = l;
    tree[i].r = r;
    tree[i].lazy = 0;
    // 如果区间长度为1,则此节点是叶子节点
    if (l == r - 1)
    {
        // 叶子节点的区间和就是单个元素的值
        tree[i].sum = arr[l];
    }
    else
    {
        //如果不是叶子节点
        // mid为区间[l,r)的区间中点
        // 左移一位的意义为将这个数乘以2,右移一位的意义为将这个数除以2并取整
        int mid = l + ((r - l) >> 1);
        // 建立左子树
        buildTree(arr, tree, i << 1, l, mid);
        // 建立右子树
        buildTree(arr, tree, (i << 1) + 1, mid, r);
        // 该节点区间和为左子树的区间和加上右子树的区间和,自底向上
        tree[i].sum = tree[i << 1].sum + tree[(i << 1) + 1].sum;
    }
}
/**
 * 在线段树中查找目标节点的值
 * tree: 线段树的节点数组
 * i: 线段树的根节点
 * target: 目标节点在原始数组的下标索引值
 */
int queryNode(TreeNode *tree, int i, int target)
{
    if (tree[i].l == tree[i].r - 1)
    {
        // 达到叶子节点,找到了
        return tree[i].sum;
    }
    else
    {
        // 要使用子树
        pushdown(tree, i);
        // 非叶子节点
        if (tree[i << 1].l <= target && target < tree[i << 1].r)
        {
            // 在左子树里面
            return queryNode(tree, i << 1, target);
        }
        else
        {
            // 在右子树里面
            return queryNode(tree, (i << 1) + 1, target);
        }
    }
}
/**
 * 在线段树中查找目标区间的区间和
 * tree: 线段树的节点数组
 * i: 线段树的根节点
 * l: 目标区间的区间左端点位置(包括)
 * r: 目标区间的区间右端点位置(不包括)
 */
int queryInterval(TreeNode *tree, int i, int l, int r)
{
    if (tree[i].l >= l && tree[i].r <= r)
    {
        // 该区间为目标区间的子集,即该区间完全被目标区间所覆盖,直接返回即可
        return tree[i].sum;
    }
    else if (r <= tree[i].l || l >= tree[i].r)
    {
        // 该区间和目标区间的交集为空集,即该区间和目标区间毫无关联,返回0
        return 0;
    }
    else
    {
        int sum = 0;
        // 要使用子树
        pushdown(tree, i);
        // 如果目标区间和左子树区间有交集,那么就查查左子树
        if (l < tree[i << 1].r)
            sum += queryInterval(tree, i << 1, l, r);
        // 如果目标区间和右子树区间有交集,那么就查查右子树
        if (r > tree[(i << 1) + 1].l)
            sum += queryInterval(tree, (i << 1) + 1, l, r);
        return sum;
    }
}
/**
 * 在线段树中修改目标节点的值
 * tree: 线段树的节点数组
 * i: 线段树的根节点
 * target: 目标节点在原始数组的下标索引值
 * k: 增量
 */
void addNode(TreeNode *tree, int target, int i, int k)
{
    // 如果是叶子节点,说明我们找到了
    if (tree[i].l == tree[i].r - 1)
    {
        tree[i].sum += k;
        return;
    }
    else
    {
        // 要使用子树
        pushdown(tree, i);
        if (target < tree[i << 1].r)
            addNode(tree, target, i << 1, k); // 查找左子树
        else
            addNode(tree, target, (i << 1) + 1, k); // 查找右子树
        // 因为子树的sum已经修改了,当前节点的sum也要更新
        tree[i].sum = tree[i << 1].sum + tree[(i << 1) + 1].sum;
    }
}

/**
 * 在线段树中修改目标节点的值
 * tree: 线段树的节点数组
 * i: 线段树的根节点
 * l: 目标区间的区间左端点位置(包括)
 * r: 目标区间的区间右端点位置(不包括)
 * k: 增量
 */
void addInterval(TreeNode *tree, int l, int r, int i, int k)
{
    if (l <= tree[i].l && r >= tree[i].r)
    {
        // 当前区间是目标区间的子集,直接记录懒标记,和更新sum,而不去更新子树
        tree[i].lazy = k;
        tree[i].sum += (tree[i].r - tree[i].l) * k; // 区间长度乘以增量
    }
    else if (r <= tree[i].l || l >= tree[i].r)
    {
        // 毫不相关
        return;
    }
    else
    {
        // 要使用子树
        pushdown(tree, i);
        // 同理
        if (l < tree[i << 1].r)
            addInterval(tree, l, r, i << 1, k);
        if (r > tree[i << 1].r)
            addInterval(tree, l, r, (i << 1) + 1, k);
        tree[i].sum = tree[i << 1].sum + tree[(i << 1) + 1].sum;
    }
}

如果还不理解可以参考一下图片

线段树详解

线段树详解

线段树详解

上一篇: 线段树详解

下一篇: 线段树详解