BZOJ2958 序列染色

题意

给出一个长度为$n$由BWX三种字符组成的字符串$S$,你需要把每一个X染成BW中的一个。
对于给出的$k$,问有多少种染色方式使得存在整数$a,b,c,d$使得:

  • $1 \le a \le b < c \le d \le n$;
  • $S_a,S_{a+1},…,S_b$均为B;
  • $S_c,S_{c+1},…,S_d$均为W;

其中$b=a+k-1,d=c+k-1$。
由于方法可能很多,因此只需要输出最后的答案对$10^9+7$取模的结果。
$n,k \le 10^6$。

题解

不难把题目转化为这样:你需要将一个序列分为$5$段。其中第$2$段与第$4$段分别为全为W与全为B,且第二段与第四段长度均为$k$,求其方案数。

考虑这样的一个DP:设$dp[i][j][k]$表示当前到了第$i$位,第$i$位填的是W还是B(0表示W,1表示B),且当前正处于第$k$段。转移时需要注意,由于直接DP可能把方案数算重,因此必须要在第一次出现连续$k$个WB的时候强制转移。直接这么讲似乎不太清晰,配合代码看看似乎比较好。来看一下这些转移:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
for(i=1;i<=n;i++){
if(a[i]!='W'){
dp[i][0][0]=Mod(dp[i][0][0],Mod(dp[i-1][0][0],dp[i-1][1][0]));
if(i>=k && sumw[i]-sumw[i-k]==0)dp[i][0][0]=Mod(dp[i][0][0],-dp[i-k][1][0]);
if(i>=k && sumw[i]-sumw[i-k]==0)dp[i][0][1]=Mod(dp[i][0][1],dp[i-k][1][0]);
dp[i][0][1]=Mod(dp[i][0][1],Mod(dp[i-1][0][1],dp[i-1][1][1]));
dp[i][0][2]=Mod(dp[i][0][2],Mod(dp[i-1][0][2],dp[i-1][1][2]));
}
if(a[i]!='B'){
dp[i][1][0]=Mod(dp[i][1][0],Mod(dp[i-1][0][0],dp[i-1][1][0]));
dp[i][1][1]=Mod(dp[i][1][1],Mod(dp[i-1][0][1],dp[i-1][1][1]));
if(i>=k && sumb[i]-sumb[i-k]==0)dp[i][1][1]=Mod(dp[i][1][1],-dp[i-k][0][1]);
if(i>=k && sumb[i]-sumb[i-k]==0)dp[i][1][2]=Mod(dp[i][1][2],dp[i-k][0][1]);
dp[i][1][2]=Mod(dp[i][1][2],Mod(dp[i-1][0][2],dp[i-1][1][2]));
}
}

其实除了几个if里头的转移不太明显,其他的还是很显然的。第二行转移里,之所以要减掉一个dp[i-k][1][0]的贡献,是因为我们强制让第一次连续出现$k$个B的时候就转移,因此要减掉把这$k$个位置全部填成B且不转移到下一段的贡献。第二个if里,之所以只让dp[i-k][1][0]转移过来,也是同样的道理。如果还从dp[i-k][1][0]处转移过来的话,那么就没有在第一次连续$k$个B的地方转移,至少有$k+1$个连续的B了。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
typedef long long ll;
const int INF=1e9;
const long double eps=1e-9;
const int mod=1e9+7;
const int maxn=1e6+10;
int dp[maxn][2][3],sumw[maxn],sumb[maxn];
char a[maxn];
inline int read(){
int x=0,flag=1;
char ch=getchar();
while(!isdigit(ch) && ch!='-')ch=getchar();
if(ch=='-')flag=-1,ch=getchar();
while(isdigit(ch))x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
return x*flag;
}
inline int Mod(int x,int y){
x+=y;
if(x>=mod)x-=mod;
if(x<0)x+=mod;
return x;
}
int main(){
int i,j,k,l,m,n,lim;
#ifndef ONLINE_JUDGE
freopen("color.in","r",stdin);
freopen("color.out","w",stdout);
#endif
n=read();k=read();
scanf("%s",a+1);
for(i=1;i<=n;i++)sumb[i]=sumb[i-1]+(a[i]=='B');
for(i=1;i<=n;i++)sumw[i]=sumw[i-1]+(a[i]=='W');
dp[0][1][0]=1;
for(i=1;i<=n;i++){
if(a[i]!='W'){
dp[i][0][0]=Mod(dp[i][0][0],Mod(dp[i-1][0][0],dp[i-1][1][0]));
if(i>=k && sumw[i]-sumw[i-k]==0)dp[i][0][0]=Mod(dp[i][0][0],-dp[i-k][1][0]);
if(i>=k && sumw[i]-sumw[i-k]==0)dp[i][0][1]=Mod(dp[i][0][1],dp[i-k][1][0]);
dp[i][0][1]=Mod(dp[i][0][1],Mod(dp[i-1][0][1],dp[i-1][1][1]));
dp[i][0][2]=Mod(dp[i][0][2],Mod(dp[i-1][0][2],dp[i-1][1][2]));
}
if(a[i]!='B'){
dp[i][1][0]=Mod(dp[i][1][0],Mod(dp[i-1][0][0],dp[i-1][1][0]));
dp[i][1][1]=Mod(dp[i][1][1],Mod(dp[i-1][0][1],dp[i-1][1][1]));
if(i>=k && sumb[i]-sumb[i-k]==0)dp[i][1][1]=Mod(dp[i][1][1],-dp[i-k][0][1]);
if(i>=k && sumb[i]-sumb[i-k]==0)dp[i][1][2]=Mod(dp[i][1][2],dp[i-k][0][1]);
dp[i][1][2]=Mod(dp[i][1][2],Mod(dp[i-1][0][2],dp[i-1][1][2]));
}
}
printf("%d\n",Mod(dp[n][0][2],dp[n][1][2]));
return 0;
}