题目链接:点我啊╭(╯^╰)╮
题目大意:
一棵树,两种操作:
①:在点 v v v 放 x x x 个蘑菇
②:将起点变为 v v v
每次操作后计算起点收集所有蘑菇的代价
收集一个蘑菇的代价为起点到终点最短路径上的第一条边权
解题思路:
计算一个点 u u u 的答案,分三部分计算
① : ①: ①:重儿子的所在子树的所有答案
这些蘑菇都以 u u u 到重儿子这条边的边权为代价
所以每次更新点 v v v 时,维护 v v v 到根节点的重链上的蘑菇数
② : ②: ②:所有轻儿子所在子树的答案
这部分在第一步时,轻重链交换时暴力统计
③ : ③: ③:以父亲边边权为代价的答案
用总蘑菇数 ? - ? 上两种情况的蘑菇和即可
核心:轻重链剖分的应用
#include<bits/stdc++.h>
#define rint register int
#define deb(x) cerr<<#x<<" = "<<(x)<<'\n';
#define fi first
#define se second
using namespace std;
typedef long long ll;
using pii = pair <ll,ll>;
const int maxn = 1e6 + 5;
int n, q, dep[maxn], fa[maxn], fv[maxn], size[maxn];
ll sum, t[maxn<<2], lz[maxn<<2], cnt[maxn];
int dfn[maxn], id[maxn], tot, son[maxn], top[maxn];
vector <pii> g[maxn];
pii ans[maxn];void dfs1(int u, int f, int de) {dep[u] = de, fa[u] = f, size[u] = 1;for(auto tmp : g[u]) {int v = tmp.fi;int w = tmp.se;if(v == f) continue;dfs1(v, u, de+1);fv[v] = w;size[u] += size[v];if(size[son[u]] < size[v]) son[u] = v;}
}void dfs2(int u, int tp) {top[u] = tp, dfn[++tot] = u, id[u] = tot;if(son[u]) dfs2(son[u], tp);for(auto tmp : g[u]) {int v = tmp.fi;if(v == fa[u]) continue;if(v == son[u]) continue;dfs2(v, v);}
}void pushdown(int rt) {if(lz[rt]) {t[rt<<1] += lz[rt];t[rt<<1|1] += lz[rt];lz[rt<<1] += lz[rt];lz[rt<<1|1] += lz[rt];lz[rt] = 0;}
}void update(ll x, int L, int R, int l, int r, int rt) {if(l>R || r<L) return;if(l>=L && r<=R) {t[rt] += x;lz[rt] += x;return;}pushdown(rt);int mid = l + r >> 1;update(x, L, R, l, mid, rt<<1);update(x, L, R, mid+1, r, rt<<1|1);t[rt] = t[rt<<1] + t[rt<<1|1];
}ll query(int pos, int l, int r, int rt) {if(pos>r || pos<l) return 0;if(l == r) return t[rt];pushdown(rt);int mid = l + r >> 1; ll ret = 0;ret += query(pos, l, mid, rt<<1);ret += query(pos, mid+1, r, rt<<1|1);return ret;
}void gao(int u, int x) {while(u) {update(x, id[top[u]], id[u], 1, n, 1);u = top[u];ans[fa[u]].fi += 1ll * x * fv[u];ans[fa[u]].se += x;u = fa[u];}
}void solve(int u) {ll res = 0, num = query(id[son[u]], 1, n, 1);res += 1ll * num * fv[son[u]];res += 1ll * (sum - cnt[u] - num - ans[u].se) * fv[u];res += ans[u].fi;printf("%lld\n", res);
}int main() {scanf("%d", &n);for(int i=1, u, v, w; i<n; i++) {scanf("%d%d%d", &u, &v, &w);g[u].push_back({v, w});g[v].push_back({u, w});}dfs1(1, 0, 1);dfs2(1, 1);scanf("%d", &q);int op, v, x, rt = 1;while(q--) {scanf("%d", &op);if(op == 1) {scanf("%d%d", &v, &x);sum += x;cnt[v] += x;gao(v, x);} else scanf("%d", &rt);solve(rt);}
}