算法竞赛进阶指南,371 页, 树的直径
参考 https://www.cnblogs.com/knife-rose/p/11189707.html
本题要点:
1、k 分类讨论:
1)K=1时,要使得走过的路最少,显然是找到树的直径。
答案 = 2 * (n - 1) - L1 + 1 // L1 是原树的直径
2)K=2的时候呢,我们仔细思考一下就会发现,对于两条路重叠的部分,因为每个点一定要走,所以要多走一次。
那么我们把直径上的边权值取反,再找出一条权值最大的路径,减掉就好了。不能用bfs求树的直径,要用dp
答案 = 2 * n - L1 - L2 //L2 是把第一条直径L1上的边改为 -1 后的直径
2、求解树的直径的两种方法优缺点:
1)双向 bfs法:
缺点:不能处理含有负边权的树
优点:明显直到直径的起点和终点
2)树形dp
缺点:起点终点难以记录
优点:能处理负边
3、用 dfs 把直径上的所有的边权由 1 改为 -1,每次把连接x和y的两条边都改了
数组 edge_id[i] 表示第i条边相反的边的编号
读入x, y 时候,
add(x, y, 1), edge[tot] = 边(x-->y)
和 add(y, x, 1), edge[tot +1] = 边(y-->x)
edge_id[tot] = tot + 1
edge_id[tot + 1] = tot
#include <cstdio>
#include <cstring>
#include <iostream>
#include <queue>
#include <algorithm>
using namespace std;
const int MaxN = 300010;
int n, k, tot, L2;
int ver[MaxN * 2], head[MaxN * 2], Next[MaxN * 2], edge[MaxN * 2];
int edge_id[MaxN * 2];
bool vis[MaxN];
int d[MaxN]; //d[x] 表示以x为根节点的子树,x能够到达的最远距离struct Node
{
int id;int dist;
}; //每个节点到当前起点 st的距离void add(int x, int y, int z)
{
ver[++tot] = y, edge[tot] = z, Next[tot] = head[x], head[x] = tot;
}// 双向 bfs 求树的直径
void bfs(int st, int& diam, int& diam_y) //st是根, diam 是直径, diam_y 是离 起点 st最远的那个点
{
Node tmp, nod;memset(vis, false, sizeof(vis));nod.id = st, nod.dist = 0;queue<Node> q;q.push(nod);vis[st] = true;while(q.size()){
nod = q.front();q.pop();if(diam < nod.dist){
diam_y = nod.id;diam = nod.dist;}for(int i = head[nod.id]; i; i = Next[i]){
int y = ver[i];if(vis[y]){
continue;}tmp.id = y, tmp.dist = nod.dist + edge[i];q.push(tmp);vis[y] = true;}}
}int st, ed;//把直径上的所有的边权值由 1 改为 -1
bool dfs(int cur) //cur 表示当前的节点
{
vis[cur] = true;if(cur == ed){
return true;}for(int i = head[cur]; i; i = Next[i]){
int y = ver[i];if(vis[y]){
continue;}if(dfs(y)){
edge[i] = -1;edge[edge_id[i]] = -1;return true;}}return false;
}void dp(int x)
{
vis[x] = true;for(int i = head[x]; i; i = Next[i]){
int y = ver[i];if(vis[y]){
continue;}dp(y);L2 = max(L2, d[x] + d[y] + edge[i]);d[x] = max(d[x], d[y] + edge[i]);}
}void solve()
{
int p, q, L1 = 0;bfs(1, L1, p); //算出pL1 = 0;bfs(p, L1, q);if(1 == k){
printf("%d\n", 2 * (n - 1) - L1 + 1);}else if(2 == k){
st = p, ed = q;memset(vis, false, sizeof(vis));dfs(st);memset(vis, false, sizeof(vis));dp(1);printf("%d\n", 2 * n - L1 - L2); }
}int main()
{
int x, y;scanf("%d%d", &n, &k);for(int i = 1; i < n; ++i){
scanf("%d%d", &x, &y);add(x, y, 1);edge_id[tot] = tot + 1;add(y, x, 1);edge_id[tot] = tot - 1;}solve();return 0;
}