矩阵乘法在动态规划中的应用

/ 0评 /

\begin{bmatrix} x_{11} & x_{12} & x_{13} \\ x_{21} & x_{22} & x_{23} \\ x_{31} & x_{32} & x_{33} \end{bmatrix}

每增加一个维度,世界便会增加无限的美感。

从 Fibonacci 数列开始

矩阵乘法应用的入门题。

如果把 F_n 放在矩阵里,构造出一个 1×2 的矩阵 \begin{bmatrix} F_{n-1} & F_{n-2}\end{bmatrix},就可以构造一个转移矩阵:

\begin{bmatrix} F_{n-1} & F_{n-2} \end{bmatrix} \ast \begin{bmatrix} 1 & 1 \\ 1 & 0 \end{bmatrix} = \begin{bmatrix} F_{n-1}+F_{n-2} & F_{n-1} \end{bmatrix} = \begin{bmatrix} F_{n} & F_{n-1} \end{bmatrix}

这里 \begin{bmatrix} 1 & 1 \\ 1 & 0\end{bmatrix} 就是转移矩阵。也就是说只需要将构造的 \begin{bmatrix} F_{n-1} & F_{n-2}\end{bmatrix} 乘以转移矩阵就可以得到 \begin{bmatrix} F_{n} & F_{n-1}\end{bmatrix}。显然可以用快速幂优化了。

其实也可以写成 4×4 的矩阵进行推导:

\displaystyle \begin{bmatrix} F_{n+1} & F_n \\ F_n & F_{n-1} \end{bmatrix} = \begin{bmatrix} 1 & 1 \\ 1 & 0 \end{bmatrix}^n = \underbrace{ \begin{bmatrix} 1 & 1 \\ 1 & 0 \end{bmatrix} \begin{bmatrix} 1 & 1 \\ 1 & 0 \end{bmatrix}\cdots \begin{bmatrix} 1 & 1 \\ 1 & 0 \end{bmatrix} }_ {\text{n times}}

拓展一下?

由上面的例子,我们可以扩展一下:对于这样的递推式:

F_{n}=F_{n-1}+F_{n-2}+\dots+F_{n-k}

我们其实都可以通过构造矩阵的方法来优化,快速计算 F_n
(不过很多简单的递推式也可以求通项公式……)

POJ 3233 Matrix Power Series

Description

Link: POJ 3233 Matrix Power Series

Given a n × n matrix A and a positive integer k, find the sum S = A + A^2 + A^3 + \dots + A^k.

n \leqslant 30, k \leqslant 10^9 and m < 10^4.

Hint

一道比较典型的矩阵乘法优化递推的题目。有很多方法。

Analysis #1

先把矩阵 A 看成一个数字。由题可知:

F(n)=A + A^2 + A^3 + \dots + A^n

那么

F(n)=A \ast F(n-1) + A

显然我们可以构造一个转移矩阵:

\begin{bmatrix} A^{n-1} & F_{n-1} \\ 0 & 1 \end{bmatrix} \ast \begin{bmatrix} A & A \\ 0 & 1 \end{bmatrix} = \begin{bmatrix} A^{n} & F_{n} \\ 0 & 1 \end{bmatrix}

也就是说我们发现:

\begin{bmatrix} A & A \\ 0 & 1 \end{bmatrix} \ast \begin{bmatrix} A & A \\ 0 & 1 \end{bmatrix} = \begin{bmatrix} A^2 & A^2+A \\ 0 & 1 \end{bmatrix}

则:

S= \begin{bmatrix} A & A \\ 0 & 1 \end{bmatrix} ^k

写一个一百多行的矩阵套矩阵快速幂就可以了。

Analysis #2

这是最简单的方法。

F(n)=A + A^2 + A^3 + \dots + A^n

  1. n 为偶数:

F(n)=A+A^2+\dots+A^{n/2} +A^{n/2}\ast F(n/2)

  1. n 为奇数:

F(n)=A\ast F(n-1)+A

这样其实就是分治,可以 \log k 求解。

Analysis #3

看看 F(n)=A\ast F(n-1)+A 这个递推式,考虑用它推通项公式。

F(n)+\lambda = A\ast F(n-1)+A +\lambda

(我们用 I 表示单位矩阵)

F(n)+\lambda = A \ast (F(n-1)+I+\frac {\lambda} {A})

\displaystyle \lambda = I+\frac {\lambda} A ,解得 \displaystyle \lambda =A\ast (A-I)^{-1}

(A-I)^{-1} 就是 A-I逆矩阵。因为这个给你的矩阵是 n\ast n方阵,所以其实可以直接用高斯消元求其逆矩阵。
接下来令 G(n)=F(n)+A\ast (A-I)^{-1},那么

