欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Top K问题

程序员文章站 2024-03-15 22:04:48
...

Top K问题

Top K问题在数据分析中非常普遍的一个问题(在面试中也经常被问到),比如:

从20亿个数字的文本中,找出最大的前100个。

解决Top K问题有两种思路,

  • 最直观:小顶堆(大顶堆 -> 最小100个数);
  • 较高效:Quick Select算法。

LeetCode上有一个问题215. Kth Largest Element in an Array,类似于Top K问题。

1. 堆

小顶堆(min-heap)有个重要的性质——每个结点的值均不大于其左右孩子结点的值,则堆顶元素即为整个堆的最小值。JDK中PriorityQueue实现了数据结构堆,通过指定comparator字段来表示小顶堆或大顶堆,默认为null,表示自然序(natural ordering)。

小顶堆解决Top K问题的思路:小顶堆维护当前扫描到的最大100个数,其后每一次的扫描到的元素,若大于堆顶,则入堆,然后删除堆顶;依此往复,直至扫描完所有元素。Java实现第K大整数代码如下:

public int findKthLargest(int[] nums, int k) {
  PriorityQueue<Integer> minQueue = new PriorityQueue<>(k);
  for (int num : nums) {
    if (minQueue.size() < k || num > minQueue.peek())
      minQueue.offer(num);
    if (minQueue.size() > k)
      minQueue.poll();
  }
  return minQueue.peek();
}

时间复杂度:n*logK

速记:堆排的时间复杂度是n*logn,这里相当于只对前Top K个元素建堆排序,想法不一定对,但一定有助于记忆。

适用场景
实现的过程中,我们先用前K个数建立了一个堆,然后遍历数组来维护这个堆。这种做法带来了三个好处:(1)不会改变数据的输入顺序(按顺序读的);(2)不会占用太多的内存空间(事实上,一次只读入一个数,内存只要求能容纳前K个数即可);(3)由于(2),决定了它特别适合处理海量数据。

这三点,也决定了它最优的适用场景。

2. Quick Select

Quick Select [1]脱胎于快排(Quick Sort),两个算法的作者都是Hoare,并且思想也非常接近:选取一个基准元素pivot,将数组切分(partition)为两个子数组,比pivot大的扔左子数组,比pivot小的扔右子数组,然后递推地切分子数组。Quick Select不同于Quick Sort的是其没有对每个子数组做切分,而是对目标子数组做切分。其次,Quick Select与Quick Sort一样,是一个不稳定的算法;pivot选取直接影响了算法的好坏,worst case下的时间复杂度达到了O(n2)O(n2)。下面给出Quick Sort的Java实现:

public void quickSort(int arr[], int left, int right) {
  if (left >= right) return;
  int index = partition(arr, left, right);
  quickSort(arr, left, index - 1);
  quickSort(arr, index + 1, right);
}

// partition subarray a[left..right] so that a[left..j-1] >= a[j] >= a[j+1..right]
// and return index j
private int partition(int arr[], int left, int right) {
  int i = left, j = right + 1, pivot = arr[left];
  while (true) {
    while (i < right && arr[++i] > pivot)
      if (i == right) break;
    while (j > left && arr[--j] < pivot)
      if (j == left) break;
    if (i >= j) break;
    swap(arr, i, j);
  }
  swap(arr, left, j);  // swap pivot and a[j]
  return j;
}

private void swap(int[] arr, int i, int j) {
  int tmp = arr[i];
  arr[i] = arr[j];
  arr[j] = tmp;
}

Quick Select的目标是找出第k大元素,所以

  • 若切分后的左子数组的长度 > k,则第k大元素必出现在左子数组中;
  • 若切分后的左子数组的长度 = k-1,则第k大元素为pivot;
  • 若上述两个条件均不满足,则第k大元素必出现在右子数组中。

Quick Select的Java实现如下:

public int findKthLargest(int[] nums, int k) {
  return quickSelect(nums, k, 0, nums.length - 1);
}

// quick select to find the kth-largest element
public int quickSelect(int[] arr, int k, int left, int right) {
  if (left == right) return arr[right];
  int index = partition(arr, left, right);
  if (index - left + 1 > k)
    return quickSelect(arr, k, left, index - 1);
  else if (index - left + 1 == k)
    return arr[index];
  else
    return quickSelect(arr, k - index + left - 1, index + 1, right);

}

