当前位置: 代码迷 >> 综合 >> 树的统计Count HYSBZ - 1036 (树链剖分)
  详细解决方案

树的统计Count HYSBZ - 1036 (树链剖分)

热度:33   发布时间:2024-01-14 22:20:09.0

题目 https://cn.vjudge.net/problem/HYSBZ-1036

树链剖分 

#include <algorithm>
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <map>
#include <cstring>
#include <vector>
using namespace std;
// Segment-tree child indices; both expand in terms of a variable named `o`
// that must be in scope at the point of use (left = 2*o, right = 2*o + 1).
#define lson o<<1
#define rson o<<1|1
// Declares a local `m` holding the midpoint of the in-scope `l` and `r`.
#define MID int m = (l+r)/2
// Large sentinel; Query2 starts its running maximum at -inf.
const int inf = 0x3f3f3f3f;
// Capacity of every per-vertex array below.
const int maxn = 300000 + 100;
// Number of vertices (read in main); segment-tree positions span [1, n].
int n;
// Adjacency lists; main stores each input edge in both directions.
vector<int> edge[maxn];
// Value attached to each vertex, indexed 1..n (read in main).
// NOTE(review): together with `using namespace std;` this name can clash with
// std::data when compiled as C++17 or later; uses may need to be written `::data`.
int data[maxn];
// Running counter used by dfss to hand out positions; main resets it to 0.
int cnt;
// fa[u]: parent of u (-1 for the root); deep[u]: depth of u (root is 0). Set by dffs.
int fa[maxn],deep[maxn];
// siz[u]: subtree size; son[u]: child with the largest subtree, or -1 if none (set by dffs).
// top[u]: first vertex of the chain containing u; tid[u]: position of u in the
// visiting order produced by dfss (1-based).
int siz[maxn],son[maxn],top[maxn],tid[maxn];
// id_data[k]: value of the vertex whose tid is k; the array build() reads from.
int id_data[maxn];
// Segment-tree node: aggregates over the range of positions the node covers.
struct Info
{
    int sum; // sum of the values in the node's range
    int MAX; // largest value in the node's range
} tree[maxn * 10];

