欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

C++用归并算法求逆序对

程序员文章站 2022-05-10 15:12:55
...

原题目:

猫猫 TOM 和小老鼠 JERRY 最近又较量上了,但是毕竟都是成年人,他们已经不喜欢再玩那种你追我赶的游戏,现在他们喜欢玩统计。

最近,TOM 老猫查阅到一个人类称之为“逆序对”的东西,这东西是这样定义的:对于给定的一段正整数序列,逆序对就是序列中 a_i>a_ja 且
i<ji<j 的有序对。知道这概念后,他们就比赛谁先算出给定的一段正整数>序列中逆序对的数目。

输入格式 第一行,一个数 n,表示序列中有 n个数。

第二行 nn 个数,表示给定的序列。序列中每个数字不超过 10^910 9 。

输出格式 输出序列中逆序对的数目。

样例

输入

6
5 4 2 6 3 1

输出

11

我先来介绍一下归并排序

时间复杂度:
O(n log n)

空间复杂度:
O(n)

定义:
归并排序是建立在归并操作上的一种有效的排序算法,该算法采用的是分治法(二分法)。

设n为序列元素的个数,a[i]为其中第i个元素,msort为归并排序的函数名,b[i]为两个相邻的子序列合并完的临时有序序列

1.将一个序列进行二分,直到分解成n个元素(l == r)

如代码:

int mid = (l + r) / 2;//取中间 
if(l == r)//若l == r了,就代表这个子序列就只剩1个元素了,需要返回 
{
    return;
}
else
{
    msort(l, mid);//分成l和中间一段,中间 + 1和r一段(二分) 
    msort(mid + 1, r);
}

2.将已有序的子序列合并(二分到最后的时候,一个元素一定是有序的),得到完全有序的序列

如代码:

int i = l;//i从l开始,到mid,因为现在排序的是l ~ r的区间且要二分合并 
int j = mid + 1;//j从mid + 1开始,到r原因同上
int t = l;//数组b的下标,数组b存的是l ~ r区间排完序的值 
while(i <= mid && j <= r)//同上i,j的解释 
{
    if(a[i] > a[j])//如果前面的元素比后面大(l ~ mid中的元素 > mid + 1 ~ r中的元素)(逆序对出现!!!) 
    { 
        ans += mid - i + 1;//由于l ~ mid和mid + 1 ~ r都是有序序列所以一旦l ~ mid中的元素 > mid + 1 ~ r中的元素而又因为第i个元素 < i + 1 ~ mid那么i + 1 ~ mid的元素都 > 第j个元素。所以+的元素个数就是i ~ mid的元素个数,及mid - i + 1(归并排序里没有这句话,求逆序对里有) 
        b[t++] = a[j];//第j个元素比i ~ mid的元素都小,那么第j个元素是目前最小的了,就放进b数组里 
        ++j;//下一个元素(mid + 1 ~ r的元素小,所以加第j个元素) 
    }
    else
    {
        b[t++] = a[i];//i小,存a[i] 
        ++i;//同理 
    }
}
while(i <= mid)//把剩的元素(因为较大所以在上面没选) 
{
    b[t++] = a[i];//存进去 
    ++i; 
}
while(j <= r)//同理 
{
    b[t++] = a[j];
    ++j;
}

将已有序的b数组赋值到a数组(原数组)里

如代码:

for(register int i = l; i <= r; ++i)//把有序序列b赋值到a里 
{
    a[i] = b[i];
}

手动模拟一下样例(以便更好的理解算法):
二分时,奇数时默认取左边n / 2 + 1,右边取n / 2了(为了方便)

那么我们开始合并(只能同一层合并,并向上传递,我把后面两个while循环在合并中省去了)

1.第四层5和4合并:i = l = 1;mid = (l + r) / 2 = 1;r = 2;j = mid + 1 = 2;5 > 4;逆序对数量 + mid - i + 1 = 1;b数组:4 5 0 0 0 0; j + 1 > r; 退出;a数组:4 5 2 6 3 1;

