BZOJ4919 大根堆

题意

给你一颗有 $n$ 个点的树,其中 $1$ 号点为根节点,每个点都有一个权值 $val_i$
你可以从树中选择一些点,注意如果 $i$ 与 $j$ 都被选中且 $i$ 是 $j$ 的祖先,那么必须满足 $val_i > val_j$
请你求出最多能同时选出多少个点。
$n \le 100000$。

题解

multiset+启发式合并

首先考虑序列上是如何在$O(n\log n)$内维护LIS的。相似的方式,我们考虑用std::multiset来维护一下这个序列。每次先将儿子上的序列的值启发式合并到父节点。再在父节点上加入一个新的值,也就是父节点的权值即可。每次合并之后,记得将儿子节点的std::multiset清空。做法看似十分暴力,但时间复杂度却是正确的。时间复杂度$O(n\log^2 n)$。

线段树合并

首先不难想到一个$O(n^2)$的DP。首先将权值离散化,记$dp[i][j]$表示到了$i$号节点,当前选择最大权值为$j$时能够选择的最多的点数,转移显然。注意到,转移有两个过程,第一个是把子树的$dp$值加起来,第二个是对$j$大于等于$val_u$的$dp$值对$dp[i][val[u]-1]$取$\max$即可。注意到这里的$dp$关于$j$是单调的,我们可以考虑维护$dp$数组的差分数组。观察性质后,容易发现,把子树的$dp$值加起来,也就是把差分数组加起来;而chkmax就是把差分数组中下标大于等于$val_u$的第一个有值的地方$-1$,再在$val_u$上$+1$。修改操作可以用线段树轻松实现,而将差分数组相加,我们可以直接用线段树合并实现。时间复杂度$O(n\log^2 n)$。

代码

multiset+启发式合并

注意std::multiset中,使用clear()后,空间将会被释放。所以这么做空间复杂度是正确的。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int INF=1e9;
const long double eps=1e-9;
const int maxn=2e5+10;
int to[maxn<<1],nex[maxn<<1],beg[maxn],val[maxn];
int e;
multiset <int> Set[maxn];
multiset <int>::iterator it;
inline int read(){
int x=0,flag=1;
char ch=getchar();
while(!isdigit(ch) && ch!='-')ch=getchar();
if(ch=='-')flag=-1,ch=getchar();
while(isdigit(ch))x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
return x*flag;
}
inline void add(int x,int y){
to[++e]=y;
nex[e]=beg[x];
beg[x]=e;
}
inline void Merge(int x,int y){
for(it=Set[y].begin();it!=Set[y].end();++it)Set[x].insert(*it);
Set[y].clear();
}
void dfs(int x,int fa){
int i,Max=0,id=0;
for(i=beg[x];i;i=nex[i]){
if(to[i]==fa)continue;
dfs(to[i],x);
if(Set[to[i]].size()>Max){
Max=Set[to[i]].size();
id=to[i];
}
}
swap(Set[x],Set[id]);
for(i=beg[x];i;i=nex[i]){
if(to[i]==fa || to[i]==id)continue;
Merge(x,to[i]);
}
it=Set[x].lower_bound(val[x]);
if(it==Set[x].end())Set[x].insert(val[x]);
else Set[x].erase(it),Set[x].insert(val[x]);
}
int main(){
int i,n,f;
#ifndef ONLINE_JUDGE
freopen("tree.in","r",stdin);
freopen("tree.out","w",stdout);
#endif
n=read();
for(i=1;i<=n;i++){
val[i]=read(),f=read();
if(i>1)add(i,f),add(f,i);
}
dfs(1,0);
printf("%d\n",Set[1].size());
return 0;
}

线段树合并

线段树合并的时候,必须采用动态开点线段树。稍微注意一下Merge()函数的写法。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int INF=1e9;
const long double eps=1e-9;
const int maxn=2e5+10;
const int maxnode=1e7+10;
int to[maxn<<1],nex[maxn<<1],beg[maxn],val[maxn],Hash[maxn],rt[maxn];
int e,cnt;
inline int read(){
int x=0,flag=1;
char ch=getchar();
while(!isdigit(ch) && ch!='-')ch=getchar();
if(ch=='-')flag=-1,ch=getchar();
while(isdigit(ch))x=(x<<3)+(x<<1)+ch-'0',ch=getchar();
return x*flag;
}
inline void add(int x,int y){
to[++e]=y;
nex[e]=beg[x];
beg[x]=e;
}
struct Seg_T{
#define mid ((l+r)>>1)
int ch[maxnode][2],delta[maxnode];
int cnt;
int Merge(int x,int y){
if(!x || !y)return x+y;
delta[x]+=delta[y];
ch[x][0]=Merge(ch[x][0],ch[y][0]);
ch[x][1]=Merge(ch[x][1],ch[y][1]);
return x;
}
bool Delete(int o,int l,int r,int p){
if(!delta[o])return false;
if(l==r){delta[o]--;return true; }
if(p<=mid && Delete(ch[o][0],l,mid,p)){ delta[o]--;return true; }
if(Delete(ch[o][1],mid+1,r,p)){ delta[o]--;return true; }
return false;
}
void Modify(int &o,int l,int r,int p){
if(!o)o=++cnt;delta[o]++;
if(l==r)return;
if(p<=mid)Modify(ch[o][0],l,mid,p);
else Modify(ch[o][1],mid+1,r,p);
}
}T;
void dfs(int x,int fa){
int i;
for(i=beg[x];i;i=nex[i]){
if(to[i]==fa)continue;
dfs(to[i],x);
rt[x]=T.Merge(rt[x],rt[to[i]]);
}
T.Delete(rt[x],1,cnt,val[x]);
T.Modify(rt[x],1,cnt,val[x]);
}
int main(){
int i,n,f;
#ifndef ONLINE_JUDGE
freopen("tree.in","r",stdin);
freopen("tree.out","w",stdout);
#endif
n=read();
for(i=1;i<=n;i++){
val[i]=read(),f=read();Hash[i]=val[i];
if(i>1)add(i,f),add(f,i);
}
sort(Hash+1,Hash+n+1);
cnt=unique(Hash+1,Hash+n+1)-Hash-1;
for(i=1;i<=n;i++)val[i]=lower_bound(Hash+1,Hash+cnt+1,val[i])-Hash;
dfs(1,0);
printf("%d\n",T.delta[rt[1]]);
return 0;
}