Splay
概述
平衡树,尽量把时间复杂度压到nlogn,支持的方面有区间修改区间查询,区间翻转,区间删除,整段最大子序列,查询第K大的数,查询数是第几大。也可以运用在树套数中。
核心函数
核心思想通过判断节点的关系如果是直线则先右旋再左旋,否则连续左旋即可(对称反之)。同时用pushup和pushdown去维护节点之间的关系。
void rotate(int x)
{
int y=tr[x].p,z=tr[y].p;
pushdown(y),pushdown(x);
int k=tr[y].s[1]==x;
tr[z].s[tr[z].s[1]==y]=x,tr[x].p=z;
tr[y].s[k]=tr[x].s[k^1],tr[tr[x].s[k^1]].p=y;
tr[x].s[k^1]=y,tr[y].p=x;
pushup(y),pushup(x);
}
void splay(int x,int k)
{
while(tr[x].p!=k)
{
int y=tr[x].p,z=tr[y].p;
if(z!=k)
if((tr[y].s[1]==x)^(tr[z].s[1]==y)) rotate(x);
else rotate(y);
rotate(x);
}
if(!k) root=x;
}
插入值
void insert(int &root,int v)
{
int u=root,p=0;
while(u) p=u,u=tr[u].s[v>tr[u].v];
u=++idx;
if(p) tr[p].s[v>tr[p].v]=u;
tr[u].init(v,p);
splay(root,u,0);
}
区间操作
核心是先找区间左点l,右点r,主要要插入哨兵,splay(l,0)再slay(r,l),接下来只要去修改r的左子树即可。
寻找第k大数
操作也是先区间操作,然后递归查询。
int get_k(int k)
{
int u=root;
while(u)
{
pushdown(u);
if(tr[tr[u].s[0]].size>=k) u=tr[u].s[0];
else if(tr[tr[u].s[0]].size+1==k) return u;
else k-=tr[tr[u].s[0]].size+1,u=tr[u].s[1];
}
return -1;
}
翻转
利用pushdown操作去进行懒标记的操作,标记代表交换左右子树。
void pushdown(int x)
{
if(tr[x].flag){
swap(tr[x].s[0],tr[x].s[1]);
tr[tr[x].s[0]].flag^=1;
tr[tr[x].s[1]].flag^=1;
tr[x].flag^=1;
}
}
求最大子序列
大致也是利用pushdown的操作,主要维护的是sum,前缀,后缀和最大序列。然后答案是三部分最大值。
删点
可以利用如下方法也可以之间用一个cnt记录重复的元素,以下是先找删点然后把它转到根,接下来找左子树的最右点和右子树的最左点,接下来 splay(l,0),splay(r,l),把左子树清空即可。
void update(int x,int y)
{
int u=root;
while(u)
{
if(tr[u].v==x) break;
if(tr[u].v<x) u=tr[u].s[1];
if(tr[u].v>x) u=tr[u].s[0];
}
splay(root,u,0);
int l=tr[u].s[0],r=tr[u].s[1];
while(tr[l].s[1]) l=tr[l].s[1];
while(tr[r].s[0]) r=tr[r].s[0];
splay(l,0),splay(r,l);
tr[r].s[0]=0;
pushup(r),pushup(l);
}
维护数列
1.pos插入一段序列
2.pos删除一段序列
3.修改一段序列为c
4.区间翻转
5.区间求和
6.求最大子序列
#include<iostream>
#include<cstring>
#include<algorithm>
#define inf 0x3f3f3f3f
using namespace std;
const int N = 5e5+10;
int n,m;
struct Node{
int s[2],p,v;
int rev,same;
int size,sum,ms,ls,rs;
void init(int _v,int _p)
{
s[0]=s[1]=0,p=_p,v=_v;
rev=same=0;
size=1,sum=ms=v;
ls=rs=max(v,0);
}
}tr[N];
int root,nodes[N],tt;
int w[N];
void pushup(int x)
{
auto& u=tr[x],&l=tr[tr[x].s[0]],&r=tr[tr[x].s[1]];
u.size=l.size+r.size+1;
u.sum=l.sum+r.sum+u.v;
u.ls=max(l.ls,l.sum+u.v+r.ls);
u.rs=max(r.rs,r.sum+u.v+l.rs);
u.ms=max(max(l.ms,r.ms),l.rs+u.v+r.ls);
}
void pushdown(int x)
{
auto &u=tr[x],&l=tr[u.s[0]],&r=tr[u.s[1]];
if(u.same)
{
u.same=u.rev=0;
if(u.s[0]) l.same=1,l.v=u.v,l.sum=l.v*l.size;
if(u.s[1]) r.same=1,r.v=u.v,r.sum=r.v*r.size;
if(u.v>0){
if(u.s[0]) l.ms=l.ls=l.rs=l.sum;
if(u.s[1]) r.ms=r.ls=r.rs=r.sum;
}
else{
if(u.s[0]) l.ms=l.v,l.ls=l.rs=0;
if(u.s[1]) r.ms=r.v,r.ls=r.rs=0;
}
}
else if(u.rev)
{
u.rev^=1,l.rev^=1,r.rev^=1;
swap(l.ls,l.rs),swap(r.ls,r.rs);
swap(l.s[0],l.s[1]),swap(r.s[0],r.s[1]);
}
}
void rotate(int x)
{
int y=tr[x].p,z=tr[y].p;
pushdown(y),pushdown(x);
int k=tr[y].s[1]==x;
tr[z].s[tr[z].s[1]==y]=x,tr[x].p=z;
tr[y].s[k]=tr[x].s[k^1],tr[tr[x].s[k^1]].p=y;
tr[x].s[k^1]=y,tr[y].p=x;
pushup(y),pushup(x);
}
void splay(int x,int k)
{
while(tr[x].p!=k)
{
int y=tr[x].p,z=tr[y].p;
if(z!=k)
if((tr[y].s[1]==x)^(tr[z].s[1]==y)) rotate(x);
else rotate(y);
rotate(x);
}
if(!k) root=x;
}
int get_k(int k)
{
int u=root;
while(u)
{
pushdown(u);
if(tr[tr[u].s[0]].size>=k) u=tr[u].s[0];
else if(tr[tr[u].s[0]].size+1==k) return u;
else k-=tr[tr[u].s[0]].size+1,u=tr[u].s[1];
}
return -1;
}
int build(int l,int r,int root)
{
int mid=l+r>>1;
int u=nodes[tt--];
tr[u].init(w[mid],root);
if(l<mid) tr[u].s[0]=build(l,mid-1,u);
if(r>mid) tr[u].s[1]=build(mid+1,r,u);
pushup(u);
return u;
}
void dfs(int u)
{
if(tr[u].s[0]) dfs(tr[u].s[0]);
if(tr[u].s[1]) dfs(tr[u].s[1]);
nodes[++tt]=u;
}
int main()
{
for(int i=1;i<N;i++) nodes[++tt]=i;
scanf("%d%d",&n,&m);
tr[0].ms=-inf,w[0]=w[n+1]=-inf;
for(int i=1;i<=n;i++) scanf("%d",&w[i]);
root = build(0,n+1,0);
char op[20];
while(m--)
{
scanf("%s",op);
if(!strcmp(op,"INSERT"))
{
int posi,tot;
scanf("%d%d",&posi,&tot);
for(int i=0;i<tot;i++) scanf("%d",&w[i]);
int l=get_k(posi+1),r=get_k(posi+2);
splay(l,0),splay(r,l);
int u=build(0,tot-1,r);
tr[r].s[0]=u;
pushup(r),pushup(l);
}
else if(!strcmp(op,"DELETE"))
{
int posi,tot;
scanf("%d%d",&posi,&tot);
int l=get_k(posi),r=get_k(posi+tot+1);
splay(l,0),splay(r,l);
dfs(tr[r].s[0]);
tr[r].s[0]=0;
pushup(r),pushup(l);
}
else if(!strcmp(op,"MAKE-SAME"))
{
int posi,tot,c;
scanf("%d%d%d",&posi,&tot,&c);
int l=get_k(posi),r=get_k(posi+tot+1);
splay(l,0),splay(r,l);
auto &son=tr[tr[r].s[0]];
son.same=1; son.v=c,son.sum=c*son.size;
if(c>0) son.ms=son.ls=son.rs=son.sum;
else son.ms=c,son.ls=son.rs=0;
pushup(r),pushup(l);
}
else if(!strcmp(op,"REVERSE"))
{
int posi,tot;
scanf("%d%d",&posi,&tot);
int l=get_k(posi),r=get_k(posi+tot+1);
splay(l,0),splay(r,l);
auto &son=tr[tr[r].s[0]];
son.rev^=1;
swap(son.ls,son.rs);
swap(son.s[0],son.s[1]);
pushup(r),pushup(l);
}
else if(!strcmp(op,"GET-SUM"))
{
int posi,tot;
scanf("%d%d",&posi,&tot);
int l=get_k(posi),r=get_k(posi+tot+1);
splay(l,0),splay(r,l);
printf("%d\n",tr[tr[r].s[0]].sum);
}
else printf("%d\n",tr[root].ms);
}
return 0;
}