当前位置: 代码迷 >> 综合 >> bzoj2286 虚树第一道
  详细解决方案

bzoj2286 虚树第一道

热度:28   发布时间:2024-01-04 12:43:18.0

 

据说虚树问题总是要套dp的。

 

主要的巧妙之处在于,可以在等同于数据大小的复杂度内建好虚树,关键是利用了dfs序的性质。

 

#include<cstdio>
#include<cmath>
#include<cstdlib>
#include<algorithm>
#include<cstdlib>
#include<stack>
#include<cstring>
#include<vector>
using namespace std;
const int N=250005;
typedef long long ll;
int n,m,k[N],num;
int head[N],tot;
struct aa
{
int to,pre;ll c;
}edge[N*2];
void addedge(int u,int v,ll c)
{
edge[++tot].to=v;edge[tot].pre=head[u];edge[tot].c=c;head[u]=tot;
}
int fa[N][22],dep[N],dfn[N],cnt;ll g[N][22];
void dfs(int u,int depth)
{
dep[u]=depth;dfn[u]=++cnt;
for (int i=1;i<=20;i++) 
{
fa[u][i]=fa[fa[u][i-1]][i-1];
g[u][i]=min(g[u][i-1],g[fa[u][i-1]][i-1]);
}
for (int v,i=head[u];i;i=edge[i].pre)
if (fa[u][0]!=(v=edge[i].to)) 
{
fa[v][0]=u;
g[v][0]=edge[i].c;
dfs(v,depth+1);
}
}
bool cmp(int a,int b) {return dfn[a]<dfn[b];}
vector<int> E[N];
ll dis[N];
bool b[N];
ll dp(int u)
{
if (b[u]) return dis[u];
ll tmp=0;int sz=E[u].size();
for (int i=0;i<sz;i++) tmp+=dp(E[u][i]);
return min(tmp,dis[u]);
}
void clear(int u)
{
b[u]=false;dis[u]=0;
int sz=E[u].size();
for (int i=0;i<sz;i++) clear(E[u][i]);
E[u].clear();
}
void add(int u,int v) {E[u].push_back(v);}
void up(int &u,int tmp)
{
for (int i=0;i<=20;i++) if (tmp&(1<<i)) u=fa[u][i];
}
int LCA(int u,int v)
{
if (dep[u]>dep[v]) up(u,dep[u]-dep[v]);
else up(v,dep[v]-dep[u]);
for (int i=20;i>=0;i--) if (fa[u][i]!=fa[v][i]) u=fa[u][i],v=fa[v][i];
if (u!=v) return fa[u][0];
return u;
}
ll mi(int u,int v)
{
ll ans=1e18;
int tmp=dep[v]-dep[u];
for (int i=0;i<=20;i++) if (tmp&(1<<i)) ans=min(ans,g[v][i]),v=fa[v][i];
return ans;
}
int top,s[N];
void work()
{
sort(k+1,k+num+1,cmp);
top=0;s[++top]=1;
for (int i=1;i<=num;i++)
{
while (top)
{
int lca=LCA(k[i],s[top]);
if (lca==s[top]) {s[++top]=k[i];break;}
else if (dep[lca]>dep[s[top-1]])
{
add(lca,s[top]);dis[s[top]]=mi(lca,s[top]);
s[top]=lca,s[++top]=k[i];
break;
}
else if (dep[lca]==dep[s[top-1]])
{
add(lca,s[top]);dis[s[top]]=mi(lca,s[top]);
s[top]=k[i];
break;
}
else 
{
add(s[top-1],s[top]);dis[s[top]]=mi(s[top-1],s[top]);
top--;
}
}
}
while (top>1)
{
add(s[top-1],s[top]);dis[s[top]]=mi(s[top-1],s[top]);
top--;
}
dis[1]=1e18;
printf("%lld\n",dp(1));
clear(1);
}
int main()
{
scanf("%d",&n);
int u,v;ll c;
for (int i=1;i<n;i++)
{
scanf("%d%d%lld",&u,&v,&c);
addedge(u,v,c);
addedge(v,u,c);
}
for (int i=0;i<=20;i++)
for (int j=1;j<=n;j++) g[j][i]=1e18;
dfs(1,1);
scanf("%d",&m);
for (int i=1;i<=m;i++)
{
scanf("%d",&num);
for (int j=1;j<=num;j++) scanf("%d",&k[j]),b[k[j]]=true;
work();
}
return 0;
}