leetcode编程:对数时间查找有序数组中位数
问题描述
输入:两个有序数组nums1,nums2
输出:两个数组的中位数
原问题描述如下
There are two sorted arrays nums1 and nums2 of size m and n respectively.
Find the median of the two sorted arrays. The overall run time complexity should be O(log (m+n)).
Example 1:
nums1 = [1, 3]
nums2 = [2]
The median is 2.0
Example 2:
nums1 = [1, 2]
nums2 = [3, 4]
The median is (2 + 3)/2 = 2.5
大致思路:
如何查找两个数组的第k个元素
比较nums1[k/2]和nums2[k/2],若nums1[k/2]>noms[k/2],则两个数组的第k个元素必定大于第二个数组的前k/2个元素。
此时,问题可以递归为 找nums1和从k/2开始的nums2两个数组的第k/2个元素。
#include <iostream>
#include <vector>
using namespace std;
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int n1 = nums1.size(), n2 = nums2.size();
if ( (n1 + n2) % 2 == 0){
int a = get_kth_from_2_sorted_vectors( (n1+n2)/2 , nums1, nums2);
int b = get_kth_from_2_sorted_vectors( (n1+n2)/2-1, nums1, nums2);
return ( a + b ) / 2.0;
}
else
return get_kth_from_2_sorted_vectors( (n1+n2-1)/2 , nums1, nums2);
}
int get_kth_from_2_sorted_vectors(int k, vector<int>& nums1, vector<int>& nums2){
int n1 = nums1.size(), n2 = nums2.size();
int start1 = 0, start2 = 0;
while (true) {
if (start1 >= n1)
return nums2[start2 + k];
if (start2 >= n2)
return nums1[start1 + k];
if (k == 0)
return nums1[start1] < nums2[start2] ? nums1[start1] : nums2[start2];
int mid = (k-1) / 2;
if (start2 + mid >= n2 || (start1 + mid < n1 && nums1[start1 + mid] < nums2[start2 + mid]))
start1 += ( mid + 1 );
else
start2 += ( mid + 1 );
k = k - mid - 1;
}
return 0;
}
};
int main(){
int array1[] = {1,2};
int array2[] = {3,4};
vector <int> v1 (array1, array1 + sizeof(array1)/sizeof(int));
vector <int> v2 (array2, array2 + sizeof(array2)/sizeof(int));
Solution * s = new Solution();
double a = s->findMedianSortedArrays(v1,v2);
cout << a << endl;
}
改进思路
当两个数组长度之和为偶数时,上述代码调用了两次查找第k大元素的函数,这个时间是可以节省的。
如果你知道两个数组的 (n1+n2)/2-1 大的元素属于哪一个数组,下标是多少,是可以直接算出两个数组的第 (n1+n2)/2 大的元素的。
如果a在nums1的下标是M,那么,a肯定大于nums[0], noms[1],..., noms[M-1](一共M个)
若此时知道a是nums1和nums2的第K小的元素(K可以是0),则a肯定大于两个数组第0小的数,第1小的数,...,第K-1小的数(一共K个)。
又在nums1中有且仅有M个数小于a,则在nums2中,有且仅有K-M个元素小于a(即nums2[0],...,nums2[K-M-1])。
a是两个数组中第K大的元素,如果第K+1大的元素在nums1中,则必为a后面的元素nums1[M+1],如果出现在nums2中,则必为nums2[K-M]。
所以找到第K+1大的元素只要比较nums1[M+1]和nums2[K-M]谁更小就可以了。
#include <iostream>
#include <vector>
using namespace std;
class Solution {
int choose_vector ;
int index ;
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int n1 = nums1.size(), n2 = nums2.size();
if ( (n1 + n2) % 2 == 0){
int a = get_kth_from_2_sorted_vectors( (n1+n2)/2-1, nums1, nums2);
int b = -1;
int index2 = (n1+n2)/2-index - 1;
if (choose_vector == 1){
if (n2 < index2 + 1)
b = nums1[index + 1];
else if (index + 1< n1)
b = nums1[index + 1] < nums2[index2] ? nums1[index+1] : nums2[index2];
else
b = nums2[index2];
}
else {
if (n1 < index2 + 1)
b = nums2 [index + 1];
else if (index + 1< n2)
b = nums2[index + 1] < nums1[index2] ? nums2[index+1] : nums1[index2];
else
b = nums1[index2];
}
return ( a + b ) / 2.0;
}
else{
return get_kth_from_2_sorted_vectors( (n1+n2-1)/2 , nums1, nums2);
}
}
int get_kth_from_2_sorted_vectors(int k, vector<int>& nums1, vector<int>& nums2){
int n1 = nums1.size(), n2 = nums2.size();
int start1 = 0, start2 = 0;
while (true) {
if (start1 >= n1){
this->choose_vector = 2;
this->index = start2 + k;
return nums2[start2 + k];
}
if (start2 >= n2){
this->choose_vector = 1;
this->index = start1 + k;
return nums1[start1 + k];
}
if (k == 0){
if (nums1[start1] < nums2[start2]){
this->choose_vector = 1;
this->index = start1;
return nums1[start1];
}
else{
this->choose_vector = 2;
this->index = start2;
return nums2[start2];
}
}
int mid = (k-1) / 2;
if (start2 + mid >= n2 || (start1 + mid < n1 && nums1[start1 + mid] < nums2[start2 + mid])) {
start1 += ( mid + 1 );
}
else {
start2 += ( mid + 1 );
}
k -= (mid + 1);
}
return 0;
}
};
int main(){
int array1[] = {2,3,4};
int array2[] = {1};
vector <int> v1 (array1, array1 + sizeof(array1)/sizeof(int));
vector <int> v2 (array2, array2 + sizeof(array2)/sizeof(int));
Solution * s = new Solution();
double a = s->findMedianSortedArrays(v1,v2);
cout << a << endl;
}