2.第四层6和3合并:i = l = 4;mid = (l + r) / 2 = 4;r = 5;j = mid + 1 = 5; 6 > 3;逆序对数量 + mid - i + 1 = 2;b数组:4 5 0 3 6 0;j + 1 > r; 退出;a数组:4 5 2 3 6 1;

3.第三层4, 5和2合并: i = l = 1;mid = (l + r) / 2 = 2;r = 3;j = mid + 1 = 2;4 > 2; 逆序对数量 + mid - i + 1 = 4;b数组:2 4 5 3 6 0;j + 1 > r;退出;a数组:2 4 5 3 6 1;

4.第三层3, 6和1合并:i = l = 4;mid = (l + r) / 2 = 5;r = 6;j = mid + 1 = 6;3 > 1;逆序对数量 + mid - i + 1 = 6;b数组:2 4 5 1 3 6;j + 1 > r;退出;a数组:2 4 5 1 3 6;

5.第二层2, 4,5和1, 3,6合并:i = l = 1;mid = (l + r) / 2 = 3;r = 6;j = mid + 1 = 4;2 > 1; 逆序对数量 + mid - i + 1 = 9;b 数组:1 2 4 5 3 6;j + 1;2 < 3;b数组:1 2 4 5 3 6;i + 1;4 > 3;逆序对数量 + mid - i + 1 = 11;b数组:1 2 3 4 5 6;j + 1;4 < 6;b数组:1 2 3 4 5 6;i + 1;5 < 6;b数组:1 2 3 4 5 6;i + 1 > mid;退出; a数组:1 2 3 4 5 6(有序)

code:

#include <bits/stdc++.h>
using namespace std;
int n, a[5000001], b[5000001];
long long ans;
inline void msort(int l, int r)//归并排序
{
    int mid = (l + r) / 2;//取中间 
    if(l == r)//若l == r了,就代表这个子序列就只剩1个元素了,需要返回 
    {
        return;
    }
    else
    {
        msort(l, mid);//分成l和中间一段,中间 + 1和r一段(二分) 
        msort(mid + 1, r);
    }
    int i = l;//i从l开始,到mid,因为现在排序的是l ~ r的区间且要二分合并 
    int j = mid + 1;//j从mid + 1开始,到r原因同上
    int t = l;//数组b的下标,数组b存的是l ~ r区间排完序的值 
    while(i <= mid && j <= r)//同上i,j的解释 
    {
        if(a[i] > a[j])//如果前面的元素比后面大(l ~ mid中的元素 > mid + 1 ~ r中的元素)(逆序对出现!!!) 
        { 
            ans += mid - i + 1;//由于l ~ mid和mid + 1 ~ r都是有序序列所以一旦l ~ mid中的元素 > mid + 1 ~ r中的元素而又因为第i个元素 < i + 1 ~ mid那么i + 1 ~ mid的元素都 > 第j个元素。所以+的元素个数就是i ~ mid的元素个数,及mid - i + 1(归并排序里没有这句话,求逆序对里有) 
            b[t++] = a[j];//第j个元素比i ~ mid的元素都小,那么第j个元素是目前最小的了,就放进b数组里 
            ++j;//下一个元素(mid + 1 ~ r的元素小,所以加第j个元素) 
        }
        else
        {
            b[t++] = a[i];//i小,存a[i] 
            ++i;//同理 
        }
    }
    while(i <= mid)//把剩的元素(因为较大所以在上面没选) 
    {
        b[t++] = a[i];//存进去 
        ++i; 
    }
    while(j <= r)//同理 
    {
        b[t++] = a[j];
        ++j;
    }
    for(register int i = l; i <= r; ++i)//把有序序列b赋值到a里 
    {
        a[i] = b[i];
    }
    return;
}
int main()
{
    scanf("%d", &n);
    for(register int i = 1; i <= n; ++i)
    {
        scanf("%d", &a[i]);
    }
    msort(1, n);//一开始序列是1 ~ n 
    printf("%lld", ans);
    return 0;
}
相关标签: C++题解