Splay


Splay


概述

平衡树,尽量把时间复杂度压到nlogn,支持的方面有区间修改区间查询,区间翻转,区间删除,整段最大子序列,查询第K大的数,查询数是第几大。也可以运用在树套数中。


核心函数

核心思想通过判断节点的关系如果是直线则先右旋再左旋,否则连续左旋即可(对称反之)。同时用pushup和pushdown去维护节点之间的关系。

void rotate(int x)
{
    int y=tr[x].p,z=tr[y].p;
    pushdown(y),pushdown(x);
    int k=tr[y].s[1]==x;
    tr[z].s[tr[z].s[1]==y]=x,tr[x].p=z;
    tr[y].s[k]=tr[x].s[k^1],tr[tr[x].s[k^1]].p=y;
    tr[x].s[k^1]=y,tr[y].p=x;
    pushup(y),pushup(x);
}
void splay(int x,int k)
{
    while(tr[x].p!=k)
    {
        int y=tr[x].p,z=tr[y].p;
        if(z!=k)
            if((tr[y].s[1]==x)^(tr[z].s[1]==y)) rotate(x);
            else    rotate(y);
        rotate(x);
    }
    if(!k)  root=x;
}

插入值

void insert(int &root,int v)
{
    int u=root,p=0;
    while(u)    p=u,u=tr[u].s[v>tr[u].v];
    u=++idx;
    if(p)   tr[p].s[v>tr[p].v]=u;
    tr[u].init(v,p);
    splay(root,u,0);
}

区间操作

核心是先找区间左点l,右点r,主要要插入哨兵,splay(l,0)再slay(r,l),接下来只要去修改r的左子树即可。


寻找第k大数

操作也是先区间操作,然后递归查询。

int get_k(int k)
{
    int u=root;
    while(u)
    {
        pushdown(u);
        if(tr[tr[u].s[0]].size>=k)  u=tr[u].s[0];
        else if(tr[tr[u].s[0]].size+1==k)   return u;
        else k-=tr[tr[u].s[0]].size+1,u=tr[u].s[1];
    }
    return -1;
}

翻转

利用pushdown操作去进行懒标记的操作,标记代表交换左右子树。

void pushdown(int x)
{
    if(tr[x].flag){
        swap(tr[x].s[0],tr[x].s[1]);
        tr[tr[x].s[0]].flag^=1;
        tr[tr[x].s[1]].flag^=1;
        tr[x].flag^=1;
    }
}

求最大子序列

大致也是利用pushdown的操作,主要维护的是sum,前缀,后缀和最大序列。然后答案是三部分最大值。


删点

可以利用如下方法也可以之间用一个cnt记录重复的元素,以下是先找删点然后把它转到根,接下来找左子树的最右点和右子树的最左点,接下来 splay(l,0),splay(r,l),把左子树清空即可。

void update(int x,int y)
{
    int u=root;
    while(u)
    {
        if(tr[u].v==x)  break;
        if(tr[u].v<x)   u=tr[u].s[1];
        if(tr[u].v>x)   u=tr[u].s[0];
    }
    splay(root,u,0);
    int l=tr[u].s[0],r=tr[u].s[1];
    while(tr[l].s[1])   l=tr[l].s[1];
    while(tr[r].s[0])   r=tr[r].s[0];
    splay(l,0),splay(r,l);
    tr[r].s[0]=0;
    pushup(r),pushup(l);
}

维护数列

1.pos插入一段序列

2.pos删除一段序列

3.修改一段序列为c

4.区间翻转

5.区间求和

6.求最大子序列

