树套树


树套树


概述

树套树一般应用在当一棵树解决不了的时候,一般是利用树状数组或线段树套平衡树或线段树。


线段树套平衡树

  1. 1 l r x,查询整数 x 在区间 [l,r][l,r] 内的排名。
  2. 2 l r k,查询区间 [l,r][l,r] 内排名为 k 的值。
  3. 3 pos x,将 pos 位置的数修改为 x。
  4. 4 l r x,查询整数 x 在区间 [l,r][l,r] 内的前驱(前驱定义为小于 x,且最大的数)。
  5. 5 l r x,查询整数 x 在区间 [l,r][l,r] 内的后继(后继定义为大于 x,且最小的数)。

如上所示我们既要维护一个顺序的序列,同时又要维护它们的相对大小关系,但是一棵树一般只能维护一个关系所以我们要使用嵌套的关系。用线段树来维护相对的序号,用平衡树来维护大小关系。那么操作一就是把左子树和右子树相对位置排名取和,操作二利用二分的思想套操作一,操作三是进行删点然后对每一个线段树里的平衡树修改,操作四五是对称操作,做一次递归操作。

#include<cmath>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#define inf 0x3f3f3f3f
using namespace std;
const int N = 1500010;
int n,m;
struct Node{
    int s[2],p,v;
    int size;
    void init(int _v,int _p)
    {
        v=_v,p=_p;
        size=1;
    }
}tr[N];
int L[N],R[N],T[N],idx;
int w[N];

void pushup(int x)  {
    tr[x].size=tr[tr[x].s[0]].size+tr[tr[x].s[1]].size+1;
}

void rotate(int x)  {
    int y=tr[x].p,z=tr[y].p;
    int k=tr[y].s[1]==x;
    tr[z].s[tr[z].s[1]==y]=x,tr[x].p=z;
    tr[y].s[k]=tr[x].s[k^1],tr[tr[x].s[k^1]].p=y;
    tr[x].s[k^1]=y,tr[y].p=x;
    pushup(y),pushup(x);
}

void splay(int& root,int x,int k)
{
    while(tr[x].p!=k)
    {
        int y=tr[x].p,z=tr[y].p;
        if(z!=k)    
            if((tr[y].s[1]==x)^(tr[z].s[1]==y)) rotate(x);
            else    rotate(y);
        rotate(x);
    }
    if(!k)  root=x;
}

void insert(int &root,int v)
{
    int u=root,p=0;
    while(u)    p=u,u=tr[u].s[v>tr[u].v];
    u=++idx;
    if(p)   tr[p].s[v>tr[p].v]=u;
    tr[u].init(v,p);
    splay(root,u,0);
}

int get_k(int root,int v)
{
    int u=root,res=0;
    while(u)
    {
        if(tr[u].v<v)   res+=tr[tr[u].s[0]].size+1,u=tr[u].s[1];
        else    u=tr[u].s[0];
    }
    return res;
}

void update(int &root,int x,int y)
{
    int u=root;
    while(u)
    {
        if(tr[u].v==x)  break;
        if(tr[u].v<x)   u=tr[u].s[1];
        if(tr[u].v>x)   u=tr[u].s[0];
    }
    splay(root,u,0);
    int l=tr[u].s[0],r=tr[u].s[1];
    while(tr[l].s[1])   l=tr[l].s[1];
    while(tr[r].s[0])   r=tr[r].s[0];
    splay(root,l,0),splay(root,r,l);
    tr[r].s[0]=0;
    pushup(r),pushup(l);
    insert(root,y);
}

int get_pre(int root,int v)
{
    int u=root,res=-inf;
    while(u)
    {
        if(tr[u].v<v)   res=max(res,tr[u].v),u=tr[u].s[1];
        else if(tr[u].v>=v) u=tr[u].s[0];
    }
    return res;
}

int get_suc(int root,int v)
{
    int u=root,res=inf;
    while(u)
    {
        if(tr[u].v>v)   res=min(res,tr[u].v),u=tr[u].s[0];
        else    u=tr[u].s[1];
    }
    return res;
}

void build(int u,int l,int r)
{
    L[u]=l,R[u]=r;
    insert(T[u],-inf),insert(T[u],inf);
    for(int i=l;i<=r;i++)   insert(T[u],w[i]);
    if(l==r)    return;
    int mid=l+r>>1;
    build(u<<1,l,mid),build(u<<1|1,mid+1,r);
}

int query(int u,int a,int b,int x)
{
    if(L[u]>=a&&R[u]<=b)    return get_k(T[u],x)-1;
    int mid=L[u]+R[u]>>1,res=0;
    if(a<=mid)  res+=query(u<<1,a,b,x);
    if(mid<b)   res+=query(u<<1|1,a,b,x);
    return res;
}

void change(int u,int p,int x)
{
    update(T[u],w[p],x);
    if(L[u]==R[u])  return;
    int mid=L[u]+R[u]>>1;
    if(p<=mid)  change(u<<1,p,x);
    else    change(u<<1|1,p,x);
}

int query_ptr(int u,int a,int b,int x)
{
    if(L[u]>=a&&R[u]<=b)  return get_pre(T[u],x);
    int mid=L[u]+R[u]>>1,res=-inf;
    if(a<=mid)  res=max(res,query_ptr(u<<1,a,b,x));
    if(b>mid)   res=max(res,query_ptr(u<<1|1,a,b,x));
    return res;
}

