矩阵乘法与矩阵快速幂 求斐波那契(Fibonacci)数列第n项

/ 0评 /

斐波那契(Fibonacci)数列的递推式是:F_{i}=F_{i-1}+F_{i-2} 。根据这个递推式,我们可以在 \Theta (n) 复杂度内求出第 n 项,但是当 n 很大时,这种方法就显得很慢。其实利用矩阵快速幂,我们可以在 \Theta (\log_2 n) 内求出第 n 项。

矩阵

定义

一个 m×n 的矩阵(matrix)是一个由 m 行(row)n 列(column)元素排列成的矩形阵列。矩阵里的元素可以是数字、符号或数学式。例如,以下就是一个 2×3 的矩阵:

\begin{bmatrix}
1 & 2 & 3 \\\\
4 & 6 &5
\end{bmatrix}

矩阵是线性代数的知识,更多这里就不介绍了。

矩阵的加减法

矩阵的加法减法很简单:

\begin{bmatrix}
1 & 2 & 3 \\\\
4 & 6 & 5
\end{bmatrix} +
\begin{bmatrix}
2 & 1 & 0 \\\\
2 & 0 & 1
\end{bmatrix} =
\begin{bmatrix}
3 & 3 & 3 \\\\
6 & 6 & 6
\end{bmatrix}

减法也同理。

矩阵乘以矩阵

矩阵之间的相乘就比较复 (bian) 杂 (tai) 了,根据 Wikipedia

矩阵相乘最重要的方法是一般矩阵乘积。它只有在第一个矩阵的列数(column)和第二个矩阵的行数(row)相同时才有定义。一般单指矩阵乘积时,指的便是一般矩阵乘积。若 A m\times n 矩阵,Bn\times p 矩阵,则他们的乘积 AB (有时记做 A \cdot B)会是一个 m\times p 矩阵。其乘积矩阵的元素如下面式子得出:

\displaystyle AB_{ij} = \sum_{r=1}^{n}a_{ir}b_{rj}= a_{i1}b_{1j}+a_{i2}b_{2j}+\dots+a_{in}b_{nj}

看个例子就懂了:

\begin{bmatrix}
1 & 2 \\\\
4 & 3
\end{bmatrix} +
\begin{bmatrix}
2 & 1 \\\\
2 & 0
\end{bmatrix} =
\begin{bmatrix}
3 & 3 \\\\
6 & 6
\end{bmatrix}

第一个矩阵第一行的每个数字(1 和 2)各自乘以第二个矩阵第一列对应位置的数字(2 和 2),然后将乘积相加( 1 x 2 + 2 x 2),得到结果矩阵左上角的值 6。同理,其他的几个值也是这么算的。维基百科上有一张形象的图:

矩阵相乘示意

矩阵快速幂求斐波那契数列

其实 POJ上这道模板题 已经告诉你怎么做了:

\displaystyle
\begin{bmatrix}
F_{n+1} & F_n \\\\
F_n & F_{n-1}
\end{bmatrix} =
\begin{bmatrix}
1 & 1 \\\\
1 & 0
\end{bmatrix}^n =
\underbrace{
\begin{bmatrix}
1 & 1 \\\\
1 & 0
\end{bmatrix}
\begin{bmatrix}
1 & 1 \\\\
1 & 0
\end{bmatrix}\cdots
\begin{bmatrix}
1 & 1 \\\\
1 & 0
\end{bmatrix}
}_ {\text{n times}}

也就是说矩阵 \begin{bmatrix} 1 & 1 \\\\ 1 & 0 \end{bmatrix} 的 n 次幂里面四个数就分别是 F_{n+1}F_nF_nF_{n-1}

例题

模板题

洛谷 3390 【模板】矩阵快速幂
POJ3070 - Fibonacci

综合

CodeForces 551D. GukiZ and Binary Operations

代码

以下是洛谷上的模板题代码:

