SP10707 COT2 - Count on a tree II (树上莫队)

it2025-05-24  13

传送门 普通莫队是对一段一维的序列上操作的算法,即使带修莫队也是增加一个时间轴修改的序列也是一维的,如果对于一个树能否操作呢?可以将对树处理成一个欧拉序来实现,欧拉序是在dfs序的基础上在绕回每个点时将这个点填进去的序列。举个例子: dfs序:12345678 欧拉序:1233445526778861 定义:s[i]表示节点i入栈的时间戳,t[i]表示节点i出栈的时间戳。 所以欧拉序的长度是2n的,对树上两点u和v之间的路径有两种情况(假设u的深度比v小): u是v的祖先:那么s[u]到s[v]之间的点(出现两次的点除外)都在u到v之间的路径上,u=1,v=5时就对应欧拉序上区间[1,7]。 u不是v的祖先,这种情况下s[u]到s[v]之间欧拉序u会出现两次,所以要用t[u]到s[v],但是这样的话u和v的lca没有出现在欧拉序内,所以移动区间时还要将lca加进去,记得统计过后在把lca的贡献去掉。 这样就能将树上的路径查询转变为欧拉序的区间查询,套莫队就完了。 欧拉序dfs就可以,求lca可以树剖,欧拉序顺带就出来了。

#include<iostream> #include<cstdio> #include<algorithm> #include<cstring> #include<unordered_map> #include<cmath> using namespace std; typedef long long ll; const int N=100010; int n,m,a[N],nums[N]; int h[N],ne[N],e[N],idx; void add(int a,int b) { e[idx]=b;ne[idx]=h[a];h[a]=idx++; } int dep[N],sz[N],son[N],fa[N],dfn[N],top[N],cnt; int oula[N],s[N],t[N]; void dfs1(int u,int p) { sz[u]=1; oula[++cnt]=u; s[u]=cnt; for(int i=h[u];~i;i=ne[i]) { int v=e[i]; if(v==p) continue; dep[v]=dep[u]+1; fa[v]=u; dfs1(v,u); sz[u]+=sz[v]; if(sz[v]>sz[son[u]]) son[u]=v; } oula[++cnt]=u; t[u]=cnt; } void dfs2(int u,int t) { top[u]=t; if(son[u]) dfs2(son[u],t); for(int i=h[u];~i;i=ne[i]) { int v=e[i]; if(v==fa[u]||v==son[u]) continue; dfs2(v,v); } } void Swap(int &x,int &y){ x^=y^=x^=y;} int get_lca(int x,int y) { while(top[x]!=top[y]) { if(dep[top[x]]>dep[top[y]]) Swap(x,y); y=fa[top[y]]; } return dep[x]>dep[y]?y:x; } int block; struct node { int l,r,lca,id; bool operator < (const node &w) const { if((l/block)==(w.l/block)) { if((l/block)&1) return r>w.r; return r<w.r; } return l<w.l; } }ed[N]; int ans,res[N],num[N]; int mp[N]; void add(int x) { if(mp[x]==0) ans++; mp[x]++; } void del(int x) { if(mp[x]==1) ans--; mp[x]--; } void check(int x) { if(num[x]==0) add(a[x]);//在区间[l,r]内有0个或者2个x else del(a[x]); num[x]^=1; } int main() { memset(h,-1,sizeof h); scanf("%d%d",&n,&m); for(int i=1;i<=n;i++) scanf("%d",a+i),nums[i]=a[i]; sort(nums+1,nums+n+1); int len=unique(nums+1,nums+n+1)-(nums+1);//离散化再预处理一下 for(int i=1;i<=n;i++) a[i]=lower_bound(nums+1,nums+len+1,a[i])-nums; int u,v; for(int i=1;i<n;i++) { scanf("%d%d",&u,&v); add(u,v);add(v,u); } dfs1(1,-1); dfs2(1,1); for(int i=1;i<=m;++i) { scanf("%d%d",&u,&v); if(s[u]>s[v]) Swap(u,v); ed[i].lca=get_lca(u,v); if(ed[i].lca==u) ed[i].l=s[u],ed[i].r=s[v];ed[i].lca=0;//u是v的lca else ed[i].l=t[u],ed[i].r=s[v]; ed[i].id=i; } block=n*2/sqrt(m); sort(ed+1,ed+m+1); int l=0,r=-1; for(int i=1;i<=m;i++) { while(l<ed[i].l) check(oula[l++]); while(l>ed[i].l) check(oula[--l]); while(r>ed[i].r) check(oula[r--]); while(r<ed[i].r) check(oula[++r]); if(ed[i].lca) add(a[ed[i].lca]); res[ed[i].id]=ans; if(ed[i].lca) del(a[ed[i].lca]); } for(int i=1;i<=m;++i) printf("%d\n",res[i]); }
最新回复(0)