// Builds the segment-tree node `o` covering positions [l, r] from id_data[].
// A leaf copies id_data[l]; an inner node combines its two children.
void build(int o, int l, int r)
{
    Info &node = tree[o];
    if (l == r)
    {
        node.sum = node.MAX = id_data[l];
        return;
    }

    const int mid = (l + r) / 2;
    const int left = o << 1;
    const int right = left | 1;

    build(left, l, mid);
    build(right, mid + 1, r);

    node.sum = tree[left].sum + tree[right].sum;
    node.MAX = std::max(tree[left].MAX, tree[right].MAX);
}
void updata(int o,int l, int r, int x, int y)
{if(l > x || r < x) return ;if(l == x && r == x){tree[o].sum = y;tree[o].MAX = y;return ;}MID;updata(lson,l,m,x,y);updata(rson,m+1,r,x,y);tree[o].sum = tree[lson].sum + tree[rson].sum;tree[o].MAX = max(tree[lson].MAX,tree[rson].MAX);
}
// Range-sum query: returns the sum of the positions in [ul, ur] that fall
// inside node `o`, which covers [l, r]. A disjoint range contributes 0.
int query1(int o, int l, int r, int ul, int ur)
{
    const bool disjoint = (r < ul) || (ur < l);
    if (disjoint)
        return 0;

    const bool fullyCovered = (ul <= l) && (r <= ur);
    if (fullyCovered)
        return tree[o].sum;

    const int mid = (l + r) / 2;
    const int leftPart = query1(o << 1, l, mid, ul, ur);
    const int rightPart = query1((o << 1) | 1, mid + 1, r, ul, ur);
    return leftPart + rightPart;
}
int query2(int o, int l, int r, int ul, int ur)
{if(ul > r || ur < l) return -333333;if(ul <= l && r <= ur){return tree[o].MAX;}MID;return max(query2(lson,l,m,ul,ur) , query2(rson,m+1,r,ul,ur));
}
// Sum of the vertex values on the tree path between x and y (inclusive).
// Climbs chain by chain, always lifting the endpoint whose chain top is
// deeper, until both endpoints share a chain.
int Query1(int x, int y)
{
    if (x == y)
        return query1(1, 1, n, tid[x], tid[x]);

    int total = 0;
    int chainX = top[x];
    int chainY = top[y];

    for (; chainX != chainY; x = fa[chainX], chainX = top[x])
    {
        // Keep (x, chainX) as the side whose chain top is deeper.
        if (deep[chainX] < deep[chainY])
        {
            std::swap(x, y);
            std::swap(chainX, chainY);
        }
        total += query1(1, 1, n, tid[chainX], tid[x]);
    }

    // Same chain now: query from the shallower endpoint to the deeper one.
    const bool xIsDeeper = deep[x] > deep[y];
    const int upper = xIsDeeper ? y : x;
    const int lower = xIsDeeper ? x : y;
    total += query1(1, 1, n, tid[upper], tid[lower]);
    return total;
}
// Maximum vertex value on the tree path between x and y (inclusive).
// Climbs chain by chain, always lifting the endpoint whose chain top is
// deeper, until both endpoints share a chain.
int Query2(int x, int y)
{
    if (x == y)
        return query2(1, 1, n, tid[x], tid[x]);

    int best = -inf;
    int chainX = top[x];
    int chainY = top[y];

    for (; chainX != chainY; x = fa[chainX], chainX = top[x])
    {
        // Keep (x, chainX) as the side whose chain top is deeper.
        if (deep[chainX] < deep[chainY])
        {
            std::swap(x, y);
            std::swap(chainX, chainY);
        }
        best = std::max(best, query2(1, 1, n, tid[chainX], tid[x]));
    }

    // Same chain now: query from the shallower endpoint to the deeper one.
    const bool xIsDeeper = deep[x] > deep[y];
    const int upper = xIsDeeper ? y : x;
    const int lower = xIsDeeper ? x : y;
    best = std::max(best, query2(1, 1, n, tid[upper], tid[lower]));
    return best;
}
// First DFS pass: for the subtree rooted at `u` (entered from `f`, at depth
// `d`) records fa[], deep[], siz[] and son[], where son[u] is the child with
// the largest subtree or -1 when u has no children.
void dffs(int u, int f, int d)
{
    fa[u] = f;
    deep[u] = d;
    siz[u] = 1;
    son[u] = -1;

    const int degree = static_cast<int>(edge[u].size());
    for (int k = 0; k < degree; ++k)
    {
        const int child = edge[u][k];
        if (child == f)
            continue; // do not walk back along the edge we arrived on

        dffs(child, u, d + 1);
        siz[u] += siz[child];

        const bool noHeavyYet = (son[u] == -1);
        if (noHeavyYet || siz[son[u]] < siz[child])
            son[u] = child;
    }
}
void dfss(int u,int t)
{tid[u] = ++cnt;top[u] = t;id_data[cnt] = data[u];if(son[u] != -1){dfss(son[u],t);}for(int i = 0;i<edge[u].size();i++){int v = edge[u][i];if(son[u] != v && fa[u] != v) dfss(v,v);}
}
int main()
{scanf("%d", &n);for(int i=1;i<=n;i++) edge[i].clear();for(int i = 1; i < n; i++){int u,v;scanf("%d %d", &u, &v);edge[u].push_back(v);edge[v].push_back(u);}for(int i = 1;i <= n;i++){scanf("%d",&data[i]);}cnt = 0;dffs(1,-1,0);dfss(1,1);build(1,1,n);int t;scanf("%d", &t);char str[10];while(t--){scanf("%s",str);int a, b, c;if(strcmp(str,"CHANGE") == 0){scanf("%d %d",&a, &b);updata(1,1,n,tid[a],b);}else{scanf("%d %d",&a, &b);if(strcmp(str,"QMAX") == 0){printf("%d\n",Query2(a,b));}else{printf("%d\n",Query1(a,b));}}}
}

 

  相关解决方案