[SDOI2016] 游戏(树链剖分+李超线段树)

it2023-07-13  64

题意

其中, n , m ≤ 100000 , ∣ a ∣ ≤ 10000 n,m\leq 100000,|a|\leq 10000 n,m100000,a10000

分析

i i i 到根的距离为 d i s i dis_i disi。 首先化简一下式子,从 s s s l c a ( s , t ) lca(s,t) lca(s,t) 路径: a d d = a × ( d i s s − d i s i ) + b = − a × d i s i + ( b + a × d i s s ) add=a\times (dis_s-dis_i)+b=-a\times dis_i+(b+a\times dis_s) add=a×(dissdisi)+b=a×disi+(b+a×diss) 可以看出是一条线段的形式, x x x 坐标是 d i s dis dis。从 l c a ( s , t ) lca(s,t) lca(s,t) t t t 同理。 然后,考虑树链剖分,这样对于一条重链的区间,它的 d i s dis dis 是递增的。 因此,这个问题就变得明显了起来,就是树链剖分后用李超线段树加线段。 不过,这里的李超线段树每个节点还要维护一个区间的最小值。考虑李超线段树的性质,这个东西是可以维护的,就是每次用最优线段去尝试更新这个区间的最值(最值肯定在两端取到),同时,还要像普通线段树那样从子节点向父节点更新。 时间复杂度是 O ( n l o g 4 n ) O(nlog^4n) O(nlog4n) 的,但是由于树剖和李超树的常数很小,可以在 1 s 1s 1s 内跑过。

代码如下

#include <bits/stdc++.h> #include <ext/pb_ds/hash_policy.hpp> #include <ext/pb_ds/assoc_container.hpp> #define lson l, m, rt << 1 #define rson m + 1, r, rt << 1 | 1 #define int long long using namespace __gnu_pbds; using namespace std; typedef long long LL; typedef unsigned long long uLL; LL z = 1; int ksm(int a, int b, int p){ int s = 1; while(b){ if(b & 1) s = z * s * a % p; a = z * a * a % p; b >>= 1; } return s; } const int N = 1e5 + 5, inf = 123456789123456789; struct node{ int a, b, c, n; }d[N * 2]; int h[N], son[N], dfn[N], siz[N], top[N], fa[N], dep[N], dis[N], dft, re[N], n, cnt; int tag[N * 4], K[N * 2], B[N * 2], mn[N * 4], tot; int get(int p, int x){ return K[p] * dis[re[x]] + B[p]; } void cr(int a, int b, int c){ d[++cnt] = {a, b, c, h[a]}, h[a] = cnt; } void dfs1(int a){ siz[a] = 1; for(int i = h[a]; i; i = d[i].n){ int b = d[i].b, c = d[i].c; if(b == fa[a]) continue; fa[b] = a; dis[b] = dis[a] + c; dep[b] = dep[a] + 1; dfs1(b); siz[a] += siz[b]; if(siz[b] >= siz[son[a]]) son[a] = b; } } void dfs2(int a, int f){ top[a] = f, dfn[a] = ++dft, re[dft] = a; if(son[a]) dfs2(son[a], f); for(int i = h[a]; i; i = d[i].n){ int b = d[i].b; if(b != fa[a] && b != son[a]) dfs2(b, b); } } int lca(int a, int b){ int f1 = top[a], f2 = top[b]; while(f1 != f2){ if(dep[f1] < dep[f2]) swap(f1, f2), swap(a, b); a = fa[f1], f1 = top[a]; } return dep[a] < dep[b]? a: b; } void build(int l, int r, int rt){ tag[rt] = 1, mn[rt] = inf; if(l == r) return; int m = l + r >> 1; build(lson); build(rson); } void update(int l, int r, int rt, int a, int b, int u){ int m = l + r >> 1, &v = tag[rt]; if(l >= a && r <= b){ if(get(u, m) < get(v, m)) swap(u, v); mn[rt] = min(mn[rt], min(get(v, l), get(v, r)));//更新 mn[rt] if(l == r) return; if(get(u, l) < get(v, l)) update(lson, a, b, u); else if(get(u, r) < get(v, r)) update(rson, a, b, u); mn[rt] = min(mn[rt], min(mn[rt << 1], mn[rt << 1 | 1]));//比普通李超树就多了这个 return; } if(a <= m) update(lson, a, b, u); if(b > m) update(rson, a, b, u); mn[rt] = min(mn[rt], min(mn[rt << 1], mn[rt << 1 | 1]));//比普通李超树就多了这个 } int query(int l, int r, int rt, int a, int b){ if(l >= a && r <= b) return mn[rt];//如果到达某一被覆盖的节点,直接返回最值 int ans = min(get(tag[rt], max(a, l)), get(tag[rt], min(r, b))), m = l + r >> 1;//注意这里的实际区间是 [max(a, l), min(b, r)],用当前节点的优势线段求一下两端 if(a <= m) ans = min(ans, query(lson, a, b)); if(b > m) ans = min(ans, query(rson, a, b)); return ans; } void add(int a, int b, int p){ int f1 = top[a], f2 = top[b]; while(f1 != f2){ update(1, n, 1, dfn[f1], dfn[a], p);//往重链加线段 a = fa[f1], f1 = top[a]; } update(1, n, 1, dfn[b], dfn[a], p);//往重链加线段 } int find(int a, int b){ int f1 = top[a], f2 = top[b], ans = inf; while(f1 != f2){ if(dep[f1] < dep[f2]) swap(a, b), swap(f1, f2); ans = min(ans, query(1, n, 1, dfn[f1], dfn[a])); a = fa[f1], f1 = top[a]; } if(dep[a] < dep[b]) swap(a, b); ans = min(ans, query(1, n, 1, dfn[b], dfn[a])); return ans; } main(){ ios::sync_with_stdio(false); cin.tie(0), cout.tie(0); int m; cin >> n >> m; for(int i = 1; i < n; i++){ int a, b, c; cin >> a >> b >> c; cr(a, b, c); cr(b, a, c); } dfs1(1); dfs2(1, 1); B[tot = 1] = inf; build(1, n, 1);//初始化线段树的每个节点 for(int i = 1; i <= m; i++){ int o; cin >> o; if(o == 1){ int s, t, a, b; cin >> s >> t >> a >> b; int ff = lca(s, t); K[++tot] = -a, B[tot] = b + a * dis[s];//两种情况 add(s, ff, tot); K[++tot] = a, B[tot] = b + a * dis[s] - a * 2 * dis[ff]; add(t, ff, tot); } else{ int s, t; cin >> s >> t; cout << find(s, t) << '\n'; } } return 0; }
最新回复(0)