【面试必会】一文搞懂 TopK 问题及其变种
这是我在面试腾讯时遇到的真实面试题,在很多面经中也能看到它的身影,今天我们就来彻底地搞懂它!
问题描述
如何从 10w 的数据中找到最大的 100 个数?
首先看问题,10w 的数据,在堆上建个数组暴力求是没有问题的,要找最大的 100 个数,那么先从最简单最暴力的方法开始。
1. 排序法
众所周知,快速排序和堆排序的时间复杂度都可以达到 ,我们只要给 10w 数据排个序,然后取出前 100 个就好了。这种方法很暴力,在数据总数不是很大时确实可以使用,比如100个里面取前20个;当然,面试时我们只需简单地提一下这种解法,就可以说下一种优化方法了。至于排序,不是本文的重点。
接下来考虑优化,我们只需要前 100 个,为什么要把全部数据排序呢?
2. 局部排序法
我们回忆一下冒泡排序和选择排序的过程,在前 k 次循环中,可以得出前 k 个最大/最小值。
以冒泡排序(降序)为例:
for(int i = 0; i < n; i++) {
for (int j = 0; j < n-i-1; j++) {
if (arr[j] < arr[j+1])
swap(arr, j, j+1); // 交换 arr[j] 和 arr[j+1]
}
}
因此在这里,我们正好利用这两种排序算法的特性,简单写下代码:
// 我们只需要把最外层的 n 换为 k
for(int i = 0; i < k; i++) {
for (int j = 0; j < n-i-1; j++) {
//...
}
}
这样子,就能获得最大的前 k 个数,并且位于 arr 中的前 k 个位置,这样的时间复杂度就变为了 。
简单比较下前两种方法的时间复杂度: 和 ,到低哪个好,得根据 K 和 N 的大小来看,如果 K 较小(K <= ) 的情况下,我们可以采用局部排序法。
3. Partition
回忆一下快速排序,快排中的每一步,都是将待排数据分做两组,其中一组的数据的任何一个数都比另一组中的任何一个大,然后再对两组分别做类似的操作,然后继续下去…
如下图,将 arr
中的数据分为小于 k
和大于k
两部分:
接下来,我们来看怎么利用这种思想求出最大的K
个数。
我们假设存在一个数组S
,从中随意挑出了一个数 X
,然后将数组 S 分为两部分:
- A:大于等于X
- B:小于X
如下图所示,我们对数组 S 进行 Partition
操作,可以得到两种情况:
-
如果
A
的个数大于K
,那么数组S
的最大K
个数,就是A
中的最大K
个数;这个很好理解,相当于说
年级
(S)前十名
(K)一定是年级前五十名
(A)中的前十名(K) -
如果
A
的个数小于K
,我们就需要在B
中找到剩余的部分,也就是A
+B
中的K-|A|
个;同样的,
年级
(S)前十名
(K)一定是年级前三名
(A)加上年级4-100名
(B)中的前7名
(K-|A|
);
如果上面这部分还没理解,可以参考下方这个小例子,如果理解了,跳过即可:
我们只需重复上面的操作,递归直到找到前K
个数即可, 这样的平均时间复杂度为 。
这里附一份伪代码:
我根据这份伪代码简单写了下代码:(Java实现,但以通用方式来写,对于cpp、go都有参考价值)
建议大家一定要自己动手实现,光看代码是不够的,万一面试官让你手写代码你就傻眼了。另外,这份代码为了好理解,很多地方实际上是不规范的,比如变量名用大写字母等等,这些大家在写的时候是可以想办法去优化的。
public int[] KBig(int[] S, int K) {
if (K <= 0) {
return new int[0];
}
if (S.length <= K) {
return S;
}
Sclass sclass = Partition(S);
return contact(KBig(sclass.Sa, K), KBig(sclass.Sb, K - sclass.Sa.length));
}
public Sclass Partition(int[] S) {
Sclass sclass = new Sclass();
int p = S[0]; // 省略了随机选择元素的过程
for (int i = 1; i < S.length; i++) {
if (S[i] > p) {
sclass.Sa = append(sclass.Sa, S[i]);
} else {
sclass.Sb = append(sclass.Sb, S[i]);
}
}
if (sclass.Sa.length < sclass.Sb.length) {
sclass.Sa = append(sclass.Sa, p);
} else {
sclass.Sb = append(sclass.Sb, p);
}
return sclass;
}
注意到伪代码中返回了两个数组,我们这里用一个类来存这两个数组:
class Sclass { // 单纯用来存储两个数组
int[] Sa = new int[0];
int[] Sb = new int[0];
}
辅助函数:
/**
* 在数组 arr 的末尾插入值 value
* @param arr 数组
* @param value 值
* @return 返回插入后的数组
*/
int[] append(int[] arr, int value) {
int[] res = new int[arr.length + 1];
System.arraycopy(arr, 0, res, 0, arr.length);
res[arr.length] = value;
return res;
}
/**
* 将两个数组连接到一起
* @param a 数组a
* @param b 数组b
* @return 返回连接后的数组
*/
public int[] contact(int[] a, int[] b) {
int[] res = new int[a.length + b.length];
for (int i = 0; i < a.length; i++) { // 通用的拷贝方式
res[i] = a[i];
}
// 在 java 中实际上可以通过 System.arraycopy 完成拷贝
System.arraycopy(b, 0, res, a.length, b.length);
return res;
}
当你写完代码,测试一下就会发现,实际上这种方法返回的最大的K
个数是没有排序的(其实题目也没有要求你排序,且如果你对Partition
的过程很清楚的话, 你也很容易知道这里返回的是无序的最大K个数)我们需要考虑清楚应用场景,有些场景没有排序要求,有些场景有,要学会选择。
4. 二分搜索
我们要找数组S
中最大的K
个数,那么如果我们知道了第K
大的数,事情会变得简单吗?聪明的读者可能已经发现了,如果我们知道了数组S
中第K
大的数p
,那么我们只需遍历一遍数组,就能找到最大的K
个数。(即所有大于等于p
的数),这一步的时间复杂度为 。
有读者可能会问,如果等于
p
的值有多个,这样遍历一遍取出来的数多于K
个,怎么办呢?事实上解决的办法有很多,我这里简单说一种,遍历的时候只把大于
p
的数取出来,最后根据大于p
的数和K
的差值,补相应的p
就好了。例子:
S = [1, 2, 3, 3, 5],p = 3,K = 2
;即我们知道第K
大的数p
为 3,我们遍历一遍 S,把所有大于p
的数取出来,即[5]
,接下来补K- [5].size() = 1
个p
,即[5,3]
就是最大的 K 个数。
回到我们的二分搜索方法中来,我们需要在S
中找到第K
大的数,伪代码如下:
- Vmax:数组S中的最大值
- Vmin:数组S中的最小值
- delta:比
所有N个数中的任意两个不相等的元素差值的最小值
小。如果所有元素都是整数, delta可以取值0.5。
整个算法的时间复杂度为 。在数据平均分布的情况下,时间复杂度为 $ O(N*log_2N) $。
在整数的情况下,可以从另一个角度来看这个算法。假设所有整数的大小都在 之间,也就是说所有整数在二进制中都可以用
m bit
来表示(从低位到高位,分别用0, 1, ..., m-1
标记)。我们可以先考察在二进制位的第(m-1)
位,将N个整数按该位为1
或者0
分成两个部分。也就是将整数分成取值为 和 两个区间。
前一个区间中的整数第(m-1)
位为0
,后一个区间中的整数第(m-1)
位为1
。如果该位为1的整数个数A
大于等于K
,那么,在所有该位为1
的整数中继续寻找最大的K
个。否则,在该位为0
的整数中寻找最大的K-A
个。接着考虑二进制位第(m-2)
位,以此类推。思路跟上面的浮点数的情况本质上一样。
5. BFPRT算法
这个算法比较复杂,我们这里不做详细介绍,简单说下, 也是类似快速排序的思想,但是能从n个元素的序列中选出第k
大/小的元素,且保证最坏时间复杂度为 。
为什么 的算法不讲,要去讲那些看起来更 “慢” 的算法呢?要注意,我们通常讲的时间复杂度是
平均
/最差
,而且是忽略掉系数的,真实应用场景下还要考虑是否容易实现(过于复杂的可能频繁出bug
得不偿失),还要考虑各种各样的问题,并不是无脑选择时间复杂度低的方法。
这个方法配合我们前面所说的,已知数组S
中第K
大的数p
,我们只需再遍历一遍数组,就能找到最大的K
个数。这一步的时间复杂度也为 。
所以总的时间复杂度就是 。
算法步骤:
-
将n个元素每5个一组,分成
n/5
(上界)组。 -
取出每一组的中位数,任意排序方法,比如插入排序。
-
递归的调用
selection
算法查找上一步中所有中位数的中位数,设为x
,偶数个中位数的情况下设定为选取中间小的一个。 -
用
x
来分割数组,设小于等于x
的个数为k
,大于x
的个数即为n-k
。 -
若
i==k
,返回x
;若i!=k
,在大于x
的元素中递归查找第i-k
小的元素。终止条件:n=1
时,返回的即是i
小元素。
6. 最大最小堆
我们前面谈到的解法有个共同的地方,如果数据量较大时,就得对数据访问多次。
那么如果面试官问的不是从 10w 中找100个数,而是10亿呢? 这个时候数据是不能一次性读入内存的,所以我们要尽可能少的遍历所有数据。
回忆我们的堆排序,我们需要维护一个最大堆/最小堆,关键点就在这里了。我们可以从100亿个数据中取出前K
个,然后用这K
个数建立一个最小堆,之后去遍历所有数据,每取出一个数,如果大于当前堆中的最小值,就替换掉当前的最小堆中的最小值,然后维护堆的秩序,只需遍历所有数据一次,我们就能获得有序的最大 K 个数
。维护堆的时间复杂度为 ,所以算法总体的时间复杂度为 。
啰嗦一句,我们这里是用最小堆,去存最大的
k
个数,为什么不用最大堆来存呢?因为更新的时候又得调换下顺序,没有必要多此一举。
接下来我们详细说说算法该怎么实现,对堆排序熟悉的同学可能已经可以自己写出来了,那么可以跳过这部分。
我们使用一个数组H[]
来建立一个K=8
的堆:
我们知道,堆中的每个元素H[i]
,它的父亲结点是H[i/2]
,左孩子结点是H[2*i+ 1]
,右孩子结点是H[2*i+2]
。每新考虑一个数X
,需要进行的更新操作伪代码如下:
解读下伪代码,一开始进行判断X
是否大于当前的堆里面最小值,如果比这个堆的最小值还小,那就不用看了,肯定不是最大的K
个数之一;如果是大于最小值,那么就替换掉最小值,如下图所示:
然后我们就要维护堆的秩序了,依次将X
跟它的左右孩子进行比较,如果比它们大,就要交换,否则不动,假设X
大于H[1]
,那么X
就要跟H[1]
交换:
交换完后,p=q
,所以接下来会继续判断X
和H[3]
的大小,假设X
小于H[3]
,那么就X
就停止于此,结束循环。
7. 总结
方法 | 时间复杂度 | 特点 |
---|---|---|
排序法 | 实现简单,数据量小,对速度要求不敏感 | |
局部排序法 | 实现简单,数据量小,且对速度不敏感时, 时可以考虑使用 |
|
Partition | 速度快,返回数据无序 | |
二分搜索 | 速度较快,特定场景下可以使用位来实现 | |
BFPRT | 实际效果并没有想象中的好 | |
最大最小堆 | 支持超大数据量,且可更新,有序 |
参考书籍:《编程之美》