JZOJ【NOIP2017提高A组模拟9.7】计数题

it2025-02-16  4

JZOJ 【NOIP2017提高A组模拟9.7】计数题

题目

Description

Input

Output

Sample Input

5 2 2 3 4 5

Sample Output

8 6

Data Constraint

题解

题意

给出 a [ i ] a[i] a[i],有一完全图, i i i j j j之间的边的值为 a [ i ] ⊕ a [ j ] a[i] \oplus a[j] a[i]a[j] ⊕ \oplus 为异或的意思) 求最小生成树及方案数

题解

科普一个东西, n n n个点的完全图的生成树个数是 n n − 2 n^{n-2} nn2 这个东西叫做凯莱定理,大家可以自行了解一下


100 % \% % 看到异或,而且要最小,且 a [ i ] a[i] a[i]二进制做多只有30位 想到可以按照最高位往下分治,分成当前这位是0和1的两堆,然后为了取值最小,那么这两堆只能连一条 那么就找到这两堆里面异或值最小的,这是 t r i e trie trie应用的经典问题 然后分治一位一位往下 最后把所有最小值加一起,方案数乘起来即可

Code

#include<cmath> #include<cstdio> #include<algorithm> #define mod 1000000007 using namespace std; long long n,mx,num,ans,ans1,tot,a[1000001],er[31],c1[1000001],c2[1000001]; struct node { long long left,right,size; }trie[400005]; long long read() { long long res=0;char ch=getchar(); while (ch<'0'||ch>'9') ch=getchar(); while (ch>='0'&&ch<='9') res=(res<<1)+(res<<3)+(ch-'0'),ch=getchar(); return res; } void insert(long long x) { long long now=1; ++trie[now].size; for (long long i=mx;i>=0;--i) { if (x&er[i]) { if (trie[now].left==0) trie[now].left=++num,trie[num].left=trie[num].right=trie[num].size=0; now=trie[now].left; ++trie[now].size; } else { if (trie[now].right==0) trie[now].right=++num,trie[num].left=trie[num].right=trie[num].size=0; now=trie[now].right; ++trie[now].size; } } } long long calc(long long x) { long long now=1,s=0; for (long long i=mx;i>=0;--i) { if (x&er[i]) { if (trie[trie[now].left].size>0) now=trie[now].left; else s+=er[i],now=trie[now].right; } else { if (trie[trie[now].right].size>0) now=trie[now].right; else s+=er[i],now=trie[now].left; } } tot=trie[now].size; return s; } long long ksm(long long x,long long y) { long long res=1; while (y) { if (y&1) res=res*x%mod; x=x*x%mod; y>>=1; } return res; } long long dg(long long l,long long r,long long d) { if (r<=l) return 1; if (d<0) return ksm(r-l+1,r-l-1); long long t1=0,t2=0; for (long long i=l;i<=r;++i) { if (a[i]&er[d]) c1[++t1]=a[i]; else c2[++t2]=a[i]; } for (long long i=1;i<=t1;++i) a[l+i-1]=c1[i]; for (long long i=1;i<=t2;++i) a[l+t1+i-1]=c2[i]; long long s1=dg(l,l+t1-1,d-1),s2=dg(l+t1,r,d-1); long long s3=(s1*s2)%mod,s4=2147483647,s5=0; if (t1==0||t2==0) return s3; num=1; trie[1].left=trie[1].right=trie[1].size=0; for (long long i=1;i<=t1;++i) insert(a[l+i-1]); for (long long i=1;i<=t2;++i) { long long sum=calc(a[l+t1+i-1]); if (sum<s4) s4=sum,s5=tot; else if (sum==s4) s5=(s5+tot)%mod; } ans+=s4; return (s3*s5)%mod; } int main() { freopen("jst.in","r",stdin); freopen("jst.out","w",stdout); n=read(); for (long long i=1;i<=n;++i) a[i]=read(),mx=max(mx,a[i]); mx=log2(mx); er[0]=1; for (long long i=1;i<=31;++i) er[i]=er[i-1]*2%mod; num=1; ans1=dg(1,n,mx); printf("%lld\n%lld\n",ans,ans1); fclose(stdin); fclose(stdout); return 0; }
最新回复(0)