欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

二分图的最佳完美匹配--KM算法(DFS寻路+BFS寻路(O(n^3))) + HDU2255入门题

程序员文章站 2022-06-09 20:17:11
...

Reference Blog:
原理清晰深刻:https://blog.csdn.net/sixdaycoder/article/details/47720471
较容易于理解:https://www.cnblogs.com/wenruo/p/5264235.html

如果二分图的每条边都有一个权(可以是负数),要求一种完备匹配方案,使得所有匹配边的权和最大,记做最佳完美匹配。(特殊的,当所有边的权为1时,就是最大完备匹配问题)

算法流程:
二分图的最佳完美匹配--KM算法(DFS寻路+BFS寻路(O(n^3))) + HDU2255入门题

dfs寻增广路模板:(只针对随即数据O(n^3),对于极限数据(w[i][j]很大)slack优化作用不显著)

const int AX = 3e2+6;
bool visx[AX];
bool visy[AX];
int w[AX][AX];
int lx[AX] , ly[AX];
int linker[AX];
int slack[AX];  
int n ;
bool dfs( int x ){
    visx[x] = true;
    for( int y = 1 ; y <= n ; y ++ ){
        if( !visy[y] && lx[x] + ly[y] == w[x][y] ){
            visy[y] = true;
            if( linker[y] == -1 || dfs( linker[y] ) ){
                linker[y] = x ;
                return true;
            }
        }else if( slack[y] > lx[x] + ly[y] - w[x][y] ){
            slack[y] = lx[x] + ly[y] - w[x][y];
        }
    }
    return false;
}

void KM(){
    memset( linker , -1 , sizeof(linker) );
    memset( ly , 0 , sizeof(ly) );
    for( int i = 1 ; i <= n ; i++ ){
        lx[i] = -INF;
        for( int j = 1 ; j <= n ; j++ ){
            if( lx[i] < w[i][j] ) lx[i] = w[i][j];
        }
    }
    for( int x = 1 ; x <= n ; x++ ){
        for( int i = 1 ; i <= n ; i++ ) slack[i] = INF;
        while(1){
            memset( visx , false , sizeof(visx) );
            memset( visy , false , sizeof(visy) );
            if( dfs(x) ){
                break;
            }else{
                int delta = INF;
                for( int j = 1 ; j <= n ; j++ ){
                    if( !visy[j] && delta > slack[j] ){
                        delta = slack[j];
                    }
                }

                for( int i = 1 ; i <= n ; i++ ){
                    if( visx[i] ) lx[i] -= delta;
                }
                for( int i = 1 ; i <= n ; i++ ){
                    if( visy[i] ) ly[i] += delta;
                    else slack[i] -= delta;
                }
            }
        }
    }
}

BFS寻路模板(真正O(n^3))

const int AX = 3e2+6;
LL w[AX][AX];
LL lx[AX] , ly[AX];
int linker[AX];
LL slack[AX];
int n ;
bool visy[AX];
int pre[AX];
void bfs( int k ){
    int x , y = 0 , yy = 0 , delta;
    memset( pre , 0 , sizeof(pre) );
    for( int i = 1 ; i <= n ; i++ ) slack[i] = INF;
    linker[y] = k;
    while(1){
        x = linker[y]; delta = INF; visy[y] = true;
        for( int i = 1 ; i <= n ;i++ ){
            if( !visy[i] ){
                if( slack[i] > lx[x] + ly[i] - w[x][i] ){
                    slack[i] = lx[x] + ly[i] - w[x][i];
                    pre[i] = y; 
                }
                if( slack[i] < delta ) delta = slack[i] , yy = i ;
            }
        }
        for( int i = 0 ; i <= n ; i++ ){
            if( visy[i] ) lx[linker[i]] -= delta , ly[i] += delta;
            else slack[i] -= delta;
        }
        y = yy ;
        if( linker[y] == -1 ) break;
    }
    while( y ) linker[y] = linker[pre[y]] , y = pre[y];
}

void KM(){
    memset( lx , 0 ,sizeof(lx) );
    memset( ly , 0 ,sizeof(ly) );
    memset( linker , -1, sizeof(linker) );
    for( int i = 1 ; i <= n ; i++ ){
        memset( visy , false , sizeof(visy) );
        bfs(i);
    }
}

HDU2255
AC Code:
O(n^3)

#include <bits/stdc++.h>
#pragma comment(linker, “/STACK:1024000000,1024000000”)
#define INF 0x3f3f3f3f
#define LL long long
using namespace std;
const int AX = 3e2+6;
bool visx[AX];
bool visy[AX];
int w[AX][AX];
int lx[AX] , ly[AX];  //可行性顶标
int linker[AX];  //记录匹配的边
int slack[AX];   //记录每个j相连的i的最小的lx[i]+ly[j]-w[i][j]
int n ;
bool dfs( int x ){
    visx[x] = true;
    for( int y = 1 ; y <= n ; y ++ ){
        if( !visy[y] && lx[x] + ly[y] == w[x][y] ){
            visy[y] = true;
            if( linker[y] == -1 || dfs( linker[y] ) ){
                linker[y] = x ;
                return true;
            }
        }else if( slack[y] > lx[x] + ly[y] - w[x][y] ){//x,y不在相等子图且y不在增广路
            slack[y] = lx[x] + ly[y] - w[x][y];
        }
    }
    return false;
}

