BZOJ3583:杰杰的女性朋友 矩阵乘法+谜の卡常数

思路:

我们可以预处理出一个矩阵\(P\),其中\(P[i][j]\)表示从\(i\)到\(j\)的道路数目.

这样\(P^d[u][v]\)就是从\(u\)到\(v\)恰好经过\(d\)条道路的方案数.而我们要求的是经过不超过\(d\)条道路(包括0条)的方案数,因此再加一个计数器状态,每走一步就把当前能到达\(v\)的方案数累加进去.

加上计数器后就得到一个\((N+1)\times(N+1)\)的转移矩阵.如果每次询问都直接对它做矩阵快速幂,单次复杂度是\(O(N^3\log d)\),会超时.

 

注意到矩阵\(P\)的构造方式很特殊:我们可以将(加上计数器后的)转移矩阵分解成一个\((N+1)\times(K+1)\)的矩阵\(out\)与一个\((K+1)\times(N+1)\)的矩阵\(in\)的乘积.

于是\(P^d=(out*in)^d\).

我们发现\(P^d=(out*in)^d=out*(in*out)^{d-1}*in\).注意到\(in*out\)是一个\((K+1)*(K+1)\)的矩阵,算起来毫无压力.

初始状态是一个\(1\times(N+1)\)的行向量.计算\(in*out\)需要\(O(NK^2)\),而行向量依次乘\(out\)、\((in*out)^{d-1}\)、\(in\)只需要\(O(NK+K^2)\),所以这部分总共也只是\(O(NK^2)\)级别的.

 

因此我们的复杂度就是\(O(Q(K^3logd+NK^2))\).

 

(注意本题卡常数)

