最近莫名其妙的做了许多矩阵优化DP的题。。。在这里做一下总结。。。
矩阵优化dp,是将dp的状态化为矩阵,dp的转移也化为矩阵,将转移变为矩阵乘法的过程。
1.什么样的题可以用矩阵优化?
目前发现两类:
第一类:
线性常系数递推方程,就是像斐波那契数列那样的(f[i]=f[i-1]+f[i-2])
对于这样的题,我们需要构造2个矩阵,初始矩阵,转移矩阵。
初始矩阵我们一般认为是列向量,转移矩阵一般为一个N*N的方阵。
举个例子,对于递推方程f[i]=a*f[i-1]+b*f[i-2]+c,我们如何构造?
首先,初始矩阵需要包含递推式右边出现的所有元素。
初始矩阵我们设为{f[2],f[1],c},那么最终答案应该为{f[n],f[n-1],c}
我们考虑转移前的矩阵{f[i-1],f[i-2],c}如何转移到转移后的矩阵{f[i],f[i-1],c}
现在我来介绍一种爽翻天的构造转移矩阵的方法:标准化方程法。
我们把转移后的矩阵放在左侧,转移前的矩阵放在右侧,列线性方程组,可以得到:
f[i]=a*f[i-1]+b*f[i-2]+1*c
f[i-1]=1*f[i-1]+0*f[i-2]+0*c
c=0*f[i-1]+0*f[i-2]+1*c
根据我们线代上课学过的知识,上面这个方程组可以写成矩阵形式,其系数矩阵就是我们要的转移矩阵:
[a,b,1]
[1,0,0]
[0,0,1]
(但愿这个矩阵没错)
我们再看一个经典例子:
poj 3070 求斐波那契数列第n项,n巨大
我们来看fib数列:f[i]=f[i-1]+f[i-2] 按照刚刚的方式找转移矩阵:
f[i]=1*f[i-1]+1*f[i-2]
f[i-1]=1*f[i-1]+0*f[i-2]
转移矩阵T为:
[1,1]
[1,0]
最终答案为:A*[T^(n-1)]
代码如下:
要注意:①n-1不要减成负的 ②矩阵快速幂要规范书写,不要瞎搞(下面代码其实就有点瞎搞了)
#include<iostream> #include<cstdio> #include<cstring> using namespace std; struct mix { long long d[3][3]; }s; const long long mod=10000; mix mcheng(mix s1,mix s2) { mix ret; memset(ret.d,0,sizeof ret.d); for(int i=1;i<=2;i++) { for(int j=1;j<=2;j++) { for(int k=1;k<=2;k++) ret.d[i][j]=(ret.d[i][j]+(s1.d[i][k]*s2.d[k][j]))%mod; } } return ret; } mix pow(mix x,long long n) { n--; mix b=s; while(n>0) { if(n&1)b=mcheng(b,x); n>>=1; x=mcheng(x,x); } return b; } void print(mix x) { printf("%I64d\n",x.d[2][1]%mod); } long long n=0; int main() { s.d[1][1]=s.d[1][2]=s.d[2][1]=1; while(scanf("%I64d",&n)) { if(n==-1)break; if(n==0){printf("0\n");continue;} mix ans=pow(s,n); print(ans); } return 0; }
来看下一题:
CF:Gym – 101473H
有两种公交车:小巴和大巴,小巴长5米,大巴长10米,小巴与大巴一起排成了一排。
现在小巴有A种涂装,大巴有B种涂装,告诉你小巴与大巴总长M,问给这个的车队涂装,有多少种涂装方案?
比如车队长为10,小巴有2种涂装,大巴有一种涂装,那么涂装方案有:
1.2辆小巴 方案数2*2
2.1辆大巴 方案数1*1
总方案4+1=5
这题一上来懵逼,队友开始向裴蜀定理和exgcd找规律,失败。
我考虑递推,计算在已知前S米的方案之后增加车队长度对答案的影响。
很容易(hard)想到f[i]=f[i-5]*A+f[i-10]*B;
紧接着(很久以后并且队友提示)想到所有i%5!=0的位置都没用,一开始先让M对5取模,递推式可以变成:
f[i]=f[i-1]*A+f[i-2]*B 一眼看出矩阵快速幂
构造矩阵:
f[i]=A*f[i-1]+B*f[i-2]
f[i-1]=1*f[i-1]+0*f[i-2]
转移矩阵为:
[A,B]
[1,0]
代码(中间写挫了,还写了个快速乘)
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #include<string> using namespace std; long long n,a,b; struct mix { long long d[3][3]; }s,q,one; const long long mod=1000000; long long work(long long a,long long b,long long c) { long long ans=0; a=a%c;b=b%c; while(b>0) { if(b&1)ans=(ans+a)%c; a=(a+a)%c; b>>=1; } return ans; } mix mcheng(mix s1,mix s2) { mix ret; memset(ret.d,0,sizeof ret.d); for(int i=0;i<2;i++) { for(int j=0;j<2;j++) { for(int k=0;k<2;k++) ret.d[i][j]=(ret.d[i][j]+work(s1.d[i][k],s2.d[k][j],mod))%mod; } } return ret; } mix pow(mix x,long long n) { if(n==0)return one; n--; mix b=s; while(n>0) { if(n&1)b=mcheng(b,x); n>>=1; x=mcheng(x,x); } return b; } void print(mix x) { q.d[0][0]=a,q.d[0][1]=0; q.d[1][0]=1,q.d[1][1]=0; printf("%06lld\n",(mcheng(x,q).d[0][0])%mod); } int main() { one.d[0][0]=1,one.d[0][1]=0; one.d[1][0]=0,one.d[1][1]=1; cin>>n>>a>>b; s.d[0][0]=a,s.d[0][1]=b; s.d[1][0]=1,s.d[1][1]=0; mix ans=pow(s,n/5-1); print(ans); return 0; }
下一题:
HDU – 2243 考研路漫漫
给出n个串,问你至少含有其中1个串且长度不小于k的字符串个数有多少。
含有的不好找,首先转化为计算不含有的数目。
把这n个串扔进AC自动机,我们在AC自动机上跑一跑,对于每一个节点,如果是某串的结尾,意味着跑出了不合法串,不管他,跳过。
如果不是某串结尾,我们设他为状态s,他可以转移到的相邻状态为t,使得sum[s][t]++,意味着s状态到t状态多了一种转移方式。
然后,我们要统计不多于k的不合法串有多少个。这其实可以转化为在一个图上走k步的方案数问题。。
但是我们不是要走k步啊!是要不多于k啊!所以我们要进行一些小小的构造。。
注意 这个邻接矩阵跟上面的那些矩阵是不太一样的,上面那些矩阵都是列到列转移的,而这个实质上是行到行的转移。。
怎么说。,。就是上面那些矩阵的转置。
举个例子:1->2的路径条数,在邻接矩阵里是(1,2)位置,但是在上面那些转移矩阵里是在(2,1)位置,这就要求我们转置一下:
在邻接矩阵最下方添加一行全1的行,这样最后一行代表的就是之前所有的方案数总和,最后统计∑f[1][i]即可,注意把最后一行也算上!
然后用总方案减去不合法方案即可。总方案是∑f[i] f[i]=26*f[i-1] 也利用刚刚的方案列转移矩阵:
f[i]=26*f[i-1] +0*∑f[i-2]
∑f[i-1]=1*f[i-1]+1*∑f[i-2]
转移矩阵为:
[26,0]
[1 ,1]
代码:
由于一直AC不了。。我把fail指针改成win指针了。。。
#include<cstdio> #include<cstdlib> #include<cstring> #include<iostream> #include<queue> using namespace std; const int N=115; const long long mod=100000; struct AC_chicken { int ch[26]; int val; int win; }tr[N]; int cnt,root,n; long long m; void add(char s[]) { int len=strlen(s+1),p=root; for(int i=1;i<=len;i++) { int c=s[i]-'a'; if(tr[p].ch[c]==0)tr[p].ch[c]=++cnt; p=tr[p].ch[c]; } tr[p].val=1; } void build_chicken() { queue<int>q; q.push(root); while(!q.empty()) { int p=q.front();q.pop(); for(int i=0;i<26;i++) if(tr[p].ch[i]!=0) { int v=tr[p].ch[i]; if(p==root)tr[v].win=root; else { int u=tr[p].win; while(u!=0) { if(tr[u].ch[i]!=0) {tr[v].win=tr[u].ch[i];break;} u=tr[u].win; } if(u==0)tr[v].win=root; } q.push(v); } } } struct mix { unsigned long long a[N][N]; }s,q,one,ret; mix mcheng(mix s1,mix s2) { memset(ret.a,0,sizeof ret.a); for(int i=1;i<=cnt;i++) { for(int j=1;j<=cnt;j++) { for(int k=1;k<=cnt;k++) ret.a[i][j]=(ret.a[i][j]+s1.a[i][k]*s2.a[k][j]); } } return ret; } mix pow(mix x,long long n) { mix b=one; while(n>0) { if(n&1)b=mcheng(b,x); n>>=1; x=mcheng(x,x); } return b; } char tep[20]; int main() { while(scanf("%d%lld",&n,&m)!=EOF) { root=cnt=1; memset(tr,0,sizeof tr); memset(&s,0,sizeof s); for(int i=1;i<=n;i++) { scanf("%s",tep+1); add(tep); } for(int i=1;i<=cnt;i++)one.a[i][i]=1; build_chicken(); for(int i=root;i<=cnt;i++) if(!tr[i].val) { for(int j=0;j<26;j++) { int p=i; while(tr[p].ch[j]==0&&p!=root)p=tr[p].win; p=tr[p].ch[j]; if(p==0)p=root; int u=p;bool flag=1; while(u!=root) { if(tr[u].val==1){flag=0;break;} u=tr[u].win; } if(p==0)p=root; if(!tr[p].val&&flag)s.a[i][p]++; } } cnt++; for(int i=1;i<=cnt;i++)s.a[i][cnt]=1; q=pow(s,m); unsigned long long ans=0; for(int i=1;i<=cnt;i++)ans=ans+q.a[1][i]; ans--; cnt=2; s.a[1][1]=26,s.a[1][2]=0; s.a[2][1]=1,s.a[2][2]=1; q=pow(s,m); unsigned long long t=q.a[1][1]+q.a[2][1]; cout<<t-1-ans<<endl; } return 0; }