欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

树上倍增法求最近公共祖先LCA

程序员文章站 2022-04-04 20:04:07
...

LCA,最近公共祖先,这个东西有很多作用,因此,如何高效求出LCA就成了一个热点的讨论话题。

树上倍增法求最近公共祖先LCA

下面所有的讨论都以图中这棵树为例子。

先来了解下什么是倍增吧,倍增其实就是二分的逆向,二分是逐渐缩小范围,而倍增是成倍扩大。这里的倍增借用二进制来表达更容易理解;倍增的做法是先求出20,21,22,,然后任意一个数字都可以用20,21,22,相加来表示,就像给你32个1,你能表示出32-bit 中的任意一个二进制一样。

倍增有什么好处呢,好处就是!倍增是一种优化手段,能提升查找等操作的效率的手段,其提升效率的原因就是二进制思想,提升的幅度为O(n)O(logn),具体的解释可以参照树状数组简单易懂的详解,数组数组的思想就是基于倍增来实现的。

这里为什么要说成树上倍增呢?因为这个算法的操作都是在树上完成的,没错,求LCA的方法还有很多,比如RMQ-ST算法也可以做,这个算法的思想也是倍增,只不过这个倍增体现在区间上,而树上倍增法求最近公共祖先LCA的倍增体现在树的深度上。

先说说朴素的做法,求两个结点的最近公共祖先,我会让一个结点先向上走到根,并记录下它走的路径,然后然让另一个结点也向上往根走,边走边在先前记录的路径中查找是否存在该结点。举个例子,求lca(3, 6),先让3走到1,路径为3, 2, 1,然后让6走到16在序列[3, 2, 1]中查找,没有找到,继续走,走到44在序列[3, 2, 1]中查找,没有找到,继续走,走到22在序列[3, 2, 1]中查找,找到了,那么lca(3, 6) = 2

分析下朴素做法的时间复杂度,算法中需要让两个结点依次走到根,且在一个结点移动的过程中还需要在路径序列中查找;假设树有n个结点,由于树可能退化成链,因此从某一个结点移动到根这个操作的时间复杂度为O(n),而查找这个操作可以使用set这一类容器,故时间复杂度为O(logn),因此朴素算法求一次LCA的时间复杂度为O(nlogn),假设要多次求LCA,这个时间复杂度显然是不能接受的。

然而,受剑指Offer66题之每日6题 - 第六天中第六题:两个链表的第一个公共结点 的启发,在树上求LCA和在两个链表的第一个公共结点是一样的,因此,朴素做法就有三种了,大家可以去剑指Offer66题之每日6题 - 第六天详细了解,这里就不多赘述了。

LCA用得普遍的地方就是求树中两个结点之间的最短路:dis[u, v] = dis[root, u] + dis[root, v] - 2 * dis[root, lca(u, v)]


现在就来好好说下树上倍增法求最近公共祖先LCA的算法了。

思想

算法的思想很简单,同剑指Offer66题之每日6题 - 第六天中第六题两个链表的第一个公共结点中的O(n)的做法一样,把两个结点移动到同一高度,然后一起向根走,一边走一边比较两个结点是否相等就行了。

但是这样做,时间复杂度还是O(n),问题的规模较大时,复杂度还是不能接受,因此,树上倍增就是来提升这个效率的,树上倍增把移动这个操作提速了,原来只能一步一步移动,现在可以移动多步了。

具体是怎么移动的呢?请看完预处理,然后接着看LCA就知道了。

预处理

首先,要预处理出数中每一个结点的深度dep以及到根的距离dis,前面也提到了,树上倍增是树深度的倍增,自然需要每一个结点的dep

然后,要预处理出每一个结点的第2i个祖先pd[u][i],什么意思呢,举个例子就明白了,例如结点11的第20=1个祖先是9,第21=2个祖先是8,第22=4个祖先是1。这一步就是要为倍增提供”零件“。

第一步可以使用dfs预处理出来,第二步,可以使用动态规划处理出来,pd[u][i] = pd[pd[u][i - 1]][i - 1],画个图就理解了。

树上倍增法求最近公共祖先LCA

结点C的第22=4个祖先等于结点C的第21=2个祖先B的第21=2个祖先A。

LCA

预处理完成后,剩下的事情就是向根结点移动了;

第一步求出两个结点之间的高度差,让较深的那个结点移动到另一个结点一样的高度上,如果是朴素算法需要一步一步移动,而树上倍增算法把这个高度差表示成二进制,从而把这个移动转化成二进制的数位上移动,这样子,复杂度一下子就降到了O(logn)。举个例子,高度差diff = 6(110),那么较深的结点先移动2,这时高度差变为4,然后较深的结点移动4,这时两个结点的高度一样了。

第二步就是两个结点同时向根移动,先看看两个结点最远的祖先是否相同,如果相同,说明最近的祖先还可能没出现,于是再看看两个结点第二远的祖先是否相同;如果两个结点最远的祖先不相同,说明这两个结点正在接近最近公共祖先,故把这两个结点同时移动到对应的祖先处。以此类推,最终可以得到最近公共祖先。这里距离都是2i,原因在第一步中已经说明。

代码

宏,全局变量

/**
 * 直系祖先,pd[u][0]
 */
#define NUM_PARENT 0
/**
 * 树中结点的最大数目
 */
#define MAXSIZE (40000 + 5)

/**
 * 求二进制中最高一位1的index
 */
#define BITOFBINARY(x) ((int)(log((x) * 1.0) / log(2.0)))