#include<cstdio>
#include<cstring>
#include<cctype>
#include<iostream>
#include<algorithm>
using namespace std;
static const int mod=1000000007;   // the modulus, 10^9+7
// Modular accumulate: x = (x + y) mod `mod`.
// Correct when x and y are both in [0, mod), since a single subtraction then suffices.
inline void inc(int&x,int y){
    x+=y;
    if(x>=mod)x-=mod;
}
// Dense matrix with fixed 1010x1010 storage. The logical size is w rows by
// h columns, and entries are addressed 1-based as d[1..w][1..h].
struct Matrix{
    int d[1010][1010],w,h;
    Matrix(){}
    // Records the logical dimensions and zeroes the entire backing store.
    Matrix(int rows,int cols){
        w=rows;
        h=cols;
        memset(d,0,sizeof d);
    }
    // Debug dump: header line, then the logical w x h region row by row,
    // followed by a blank line.
    void output(){
        printf("w=%d h=%d\n",w,h);
        for(int r=1;r<=w;++r){
            const int*row=d[r];
            for(int c=1;c<=h;++c)printf("%d ",row[c]);
            puts("");
        }
        puts("");
    }
};
long long calc=0;   // debug counter of scalar multiply-adds performed by mul(); main resets it per query
Matrix t,res,ret;   // global scratch: res receives mul()'s product; t and ret are used by Pow()
// Returns A*B modulo mod, where A is A.w x A.h and B is A.h x B.h (1-based entries).
// The product is assembled in the global `res` and returned by value, so neither
// A nor B may be `res` itself.
// The loops run in i-k-j order so the innermost loop walks a row of B and a row of
// res contiguously, instead of striding down a column of B (rows are ~4 KB apart).
// Modular addition is commutative, so the result is identical to the i-j-k order.
Matrix mul(const Matrix&A,const Matrix&B){
    res.w=A.w,res.h=B.h;
    for(int i=1;i<=res.w;++i){
        int*row=res.d[i];
        for(int j=1;j<=res.h;++j)row[j]=0;
        for(int k=1;k<=A.h;++k){
            const long long a=A.d[i][k];
            const int*brow=B.d[k];
            calc+=res.h;   // same total as one increment per (i,j,k) term
            for(int j=1;j<=res.h;++j)inc(row[j],a*brow[j]%mod);
        }
    }
    return res;
}
// Returns x^y modulo mod for a square matrix x (w == h), by binary exponentiation.
// Uses the global scratch matrices t (running square) and ret (accumulator).
// y <= 0 yields the identity of x's size; previously `y--` made y negative and
// `y>>=1` never reached 0, so the loop did not terminate.
Matrix Pow(const Matrix&x,int y){//for w=h
    if(y<=0){
        // Only the logical w x h region is ever read by mul(), so filling it suffices.
        ret.w=x.w,ret.h=x.h;
        for(int i=1;i<=ret.w;++i)for(int j=1;j<=ret.h;++j)ret.d[i][j]=(i==j);
        return ret;
    }
    t=ret=x;   // ret starts at x^1, so one factor is already accounted for
    --y;
    while(y){
        if(y&1)ret=mul(ret,t);
        y>>=1;
        if(y)t=mul(t,t);   // skip the squaring after the last bit: its result would be unused
    }
    return ret;
}
// Capacity bounds for the per-vertex factor rows: vertices are indexed 1-based
// below N, and the second index runs 1..m, so m must stay below M.
static const int N=1010;
static const int M=21;
int n,m;                 // vertex count and K
int in[N][M];            // in[i][1..m]: "in" factor row of vertex i
int out[N][M];           // out[i][1..m]: "out" factor row of vertex i
Matrix Solve1,Solve2,Solve3,ans;
int main(){
    #ifndef ONLINE_JUDGE
    freopen("tt.in","r",stdin);
    #endif
    int n,m;scanf("%d%d",&n,&m);
    register int i,j;
    for(i=1;i<=n;++i){for(j=1;j<=m;++j)scanf("%d",&out[i][j]);for(j=1;j<=m;++j)scanf("%d",&in[i][j]);}
    int q,u,v,d;
    scanf("%d",&q);
    while(q--){
        calc=0;//debug
        scanf("%d%d%d",&u,&v,&d);
        if(d==0)printf("%d\n",u==v?1:0);
        else if(d==1){
            int ret=0;for(int i=1;i<=m;++i)ret=(ret+(long long)out[u][i]*in[v][i]%mod)%mod;
            printf("%d\n",(ret+(u==v?1:0))%mod);
        }
        else{
            Solve1=Matrix(n+1,m+1),Solve2=Matrix(m+1,n+1);
            for(i=1;i<=n;++i)for(j=1;j<=m;++j)Solve1.d[i][j]=out[i][j];
            Solve1.d[n+1][m+1]=1;
            for(i=1;i<=n;++i)for(j=1;j<=m;++j)Solve2.d[j][i]=in[i][j];
            for(j=1;j<=m;++j)Solve2.d[j][n+1]=in[v][j];
            Solve2.d[m+1][n+1]=1;
            //Solve1.output();
            //Solve2.output();
            Solve3=mul(Solve2,Solve1);//Solve3.output();
            Solve3=Pow(Solve3,d-1);
            ans=Matrix(1,n+1);ans.d[1][u]=1;
            ans=mul(ans,Solve1),ans=mul(ans,Solve3),ans=mul(ans,Solve2);
            printf("%d\n",(ans.d[1][n+1]+(u==v?1:0))%mod);
        }
        //printf("%I64d\n",calc);
    }
    return 0;
}
#include<cstdio>
#include<cstring>
#include<cctype>
#include<iostream>
#include<algorithm>
using namespace std;
static const int mod=(1e9)+7;
// Modular accumulate: x = (x + y) mod `mod`; valid when x, y are in [0, mod).
inline void inc(int&x,int y){if((x+=y)>=mod)x-=mod;}
#define N 1010
#define M 22
int n,m,in[N][M],out[N][M];
// Working matrices, 1-based. main fills Solve1 as (n+1)x(m+1), Solve2 as (m+1)x(n+1)
// and Solve3 as (m+1)x(m+1). ans only ever uses row 1, and reg is scratch with at
// most m+1 (< M) rows, so neither needs N rows (this saves ~8 MB of static memory).
int Solve1[N][M],Solve2[M][N],Solve3[M][M],ans[2][N],reg[M][N],t[M][M],one[M][M];
// x <- x * y over the top-left (m+1)x(m+1) blocks, modulo mod.
// The product goes through reg first, so x and y may be the same array.
static void mul_small(int x[][M],int y[][M]){
    for(int i=1;i<=m+1;++i)for(int j=1;j<=m+1;++j){
        int s=0;
        for(int k=1;k<=m+1;++k)inc(s,(long long)x[i][k]*y[k][j]%mod);
        reg[i][j]=s;
    }
    for(int i=1;i<=m+1;++i)for(int j=1;j<=m+1;++j)x[i][j]=reg[i][j];
}
// Replaces Solve3 (its (m+1)x(m+1) block) with Solve3^d modulo mod, by binary
// exponentiation: t holds the running square and one the accumulated product.
// d <= 0 leaves the identity; a negative d previously never terminated,
// because `d>>=1` on -1 stays -1.
inline void work(int d){//solve Solve3^d
    memset(one,0,sizeof one);
    for(int i=1;i<=m+1;++i)one[i][i]=1;
    for(int i=1;i<=m+1;++i)for(int j=1;j<=m+1;++j)t[i][j]=Solve3[i][j];
    while(d>0){
        if(d&1)mul_small(one,t);
        d>>=1;
        if(d)mul_small(t,t);   // the squaring after the last bit would be discarded
    }
    for(int i=1;i<=m+1;++i)for(int j=1;j<=m+1;++j)Solve3[i][j]=one[i][j];
}
// Debug helper: prints `title` on its own line, then rows 1..rows of `a`
// (columns 1..cols), each entry right-aligned in a 9-wide field followed by a
// space, with one line per row.
template<int C>
static void dump_block(const char*title,int (*a)[C],int rows,int cols){
    puts(title);
    for(int r=1;r<=rows;++r){
        for(int c=1;c<=cols;++c)printf("%9d ",a[r][c]);
        puts("");
    }
}
// Dumps Solve1 over its (n+1) x (m+1) region.
void pr1(){dump_block("Solve1",Solve1,n+1,m+1);}
// Dumps Solve2 over its (m+1) x (n+1) region.
void pr2(){dump_block("Solve2",Solve2,m+1,n+1);}
// Dumps Solve3 over its (m+1) x (m+1) region.
void pr3(){dump_block("Solve3",Solve3,m+1,m+1);}
// Reads n, m and each vertex's out/in factor rows, then answers every query
// (u, v, d): the number of ways from u to v using at most d roads, modulo mod.
// `register` was dropped (ill-formed since C++17). The query-independent parts of
// Solve1/Solve2 are built once, and only the single answer entry is computed.
int main(){
    #ifndef ONLINE_JUDGE
    freopen("tt.in","r",stdin);
    #endif
    scanf("%d%d",&n,&m);
    int i,j,k;
    for(i=1;i<=n;++i){for(j=1;j<=m;++j)scanf("%d",&out[i][j]);for(j=1;j<=m;++j)scanf("%d",&in[i][j]);}

    // Solve1 = out with a 1 in the (n+1, m+1) corner (query-independent).
    memset(Solve1,0,sizeof Solve1);
    for(i=1;i<=n;++i)for(j=1;j<=m;++j)Solve1[i][j]=out[i][j];
    Solve1[n+1][m+1]=1;
    // Solve2 = in^T with a 1 in the (m+1, n+1) corner. Its column n+1 (the counter
    // column) depends on v and is refilled per query.
    memset(Solve2,0,sizeof Solve2);
    for(i=1;i<=n;++i)for(j=1;j<=m;++j)Solve2[j][i]=in[i][j];
    Solve2[m+1][n+1]=1;

    int q,u,v,d;
    scanf("%d",&q);
    while(q--){
        scanf("%d%d%d",&u,&v,&d);
        const int same=(u==v?1:0);   // the zero-length walk when u == v
        if(d==0){printf("%d\n",same);continue;}
        if(d==1){
            int direct=0;
            for(j=1;j<=m;++j)direct=(direct+(long long)out[u][j]*in[v][j]%mod)%mod;
            printf("%d\n",(direct+same)%mod);
            continue;
        }
        for(j=1;j<=m;++j)Solve2[j][n+1]=in[v][j];

        // Solve3 = Solve2 * Solve1, then raise it to the (d-1)-th power.
        for(i=1;i<=m+1;++i)for(j=1;j<=m+1;++j){
            Solve3[i][j]=0;for(k=1;k<=n+1;++k)inc(Solve3[i][j],(long long)Solve2[i][k]*Solve1[k][j]%mod);
        }
        work(d-1);

        // e_u * Solve1 is simply row u of Solve1, reduced mod.
        for(j=1;j<=m+1;++j)ans[1][j]=Solve1[u][j]%mod;
        // (e_u * Solve1) * Solve3
        for(j=1;j<=m+1;++j){
            reg[1][j]=0;for(k=1;k<=m+1;++k)inc(reg[1][j],(long long)ans[1][k]*Solve3[k][j]%mod);
        }
        // Only entry n+1 of the final product with Solve2 (the counter) is needed.
        int total=0;
        for(k=1;k<=m+1;++k)inc(total,(long long)reg[1][k]*Solve2[k][n+1]%mod);

        printf("%d\n",(total+same)%mod);
    }
    return 0;
}