上面给出的代码都是求解第k大元素;若想要得到Top K元素,仅需要将代码做稍微的修改:比如,扫描完成后的小顶堆对应于Top K,Quick Select算法用中间变量保存Top K元素。

时间复杂度:n

速记:记住就行,基于partition函数的时间复杂度比较难证明,从来没考过。

适用场景
对照着堆排的解法来看,partition函数会不断地交换元素的位置,所以它肯定会改变数据输入的顺序;既然要交换元素的位置,那么所有元素必须要读到内存空间中,所以它会占用比较大的空间,至少能容纳整个数组;数据越多,占用的空间必然越大,海量数据处理起来相对吃力。

但是,它的时间复杂度很低,意味着数据量不大时,效率极高。

3. 参考资料

[1] Hoare, Charles Anthony Richard. “Algorithm 65: find.” Communications of the ACM 4.7 (1961): 321-322.
[2] James Aspnes, QuickSelect.
[3] https://www.cnblogs.com/en-heng/p/6336625.html
[4] https://blog.csdn.net/luochoudan/article/details/53736752

5.4 完整代码
import java.util.PriorityQueue;

/**
 * @author leahy(583310958 @ qq.com)
 * @date 2019/11/15 21:08
 */
public class TopK {
    public static void main(String[] args) {
        int[] datas = {2,4,5,0,1,11,45,6,10,57,30};
        int[] topK = findKthLargest03(datas, 5);
        for(int k : topK) {
            System.out.println(k);
        }
    }

    /**
     * 使用小堆顶
     * JDK中自带的PriorityQueue
     */
    public static int[] findKthLargest01(int[] nums, int k) {
        int[] result = new int[k];
        PriorityQueue<Integer> minQueue = new PriorityQueue<>();
        for (int num : nums) {
            if (minQueue.size() < k || num > minQueue.peek()) {
                minQueue.offer(num);
            }
            if (minQueue.size() > k) {
                minQueue.poll();
            }
        }
        int  i = 0;
        while (minQueue.size() != 0) {
            result[i] = minQueue.poll();
            i++;
        }
        return result;
    }

    /**
     * 使用小堆顶
     * 自己造*
     */
    public static int[] findKthLargest02(int[] nums, int k) {
        //初始化一个含有k个元素的数组
        int[] result = new int[k];
        for (int i = 0; i < k; i++) {
            result[i] = nums[i];
        }
        //构造最小堆
        for(int i = result.length/2 -1; i >= 0; i--) {
            buildHeap(result,i,result.length);
        }
        //更新迭代,得到TopK
        for (int j = k; j < nums.length; j++) {
            int temp = result[0];
            if (nums[j] > temp) {
                result[0] = nums[j];
                buildHeap(result, 0, result.length);
            }
        }
        return result;
    }
    public static void buildHeap(int[] nums, int index, int length) {
        int left = index * 2 + 1;
        int right = index * 2 + 2;
        int largest = index;
        if (left < length && nums[left] < nums[largest]) {
            largest = left;
        }
        if (right < length && nums[right] < nums[largest]) {
            largest = right;
        }
        if (index != largest){
            Swap(nums, largest, index);
            buildHeap(nums, largest, length);
        }
    }
    public static void Swap(int[] nums, int i, int j) {
        int temp = nums[i];
        nums[i] = nums[j];
        nums[j] = temp;
    }

    /**
     * 采用快排的方法
     * 不稳定
     */
    public static int[] findKthLargest03(int[] nums, int k) {
        int n = quickSelect(nums, k, 0,nums.length - 1);
        int[] result = new int[k];
        for (int i = 0; i < k; i++) {
            result[i] = nums[i];
        }
        return result;
    }
    public static int quickSelect(int[] nums, int k, int left, int right) {
        if (left == right)
            return right;
        int index = partition(nums, left, right);
        if(index - left + 1 > k) {
            return quickSelect(nums, k, left, index - 1);
        }
        else if (index - left + 1== k) {
            return index;
        }
        else
            return quickSelect(nums, k -index + left - 1, index + 1, right);
    }
    public static int partition(int[] nums, int left, int right) {
        int partitionIndex = left;
        for (int i = left + 1; i <= right; i++) {
            if (nums[i] > nums[left]) {
                partitionIndex++;
                if (partitionIndex != i) {
                    Swap(nums, partitionIndex, i);
                }
            }
        }
        Swap(nums, partitionIndex, left);
        return partitionIndex;
    }
}