Problem
给出一个有n个节点的树,每个节点i有权值v[i] (1≤v[i]≤n) ,对于三个不同的点x,y,z,如果y一定出现在x到z的路径上且满足v[z]-v[y]=v[y]-v[x]>0,则称公差d(d=v[z]-v[y]=v[y]-v[x])出现过。问有多少不同的公差出现过。
1≤n≤50000
Solution
这样的题一眼就知道要考虑用bitset啦
假设我们现在已经知道了y,看是否存在可行的x和z
那么就是要以v[y]为中心,将左边的翻过去,然后and一下,最后xor得出答案
但是,现在的问题是bitset不支持翻转,于是我们可以开两个bitset,一个存的是v[i],另一个存的是n-v[i]+1,这样我们就可以将翻转操作去掉了。
由于这样做对于每个点都要开一个bitset,这样就导致内存大到飞起,于是有两种解决办法:
1、用莫队算法,但是这样做我们就需要稍微调一下块大小的说。
2、在dfn序上每隔一定数量就记录一个前缀的bitset和后缀的bitset,这样的复杂度是优于莫队的。
Code
第一次在Noi Linux下改题诶,惊恐的发现很多软件都装不了,实在不爽。
#include<iostream>
#include<algorithm>
#include<cstring>
#include<cstdio>
#include<cmath>
#include<set>
#include<map>
#include<bitset>#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)using namespace std;typedef long long LL;
typedef double db;int get(){char ch;while(ch=getchar(),(ch<'0'||ch>'9')&&ch!='-');if(ch=='-'){int s=0;while(ch=getchar(),ch>='0'&&ch<='9')s=s*10+ch-'0';return -s;}int s=ch-'0';while(ch=getchar(),ch>='0'&&ch<='9')s=s*10+ch-'0';return s;
}const int N = 50010;
const int blk = 300;bitset<N>t1,t2,ans,T1,T2;
struct section{int l,r,v;
}a[N];
struct edge{int x,nxt;
}e[N*2];
int h[N],tot;
int fa[N],v[N],d[N],k;
int n;
int bl[N];
int s1[N],s2[N];void inse(int x,int y){e[++tot].x=y;e[tot].nxt=h[x];h[x]=tot;
}void dfs(int x){a[x].v=v[fa[x]];d[a[x].l=++k]=x;for(int p=h[x];p;p=e[p].nxt)if (!a[e[p].x].l){fa[e[p].x]=x;dfs(e[p].x);}a[x].r=k;
}bool cmp(section a,section b){if (bl[a.l]<bl[b.l])return bl[a.l]<bl[b.l];return a.r<b.r;
}void add(int v){int v_=n+1-v;s1[v]++;s2[v]--;if (s1[v]==1)t1[v_]=1,T1[v]=1;if (s2[v]==0)t2[v]=0,T2[v_]=0;
}void del(int v){int v_=n+1-v;s1[v]--;s2[v]++;if (s1[v]==0)t1[v_]=0,T1[v]=0;if (s2[v]==1)t2[v]=1,T2[v_]=1;
}void solve(){int l=1,r=0;fo(i,1,n){s2[v[i]]++;if (s2[v[i]]==1)t2[v[i]]=1,T2[n-v[i]+1]=1;}bitset<N>u;fo(i,1,n){while(r<a[i].r)add(v[d[++r]]);while(r>a[i].r)del(v[d[r--]]);while(l<a[i].l)del(v[d[l++]]);while(l>a[i].l)add(v[d[--l]]);int v=a[i].v;if (!v)continue;ans|=((t2>>v)&(t1>>(n+1-v)))|((T2>>(n+1-v))&(T1>>v));u=((t2>>v)&(t1>>(n+1-v)))|((T2>>(n+1-v))&(T1>>v));}
}void getans(){int fan=0;fo(i,1,n)if (ans[i]==1){fan++;}printf("%d\n",fan);
}int main(){n=get();fo(i,1,n)v[i]=get();fo(i,2,n){int x=get(),y=get();inse(x,y);inse(y,x);}dfs(1);fo(i,1,n)bl[i]=i/blk;sort(a+1,a+1+n,cmp);solve();getans();return 0;
}