题意
JYY有两棵树A和B:树A有N个点,编号为1到N;树B有N+1个点,编号为1到N+1。JYY知道树B恰好是由树A加上一个叶节点,然后将节点的编号打乱后得到的。他想知道,这个多余的叶子到底是树B中的哪一个叶节点呢?
题解
好久没有写过树hash了。。并不知道怎么写简单
Rose告诉了我一个不错的hash方法
我们只需要fx=base×(Πfson+totx)f_x=base\times (\Pi f_{son}+tot_x)fx?=base×(Πfson?+totx?)
fff是每一个子树的hash值
可以发现,这玩意支持换根,那么就不限于找重心了
那第一颗树所有的hash值丢进去
第二颗树删除一个节点的hash值也可以用类似的方法弄出来
弄个set,找一找有没有出现过就好了
一开始base设为2333,模数为1e9+7被卡了。。
然后把base改为了233才过
em…脸有点黑
CODE:
#include<cstdio>
#include<algorithm>
#include<iostream>
#include<cstring>
#include<set>
using namespace std;
typedef long long LL;
const int MOD=1e9+7;
const int base=233;
const int N=100005;
int Inv;
int add (int x,int y) {
x=x+y;return x>=MOD?x-MOD:x;}
int mul (int x,int y) {
return (LL)x*y%MOD;}
int dec (int x,int y) {
x=x-y;return x<0?x+MOD:x;}
int Pow (int x,int y)
{
if (y==1) return x;int lalal=Pow(x,y>>1);lalal=mul(lalal,lalal);if (y&1) lalal=mul(lalal,x);return lalal;
}
int n;
struct qq
{
int x,y,last;
}e[N*2];int num,last[N];
void init (int x,int y)
{
e[++num].x=x;e[num].y=y;e[num].last=last[x];last[x]=num;
}
int f[N];
int tot[N];
void dfs (int x,int fa)
{
tot[x]=1;f[x]=1;for (int u=last[x];u!=-1;u=e[u].last){
int y=e[u].y;if (y==fa) continue;dfs(y,x);tot[x]=tot[x]+tot[y];f[x]=mul(f[x],f[y]);}f[x]=add(f[x],tot[x]);f[x]=mul(f[x],base);
}
set<int> s;
void dfs1 (int x,int fa)
{
s.insert(f[x]);int lalal=dec(mul(f[x],Inv),n),xx;for (int u=last[x];u!=-1;u=e[u].last){
int y=e[u].y;if (y==fa) continue;xx=mul(lalal,Pow(f[y],MOD-2));xx=add(xx,n-tot[y]);xx=mul(xx,base);f[y]=mul(f[y],Inv);f[y]=dec(f[y],tot[y]);f[y]=mul(f[y],xx);f[y]=add(f[y],n);f[y]=mul(f[y],base);dfs1(y,x);}
}
int du[N];
void dfs2 (int x,int fa)
{
int lalal=dec(mul(f[x],Inv),n),xx;for (int u=last[x];u!=-1;u=e[u].last){
int y=e[u].y;if (y==fa) continue;xx=mul(lalal,Pow(f[y],MOD-2));xx=add(xx,n-tot[y]);xx=mul(xx,base);f[y]=mul(f[y],Inv);f[y]=dec(f[y],tot[y]);f[y]=mul(f[y],xx);f[y]=add(f[y],n);f[y]=mul(f[y],base);dfs2(y,x);}
}
int main()
{
Inv=Pow(base,MOD-2);num=0;memset(last,-1,sizeof(last));scanf("%d",&n);for (int u=1;u<n;u++){
int x,y;scanf("%d%d",&x,&y);init(x,y);init(y,x);}dfs(1,0);dfs1(1,0);/*for (int u=1;u<=n;u++) printf("%d ",f[u]);printf("\n");*/num=0;memset(last,-1,sizeof(last));n++;for (int u=1;u<n;u++){
int x,y;scanf("%d%d",&x,&y);du[x]++;du[y]++;init(x,y);init(y,x);}dfs(1,0);dfs2(1,0);for (int u=1;u<=n;u++){
if (du[u]==1){
int xx=mul(f[u],Inv);xx=dec(xx,n);if (s.find(xx)!=s.end()){
printf("%d\n",u);break;}}}return 0;
}