BZOJ3583:杰杰的女性朋友 矩阵乘法+谜の卡常数
思路:
我们可以预处理出一个矩阵\(P\),其中\(P[i][j]\)表示从\(i\)到\(j\)的道路数目.
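这样\(P^d[u][v]\)就是从\(u\)到\(v\)恰好经过\(d\)条道路的方案数.如果要求经过的道路数不超过\(d\),我们就加一个计数器(第\(N+1\)个点):把\(P\)扩展一行一列,第\(N+1\)列取\(P[i][v]\),并令\(P[N+1][N+1]=1\).这样每多走一步,都会把恰好到达\(v\)的方案数累加进计数器,\(P^d[u][N+1]\)就是长度为\(1\)到\(d\)的路径总数;当\(u=v\)时再加上长度为\(0\)的那一条.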
这样就得到一个\((N+1)*(N+1)\)的转移矩阵.但如果直接对它做矩阵快速幂,每次矩阵乘法都是\(O(N^3)\),肯定会超时.
注意到矩阵\(P\)的构造很特殊:\(P[i][j]=\sum_{t=1}^{K}out[i][t]\cdot in[j][t]\).因此(连同计数器一起)可以把它分解成一个\((N+1)*(K+1)\)的矩阵\(out\)与一个\((K+1)*(N+1)\)的矩阵\(in\)的乘积.
于是\(P^d=(out*in)^d\).
我们发现\(P^d=(out*in)^d=out*(in*out)^{d-1}*in\).注意到\(in*out\)是一个\((K+1)*(K+1)\)的矩阵,算起来毫无压力.
答案是一个\(1*(N+1)\)的行向量(初始只有第\(u\)位为\(1\)),依次右乘\(out\),\((in*out)^{d-1}\),\(in\)只需要\(O(NK+K^2)\);真正的瓶颈是计算\(in*out\),为\(O(NK^2)\).
因此我们的复杂度就是\(O(Q(K^3logd+NK^2))\).
(注意本题卡常数)
#include<cstdio>
#include<cstring>
#include<cctype>
#include<iostream>
#include<algorithm>
using namespace std;

// Straightforward version built on a generic fixed-capacity Matrix type.
// Per query the answer is (e_u * out' * (in' * out')^(d-1) * in')[n+1] + [u == v], where
//   out' ((n+1) x (m+1)) is out plus out'[n+1][m+1] = 1, and
//   in'  ((m+1) x (n+1)) is in transposed plus column n+1 = in[v][.] and in'[m+1][n+1] = 1
// (column/row n+1 act as a running counter of paths that reached v).

static const int mod = (1e9) + 7;

// x = (x + y) mod `mod`; both operands are expected to lie in [0, mod).
inline void inc(int &x, int y) { if ((x += y) >= mod) x -= mod; }

// Dense matrix with 1-based indices; rows 1..w and columns 1..h are in use.
// The backing array is about 4 MB, so instances live in static storage and are never
// returned or copied by value (a by-value temporary would be placed on the stack).
struct Matrix {
    int d[1010][1010], w, h;

    // Become a zero rows x cols matrix; only the region in use is cleared.
    void reset(int rows, int cols) {
        w = rows;
        h = cols;
        for (int i = 1; i <= w; ++i) memset(d[i] + 1, 0, sizeof(int) * h);
    }

    // Copy the in-use region (and dimensions) of `o` into this matrix.
    void assign(const Matrix &o) {
        w = o.w;
        h = o.h;
        for (int i = 1; i <= w; ++i) memcpy(d[i] + 1, o.d[i] + 1, sizeof(int) * h);
    }

    // Debug dump of the in-use region.
    void output() const {
        printf("w=%d h=%d\n", w, h);
        for (int i = 1; i <= w; ++i) {
            for (int j = 1; j <= h; ++j) printf("%d ", d[i][j]);
            puts("");
        }
        puts("");
    }
};

static Matrix res, t;  // scratch storage for mul() and Pow()

// C = A * B, requiring A.h == B.w. C may alias A or B because the product is built in
// `res` first; A and B must not themselves be `res`.
static void mul(const Matrix &A, const Matrix &B, Matrix &C) {
    res.w = A.w;
    res.h = B.h;
    for (int i = 1; i <= res.w; ++i)
        for (int j = 1; j <= res.h; ++j) {
            int acc = 0;
            for (int k = 1; k <= A.h; ++k) inc(acc, (long long)A.d[i][k] * B.d[k][j] % mod);
            res.d[i][j] = acc;
        }
    C.assign(res);
}

// dst = x^y for a square x; the identity when y == 0. dst may alias x; x must not be `t`.
static void Pow(const Matrix &x, int y, Matrix &dst) {
    t.assign(x);
    dst.reset(x.w, x.h);
    for (int i = 1; i <= dst.w; ++i) dst.d[i][i] = 1;
    for (; y > 0; y >>= 1) {
        if (y & 1) mul(dst, t, dst);
        if (y > 1) mul(t, t, t);  // no squaring needed after the highest bit
    }
}

#define N 1010
#define M 21
int n, m, in[N][M], out[N][M];
static Matrix Solve1, Solve2, Solve3, ans;

int main() {
#ifndef ONLINE_JUDGE
    freopen("tt.in", "r", stdin);
#endif
    if (scanf("%d%d", &n, &m) != 2) return 0;
    for (int i = 1; i <= n; ++i) {
        for (int j = 1; j <= m; ++j) scanf("%d", &out[i][j]);
        for (int j = 1; j <= m; ++j) scanf("%d", &in[i][j]);
    }
    int q;
    if (scanf("%d", &q) != 1) return 0;
    while (q--) {
        int u, v, d;
        if (scanf("%d%d%d", &u, &v, &d) != 3) break;
        // The empty path (length 0) exists only when u == v.
        // NOTE(review): u, v are assumed to lie in 1..n — not validated here.
        const int self = (u == v) ? 1 : 0;
        if (d == 0) {
            printf("%d\n", self);
        } else if (d == 1) {
            // Exactly one road: sum_t out[u][t] * in[v][t].
            int ret = 0;
            for (int i = 1; i <= m; ++i) ret = (ret + (long long)out[u][i] * in[v][i] % mod) % mod;
            printf("%d\n", (ret + self) % mod);
        } else {
            // Solve1 = out'.
            Solve1.reset(n + 1, m + 1);
            for (int i = 1; i <= n; ++i)
                for (int j = 1; j <= m; ++j) Solve1.d[i][j] = out[i][j];
            Solve1.d[n + 1][m + 1] = 1;
            // Solve2 = in', whose column n+1 feeds the counter from paths arriving at v.
            Solve2.reset(m + 1, n + 1);
            for (int i = 1; i <= n; ++i)
                for (int j = 1; j <= m; ++j) Solve2.d[j][i] = in[i][j];
            for (int j = 1; j <= m; ++j) Solve2.d[j][n + 1] = in[v][j];
            Solve2.d[m + 1][n + 1] = 1;
            // Solve3 = (in' * out')^(d-1), a small (m+1) x (m+1) matrix.
            mul(Solve2, Solve1, Solve3);
            Pow(Solve3, d - 1, Solve3);
            // Row vector e_u pushed through out', Solve3 and in'.
            ans.reset(1, n + 1);
            ans.d[1][u] = 1;
            mul(ans, Solve1, ans);
            mul(ans, Solve3, ans);
            mul(ans, Solve2, ans);
            printf("%d\n", (ans.d[1][n + 1] + self) % mod);
        }
    }
    return 0;
}
#include<cstdio>
#include<cstring>
#include<cctype>
#include<iostream>
#include<algorithm>
using namespace std;

// Optimised version: plain fixed-size arrays instead of a generic Matrix type.
//
// Notation (1-based indices, K == m):
//   out' ((n+1) x (m+1)): out'[i][j] = out[i][j], out'[n+1][m+1] = 1, zero elsewhere.
//   in'  ((m+1) x (n+1)): in'[j][i] = in[i][j], in'[j][n+1] = in[v][j], in'[m+1][n+1] = 1.
// Paths of length 1..d from u to v: (e_u * out' * (in' * out')^(d-1) * in')[n+1].
// In in' * out' only column m+1 depends on v, so the m x m block is computed once up front.

static const int mod = (1e9) + 7;

// x = (x + y) mod `mod`; both operands are expected to lie in [0, mod).
inline void inc(int &x, int y) { if ((x += y) >= mod) x -= mod; }

#define N 1010
#define M 22

int n, m, in[N][M], out[N][M];
int core[M][M];  // core[i][j] = sum_k in[k][i] * out[k][j] for 1 <= i, j <= m (query-independent)
int Solve3[M][M], reg[M][M], t[M][M], one[M][M];

// c = a * b for (m+1) x (m+1) matrices. c may alias a or b: the product is staged in reg.
static void mulSquare(int (*a)[M], int (*b)[M], int (*c)[M]) {
    for (int i = 1; i <= m + 1; ++i)
        for (int j = 1; j <= m + 1; ++j) {
            int acc = 0;
            for (int k = 1; k <= m + 1; ++k) inc(acc, (long long)a[i][k] * b[k][j] % mod);
            reg[i][j] = acc;
        }
    for (int i = 1; i <= m + 1; ++i)
        for (int j = 1; j <= m + 1; ++j) c[i][j] = reg[i][j];
}

// Replace Solve3 by Solve3^d (the identity when d == 0) using binary exponentiation.
inline void work(int d) {
    memset(one, 0, sizeof one);
    for (int i = 1; i <= m + 1; ++i) one[i][i] = 1;
    memcpy(t, Solve3, sizeof t);
    while (d) {
        if (d & 1) mulSquare(one, t, one);
        d >>= 1;
        if (d) mulSquare(t, t, t);  // no squaring needed after the highest bit
    }
    memcpy(Solve3, one, sizeof one);
}

int main() {
#ifndef ONLINE_JUDGE
    freopen("tt.in", "r", stdin);
#endif
    if (scanf("%d%d", &n, &m) != 2) return 0;
    for (int i = 1; i <= n; ++i) {
        for (int j = 1; j <= m; ++j) scanf("%d", &out[i][j]);
        for (int j = 1; j <= m; ++j) scanf("%d", &in[i][j]);
    }
    // The O(N K^2) part of in' * out' does not depend on the query: build it once.
    for (int i = 1; i <= m; ++i)
        for (int j = 1; j <= m; ++j) {
            int acc = 0;
            for (int k = 1; k <= n; ++k) inc(acc, (long long)in[k][i] * out[k][j] % mod);
            core[i][j] = acc;
        }
    int q;
    if (scanf("%d", &q) != 1) return 0;
    while (q--) {
        int u, v, d;
        if (scanf("%d%d%d", &u, &v, &d) != 3) break;
        // The empty path (length 0) exists only when u == v.
        // NOTE(review): u, v are assumed to lie in 1..n and d >= 0 — not validated here.
        const int self = (u == v) ? 1 : 0;
        if (d == 0) {
            printf("%d\n", self);
            continue;
        }
        if (d == 1) {
            // Exactly one road: sum_t out[u][t] * in[v][t].
            int ret = 0;
            for (int i = 1; i <= m; ++i) ret = (ret + (long long)out[u][i] * in[v][i] % mod) % mod;
            printf("%d\n", (ret + self) % mod);
            continue;
        }
        // Solve3 = in' * out' for this v: the shared core, last column in[v][.], and a 1 at (m+1, m+1).
        for (int i = 1; i <= m; ++i) {
            for (int j = 1; j <= m; ++j) Solve3[i][j] = core[i][j];
            Solve3[i][m + 1] = in[v][i] % mod;
            Solve3[m + 1][i] = 0;
        }
        Solve3[m + 1][m + 1] = 1;
        work(d - 1);
        // row = e_u * out' = (out[u][1..m], 0); cur = row * Solve3^(d-1).
        int row[M], cur[M];
        for (int j = 1; j <= m; ++j) row[j] = out[u][j] % mod;
        row[m + 1] = 0;
        for (int j = 1; j <= m + 1; ++j) {
            int acc = 0;
            for (int k = 1; k <= m + 1; ++k) inc(acc, (long long)row[k] * Solve3[k][j] % mod);
            cur[j] = acc;
        }
        // Only column n+1 of cur * in' is needed: in'[k][n+1] = in[v][k] (k <= m), in'[m+1][n+1] = 1.
        int total = 0;
        for (int k = 1; k <= m; ++k) inc(total, (long long)cur[k] * in[v][k] % mod);
        inc(total, cur[m + 1]);
        printf("%d\n", (total + self) % mod);
    }
    return 0;
}