当前位置: 代码迷 >> 综合 >> POJ 2400 KM算法 最小权匹配 回溯输出所有最优匹配方案
  详细解决方案

POJ 2400 KM算法 最小权匹配 回溯输出所有最优匹配方案

热度:98   发布时间:2024-01-13 17:52:51.0

很蛋疼的一题 

首先输入就很蛋疼, 据网上的神牛们纷纷说题目的矩阵给反了。然后按反着来还真给过了

KM的话 由于是 n与n的匹配,所以直接取负求KM毫无压力

但是如果两边点数不等,据说会有问题    


#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <vector>
#include <queue>
#define MAXN 505
#define MAXM 555555
#define INF 1000000000
using namespace std;
int n, m, ny, nx;
int w[MAXN][MAXN];
int lx[MAXN], ly[MAXN];
int linky[MAXN];
int visx[MAXN], visy[MAXN];
int slack[MAXN], ans;
int res[MAXN], v[MAXN];
bool find(int x)
{visx[x] = 1;for(int y = 1; y <= ny; y++){if(visy[y]) continue;int t = lx[x] + ly[y] - w[x][y];if(t == 0){visy[y] = 1;if(linky[y] == -1 || find(linky[y])){linky[y] = x;return true;}}else if(slack[y] > t) slack[y] = t;}return false;
}
int KM()
{memset(linky, -1, sizeof(linky));for(int i = 1; i <= nx; i++) lx[i] = -INF;memset(ly, 0, sizeof(ly));for(int i = 1; i <= nx; i++)for(int j = 1; j <= ny; j++)if(w[i][j] > lx[i]) lx[i] = w[i][j];for(int x = 1; x <= nx; x++){for(int i = 1; i <= ny; i++) slack[i] = INF;while(true){memset(visx, 0, sizeof(visx));memset(visy, 0, sizeof(visy));if(find(x)) break;int d = INF;for(int i = 1; i <= ny; i++)if(!visy[i]) d = min(d, slack[i]);for(int i = 1; i <= nx; i++)if(visx[i]) lx[i] -=d;for(int i = 1; i <= ny; i++)if(visy[i]) ly[i] += d;else slack[i] -= d;}}ans = 0;for(int i = 1; i <= ny; i++)if(linky[i] > 0) ans -= w[linky[i]][i];return ans;
}
int cnt;
void dfs(int u, int sum)
{if(sum > ans) return;if(u > n){if(sum != ans) return;printf("Best Pairing %d\n", ++cnt);for(int i = 1; i <= n; i++)printf("Supervisor %d with Employee %d\n", i, res[i]);}else{for(int i = 1; i <= n; i++)if(!v[i]){res[u] = i;v[i] = 1;dfs(u + 1, sum - w[u][i]);v[i] = 0;}}
}
int main()
{int x, y, z, e, cas = 0, T;scanf("%d", &T);while(T--){if(cas) printf("\n");scanf("%d", &n);nx = n, ny = n, m = n;for(int i = 1; i <= n; i++)for(int j = 1; j <= m; j++)w[i][j] = 0;for(int i = 1; i <= n; i++)for(int j = 0; j < n; j++){scanf("%d", &x);w[x][i] -= j;}for(int i = 1; i <= n; i++)for(int j = 0; j < n; j++){scanf("%d", &x);w[i][x] -= j;}printf("Data Set %d, Best average difference: %f\n", ++cas, 0.5 * KM() / n);cnt = 0;memset(v, 0, sizeof(v));dfs(1, 0);}return 0;
}