题意:
给出一棵 n 个点的无根树,请在这棵树上选三个互不相同的节点,使得这个三个
节点两两之间距离相等,输出方案数即可。
分析:
定义
设f(x,i)f(x,i)f(x,i)表示在以x为根的子树中,与x距离为i的节点数
g(x,i)g(x,i)g(x,i)表示在以x为根的子树中选择了两个节点,最后一个点需满足与x的距离为i的方案数。
所以f(x,i)=∑uf(u,i?1)f(x,i)=\sum_u f(u,i-1)f(x,i)=∑u?f(u,i?1)
g(x,i)=(∑ug(u,i+1))+(∑u,vf(u,i)?f(v,i))g(x,i)=(\sum_u g(u,i+1))+(\sum_{u,v}f(u,i)*f(v,i))g(x,i)=(∑u?g(u,i+1))+(∑u,v?f(u,i)?f(v,i))
可以通过继承其最大儿子的DP值,使其最大儿子的转移速度为O(1),再O(n)暴力枚举每个轻儿子的最大深度。
继承的具体操作,可以通过预开内存的方式实现。
(即建立一个内存池,然后f(x)从iii位置开始,f(sonx)f(son_x)f(sonx?)就从i+1i+1i+1位置开始,g(x)从i位置开始,g(sonx)g(son_x)g(sonx?)就从i-1位置开始,详见代码中两个dfs)
总的复杂度为O(∑重链长度)=O(n)O(\sum 重链长度)=O(n)O(∑重链长度)=O(n)
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#define SF scanf
#define PF printf
#define MAXN 100010
using namespace std;
typedef long long ll;
vector<int> a[MAXN];
ll Gpool[MAXN*2];
ll Fpool[MAXN];
ll *f[MAXN],*g[MAXN],*fcnt=Fpool,*gcnt=Gpool;
int maxl[MAXN],son[MAXN];
void prepare(int x,int fa=0){
for(int i=0;i<int(a[x].size());i++){
int u=a[x][i];if(u==fa)continue;prepare(u,x);maxl[x]=max(maxl[x],maxl[u]+1);}for(int i=0;i<int(a[x].size());i++){
int u=a[x][i];if(u==fa)continue;if(son[x]==0||maxl[u]>maxl[son[x]])son[x]=u;}
}
void dfsf(int x,int fa=0){
f[x]=++fcnt;if(son[x])dfsf(son[x],x);for(int i=0;i<int(a[x].size());i++){
int u=a[x][i];if(u==fa||u==son[x])continue;dfsf(u,x);}
}
void dfsg(int x,int fa=0){
for(int i=0;i<int(a[x].size());i++){
int u=a[x][i];if(u==fa||u==son[x])continue;dfsg(u,x);gcnt+=maxl[u];}if(son[x])dfsg(son[x],x);g[x]=++gcnt;
}
ll ans;
void dp(int x,int fa=0){
f[x][0]=1;if(son[x]==0)return ;dp(son[x],x);ans+=g[x][0];for(int i=0;i<int(a[x].size());i++){
int u=a[x][i];if(u==fa||u==son[x])continue;dp(u,x); for(int i=1;i<=maxl[u]+1;i++)ans+=g[x][i]*f[u][i-1];for(int i=0;i<=maxl[u]-1;i++)ans+=f[x][i]*g[u][i+1];for(int i=0;i<=maxl[u]-1;i++)g[x][i]+=g[u][i+1];for(int i=1;i<=maxl[u]+1;i++)g[x][i]+=(f[u][i-1]*f[x][i]);for(int i=1;i<=maxl[u]+1;i++)f[x][i]+=f[u][i-1];}
// PF("%d %d(%lld):\n",x,maxl[x],ans);
// for(int i=0;i<=maxl[x];i++)
// PF("%d:[%lld %lld]\n",i,f[x][i],g[x][i]);
// PF("-------------------\n");
}
int n,u,v;
void init(){
fcnt=Fpool;gcnt=Gpool;memset(Fpool,0,sizeof Fpool);memset(Gpool,0,sizeof Gpool);memset(maxl,0,sizeof maxl);memset(son,0,sizeof son);for(int i=1;i<=n;i++)a[i].clear();ans=0;
}
int main(){
// freopen("three1-3.in","r",stdin);while(SF("%d",&n)!=EOF){
if(n==0)break;init();for(int i=1;i<n;i++){
SF("%d%d",&u,&v);a[u].push_back(v);a[v].push_back(u); }prepare(1);dfsf(1);dfsg(1);dp(1);PF("%lld\n",ans);}
}