前段时间学习过树链剖分,这个题正好试试水,还是比较简单的吧,自己一个人不看博客还是可以A了。
本人比较懒,题解就看「BZOJ1036」[ZJOI2008] 树的统计Count这个博客吧,讲的很好,orz。这种题一般都是按这种套路来。
#include<cstdio>
#include<vector>
#include<algorithm>
using namespace std;const int maxn=30000+1000;
int fa[maxn],son[maxn],size[maxn],du[maxn],top[maxn],Id[maxn],rev[maxn];
int id;
int n;
int val[maxn];
vector<int>G[maxn];
struct Node{
int Max;int Sum;
}node[maxn<<2];void dfs1(int u,int f,int d){
fa[u]=f;du[u]=d;size[u]=1;int len=G[u].size();for(int i=0;i<len;i++){
int v=G[u][i];if(v==f) continue;dfs1(v,u,d+1);size[u]+=size[v];if(size[son[u]]<size[v]) son[u]=v; }
}void dfs2(int u,int t){
top[u]=t;Id[u]=++id;rev[id]=u;if(!son[u]) return ;dfs2(son[u],t);int len=G[u].size();for(int i=0;i<len;i++){
int v=G[u][i];if(v==fa[u] || v==son[u]) continue;dfs2(v,v);}
}void build(int l,int r,int root){
if(l==r){
node[root].Max=val[rev[l]];node[root].Sum=val[rev[l]];return ;}int mid=(l+r)>>1;build(l,mid,root<<1);build(mid+1,r,root<<1|1);node[root].Max=max(node[root<<1].Max,node[root<<1|1].Max);node[root].Sum=node[root<<1].Sum+node[root<<1|1].Sum;
}void Update(int l,int r,int ind,int root,int v){
if(l==r){
node[root].Max=v;node[root].Sum=v;return ;}int mid=(l+r)>>1;if(mid>=ind) Update(l,mid,ind,root<<1,v);else Update(mid+1,r,ind,root<<1|1,v);node[root].Max=max(node[root<<1].Max,node[root<<1|1].Max);node[root].Sum=node[root<<1].Sum+node[root<<1|1].Sum;
}int queryMax(int l,int r,int L,int R,int root){
if(L<=l&&R>=r) return node[root].Max;int mid=(l+r)>>1;if(mid>=R) return queryMax(l,mid,L,R,root<<1);else if(mid<L) return queryMax(mid+1,r,L,R,root<<1|1);else return max(queryMax(l,mid,L,R,root<<1),queryMax(mid+1,r,L,R,root<<1|1));
}int querySum(int l,int r,int L,int R,int root){
if(L<=l&&R>=r) return node[root].Sum;int mid=(l+r)>>1;if(mid>=R) return querySum(l,mid,L,R,root<<1);else if(mid<L) return querySum(mid+1,r,L,R,root<<1|1);else return querySum(l,mid,L,R,root<<1)+querySum(mid+1,r,L,R,root<<1|1);
}int getMax(int u,int v){
int Max=-1e9;while(top[u]!=top[v]){
if(du[top[u]]<du[top[v]]) swap(u,v);Max=max(Max,queryMax(1,n,Id[top[u]],Id[u],1));u=fa[top[u]];}if(du[u]<du[v]) swap(u,v);Max=max(Max,queryMax(1,n,Id[v],Id[u],1));return Max;
}int getSum(int u,int v){
int Sum=0;while(top[u]!=top[v]){
if(du[top[u]]<du[top[v]]) swap(u,v);Sum+=querySum(1,n,Id[top[u]],Id[u],1);u=fa[top[u]];}if(du[u]<du[v]) swap(u,v);Sum+=querySum(1,n,Id[v],Id[u],1);return Sum;
}int main(){
scanf("%d",&n);for(int i=0;i<n-1;i++){
int u,v;scanf("%d%d",&u,&v);G[u].push_back(v);G[v].push_back(u);}for(int i=1;i<=n;i++) scanf("%d",&val[i]);dfs1(1,-1,0);dfs2(1,1);build(1,n,1);int q;scanf("%d",&q);while(q--){
char tmp[100];int x,y;scanf("%s%d%d",tmp,&x,&y);if(tmp[1]=='H') Update(1,n,Id[x],1,y);else if(tmp[1]=='M') printf("%d\n",getMax(x,y));else printf("%d\n",getSum(x,y));}
}