当前位置: 代码迷 >> 综合 >> HOJ 2586 How far away ?(算法竞赛进阶指南,树上倍增法,最近公共祖先LCA)
  详细解决方案

HOJ 2586 How far away ?(算法竞赛进阶指南,树上倍增法,最近公共祖先LCA)

热度:38   发布时间:2023-12-13 19:34:34.0

算法竞赛进阶指南, 376页,树上倍增法求 最近公共祖先LCA
题目意思:
求任意两点的距离。先求出任意两点的 LCA(x, y), 然后答案就是:
dist[x] + dist[y] - 2 * dist[lca(x, y)]

本题要点:
1、以点1为根,d[i]表示从i到根的深度(d[1] = 1); dist[i] 表示i到1点的距离。
f[i][k] 表示 点i 向上走 2^k 步到达的节点号, 如果该节点不存在, 就令 f[i][k] = 0。
f[i][0]就是节点i的父节点。除此之外, 动态规划转态转移方程。
f[i][k] = f[f[i][k - 1]][k - 1]
2、假设深度 d[y] >= d[x], 然后,y点就不断向上走 2^(log(n)), 2^(log(n) - 1), …, 2^1, 2^0 步,
使得x和y处于同一个高度,此时x == y ,说明x 节点就是要求的 LCA。
3、 经过第二步,x和y已经处于深度,然后x和y同时向上走 2^(log(n)), 2^(log(n) - 1), …, 2^1, 2^0 步,
如果 f[x][k] != f[y][k], 才同时向上走。
最后,x和y差一步就相会了,此时 LCA 就是 f[x][0] (x的父节点)

#include <cstdio>
#include <cstring>
#include <iostream>
#include <queue>
#include <cmath>
#include <algorithm>
using namespace std;
const int MaxN = 50010;
int ver[MaxN * 2], head[MaxN * 2], Next[MaxN * 2], edge[MaxN * 2];
int f[MaxN][20];	//f[i][k] 表示 点i 向上走 2^k 步到达的节点号(n个节点,节点编号从1到n)
int d[MaxN];		//d[i] 点i的深度
int dist[MaxN];		//dist[i]表示点i到点1的距离
int n, q, T, depth;	//depth 表示树的深度
int tot;void add(int x, int y, int z)
{
    ver[++tot] = y, edge[tot] = z, Next[tot] = head[x], head[x] = tot;	
}void bfs()
{
    d[1] = 1, dist[1] = 0;queue<int> q;q.push(1);while(!q.empty()){
    int x = q.front();q.pop();for(int i = head[x]; i; i = Next[i]){
    int y = ver[i];if(d[y]){
    continue;}d[y] = d[x] + 1;dist[y] = dist[x] + edge[i];f[y][0] = x;	//点y的父节点是xfor(int t = 1; t <= depth; ++t){
    f[y][t] = f[f[y][t - 1]][t - 1];	}q.push(y);}}
}int lca(int x, int y)	//节点x和y的最近公共祖先
{
    if(d[x] > d[y]){
    swap(x, y);	//使得 d[x] <= d[y], 然后调整 y}for(int t = depth; t >= 0; --t)//这个循环的目的是,使得y和x处于同一高度{
    if(d[f[y][t]] >= d[x]){
    y = f[y][t];}}if(x == y){
    return x;}for(int t = depth; t >= 0; --t)		//这个循环使得,x和y节点处于同一个深度{
    if(f[x][t] != f[y][t]){
    x = f[x][t], y = f[y][t];}}return f[x][0];
}int main()
{
    int x, y, z;scanf("%d", &T);while(T--){
    scanf("%d%d", &n, &q);depth = (int)(log(n) / log(2)) + 1;tot = 0;for(int i = 1; i <= n; ++i){
    head[i] = d[i] = 0;}for(int i = 1; i < n; ++i){
    scanf("%d%d%d", &x, &y, &z);		add(x, y, z), add(y, x, z);}bfs();for(int i = 1; i <= q; ++i){
    scanf("%d%d", &x, &y);printf("%d\n", dist[x] + dist[y] - 2 * dist[lca(x, y)]);}}return 0;
}/* 2 3 2 1 2 10 3 1 15 1 2 2 32 2 1 2 100 1 2 2 1 *//* 10 25 100 100 */