G(n)=A\ast G(n-1)

G(n)=A^{n-1}\ast G(1)

矩阵快速幂即可求解。

Code

我只写了第一种思路的代码……

这题代码有个坑点,如果全都开 long long 会过不了,必须开 int 让其自然溢出……可能数据有点问题[1]

(不过谁能想到有人居然所有代码前面都加 #define int long long 的……)

/*
 * Vjudge CONTEST244508 矩阵乘法专题训练
 * POJ 3233
 * C - Matrix Power Series
 * 180929 By SkyWT
 */

#include<cstdio>
#include<iostream>
#include<cstring>
#include<string>
#include<cmath>
#include<algorithm>
#include<vector>
#include<queue>
#include<stack>
#include<map>

using namespace std;

#define memset_clear(_) memset(_,0,sizeof(_))
#define memset_clear_tre(_) memset(_,1,sizeof(_))
#define memset_clear_reg(_) memset(_,-1,sizeof(_))
#define memset_clear_max(_) memset(_,0x3f,sizeof(_))
#define memset_clear_min(_) memset(_,0x80,sizeof(_))
#define sqr(_) ((_)*(_))

// #define int long long

const int maxn=35;
int n,k,tt;

inline int read(){
    int ret=0,f=1;char ch=getchar();
    while (ch<'0'||ch>'9'){if (ch=='-') f=-1;ch=getchar();}
    while (ch>='0'&&ch<='9') ret=ret*10+ch-'0',ch=getchar();
    return ret*f;
}

struct matrix{
    int a[maxn][maxn];
    void init(){ // 构造一个单位矩阵
        memset(a,0,sizeof(a));
        for (int i=0;i<n;i++) a[i][i]=1;
    }
    void clear(){
        memset(a,0,sizeof(a));
    }
    matrix operator +(matrix b){
        matrix c;c.clear();
        for (int i=0;i<n;i++)
        for (int j=0;j<n;j++)
            c.a[i][j]=a[i][j]+b.a[i][j];
        return c;
    }
    matrix operator *(matrix b){
        matrix c; c.clear();
        for (int i=0;i<n;i++)
        for (int j=0;j<n;j++)
        for (int k=0;k<n;k++)
            c.a[i][j]=(c.a[i][j]+a[i][k]*b.a[k][j]%tt)%tt;
        return c;
    }
    matrix operator ^(int b){
        matrix ret; ret.init();
        matrix w;   w.clear();
        for (int i=0;i<2;i++)
        for (int j=0;j<2;j++)
            w.a[i][j]=a[i][j];
        while (b){
            if (b&1) ret=ret*w;
            w=w*w;b>>=1;
        }
        return ret;
    }
}fst;

struct matrix_matrix{
    matrix a[3][3];
    void init(matrix number,bool flag){
        if (flag){
            a[0][0].init();a[1][1].init();
            a[0][1].clear();a[1][0].clear();
        } else {
            a[0][0]=a[0][1]=number;
            a[1][0].clear();a[1][1].init();
        }
    }
    void clear(){
        a[0][0].clear();a[0][1].clear();
        a[1][0].clear();a[1][1].clear();
    }
    matrix_matrix operator *(matrix_matrix b){
        matrix_matrix c; c.clear();
        for (int i=0;i<2;i++)
        for (int j=0;j<2;j++)
        for (int k=0;k<2;k++)
            c.a[i][j]=(c.a[i][j]+a[i][k]*b.a[k][j]);
        return c;
    }
    matrix_matrix operator ^(int b){
        matrix tmp;
        matrix_matrix ret;ret.init(tmp,true);
        matrix_matrix w;w.init(fst,false);
        while (b){
            if (b&1) ret=ret*w;
            w=w*w;b>>=1;
        }
        return ret;
    }
    void write(){
        for (int i=0;i<n;i++){
            for (int j=0;j<n;j++) printf("%d ",a[0][1].a[i][j]%tt);
            printf("\n");
        }
    }juzhegn
};

signed main(){
    n=read();k=read();tt=read();
    for (int i=0;i<n;i++)
    for (int j=0;j<n;j++)
        fst.a[i][j]=read()%tt;
    matrix_matrix now;
    now.init(fst,false);
    now=now^k;
    now.write();
    return 0;
}

HDU 2065 "红色病毒"问题

Description

Link:HDU 2065 "红色病毒"问题

医学界发现的新病毒因其蔓延速度和 Internet 上传播的“红色病毒”不相上下,被称为“红色病毒”。经研究发现,该病毒及其变种的 DNA 的一条单链中,胞嘧啶、腺嘧啶均是成对出现的。
现在有一长度为N的字符串,满足以下条件:

  1. 字符串仅由 A,B,C,D 四个字母组成;
  2. A 出现偶数次(也可以不出现);
  3. C 出现偶数次(也可以不出现)。

计算满足条件的字符串个数。
当 N=2 时,所有满足条件的字符串有如下 6 个:BB,BD,DB,DD,AA,CC。
由于这个数据肯能非常庞大,你只要给出最后两位数字即可。

Hint

矩阵优化 DP 的典型。

对于前面一题递推的题目,我们尚可直接用求通项的数学方法,但是这题似乎不行了……

Analysis #1

首先可以想到 F[i] 表示长度为 i 的字符串,偶数个 A 并且偶数个 C 的数量。但是直接这样定义显然是没法直接状态转移的……
我们自然想到分别在 DP 状态里记录下 A 和 C 了。
为了方便放进矩阵里,我们尽量定义成一个二维的:F[i][0/1/2/3],i 表示字符串长度为 i,第二维讨论一下:

  1. A 和 C 均出现了偶数次;
  2. A 出现了偶数次,C 出现了奇数次;
  3. A 出现了奇数次,C 出现了偶数次;
  4. A 和 C 均出现奇数次。

那么就有如下转移方程:

F(i,0)=2\ast F(i-1,0) + F(i-1,1) + F(i-1,2) \\ F(i,1)=2\ast F(i-1,1) + F(i-1,0) + F(i-1,3) \\ F(i,2)=2\ast F(i-1,2) + F(i-1,0) + F(i-1,3) \\ F(i,3)=2\ast F[i-1,3] + F(i-1,1) + F(i-1,2)

让我们对齐一下,以便于后续的推导:

F(i,0)=2\ast F(i-1,0) + 1\ast F(i-1,1) + 1\ast F(i-1,2) + 0\ast F(i-1,3) \\ F(i,1)=1\ast F(i-1,0) + 2\ast F(i-1,1) + 0\ast F(i-1,2) + 1\ast F(i-1,3) \\ F(i,2)=1\ast F(i-1,0) + 0\ast F(i-1,0) + 2\ast F(i-1,2) + 1\ast F(i-1,3) \\ F(i,3)=0\ast F(i-1,0) + 1\ast F(i-1,1) + 1\ast F(i-1,2) + 2\ast F(i-1,3)

现在考虑通过矩阵转移。假设我们把 F(i,j) 看成如下矩阵:

\begin{bmatrix} F(i,0) \\ F(i,1) \\ F(i,2) \\ F(i,3) \end{bmatrix}

现在我们需要转移成这样子:

\begin{bmatrix} 2\ast F(i-1,0) + 1\ast F(i-1,1) + 1\ast F(i-1,2) + 0\ast F(i-1,3) \\ 1\ast F(i-1,0) + 2\ast F(i-1,1) + 0\ast F(i-1,2) + 1\ast F(i-1,3) \\ 1\ast F(i-1,0) + 0\ast F(i-1,0) + 2\ast F(i-1,2) + 1\ast F(i-1,3) \\ 0\ast F(i-1,0) + 1\ast F(i-1,1) + 1\ast F(i-1,2) + 2\ast F(i-1,3) \end{bmatrix}

其实转移矩阵已经很显然了,就是那几个系数构成的矩阵:

\begin{bmatrix} 2 & 1 & 1 & 0 \\ 1 & 2 & 0 & 1 \\ 1 & 0 & 2 & 1 \\ 0 & 1 & 1 & 2 \end{bmatrix}

现在还有个微小的问题,这个矩阵是 4×4 的,但是原来是 4×1 的……4 行 1 列的矩阵可不能与 4 行 4 列的矩阵相乘。我们只能把 4×1 的矩阵写成 1×4 的:

\begin{bmatrix} F(i,0) & F(i,1) & F(i,2) & F(i,3) \end{bmatrix}

(可以发现这是一个沿”主对角线“对称的矩阵)
接下来考虑下初始矩阵,不难得出答案:

\begin{bmatrix} 2 & 1 & 1 & 0 \end{bmatrix} \ast \begin{bmatrix} 2 & 1 & 1 & 0 \\ 1 & 2 & 0 & 1 \\ 1 & 0 & 2 & 1 \\ 0 & 1 & 1 & 2 \end{bmatrix} ^{n-1} = \begin{bmatrix} F(n,0) & F(n,1) & F(n,2) & F(n,3) \end{bmatrix}

矩阵快速幂优化即可。

Analysis #2

网上题解搜得:泰勒级数?不会。

Code

需要注意题目的输入格式:

每组输入的第一行是一个整数 T,表示测试实例的个数。下面是 T 行数据,每行一个整数 N (1\leqslant N<2^{64}),当 T=0 时结束。

N 达到了 2^{64},所以我们需要开 unsigned long long
(正常的操作当然是#define int unsigned long long 啦~)

解法 #1 的代码:

/*
 * Vjudge CONTEST 244508
 * HDU 2065
 * B - "红色病毒"问题
 * 181002 By SkyWT
 */

#include<cstdio>
#include<iostream>
#include<cstring>
#include<string>
#include<cmath>
#include<algorithm>
#include<vector>
#include<queue>
#include<stack>
#include<map>

using namespace std;

#define memset_clear(_) memset(_,0,sizeof(_))
#define memset_clear_tre(_) memset(_,1,sizeof(_))
#define memset_clear_reg(_) memset(_,-1,sizeof(_))
#define memset_clear_max(_) memset(_,0x3f,sizeof(_))
#define memset_clear_min(_) memset(_,0x80,sizeof(_))
#define sqr(_) ((_)*(_))

#define int unsigned long long

inline int read(){
    int ret=0,f=1;char ch=getchar();
    while (ch<'0'||ch>'9') {if (ch=='-') f=-1;ch=getchar();}
    while (ch>='0'&&ch<='9') ret=ret*10+ch-'0',ch=getchar();
    return ret*f;
}

const int maxn=5,tt=100;
const int transfer[4][4]={{2,1,1,0},
                          {1,2,0,1},
                          {1,0,2,1},
                          {0,1,1,2}};

struct matrix{
    int n,m,a[maxn][maxn];
    void set_identity(){
        memset_clear(a);
        for (int i=0;i<min(n,m);i++) a[i][i]=1;
    }
    void init(){
        memset_clear(a);
    }
    void give(matrix &b){
        b.n=n;b.m=m;
        for (int i=0;i<n;i++)
        for (int j=0;j<m;j++)
            b.a[i][j]=a[i][j];
    }
    matrix operator +(matrix b){
        if (b.n!=n||b.m!=m) printf("ERROR: TWO DEFFERENT MATRIX MAKE +\n");
        matrix c; c.init();
        c.n=n;c.m=m;
        for (int i=0;i<n;i++)
        for (int j=0;j<m;j++)
            c.a[i][j]=(a[i][j]+b.a[i][j])%tt;
        return c;
    }
    matrix operator -(matrix b){
        if (b.n!=n||b.m!=m) printf("ERROR: TWO DEFFERENT MATRIX MAKE -\n");
        matrix c; c.init();
        c.n=n;c.m=m;
        for (int i=0;i<n;i++)
        for (int j=0;j<m;j++)
            c.a[i][j]=(a[i][j]-b.a[i][j]+tt)%tt;
        return c;
    }
    matrix operator *(matrix b){
        if (m!=b.n) printf("ERROR: TWO DEFFERENT MATRIX MAKE *\n");
        matrix c; c.init();
        c.n=n;c.m=b.m;
        for (int i=0;i<  n;i++)
        for (int j=0;j<  m;j++)
        for (int k=0;k<b.m;k++)
            c.a[i][j]=(c.a[i][j]+a[i][k]*b.a[k][j]%tt)%tt;
        return c;
    }
    matrix operator ^(int b){
        matrix ret;ret.n=n;ret.m=m;ret.set_identity();
        matrix w;w.n=n;w.m=m;give(w);
        while (b){
            if (b&1) ret=ret*w;
            w=w*w;b>>=1;
        }
        return ret;
    }
};

signed main(){
    int T=0,n=0;
    scanf("%llu",&T);
    while (T){
        for (int k=0;k<T;k++){
            scanf("%llu",&n);

            matrix now;now.init();
            now.n=1;now.m=4;
            now.a[0][0]=2;now.a[0][1]=now.a[0][2]=1;

            matrix tomul;
            tomul.n=4;tomul.m=4;
            for (int i=0;i<4;i++)
            for (int j=0;j<4;j++)
                tomul.a[i][j]=transfer[i][j];

            tomul=tomul^(n-1);
            now=now*tomul;
            printf("Case %llu: %llu\n",k+1,now.a[0][0]);
        }
        printf("\n");
        scanf("%llu",&T);
    }
    return 0;
}

知识共享许可协议 本文章采用 知识共享署名-非商业性使用-相同方式共享 4.0 国际许可协议 进行许可。
欢迎转载,如有错误欢迎指出。
本文链接:https://skywt.cn/posts/matrix-multiply-in-dp/


发表评论

电子邮件地址不会被公开。 必填项已用*标注