#include<cstdio>
#include<cstring>
#include<iostream>
using namespace std;
const int maxn=105,tt=1000000007;;
int n;
long long m;
inline int read(){
    int ret=0,f=1;char ch=getchar();
    while (ch<'0'||ch>'9') {if (ch=='-') f=-1;ch=getchar();}
    while (ch>='0'&&ch<='9') ret=ret*10+ch-'0',ch=getchar();
    return ret*f;
}
struct Matrix{
    int a[maxn][maxn];
    void ReadMatrix(){
        memset(a,0,sizeof(a));
        for (int i=0;i<n;i++)
        for (int j=0;j<n;j++) a[i][j]=read();
    }
    void Init(){
        memset(a,0,sizeof(a));
        for (int i=0;i<n;i++) a[i][i]=1;
    }
    void Clear(){
        memset(a,0,sizeof(a));
    }
    Matrix operator *(Matrix b){
        Matrix c; c.Clear();
        for (int i=0;i<n;i++)
        for (int j=0;j<n;j++)
        for (int k=0;k<n;k++)
            c.a[i][j]=((long long)c.a[i][j]+(long long)a[i][k]*b.a[k][j])%tt;
        return c;
    }
    Matrix operator ^(long long b){
        Matrix ret; ret.Init();
        Matrix w;   w.Clear();
        for (int i=0;i<n;i++)
        for (int j=0;j<n;j++) w.a[i][j]=a[i][j];
        while (b){
            if (b%2) ret=ret*w;
            b=b/2;w=w*w;
        }
        return ret;
    }
    void Write(){
        for (int i=0;i<n;i++){
            for (int j=0;j<n;j++) printf("%d ",a[i][j]);
            printf("\n");
        }
    }
};
int main(){
    n=read();scanf("%lld",&m);
    Matrix a; a.ReadMatrix();
    // printf("Read part finished.\n");
    a=a^m;
    a.Write();
    return 0;
}

下面是 POJ 上模板题代码:

#include<cstdio>
#include<cstring>
#include<iostream>
using namespace std;
const int tt=10000;
int n;
struct Matrix{
    int a[3][3];
    void Init(){
        memset(a,0,sizeof(a));
        a[0][0]=a[1][0]=a[0][1]=1;
        a[1][1]=0;
    }
    void Clear(){
        memset(a,0,sizeof(a));
    }
    Matrix operator *(Matrix b){
        Matrix c; c.Clear();
        for (int i=0;i<2;i++)
        for (int j=0;j<2;j++)
        for (int k=0;k<2;k++)
            c.a[i][j]=((long long)c.a[i][j]+(long long)a[i][k]*b.a[k][j])%tt;
        return c;
    }
    Matrix operator ^(long long b){
        Matrix ret; ret.Init();
        Matrix w;   w.Clear();
        for (int i=0;i<2;i++)
        for (int j=0;j<2;j++) w.a[i][j]=a[i][j];
        while (b){
            if (b%2) ret=ret*w;
            b=b/2;w=w*w;
        }
        return ret;
    }
    void Write(){
        for (int i=0;i<2;i++){
            for (int j=0;j<2;j++) printf("%d ",a[i][j]);
            printf("\n");
        }
    }
};
int main(){
    scanf("%d",&n);n--;
    while (n!=-2){
        if (n==-1) printf("0\n"); else{
            Matrix a; a.Init();
            a=a^n;
            printf("%d\n",a.a[1][0]);
        }
        scanf("%d",&n);n--;
    }
    return 0;
}

参考

矩阵 - 维基百科,自由的百科全书
矩阵乘法 - 维基百科,自由的百科全书
理解矩阵乘法 - 阮一峰的网络日志


知识共享许可协议 本文章采用 知识共享署名-非商业性使用-相同方式共享 4.0 国际许可协议 进行许可。
欢迎转载,如有错误欢迎指出。
本文链接:https://skywt.cn/posts/matrix-multiply/


发表评论

电子邮件地址不会被公开。 必填项已用*标注