int query_suc(int u,int a,int b,int x)
{
    if(L[u]>=a&&R[u]<=b)  return get_suc(T[u],x);
    int mid=L[u]+R[u]>>1,res=inf;
    if(a<=mid)  res=min(res,query_suc(u<<1,a,b,x));
    if(b>mid)   res=min(res,query_suc(u<<1|1,a,b,x));
    return res;
}

int main()
{
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++)   scanf("%d",&w[i]);
    build(1,1,n);
    while(m--)
    {
        int op,a,b,x;
        scanf("%d",&op);
        if(op==1)
        {
            scanf("%d%d%d",&a,&b,&x);
            printf("%d\n",query(1,a,b,x)+1);
        }
        else if(op==2)
        {
            scanf("%d%d%d",&a,&b,&x);
            int l=0,r=1e8;
            while(l<r)
            {
                int mid=l+r+1>>1;
                if(query(1,a,b,mid)+1<=x)   l=mid;
                else r=mid-1;
            }
            printf("%d\n",l);
        }
        else if(op==3)
        {
            scanf("%d%d",&a,&x);
            change(1,a,x);
            w[a]=x;
        }
        else if(op==4)
        {
            scanf("%d%d%d",&a,&b,&x);
            printf("%d\n",query_ptr(1,a,b,x));
        }
        else if(op==5)
        {
            scanf("%d%d%d",&a,&b,&x);
            printf("%d\n",query_suc(1,a,b,x));
        }
    }
    return 0;
}

线段树套线段树

有 N 个位置,M 个操作。每个位置可以同时存储多个数。

操作有两种,每次操作:

  • 如果是 1 a b c 的形式,表示在第 a 个位置到第 b 个位置,每个位置加入一个数 c。
  • 如果是 2 a b c 的形式,表示询问从第 a 个位置到第 b 个位置,第 c 大的数是多少。

操作的思想是利用一棵权值线段树来维护各个点的数量,但是由于存在区间的原因所以我们需要再在里面套一棵线段树来维护一下区间的点数量,可以利用标记持久化和动态开点的思想来解决问题。

#include<iostream>
#include<cstring>
#include<algorithm>
#include<vector>
#define inf 0x3f3f3f3f
using namespace std; 
const int N = 5e4+10,P = N*17*17,M = N*4;
int n,m;
struct Node{
    int l,r,sum,add;
}tr[P];
int L[M],R[M],T[M],idx;
struct Query{
    int op,a,b,c;
}q[N];
vector<int> nums;
int get(int c)  {
    return lower_bound(nums.begin(),nums.end(),c)-nums.begin();
}
void build(int u,int l,int r)
{
    L[u]=l,R[u]=r,T[u]=++idx;
    if(l==r)    return;
    int mid=l+r>>1;
    build(u<<1,l,mid),build(u<<1|1,mid+1,r);
}
int intersection(int a,int b,int c,int d)   {
    return  min(b,d)-max(a,c)+1;
}
void update(int u,int l,int r,int pl,int pr)
{
    tr[u].sum+=intersection(l,r,pl,pr);
    if(l>=pl&&r<=pr)
    {
        tr[u].add++;
        return;
    }
    int mid=l+r>>1;
    if(pl<=mid)
    {
        if(!tr[u].l)    tr[u].l=++idx;
        update(tr[u].l,l,mid,pl,pr);
    }
    if(pr>mid)
    {
        if(!tr[u].r)    tr[u].r=++idx;
        update(tr[u].r,mid+1,r,pl,pr);
    }
}
int get_sum(int u,int l,int r,int pl,int pr,int add)
{
    if(l>=pl&&r<=pr)    return tr[u].sum+(r-l+1)*add;
    int mid=l+r>>1,res=0;
    add+=tr[u].add;
    if(pl<=mid)
    {
        if(tr[u].l) res+=get_sum(tr[u].l,l,mid,pl,pr,add);
        else    res+=intersection(l,mid,pl,pr)*add;
    }
    if(pr>mid)
    {
        if(tr[u].r) res+=get_sum(tr[u].r,mid+1,r,pl,pr,add);
        else    res+=intersection(mid+1,r,pl,pr)*add;
    }
    return res;
}
void change(int u,int a,int b,int c)
{
    update(T[u],1,n,a,b);
    if(L[u]==R[u])  return;
    int mid=L[u]+R[u]>>1;
    if(c<=mid)  change(u<<1,a,b,c);
    else    change(u<<1|1,a,b,c);
}
int query(int u,int a,int b,int c)
{
    if(L[u]==R[u])  return R[u];
    int mid=L[u]+R[u]>>1;
    int k=get_sum(T[u<<1|1],1,n,a,b,0);
    if(k>=c)    return query(u<<1|1,a,b,c);
    return  query(u<<1,a,b,c-k);
}
int main()
{
    scanf("%d%d",&n,&m);
    for(int i=0;i<m;i++)
    {
        scanf("%d%d%d%d",&q[i].op,&q[i].a,&q[i].b,&q[i].c);
        if(q[i].op==1)  nums.push_back(q[i].c);
    }
    sort(nums.begin(),nums.end());
    nums.erase(unique(nums.begin(),nums.end()),nums.end());
    build(1,0,nums.size()-1);
    for(int i=0;i<m;i++)
    {
        int op=q[i].op,a=q[i].a,b=q[i].b,c=q[i].c;
        if(op==1)   change(1,a,b,get(c));
        else    printf("%d\n",nums[query(1,a,b,c)]);
    }
    return 0;
}


文章作者: Dydong
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 Dydong !
  目录