ST表(Sparse Table)


ST表(Sparse Table)


  1. 问题概述
  2. ST表介绍
  3. 与BIT的比对
  4. 代码实现

  问题概述:已知有一个长度位n的数组,里面的数是固定的,进行m次的查询求出给出区间范围内的最大||最小值。

  暴力算法:直接一波硬算,疯狂重复查询,时间复杂度是O(mn),于是引进ST表~


ST表介绍:

  ST表是一个对固定的数组区间初始化的打表方法,把时间复杂度降低至O(nlogn),查询的时间为O(1)。具体的做法是定义一个ST的二维表格,其中的

ST[i][j]代表覆盖的区间是i<=dx<=i+2^j-1。如图所示是dx=1和dx=6的覆盖区间,ST[i][j]之中记录的是区间中的最大||最小值,如果要查询区间[1,9]的最大||最小值,我们只需要取[1,8]和[6,9]或者更大重复的覆盖区间即可,ans=max(ST[1,3],ST[6,2])。


与BIT对比:

  那么如此的算法与BIT有什么联系和区别呢,显而易见他们都可以查询去区间最大||最小值,初始化的时间复杂度都是O(nlogn),但是ST表却不支持动态的修改与BIT的前缀和维护,但是查询速度ST表是O(1),BIT是O(nlogn)。copy一遍之前BIT的代码:

 1   #include<bits/stdc++.h>
 2   using namespace std;
 3   const int maxn = 1e6+10;
 4   int n,m;
 5   int a[maxn],tr[maxn];
 6   int lowbit(int x){return x& -x;}
 7   void add(int point,int dig)    {for(int i=point;i<=n;i+=lowbit(i))    tr[i]+=dig;}
 8   int ask(int x)
 9   {
10      int ans=0;
11      for(int i=x;i>0;i-=lowbit(i))    ans+=tr[i];
12      return ans;
13  }
14  int main()
15  {
16      memset(tr,0,sizeof(tr));
17      scanf("%d%d",&n,&m);
18      for(int i=1;i<=n;i++)    scanf("%d",&a[i]);
19      for(int i=1;i<=n;i++)    add(i,a[i]);
20      while(m--)
21      {
22          int op;    scanf("%d",&op);
23          if(op==1){
24              int x,k;    scanf("%d%d",&x,&k);
25              add(x,k);
26          }
27          else{
28              int x,y;    scanf("%d%d",&x,&y);
29              printf("%d\n",ask(y)-ask(x-1));
30          }
31      }
32      return 0;
33   }

代码实现:

  ST的代码实现分为两个阶段:

  初始化函数:从区间为单个点开始扩展,输入的时候是单点输入那么就可以直接读入到ST[i][0]中,然后就是从j=1(len=1)开始扩展到一把梭直接覆盖到全部区间结束,那么i的范围是从1开始到覆盖不下了停止,每一次新的j所用的数据都是之前j-1所提供的数据,所以st[i][j]=max(st[i][j-1],st[i+(1<<j-1)][j-1])。

1   scanf("%d%d",&n,&m);
2     for(int i=1;i<=n;i++)    scanf("%d",&st[i][0]);
3     for(int j=1;j<21;j++)
4         for(int i=1;i+(1<<j)-1<=n;i++)
5             st[i][j]=max(st[i][j-1],st[i+(1<<j-1)][j-1]);

  查询方法:我们要选取从左端点开始选一段长度需要大于(r-l>>1)的长度,显而易见不然查询是失败的,这里补一个公式就是2^log2x>x/2,因此我们选用的长度len=log2(r-l+1),因此返回max(st[l][len],st[r-(1<<len)+1][len])即可。

1 int ask(int l,int r)
2 {
3     int len=log2(r-l+1);
4     return max(st[l][len],st[r-(1<<len)+1][len]);
5 }

  完整代码:

 1 #include<bits/stdc++.h>
 2 using namespace std;
 3 const int maxn = 1e5+10;
 4 int a[maxn];
 5 int n,m;
 6 int st[maxn][25];
 7 int ask(int l,int r)
 8 {
 9     int len=log2(r-l+1);
10     return max(st[l][len],st[r-(1<<len)+1][len]);
11 }
12 int main()
13 {
14     scanf("%d%d",&n,&m);
15     for(int i=1;i<=n;i++)    scanf("%d",&st[i][0]);
16     for(int j=1;j<21;j++)
17         for(int i=1;i+(1<<j)-1<=n;i++)
18             st[i][j]=max(st[i][j-1],st[i+(1<<j-1)][j-1]);
19     while(m--)
20     {
21         int l,r;
22         scanf("%d%d",&l,&r);
23         printf("%d\n",ask(l,r));
24     }
25     return 0;
26 }


文章作者: Dydong
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 Dydong !
  目录