void KM(){
    memset( linker , -1 , sizeof(linker) );
    memset( ly , 0 , sizeof(ly) );
    for( int i = 1 ; i <= n ; i++ ){
        lx[i] = -INF;
        for( int j = 1 ; j <= n ; j++ ){
            if( lx[i] < w[i][j] ) lx[i] = w[i][j];
        }
    }
    for( int x = 1 ; x <= n ; x++ ){
        for( int i = 1 ; i <= n ; i++ ) slack[i] = INF;//每次匹配x都要更新slack
        while(1){
            memset( visx , false , sizeof(visx) );
            memset( visy , false , sizeof(visy) );
            if( dfs(x) ){
                break;
            }else{ // 匹配失败后x一定在增广路,寻找不在增广路的j
                int delta = INF;
                for( int j = 1 ; j <= n ; j++ ){
                    if( !visy[j] && delta > slack[j] ){
                        delta = slack[j];
                    }
                }

                for( int i = 1 ; i <= n ; i++ ){
                    if( visx[i] ) lx[i] -= delta;
                }
                for( int i = 1 ; i <= n ; i++ ){
                    if( visy[i] ) ly[i] += delta;
                    else slack[i] -= delta;
                    //修改顶标后,要把所有的slack值都减去delta
                     //slack[j] = min(lx[i] + ly[j] -w[i][j])
                     //在增广路的lx[i]减少,所以不在增广路的slack[j]减小
                }
            }
        }
    }
}

int main(){
    int x ;
    while( ~scanf("%d",&n) ){
        for( int i = 1 ; i <= n ; i++ ){
            for( int j = 1 ; j <= n ; j++ ){
                scanf("%d",&x);
                w[i][j] = x ;
            }
        }
        KM();
        int res = 0 ;
        for( int i = 1 ; i <= n ; i++ ){
            if( linker[i] != -1 ){
                res += w[linker[i]][i] ;
            }
        }
        printf("%d\n",res);
    }
    return 0 ;
}

TLE Code:
O(n^4)

#include <bits/stdc++.h>
#pragma comment(linker, “/STACK:1024000000,1024000000”)
#define INF 0x3f3f3f3f
#define LL long long
using namespace std;
const int AX = 3e2+6;
bool visx[AX];
bool visy[AX];
int w[AX][AX];
int lx[AX] , ly[AX];
int delta ;
int linker[AX];
int n ;
bool dfs( int x ){
    visx[x] = true;
    for( int y = 1 ; y <= n ; y ++ ){
        if( !visy[y] && lx[x] + ly[y] == w[x][y] ){
            visy[y] = true;
            if( linker[y] == -1 || dfs( linker[y] ) ){
                linker[y] = x ;
                return true;
            }
        }
    }
    return false;
}

void KM(){
    memset( linker , -1 , sizeof(linker) );
    memset( ly , 0 , sizeof(ly) );
    for( int i = 1 ; i <= n ; i++ ){
        lx[i] = -INF;
        for( int j = 1 ; j <= n ; j++ ){
            if( lx[i] < w[i][j] ) lx[i] = w[i][j];
        }
    }
    for( int x = 1 ; x <= n ; x++ ){
        while(1){
            memset( visx , false , sizeof(visx) );
            memset( visy , false , sizeof(visy) );
            if( dfs(x) ){
                break;
            }else{
                for( int i = 1 ; i <= n ; i++ ){
                    if( visx[i] ){
                        for( int j = 1 ; j <= n ; j++ ){
                            if( !visy[j] ){
                                delta = min( delta , lx[x]+ly[j]-w[i][j] );
                            }
                        }
                    }
                }
                for( int i = 1 ; i <= n ; i++ ){
                    if( visx[i] ) lx[i] -= delta;
                }
                for( int i = 1 ; i <= n ; i++ ){
                    if( visy[i] ) ly[i] += delta;
                }
            }
        }
    }
}

int main(){
    int x ;
    while( ~scanf("%d",&n) ){
        for( int i = 1 ; i <= n ; i++ ){
            for( int j = 1 ; j <= n ; j++ ){
                scanf("%d",&x);
                w[i][j] = x ;
            }
        }
        KM();
        int res = 0 ;
        for( int i = 1 ; i <= n ; i++ ){
            if( linker[i] != -1 ){
                res += w[linker[i]][i] ;
            }
        }
        printf("%d\n",res);
    }
    return 0 ;
}