关于如何建立虚树,借鉴一下学习时看的大佬的博客:
考虑得到了询问点,如何构造出一棵虚树。
首先我们要先对整棵树dfs一遍,求出他们的dfs序,然后对每个节点以dfs序为关键字从小到大排序
同时维护一个栈,表示从根到栈顶元素这条链
假设当前要加入的节点为p,栈顶元素为x=s[top],lca为他们的最近公共祖先
因为我们是按照dfs序遍历,因此lca不可能是p
那么现在会有两种情况
- lca是x,直接将p入栈。
- x,p分别位于lca的两棵子树中,此时x这棵子树已经遍历完毕,(如果没有,即x的子树中还有一个未加入的点y,但是dfn[y]<dfn[p],即应先访问y), 我们需要对其进行构建
设栈顶元素为x,第二个元素为y
- 若dfn[y]>dfn[lca],可以连边y?>x,将x出栈;
- 若dfn[y]=dfn[lca],即y=lca,连边lca?>x,此时子树构建完毕(break);
- 若dfn[y]<dfn[lca],即lca在y,x之间,连边lca?>x,x出栈,再将lca入栈。此时子树构建完毕(break)。
然后关于这道题,有一些特殊的地方。依题意可知如果一个节点需要断开,那它子树的节点学再需要断开就不用考虑了,因为一定会被断开。所以建虚树时如果栈顶元素是新元素的祖先,那新元素不需要进栈,所以栈内除栈顶元素以外的都是lca,所以我们最后建出的虚树只有叶子节点是需要被断开的点,其他都是lca,断不断均可。
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int N=250010;
const long long inf=4567891012345678910;
struct edge{int y,next,w;
}data[N*2];
struct edge1{int y,next;
}data1[N*2];
long long dp[N],w1[N];
int n,m,k,num,num1,num2,h[N],q[N],dfn[N],dep[N],h1[N],fa[N][20],sta[N],top;
inline int read(){int x=0;char ch=getchar();while(ch<'0'||ch>'9')ch=getchar();while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}return x;
}
inline long long min2(long long x,long long y){return x<y?x:y;
}
inline void addedge(int x,int y,int w){data[++num].y=y,data[num].w=w,data[num].next=h[x],h[x]=num;
}
inline void addedge1(int x,int y){data1[++num1].y=y,data1[num1].next=h1[x],h1[x]=num1;
}
void dfs1(int u,int f,int d){dfn[u]=++num2;dep[u]=d;for(int i=h[u];i!=-1;i=data[i].next){int v=data[i].y;if(v!=f){w1[v]=min2(1ll*data[i].w,w1[u]),fa[v][0]=u;dfs1(v,u,d+1);} }
}
bool cmp(int i,int j){return dfn[i]<dfn[j];
}
int lca(int x,int y){if(dep[x]<dep[y])swap(x,y);for(int i=18;i>=0;--i)if(dep[fa[x][i]]>=dep[y])x=fa[x][i];if(x==y)return x;for(int i=18;i>=0;--i)if(fa[x][i]!=fa[y][i])x=fa[x][i],y=fa[y][i];return fa[x][0];
}
void ins(int x){if(top==1){sta[++top]=x;return;}int l=lca(x,sta[top]);if(l==sta[top])return;while(top>1&&dfn[sta[top-1]]>=dfn[l])addedge1(sta[top-1],sta[top]),--top;if(sta[top]!=l)addedge1(l,sta[top]),sta[top]=l;sta[++top]=x;
}
void dfs2(int u,int f){dp[u]=w1[u];if(h1[u]==-1)return;long long sum=0;for(int i=h1[u];i!=-1;i=data1[i].next){int v=data1[i].y;dfs2(v,u);sum+=dp[v]; }dp[u]=min2(dp[u],sum),h1[u]=-1;
}
int main(){n=read();memset(h,-1,sizeof h),num=0;for(int u,v,w,i=1;i<n;++i){u=read(),v=read(),w=read();addedge(u,v,w),addedge(v,u,w);}w1[1]=inf,num2=0;dfs1(1,0,1);for(int i=1;i<=18;++i)for(int j=1;j<=n;++j)fa[j][i]=fa[fa[j][i-1]][i-1];m=read();memset(h1,-1,sizeof h1);for(int i=1;i<=m;++i){k=read(),num1=0;for(int j=1;j<=k;++j)q[j]=read();sort(q+1,q+k+1,cmp);top=1,sta[1]=1;for(int j=1;j<=k;++j)ins(q[j]);while(top>1)addedge1(sta[top-1],sta[top]),--top;dfs2(1,0);printf("%lld\n",dp[1]);}return 0;
}