BZOJ1129: [POI2008]Per 数学+树状数组
一道水题各种想错+写错搞了整整一天多,大概是看番太多了吧。。
算了就不说废话了。
首先显然应该用CRT对吧。
对于每个$p_i^{q_i}$算一个答案再合并起来就行了。
考虑如何算排名,分别计算第一位比这个排列要小的,在计算第二位比这个排列要小的。。。以此类推加起来就得到排名了。
考虑如何计算第一位小的。
假设共有$m$种数字,分别有$a_1,...a_m$种,总长度为$n$,让第$i$种数字排在最前面,那么剩下的数字自由组合方案数为:
\[\frac{n!}{{a_1}!{a_2}!...{a_{i-1}}!{a_{i+1}}!...{a_m}!}\]
那么如果是前$i$种数字的话,答案就应该是:
\[\frac{n!}{\prod_{j=1}^{m}{a_j}!}\times\sum_{j=1}^{i}a_j\]
用树状数组维护$a_i$的前缀和,利用一个计数器数组维护目前每种数字的个数。
算的时候将与$p$互质的单独拿出来算,详细见代码。
时间复杂度$O(nlognlogp)$。
代码:
#include<cstdio> #include<cstring> #include<cctype> #include<iostream> #include<algorithm> using namespace std; typedef long long ll; int getc(){ static const int L=1<<15; static char buf[L],*S=buf,*T=buf; if(S==T){ T=(S=buf)+fread(buf,1,L,stdin); if(S==T) return EOF; } return *S++; } int getint(){ int c; while(!isdigit(c=getc())); int x=c-'0'; while(isdigit(c=getc())) x=(x<<1)+(x<<3)+c-'0'; return x; } void exgcd(ll a,ll b,ll&d,ll&x,ll&y){ if(!b){ d=a; x=1; y=0; } else{ exgcd(b,a%b,d,y,x); y-=x*(a/b); } } int inv(int a,int n){ ll d,x,y; exgcd(a,n,d,x,y); return (x%n+n)%n; } #define N 300010 int n,m,_n,a[N],b[N],c[N]; int res,p,p_q; int f[N],num[N],pow[N]; int A[N]; void up(int x,int c){ for(;x<=n;x+=x&-x) A[x]+=c; } int ask(int x){ int res=0; for(;x;x^=(x&-x)) res+=A[x]; return res; } void inc(int&x,int y){ if((x+=y)>=p_q) x-=p_q; } void work(){ int i,j,k,_i; for(f[0]=1,i=1;i<=n;++i){ for(_i=i;_i%p==0;_i/=p); f[i]=(ll)f[i-1]*_i%p_q; } for(i=1;i<=n;++i) num[i]=num[i/p]+i/p; for(pow[0]=1,i=1;i<=n;++i) pow[i]=(ll)pow[i-1]*p%p_q; for(i=1;i<=n;++i) A[i]=c[i]=0; for(i=1;i<=n;++i) ++c[a[i]],up(a[i],1); int linv=1,lnum=0,ans=0; for(i=1;i<=_n;++i){ lnum+=num[c[i]]; linv=(ll)linv*inv(f[c[i]],p_q)%p_q; } for(i=1;i<=n;++i){ inc(ans,(ll)f[n-i]*linv%p_q*pow[num[n-i]-lnum]%p_q*ask(a[i]-1)%p_q); up(a[i],-1); lnum=lnum-num[c[a[i]]]+num[c[a[i]]-1]; linv=(ll)linv*f[c[a[i]]]%p_q*inv(f[c[a[i]]-1],p_q)%p_q; --c[a[i]]; } inc(ans,1); res=(res+(ll)ans*(m/p_q)%m*inv(m/p_q,p_q)%m)%m; } int main(){ #ifndef ONLINE_JUDGE freopen("tt.in","r",stdin); #endif int i,j; n=getint(); m=getint(); for(i=1;i<=n;++i) a[i]=b[i]=getint(); sort(b+1,b+n+1); _n=unique(b+1,b+n+1)-b-1; for(i=1;i<=n;++i) a[i]=lower_bound(b+1,b+_n+1,a[i])-b; int _m=m; for(i=2;i*i<=m;++i){ if(_m%i==0){ for(p=i,p_q=1;_m%i==0;_m/=i,p_q*=i); work(); } } if(_m!=1) p=p_q=_m,work(); cout<<res<<endl; return 0; }