Description
有一个树形结构的宾馆,n个房间,n-1条无向边,每条边的长度相同,任意两个房间可以相互到达。吉丽要给他的三个妹子各开(一个)房(间)。三个妹子住的房间要互不相同(否则要打起来了),为了让吉丽满意,你需要让三个房间两两距离相同。
有多少种方案能让吉丽满意?
Input
第一行一个数n。 接下来n-1行,每行两个数x,y,表示x和y之间有一条边相连。
Output
让吉丽满意的方案数。
Sample Input
7
1 2
5 7
2 5
2 3
5 6
4 5
Sample Output
5
HINT
【样例解释】
{1,3,5},{2,4,6},{2,4,7},{2,6,7},{4,6,7}
【数据范围】
n≤5000
题解
首先有个性质
这三个点的路径会交于一点,三个点到这个点的距离都相等且位于同一棵子树
那么 N 2 N^2 N2的就可以直接做了…
这个 O ( N ) O(N) O(N)的有点神仙…
设一个dp状态
f [ i ] [ j ] f[i][j] f[i][j]表示 i i i点子树中有多少距离为 j j j的点
g [ i ] [ j ] g[i][j] g[i][j]表示 i i i点子树中,有多少对,在 i i i的子树外面找一个距离为 j j j的点就能构成一个合法三元组的数量
每扫到一棵子树 y y y
可以这样累加答案
a n s + = ∑ f [ x ] [ j ] ? g [ y ] [ j + 1 ] + g [ x ] [ j ] ? f [ y ] [ j ? 1 ] ans+=\sum f[x][j]*g[y][j+1]+g[x][j]*f[y][j-1] ans+=∑f[x][j]?g[y][j+1]+g[x][j]?f[y][j?1]
然后更新这样更新
g [ x ] [ i ] + = f [ x ] [ i ] ? f [ y ] [ i ? 1 ] g[x][i]+=f[x][i]*f[y][i-1] g[x][i]+=f[x][i]?f[y][i?1]
f [ x ] [ i ] + = f [ y ] [ i ? 1 ] f[x][i]+=f[y][i-1] f[x][i]+=f[y][i?1]
g [ x ] [ i ] + = g [ y ] [ i + 1 ] g[x][i]+=g[y][i+1] g[x][i]+=g[y][i+1]
注意到这样是 ∑ m a x d e p [ i ] \sum maxdep[i] ∑maxdep[i]的
考虑如何优化
我们发现,每个点的第一次转移其实就是找一个儿子,让他的 g g g数组左移一位, f f f数组右移一位
指针移动即可
这样可以节省 ∑ m a x d e p [ s o n [ i ] ] \sum maxdep[son[i]] ∑maxdep[son[i]]的复杂度
长链剖分,每个点选择它的重儿子做更新
其他儿子暴力更新
单考虑重链上的转移,一次是 O ( 1 ) O(1) O(1)的,所以总复杂度不会超过 O ( N ) O(N) O(N)
再考虑轻边,每条重链只会在头的位置被计算完整个重链的长度
显然重链没有重叠
所以轻边的转移总复杂度也是 O ( N ) O(N) O(N)的
于是就是 O ( N ) O(N) O(N)的优秀复杂度了…
空间的话…学习了一下别人的开法毕竟不会指针数组这个东西
upd:
考试忘记怎么开了
那就手写了一个
感觉还挺好写的qaq
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<algorithm>
#include<cmath>
#include<queue>
#include<vector>
#include<ctime>
#include<map>
#include<bitset>
#include<set>
#define LL long long
#define mp(x,y) make_pair(x,y)
#define pll pair<long long,long long>
#define pii pair<int,int>
#define f(x,i) f1[beginf[x]+i]
#define g(x,i) f2[beging[x]+i]
using namespace std;
inline int read()
{
int f=1,x=0;char ch=getchar();while(ch<'0'||ch>'9'){
if(ch=='-')f=-1;ch=getchar();}while(ch>='0'&&ch<='9'){
x=x*10+ch-'0';ch=getchar();}return x*f;
}
int stack[20];
inline void write(LL x)
{
if(x<0){
putchar('-');x=-x;}if(!x){
putchar('0');return;}int top=0;while(x)stack[++top]=x%10,x/=10;while(top)putchar(stack[top--]+'0');
}
inline void pr1(int x){
write(x);putchar(' ');}
inline void pr2(LL x){
write(x);putchar('\n');}
const int MAXN=100005;
struct edge{
int x,y,next;}a[2*MAXN];int len,last[MAXN];
void ins(int x,int y){
len++;a[len].x=x;a[len].y=y;a[len].next=last[x];last[x]=len;}LL f1[MAXN*2],f2[MAXN*2];
int beginf[MAXN],beging[MAXN],T;int dep[MAXN],son[MAXN],maxdep[MAXN],fa[MAXN];
void pre_tree_node(int x)
{
maxdep[x]=dep[x];for(int k=last[x];k;k=a[k].next){
int y=a[k].y;if(y!=fa[x]){
fa[y]=x;dep[y]=dep[x]+1;pre_tree_node(y);if(maxdep[y]>maxdep[son[x]])son[x]=y;maxdep[x]=max(maxdep[x],maxdep[y]);}}
}
void pre_tree_edge(int x,int u1,int u2)
{
beginf[x]=u1;beging[x]=u2;if(son[x])pre_tree_edge(son[x],u1+1,u2-1);for(int k=last[x];k;k=a[k].next){
int y=a[k].y;if(y!=fa[x]&&y!=son[x]){
int lin=T;T+=(maxdep[y]-dep[y]+1)*2;pre_tree_edge(y,lin+1,T-(maxdep[y]-dep[y]+1)+1);}}
}
int n;
LL ans;
void dp(int x,int fa)
{
f(x,0)=1;if(son[x])dp(son[x],x);ans+=g(x,0);for(int k=last[x];k;k=a[k].next){
int y=a[k].y;if(y!=fa&&y!=son[x]){
dp(y,x);int ln=maxdep[y]-dep[y];for(int i=0;i<=ln;i++)ans+=f(y,i)*g(x,i+1);for(int i=1;i<=ln;i++)ans+=g(y,i)*f(x,i-1);for(int i=0;i<=ln;i++)g(x,i+1)+=f(y,i)*f(x,i+1);for(int i=1;i<=ln;i++)g(x,i-1)+=g(y,i);for(int i=0;i<=ln;i++)f(x,i+1)+=f(y,i);}}
}
int main()
{
n=read();for(int i=1;i<=2*n;i++)f1[i]=f2[i]=0;len=0;memset(last,0,sizeof(last));for(int i=1;i<n;i++){
int x=read(),y=read();ins(x,y);ins(y,x);}memset(fa,0,sizeof(fa));memset(son,0,sizeof(son));pre_tree_node(1);T=2*(maxdep[1]-dep[1]+1);pre_tree_edge(1,1,T-(maxdep[1]-dep[1]+1)+1);ans=0;dp(1,0);pr2(ans);
// }return 0;
}
#include<cstdio>
#include<cstring>
#include<cstdlib>
#include<algorithm>
#include<cmath>
#include<queue>
#include<vector>
#include<ctime>
#include<map>
#include<bitset>
#define LL long long
#define mp(x,y) make_pair(x,y)
using namespace std;
inline int read()
{
int f=1,x=0;char ch=getchar();while(ch<'0'||ch>'9'){
if(ch=='-')f=-1;ch=getchar();}while(ch>='0'&&ch<='9'){
x=x*10+ch-'0';ch=getchar();}return x*f;
}
int stack[20];
inline void write(LL x)
{
if(!x){
putchar('0');return;}int top=0;while(x)stack[++top]=x%10,x/=10;while(top)putchar(stack[top--]+'0');
}
inline void pr1(int x){
write(x);putchar(' ');}
inline void pr2(LL x){
write(x);putchar('\n');}
struct node{
int y,next;}a[210000];int len,last[110000];
void ins(int x,int y){
len++;a[len].y=y;a[len].next=last[x];last[x]=len;}
LL spa[100005*10];
LL *f[100005],*g[100005],*now=spa+100005,ans;
int dep[100005],son[100005];
void init(int x,int fa)
{
for(int k=last[x];k;k=a[k].next){
int y=a[k].y;if(y!=fa){
init(y,x);if(dep[y]>dep[son[x]])son[x]=y;}}dep[x]=dep[son[x]]+1;
}
void newnode(int x)
{
f[x]=now;now=now+2*dep[x]+1;g[x]=now;now=now+2*dep[x]+1;
}
void dp(int x,int fa)
{
f[x][0]=1;if(son[x]){
f[son[x]]=f[x]+1;g[son[x]]=g[x]-1;dp(son[x],x);ans+=g[son[x]][1];}for(int k=last[x];k;k=a[k].next){
int y=a[k].y;if(y!=fa&&y!=son[x]){
newnode(y);dp(y,x);for(int j=dep[y];j>=0;j--){
if(j)ans=ans+g[y][j]*f[x][j-1];ans=ans+g[x][j+1]*f[y][j];g[x][j+1]+=f[x][j+1]*f[y][j];}for(int j=dep[y];j>=0;j--){
if(j)g[x][j-1]+=g[y][j];f[x][j+1]+=f[y][j];}}}
}
int n;
int main()
{
n=read();for(int i=1;i<n;i++){
int x=read(),y=read();ins(x,y);ins(y,x);}init(1,0);newnode(1);dp(1,0);pr2(ans);return 0;
}