题意:
一棵树,每个点有一种颜色(10以内),一个价值,现在有Q次询问,每次询问在path(u,v)上,对于每种颜色算价值,价值是路径上相同颜色相邻的点的价值差的平方,如:u到v上颜色为1的点价值分别为a,b,c,则产生的价值是 (a?b)2+(b?c)2 ( a ? b ) 2 + ( b ? c ) 2 ,注意,LCA(u,v)的颜色的价值不要
思路:
树上莫队,那么首先对树分块,然后存询问,然后思考下怎么转移
首先,L->L’和R-R’肯定是要分开维护的,所以使用双端队列维护,一个用front,一个用back,其次,路径顺序是对答案有影响的,那么比如我们在转移u->v的时候,要u->lca(u,v)->v,那么我们要先把路径保存下来,再进行处理
错误及反思:
代码:
#include<bits/stdc++.h>
using namespace std;
const int N = 100100;stack<int> s; //分块时需要
int blocks,nowblo;//块大小,当前处于哪个块
int n,m;//节点数,问题数
int to[N][20],depth[N];//LCA
int block[N];
long long now,ans[N];
int head[N],tot;//链式前向星
bool vis[N];//标记路径long long num[15];
int col[N],val[N];
deque<int> c[15];struct EDGE{int to,nex;
}e[N*2];//双向边void add(int x,int y){e[tot].nex=head[x];e[tot].to=y;head[x]=tot++;
}struct Q{int l,r,id;
}q[N];bool cmp(Q a,Q b){ //排序if(block[a.l]!=block[b.l])return block[a.l]<block[b.l];return block[a.r]<block[b.r];
}void dfsblock(int now,int fa,int dep){ //树上分块int si=s.size();to[now][0]=fa; depth[now]=dep;for(int i=head[now];i!=-1;i=e[i].nex){int x=e[i].to;if(x!=fa){dfsblock(x,now,dep+1);if(s.size()-si>=blocks){while(s.size()!=si){block[s.top()]=nowblo;s.pop();}nowblo++;}}}s.push(now);
}void getlca(){ //预处理lcafor(int i=1;i<=18;i++)for(int j=1;j<=n;j++)to[j][i]=to[to[j][i-1]][i-1];
}int lca(int a,int b){ //得到lcaif(depth[a]>depth[b]) swap(a,b);for(int i=18;i>=0;i--) //注意大小if(depth[to[b][i]]>=depth[a])b=to[b][i];if(a==b) return a;for(int i=18;i>=0;i--){ //注意大小if(to[a][i]!=to[b][i]){a=to[a][i];b=to[b][i];}}return to[a][0];
}void change(int u,int v,int x){ //修改的时候将路径取反且不要动lcaint ancs=lca(u,v);while(u!=ancs){if(!vis[u]){if(x){if(c[col[u]].size()>0){now+=1ll*(val[c[col[u]].front()]-val[u])*(val[c[col[u]].front()]-val[u]);num[col[u]]+=1ll*(val[c[col[u]].front()]-val[u])*(val[c[col[u]].front()]-val[u]);}c[col[u]].push_front(u);}else{if(c[col[u]].size()>0){now+=1ll*(val[c[col[u]].back()]-val[u])*(val[c[col[u]].back()]-val[u]);num[col[u]]+=1ll*(val[c[col[u]].back()]-val[u])*(val[c[col[u]].back()]-val[u]);}c[col[u]].push_back(u);}}else{if(x){if(c[col[u]].size()>=2){int w=val[c[col[u]].front()];c[col[u]].pop_front();num[col[u]]-=1ll*(val[c[col[u]].front()]-w)*(val[c[col[u]].front()]-w);now-=1ll*(val[c[col[u]].front()]-w)*(val[c[col[u]].front()]-w);}elsec[col[u]].pop_front();}else{if(c[col[u]].size()>=2){int w=val[c[col[u]].back()];c[col[u]].pop_back();num[col[u]]-=1ll*(val[c[col[u]].back()]-w)*(val[c[col[u]].back()]-w);now-=1ll*(val[c[col[u]].back()]-w)*(val[c[col[u]].back()]-w);}elsec[col[u]].pop_back();}}vis[u]=!vis[u];u=to[u][0];}vector<int> tmp;while(v!=ancs){tmp.push_back(v);v=to[v][0];}for(int i=tmp.size()-1;i>=0;i--){int k=tmp[i];// printf("%d %d\n",ancs,k);if(!vis[k]){if(x){if(c[col[k]].size()>0){now+=1ll*(val[c[col[k]].front()]-val[k])*(val[c[col[k]].front()]-val[k]);num[col[k]]+=1ll*(val[c[col[k]].front()]-val[k])*(val[c[col[k]].front()]-val[k]);}c[col[k]].push_front(k);}else{if(c[col[k]].size()>0){now+=1ll*(val[c[col[k]].back()]-val[k])*(val[c[col[k]].back()]-val[k]);num[col[k]]+=1ll*(val[c[col[k]].back()]-val[k])*(val[c[col[k]].back()]-val[k]);}c[col[k]].push_back(k);}}else{if(x){if(c[col[k]].size()>=2){int w=val[c[col[k]].front()];c[col[k]].pop_front();num[col[k]]-=1ll*(val[c[col[k]].front()]-w)*(val[c[col[k]].front()]-w);now-=1ll*(val[c[col[k]].front()]-w)*(val[c[col[k]].front()]-w);}elsec[col[k]].pop_front();}else{if(c[col[k]].size()>=2){int w=val[c[col[k]].back()];c[col[k]].pop_back();num[col[k]]-=1ll*(val[c[col[k]].back()]-w)*(val[c[col[k]].back()]-w);now-=1ll*(val[c[col[k]].back()]-w)*(val[c[col[k]].back()]-w);}elsec[col[k]].pop_back();}}vis[k]=!vis[k];}
}void init(){memset(head,-1,sizeof(head));memset(vis,false,sizeof(vis));blocks=sqrt(n);nowblo=0;now=0; tot=0;
}int main(){scanf("%d",&n);init(); //初始化for(int i=1;i<=n;i++) scanf("%d",&col[i]);for(int i=1;i<=n;i++) scanf("%d",&val[i]);for(int i=1,u,v;i<n;i++){ //树的边scanf("%d%d",&u,&v);add(v,u); add(u,v);}scanf("%d",&m);for(int i=1;i<=m;i++){ //问题scanf("%d%d",&q[i].l,&q[i].r);q[i].id=i;}dfsblock(1,1,1); //树上分块while(s.empty()){block[s.top()]=nowblo-1;s.pop();}getlca(); //预处理倍增lcasort(q+1,q+m+1,cmp); //问题排序for(int i=1,l=1,r=1;i<=m;i++){if(l!=q[i].l) change(l,q[i].l,0); //左边节点if(r!=q[i].r) change(r,q[i].r,1); //右边节点int ancs=lca(q[i].l,q[i].r);ans[q[i].id]=now-num[col[ancs]];l=q[i].l; r=q[i].r;// printf("%lld %lld\n",num[col[ancs]],now);}for(int i=1;i<=m;i++)printf("%lld\n",ans[i]);
}