二分图的最佳完美匹配--KM算法(DFS寻路+BFS寻路(O(n^3))) + HDU2255入门题
程序员文章站
2022-06-09 20:17:11
...
Reference Blog:
原理清晰深刻:https://blog.csdn.net/sixdaycoder/article/details/47720471
较易理解:https://www.cnblogs.com/wenruo/p/5264235.html
如果二分图的每条边都有一个权(可以是负数),要求一种完备匹配方案,使得所有匹配边的权和最大,记作最佳完美匹配。(特别地,当所有边的权都为1时,就是最大完备匹配问题)
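算法流程:
1. 初始化可行顶标:lx[i] = max_j w[i][j],ly[j] = 0,保证对任意边 (i, j) 都有 lx[i] + ly[j] >= w[i][j]。
2. 依次为每个左部点 x 在相等子图(由满足 lx[i] + ly[j] == w[i][j] 的边构成)中寻找增广路,找到则沿增广路修改匹配。
3. 若找不到,取 delta = min{ lx[i] + ly[j] - w[i][j] : i 在交错树中,j 不在交错树中 },把交错树中左部点的 lx 减去 delta、右部点的 ly 加上 delta,使至少一条新边进入相等子图,然后回到第 2 步继续寻找。
4. 所有左部点都匹配后,得到的相等子图完美匹配即为最佳完美匹配。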
DFS 寻增广路模板(带 slack 优化;仅在随机数据下近似 O(n^3),对于极限数据(如 w[i][j] 很大)slack 优化作用不显著,最坏仍为 O(n^4)):
const int AX = 300 + 6;   // array capacity; vertices are 1-indexed, so n must stay below AX
int n ;                   // number of vertices on each side
int w[AX][AX];            // w[x][y]: weight of the edge (left x, right y)
int lx[AX] , ly[AX];      // vertex labels of the left / right side
int linker[AX];           // linker[y]: left vertex matched to right vertex y, -1 when free
int slack[AX];            // slack[y]: smallest lx[x]+ly[y]-w[x][y] recorded for y
bool visx[AX];            // left vertices visited during the current search
bool visy[AX];            // right vertices visited during the current search

// Searches for an augmenting path from left vertex x along tight edges
// (lx[x] + ly[y] == w[x][y]). On success the matching is flipped along the path
// and true is returned. Every edge that cannot be used here (already-visited y,
// or a non-tight edge) lowers slack[y] to its current gap.
bool dfs( int x ){
    visx[x] = true;
    for( int y = 1 ; y <= n ; ++y ){
        const int gap = lx[x] + ly[y] - w[x][y];
        if( visy[y] || gap != 0 ){
            if( gap < slack[y] ) slack[y] = gap;
            continue;
        }
        visy[y] = true;
        // Take y if it is free, or if its partner can be re-routed elsewhere.
        if( linker[y] == -1 || dfs( linker[y] ) ){
            linker[y] = x;
            return true;
        }
    }
    return false;
}
// Kuhn-Munkres (DFS + slack): maximum-weight perfect matching of the n x n
// weight matrix w (1-indexed). On return linker[y] is the left vertex matched
// to right vertex y.
void KM(){
    memset( linker , -1 , sizeof(linker) );
    memset( ly , 0 , sizeof(ly) );
    // Feasible starting labels: lx[i] = max_j w[i][j], ly[j] = 0.
    for( int i = 1 ; i <= n ; i++ ){
        int best = -INF;
        for( int j = 1 ; j <= n ; j++ )
            if( w[i][j] > best ) best = w[i][j];
        lx[i] = best;
    }
    for( int x = 1 ; x <= n ; x++ ){
        // Slack values belong to the tree of the current root only.
        for( int j = 1 ; j <= n ; j++ ) slack[j] = INF;
        for( ;; ){
            memset( visx , false , sizeof(visx) );
            memset( visy , false , sizeof(visy) );
            if( dfs( x ) ) break;
            // No augmenting path: the smallest slack outside the tree is the label shift.
            int delta = INF;
            for( int j = 1 ; j <= n ; j++ )
                if( !visy[j] && slack[j] < delta ) delta = slack[j];
            for( int i = 1 ; i <= n ; i++ ){
                if( visx[i] ) lx[i] -= delta;
                if( visy[i] ) ly[i] += delta;
                else slack[i] -= delta;   // tree-side lx dropped, so outside slacks drop too
            }
        }
    }
}
BFS寻路模板(真正O(n^3))
const int AX = 3e2+6;
LL w[AX][AX];
LL lx[AX] , ly[AX];
int linker[AX];
LL slack[AX];
int n ;
bool visy[AX];
int pre[AX];
void bfs( int k ){
int x , y = 0 , yy = 0 , delta;
memset( pre , 0 , sizeof(pre) );
for( int i = 1 ; i <= n ; i++ ) slack[i] = INF;
linker[y] = k;
while(1){
x = linker[y]; delta = INF; visy[y] = true;
for( int i = 1 ; i <= n ;i++ ){
if( !visy[i] ){
if( slack[i] > lx[x] + ly[i] - w[x][i] ){
slack[i] = lx[x] + ly[i] - w[x][i];
pre[i] = y;
}
if( slack[i] < delta ) delta = slack[i] , yy = i ;
}
}
for( int i = 0 ; i <= n ; i++ ){
if( visy[i] ) lx[linker[i]] -= delta , ly[i] += delta;
else slack[i] -= delta;
}
y = yy ;
if( linker[y] == -1 ) break;
}
while( y ) linker[y] = linker[pre[y]] , y = pre[y];
}
// Kuhn-Munkres (BFS variant): maximum-weight perfect matching of the n x n
// weight matrix w (1-indexed). All labels start at 0; bfs() repairs them while
// growing each alternating tree. On return linker[y] is the left vertex matched
// to right vertex y.
void KM(){
    memset( ly , 0 , sizeof(ly) );
    memset( lx , 0 , sizeof(lx) );
    memset( linker , -1 , sizeof(linker) );
    // Insert left vertices one at a time; each bfs() call ends with an augmentation.
    for( int k = 1 ; k <= n ; ++k ){
        memset( visy , 0 , sizeof(visy) );
        bfs( k );
    }
}
HDU2255
AC Code:
DFS + slack 优化版本(随机数据下近似 O(n^3))
#include <bits/stdc++.h>
// `#pragma comment(linker, ...)` is an MSVC extension (GCC/Clang ignore it); the
// string must use plain ASCII quotes or MSVC rejects it and the option is lost.
// Presumably meant to give the recursive dfs() a larger stack.
#pragma comment(linker, "/STACK:1024000000,1024000000")
#define INF 0x3f3f3f3f
#define LL long long
using namespace std;
const int AX = 3e2+6;
bool visx[AX];        // visx[x]: left vertex x was visited by the current search
bool visy[AX];        // visy[y]: right vertex y was visited by the current search
int w[AX][AX];        // w[x][y]: weight of the edge (left x, right y), 1-indexed
int lx[AX] , ly[AX];  // feasible vertex labels: lx[x] + ly[y] >= w[x][y]
int linker[AX];       // linker[y]: left vertex matched to right vertex y, -1 if unmatched
int slack[AX];        // slack[y]: smallest lx[x]+ly[y]-w[x][y] over left vertices x visited for the current root
int n ;

// Tries to find an augmenting path from left vertex x using only tight edges
// (lx[x] + ly[y] == w[x][y]). On success the matching is flipped along the path
// and true is returned. Every edge that cannot be used now lowers slack[y].
bool dfs( int x ){
    visx[x] = true;
    for( int y = 1 ; y <= n ; y ++ ){
        if( !visy[y] && lx[x] + ly[y] == w[x][y] ){
            visy[y] = true;
            // y is free, or its current partner can be re-matched elsewhere.
            if( linker[y] == -1 || dfs( linker[y] ) ){
                linker[y] = x ;
                return true;
            }
        }else if( slack[y] > lx[x] + ly[y] - w[x][y] ){
            // Edge unusable now (not tight, or y already visited): remember its gap.
            // KM() only reads slack[j] of unvisited j when it picks delta.
            slack[y] = lx[x] + ly[y] - w[x][y];
        }
    }
    return false;
}

// Kuhn-Munkres (DFS + slack): maximum-weight perfect matching of the n x n
// weight matrix w (1-indexed). On return linker[y] is the left vertex matched
// to right vertex y.
void KM(){
    memset( linker , -1 , sizeof(linker) );
    memset( ly , 0 , sizeof(ly) );
    // Feasible starting labels: lx[i] = max_j w[i][j], ly[j] = 0.
    for( int i = 1 ; i <= n ; i++ ){
        lx[i] = -INF;
        for( int j = 1 ; j <= n ; j++ ){
            if( lx[i] < w[i][j] ) lx[i] = w[i][j];
        }
    }
    for( int x = 1 ; x <= n ; x++ ){
        for( int i = 1 ; i <= n ; i++ ) slack[i] = INF;   // slack must be reset for every new root x
        while(1){
            memset( visx , false , sizeof(visx) );
            memset( visy , false , sizeof(visy) );
            if( dfs(x) ){
                break;
            }else{
                // The search failed, so x is in the alternating tree:
                // shift labels by the smallest slack of a right vertex j outside the tree.
                int delta = INF;
                for( int j = 1 ; j <= n ; j++ ){
                    if( !visy[j] && delta > slack[j] ){
                        delta = slack[j];
                    }
                }
                for( int i = 1 ; i <= n ; i++ ){
                    if( visx[i] ) lx[i] -= delta;
                }
                for( int i = 1 ; i <= n ; i++ ){
                    if( visy[i] ) ly[i] += delta;
                    else slack[i] -= delta;
                    // After relabelling every remaining slack must drop by delta:
                    // slack[j] = min(lx[i] + ly[j] - w[i][j]) over tree vertices i,
                    // and lx[i] just decreased for those i, so slack[j] decreases
                    // for every j outside the tree.
                }
            }
        }
    }
}
// Reads test cases until EOF: n, then an n x n weight matrix (row-major).
// For each case prints the total weight of the matching found by KM().
int main(){
    int value ;
    while( scanf("%d",&n) != EOF ){
        for( int r = 1 ; r <= n ; ++r ){
            for( int c = 1 ; c <= n ; ++c ){
                scanf("%d",&value);
                w[r][c] = value ;
            }
        }
        KM();
        int res = 0 ;
        for( int c = 1 ; c <= n ; ++c )
            if( linker[c] != -1 ) res += w[linker[c]][c];
        printf("%d\n",res);
    }
    return 0 ;
}
TLE Code:
O(n^4)(不维护 slack 数组,每次修改顶标都要 O(n^2) 枚举求 delta)
#include <bits/stdc++.h>
// `#pragma comment(linker, ...)` is an MSVC extension (GCC/Clang ignore it); the
// string must use plain ASCII quotes or MSVC rejects it and the option is lost.
#pragma comment(linker, "/STACK:1024000000,1024000000")
#define INF 0x3f3f3f3f
#define LL long long
using namespace std;
const int AX = 3e2+6;
bool visx[AX];        // visx[x]: left vertex x was visited by the current search
bool visy[AX];        // visy[y]: right vertex y was visited by the current search
int w[AX][AX];        // w[x][y]: weight of the edge (left x, right y), 1-indexed
int lx[AX] , ly[AX];  // feasible vertex labels: lx[x] + ly[y] >= w[x][y]
int linker[AX];       // linker[y]: left vertex matched to right vertex y, -1 if unmatched
int n ;

// Tries to find an augmenting path from left vertex x using only tight edges
// (lx[x] + ly[y] == w[x][y]); flips the matching along it and returns true on success.
bool dfs( int x ){
    visx[x] = true;
    for( int y = 1 ; y <= n ; y ++ ){
        if( !visy[y] && lx[x] + ly[y] == w[x][y] ){
            visy[y] = true;
            if( linker[y] == -1 || dfs( linker[y] ) ){
                linker[y] = x ;
                return true;
            }
        }
    }
    return false;
}

// Kuhn-Munkres without slack bookkeeping: every relabel rescans all
// (tree, non-tree) pairs in O(n^2), which gives O(n^4) overall.
// On return linker[y] is the left vertex matched to right vertex y.
void KM(){
    memset( linker , -1 , sizeof(linker) );
    memset( ly , 0 , sizeof(ly) );
    // Feasible starting labels: lx[i] = max_j w[i][j], ly[j] = 0.
    for( int i = 1 ; i <= n ; i++ ){
        lx[i] = -INF;
        for( int j = 1 ; j <= n ; j++ ){
            if( lx[i] < w[i][j] ) lx[i] = w[i][j];
        }
    }
    for( int x = 1 ; x <= n ; x++ ){
        while(1){
            memset( visx , false , sizeof(visx) );
            memset( visy , false , sizeof(visy) );
            if( dfs(x) ){
                break;
            }
            // No augmenting path: delta is the smallest gap between a tree vertex i and a
            // non-tree vertex j. It must start at INF on every round; the old global
            // delta stayed at 0, so the labels never changed and this loop never ended.
            int delta = INF;
            for( int i = 1 ; i <= n ; i++ ){
                if( !visx[i] ) continue;
                for( int j = 1 ; j <= n ; j++ ){
                    if( !visy[j] ){
                        delta = min( delta , lx[i]+ly[j]-w[i][j] );   // gap of edge (i, j), not (x, j)
                    }
                }
            }
            for( int i = 1 ; i <= n ; i++ ){
                if( visx[i] ) lx[i] -= delta;
            }
            for( int i = 1 ; i <= n ; i++ ){
                if( visy[i] ) ly[i] += delta;
            }
        }
    }
}
// Reads test cases until EOF: n, then an n x n weight matrix (row-major).
// For each case prints the total weight of the matching found by KM().
int main(){
    int value ;
    while( scanf("%d",&n) != EOF ){
        for( int r = 1 ; r <= n ; ++r ){
            for( int c = 1 ; c <= n ; ++c ){
                scanf("%d",&value);
                w[r][c] = value ;
            }
        }
        KM();
        int res = 0 ;
        for( int c = 1 ; c <= n ; ++c )
            if( linker[c] != -1 ) res += w[linker[c]][c];
        printf("%d\n",res);
    }
    return 0 ;
}