虚树就是将树上我们需要的关键信息,浓缩到一颗新的树上,这棵树上除了关键点还有任意一对关键点的 l c a lca lca的信息
这里介绍利用单调栈的做法,首先我们要明确一个目的,我们要用单调栈来维护一条虚树上的链。也就是一个栈里相邻的两个节点在虚树上也是相邻的,而且栈是从底部到栈首单调递增的(指的是栈中节点 DFS 序单调递增)
我们每次遇到一个新的关键点,将他和栈顶求 l c a lca lca,若 l c a lca lca不在栈顶,就弹栈,然后将 l c a lca lca和栈顶比较 d f n dfn dfn,重复这个操作直到栈顶的 d f n dfn dfn序比 l c a lca lca小,然后将 l c a lca lca压进栈
tip:每次弹栈时就将新的栈顶和弹出去的点连边虚树模板题,我们每次将关键点建出一颗虚树,虚树上的边就是原树上该关键点到 l c a lca lca的路径上最小的一条边,然后直接DP
记 f [ i ] f[i] f[i]表示 i i i不和子树内任何一个关键点联通的最小花费,转移就是:
v v v是关键点 f[u]=f[u]+w(u,v); v v v不是关键点 f[u]=f[u]+min( w(u,v), f[v] ); 代码: #include<bits/stdc++.h> using namespace std; namespace zzc { const int maxn = 2.5e5+5; struct island { int u,dfn; }k[maxn]; bool cmp(island a,island b) { return a.dfn<b.dfn; } struct edge { int to,val; edge (int to,int val):to(to),val(val){} }; int kcnt=0,n,m,tim; vector<edge> mp[maxn],vt[maxn]; int dfn[maxn],fa[maxn][25],dep[maxn]; long long g[maxn][25],d[maxn]; bool vis[maxn]; void dfs(int u,int ff,long long w) { dep[u]=dep[ff]+1; dfn[u]=++tim; fa[u][0]=ff; g[u][0]=w; for(int i=0;i<(int)mp[u].size();i++) { int v=mp[u][i].to; long long val=mp[u][i].val; if(v==ff) continue; dfs(v,u,val); } } int lca(int x,int y) { if(dep[x]<dep[y]) swap(x,y); for(int i=20;i>=0;i--) { if(dep[fa[x][i]]>=dep[y]) x=fa[x][i]; } if(x==y) return x; for(int i=20;i>=0;i--) { if(fa[x][i]!=fa[y][i]) { x=fa[x][i]; y=fa[y][i]; } } return fa[x][0]; } int query(int x,int y) { long long ans=0x3f3f3f; if(dep[x]<dep[y]) swap(x,y); for(int i=20;i>=0;i--) { if(dep[fa[x][i]]>=dep[y]) { ans=min(ans,g[x][i]); x=fa[x][i]; } } return ans; } void add(int u,int v) { int w=query(v,u); vt[u].push_back(edge(v,w)); //printf("%d %d %d \n",u,v,w); } void solve(int u) { for(int i=0;i<(int)vt[u].size();i++) { int v=vt[u][i].to; int val=vt[u][i].val; solve(v); if(vis[v]) d[u]+=val; else d[u]+=min(1ll*val,d[v]); vis[v]=false;d[v]=0; } vt[u].clear(); } void work() { int a,b,c; scanf("%d",&n); for(int i=1;i<n;i++) { scanf("%d%d%d",&a,&b,&c); mp[a].push_back(edge(b,c)); mp[b].push_back(edge(a,c)); } dfs(1,0,0); for(int i=1;i<=20;i++) { for(int j=1;j<=n;j++) { fa[j][i]=fa[fa[j][i-1]][i-1]; g[j][i]=min(g[j][i-1],g[fa[j][i-1]][i-1]); } } scanf("%d",&m); while(m--) { int cnt; kcnt=0; scanf("%d",&cnt); for(int i=1;i<=cnt;i++) { scanf("%d",&k[++kcnt].u); k[kcnt].dfn=dfn[k[kcnt].u]; vis[k[kcnt].u]=true; } stack<int> s; sort(k+1,k+cnt+1,cmp); s.push(1); for(int i=1;i<=cnt;i++) { int u=k[i].u; int l=lca(u,s.top()); while(s.top()!=l) { int tmp=s.top();s.pop(); if(dfn[s.top()]<dfn[l]) s.push(l); add(s.top(),tmp); } s.push(u); } while(s.top()!=1) { int tmp=s.top(); s.pop(); add(s.top(),tmp); } solve(1); printf("%lld\n",d[1]); d[1]=0; } } } int main() { zzc::work(); return 0; }