BZOJ2006 [NOI2010]超级钢琴

题意

小Z是一个小有名气的钢琴家,最近C博士送给了小Z一架超级钢琴,小Z希望能够用这架钢琴创作出世界上最美妙的音乐。这架超级钢琴可以弹奏出$n$个音符,编号为$1$至$n$。第$i$个音符的美妙度为$A_i$,其中$A_i$可正可负。一个“超级和弦”由若干个编号连续的音符组成,包含的音符个数不少于$L$且不多于$R$。我们定义超级和弦的美妙度为其包含的所有音符的美妙度之和。两个超级和弦被认为是相同的,当且仅当这两个超级和弦所包含的音符集合是相同的。小Z决定创作一首由$k$个超级和弦组成的乐曲,为了使得乐曲更加动听,小Z要求该乐曲由$k$个不同的超级和弦组成。我们定义一首乐曲的美妙度为其所包含的所有超级和弦的美妙度之和。小Z想知道他能够创作出来的乐曲美妙度最大值是多少。
$n,k \leqslant 500000$。

题解

个人认为是一道不错的题,思路比较巧妙。先求出以每个点为左端点的答案以及答案的位置,这个可以用ST表轻松解决。接下来把这些答案丢到一个优先队列里头。丢到优先队列里头的东西大概长这样:$\{i,l,r,t\}$。其中$i$表示区间的左端点,$l$和$r$分别表示右端点的范围,$t$表示在这个范围内答案最优的右端点。每次从队首取出一个元素$\{i,l,r,t\}$,就把$\{i,l,t-1,ans(l,t-1)\}$和$\{i,t+1,r,ans(t+1,r)\}$丢到优先队列里头,并统计答案即可。
时间复杂度$O((n+k)\log n)$。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#include<queue>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std;
typedef long long ll;
const int INF=1e9;
const long double eps=1e-9;
const int maxn=5e5+10;
int sum[maxn],pos[maxn][20],logn[maxn];
struct node{
int i,l,r,t;
bool operator < (const node &rhs) const {
return sum[t]-sum[i-1]<sum[rhs.t]-sum[rhs.i-1];
}
};
priority_queue <node> q;
inline int read(){
int x=0,flag=1;
char ch=getchar();
while(!isdigit(ch) && ch!='-')ch=getchar();
if(ch=='-')flag=-1,ch=getchar();
while(isdigit(ch))x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
return x*flag;
}
inline int GetPos(int l,int r){
int x=pos[l][logn[r-l+1]],y=pos[r-(1<<logn[r-l+1])+1][logn[r-l+1]];
return sum[x]>sum[y]?x:y;
}
int main(){
int i,j,k,m,n,l,r;
#ifndef ONLINE_JUDGE
freopen("BZOJ2006.in","r",stdin);
freopen("BZOJ2006.out","w",stdout);
#endif
n=read();k=read();l=read();r=read();
for(i=2;i<=n;i++)logn[i]=logn[i>>1]+1;
for(i=1;i<=n;i++)sum[i]=sum[i-1]+read(),pos[i][0]=i;
for(j=1;j<=20;j++)
for(i=1;i+(1<<j)-1<=n;i++){
int x=pos[i][j-1],y=pos[i+(1<<(j-1))][j-1];
pos[i][j]=sum[x]>sum[y]?x:y;
}
for(i=1;i<=n;i++){
int L=min(n+1,i+l-1),R=min(n,i+r-1);
if(L>R)continue;
int Pos=GetPos(L,R);
q.push((node){i,L,R,Pos});
}
ll ans=0;
for(i=1;i<=k;i++){
node t=q.top();q.pop();
ans+=1ll*(sum[t.t]-sum[t.i-1]);
if(t.l<=t.t-1)q.push((node){t.i,t.l,t.t-1,GetPos(t.l,t.t-1)});
if(t.t+1<=t.r)q.push((node){t.i,t.t+1,t.r,GetPos(t.t+1,t.r)});
}
printf("%lld\n",ans);
return 0;
}