/**
 * 求二进制中最低一位1所表示的数值
 */
int lowbit(int x)
{
    return x & -x;
}

/**
 * 树高的最大幂次
 */
const int MAXDEP = BITOFBINARY(MAXSIZE);

/**
 * 每个结点的深度,距根结点的距离
 */
int dep[MAXSIZE], dis[MAXSIZE];

/**
 * 每个结点的不同深度幂次的祖先
 */
int pd[MAXSIZE][MAXDEP + 1];

预处理

/**
 * 求出每个结点的深度,距离根的距离及它们的直系祖先
 */
void init_dfs(int src)
{
    for (int i = head[src]; i + 1; i = edges[i].next) {
        int to = edges[i].to;

        // 领接表建树,避免重复访问
        if (to == pd[src][NUM_PARENT])
            continue;
        dep[to] = dep[src] + 1;
        dis[to] = dis[src] + edges[i].val;
        pd[to][NUM_PARENT] = src;
        init_dfs(to);
    }
}

/**
 * 动态规划求出每个结点不同距离的祖先
 */
void init_redouble()
{
    for (int power = 1; power <= MAXDEP; ++power)
        for (int i = 1; i <= n; i++)
            pd[i][power] = pd[pd[i][power - 1]][power - 1];
}

LCA

int lca(int x, int y)
{
    // 始终保持x结点的深度较深
    if (dep[x] < dep[y])
        swap(x, y);

    // 求出高度差,并使x移动到同y一样的高度
    for (int diff = dep[x] - dep[y]; diff; diff -= lowbit(diff))
        x = pd[x][BITOFBINARY(lowbit(diff))];

    // 处理x和y是同一个结点或y是x的祖先这两种情况
    if (x == y)
        return x;

    // x和y一样的高度,同时移动x, y
    for (int i = MAXDEP; i >= 0; --i)
        if (pd[x][i] != pd[y][i])
            x = pd[x][i],
            y = pd[y][i];
    return pd[x][NUM_PARENT];
}

完整代码

这里结合一个题目背景,HDU2586:How far away?,完整地给出代码。

这个题目的意思是:给你n个点,n - 1条边的最小生成树,然后给你m次询问,每次询问树中任意两个结点之间的最短路。

做法是随便令一个结点为根,然后用树上倍增的方法求lca,然后利用dis[u, v] = dis[root, u] + dis[root, v] - 2 * dis[root, lca(u, v)]可以求得答案。

n达到了40000m达到了200,朴素做法或许行得通,但我没试过。

#include <bits/stdc++.h>

using namespace std;

#define MAXSIZE (40000 + 5)
#define NUM_PARENT 0

#define BITOFBINARY(x) ((int)(log((x) * 1.0) / log(2.0)))

typedef struct Edge Edge;

struct Edge {
    int to, val;
    int next;
    Edge() {};
    Edge(int to, int val, int next = -1) :
        to(to), val(val), next(next) {}
};

int n, m;
Edge edges[MAXSIZE * 2];
int head[MAXSIZE];

int lowbit(int x)
{
    return x & -x;
}

void add_edge(int x, int y, int val, int i)
{
    edges[i] = Edge(y, val, head[x]);
    head[x] = i;
}

const int MAXDEP = BITOFBINARY(MAXSIZE);

int dep[MAXSIZE], dis[MAXSIZE];
int pd[MAXSIZE][MAXDEP + 1];

void init_dfs(int src)
{
    for (int i = head[src]; i + 1; i = edges[i].next) {
        int to = edges[i].to;
        if (to == pd[src][NUM_PARENT])
            continue;
        dep[to] = dep[src] + 1;
        dis[to] = dis[src] + edges[i].val;
        pd[to][NUM_PARENT] = src;
        init_dfs(to);
    }
}

void init_redouble()
{
    for (int power = 1; power <= MAXDEP; ++power)
        for (int i = 1; i <= n; i++)
            pd[i][power] = pd[pd[i][power - 1]][power - 1];
}

int lca(int x, int y)
{
    if (dep[x] < dep[y])
        swap(x, y);

    for (int diff = dep[x] - dep[y]; diff; diff -= lowbit(diff))
        x = pd[x][BITOFBINARY(lowbit(diff))];

    if (x == y)
        return x;

    for (int i = MAXDEP; i >= 0; --i)
        if (pd[x][i] != pd[y][i])
            x = pd[x][i],
            y = pd[y][i];
    return pd[x][NUM_PARENT];
}

int main()
{
    int T;
    for (scanf("%d", &T); T--; ) {
        int x, y, val;
        scanf("%d%d", &n, &m);

        int root = 1;

        memset(head, -1, sizeof(head));
        memset(pd, 0, sizeof(pd));
        dis[root] = 0;
        dep[root] = 1;

        for (int i = 0; i < 2 * (n - 1); i += 2) {
            scanf("%d%d%d", &x, &y, &val);
            add_edge(x, y, val, i);
            add_edge(y, x, val, i + 1);
        }

        init_dfs(root);
        init_redouble();

        for (; m--; ) {
            scanf("%d%d", &x, &y);
            printf("%d\n", dis[x] + dis[y] - 2 * dis[lca(x, y)]);
        }
    }
    return 0;
}

复杂度

预处理中,init_dfs的时间复杂度为O(n)init_redouble的时间复杂度为O(nlogn),所以总的复杂度为O(nlogn)

由于倍增算法把树上的移动转为在二进制数位上的移动,故单次lca的时间复杂度为O(logn),可以接受;

相关标签: 二进制