#include<iostream>
#include<cstring>
#include<algorithm>
#define inf 0x3f3f3f3f
using namespace std;
const int N = 5e5+10;
int n,m;
struct Node{
    int s[2],p,v;
    int rev,same;
    int size,sum,ms,ls,rs;
    void init(int _v,int _p)
    {
        s[0]=s[1]=0,p=_p,v=_v;
        rev=same=0;
        size=1,sum=ms=v;
        ls=rs=max(v,0);
    }
}tr[N];
int root,nodes[N],tt;
int w[N];
void pushup(int x)
{
    auto& u=tr[x],&l=tr[tr[x].s[0]],&r=tr[tr[x].s[1]];
    u.size=l.size+r.size+1;
    u.sum=l.sum+r.sum+u.v;
    u.ls=max(l.ls,l.sum+u.v+r.ls);
    u.rs=max(r.rs,r.sum+u.v+l.rs);
    u.ms=max(max(l.ms,r.ms),l.rs+u.v+r.ls);
}
void pushdown(int x)
{
    auto &u=tr[x],&l=tr[u.s[0]],&r=tr[u.s[1]];
    if(u.same)
    {
        u.same=u.rev=0;
        if(u.s[0])  l.same=1,l.v=u.v,l.sum=l.v*l.size;
        if(u.s[1])  r.same=1,r.v=u.v,r.sum=r.v*r.size;
        if(u.v>0){
            if(u.s[0])  l.ms=l.ls=l.rs=l.sum;
            if(u.s[1])  r.ms=r.ls=r.rs=r.sum;
        }
        else{
            if(u.s[0])  l.ms=l.v,l.ls=l.rs=0;
            if(u.s[1])  r.ms=r.v,r.ls=r.rs=0;
        }
    }
    else if(u.rev)
    {
        u.rev^=1,l.rev^=1,r.rev^=1;
        swap(l.ls,l.rs),swap(r.ls,r.rs);
        swap(l.s[0],l.s[1]),swap(r.s[0],r.s[1]);
    }
}
void rotate(int x)
{
    int y=tr[x].p,z=tr[y].p;
    pushdown(y),pushdown(x);
    int k=tr[y].s[1]==x;
    tr[z].s[tr[z].s[1]==y]=x,tr[x].p=z;
    tr[y].s[k]=tr[x].s[k^1],tr[tr[x].s[k^1]].p=y;
    tr[x].s[k^1]=y,tr[y].p=x;
    pushup(y),pushup(x);
}
void splay(int x,int k)
{
    while(tr[x].p!=k)
    {
        int y=tr[x].p,z=tr[y].p;
        if(z!=k)
            if((tr[y].s[1]==x)^(tr[z].s[1]==y)) rotate(x);
            else    rotate(y);
        rotate(x);
    }
    if(!k)  root=x;
}
int get_k(int k)
{
    int u=root;
    while(u)
    {
        pushdown(u);
        if(tr[tr[u].s[0]].size>=k)  u=tr[u].s[0];
        else if(tr[tr[u].s[0]].size+1==k)   return u;
        else k-=tr[tr[u].s[0]].size+1,u=tr[u].s[1];
    }
    return -1;
}
int build(int l,int r,int root)
{
    int mid=l+r>>1;
    int u=nodes[tt--];
    tr[u].init(w[mid],root);
    if(l<mid)   tr[u].s[0]=build(l,mid-1,u);
    if(r>mid)   tr[u].s[1]=build(mid+1,r,u);
    pushup(u);
    return u;
}
void dfs(int u)
{
    if(tr[u].s[0])  dfs(tr[u].s[0]);
    if(tr[u].s[1])  dfs(tr[u].s[1]);
    nodes[++tt]=u;
}
int main()
{
    for(int i=1;i<N;i++)    nodes[++tt]=i;
    scanf("%d%d",&n,&m);
    tr[0].ms=-inf,w[0]=w[n+1]=-inf;
    for(int i=1;i<=n;i++)   scanf("%d",&w[i]);
    root = build(0,n+1,0);
    char op[20];
    while(m--)
    {
        scanf("%s",op);
        if(!strcmp(op,"INSERT"))
        {
            int posi,tot;
            scanf("%d%d",&posi,&tot);
            for(int i=0;i<tot;i++)  scanf("%d",&w[i]);
            int l=get_k(posi+1),r=get_k(posi+2);
            splay(l,0),splay(r,l);
            int u=build(0,tot-1,r);
            tr[r].s[0]=u;
            pushup(r),pushup(l);
        }
        else if(!strcmp(op,"DELETE"))
        {
            int posi,tot;
            scanf("%d%d",&posi,&tot);
            int l=get_k(posi),r=get_k(posi+tot+1);
            splay(l,0),splay(r,l);
            dfs(tr[r].s[0]);
            tr[r].s[0]=0;
            pushup(r),pushup(l);
        }
        else if(!strcmp(op,"MAKE-SAME"))
        {
            int posi,tot,c;
            scanf("%d%d%d",&posi,&tot,&c);
            int l=get_k(posi),r=get_k(posi+tot+1);
            splay(l,0),splay(r,l);
            auto &son=tr[tr[r].s[0]];
            son.same=1; son.v=c,son.sum=c*son.size;
            if(c>0) son.ms=son.ls=son.rs=son.sum;
            else son.ms=c,son.ls=son.rs=0;
            pushup(r),pushup(l);
        }
        else if(!strcmp(op,"REVERSE"))
        {
            int posi,tot;
            scanf("%d%d",&posi,&tot);
            int l=get_k(posi),r=get_k(posi+tot+1);
            splay(l,0),splay(r,l);
            auto &son=tr[tr[r].s[0]];
            son.rev^=1;
            swap(son.ls,son.rs);
            swap(son.s[0],son.s[1]);
            pushup(r),pushup(l);
        }
        else if(!strcmp(op,"GET-SUM"))
        {
            int posi,tot;
            scanf("%d%d",&posi,&tot);
            int l=get_k(posi),r=get_k(posi+tot+1);
            splay(l,0),splay(r,l);
            printf("%d\n",tr[tr[r].s[0]].sum);
        }
        else printf("%d\n",tr[root].ms);
    }
    return 0;
}


文章作者: Dydong
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 Dydong !
  目录