Icebound

icebound-area

【高级数据结构】Splay整理

Yts教授喜欢数据结构,他教了我们许多高级数据结构。
这一天Yts教授看到了萌萌的icebound同学,于是就找到icebound同学,教会了他splay。。。。
在算法竞赛里常用的平衡树只有splay和treap,然而treap由于std::set的原因,鸡肋了许多。
Splay的高明之处在于,Splay有一个神奇的操作:换根。

我们定义函数splay(x,y):将x点换成y的儿子,并且不改变整颗树的中序遍历。
这太imba了!!!在不改变中序遍历的情况下,可以随便换!
而且,如果我们splay(x,0) 就可以把x点换成根!
这太imba了!!!
而splay操作则是由zig、zag(左旋与右旋)组成的。
所谓左旋和右旋,是在一个很小的子树内执行的操作。
zig(x),只能对左孩子使用,会在不改变中序遍历的情况下使得x点的深度减少1,但是会使得x点从左儿子变成右儿子。
zag(x),把右儿子变成左儿子,深度减1。
有了zig和zag,我们可以着手构建splay函数了。
首先,我们要把一个点转到某个点下方,需要用到好多次zig和zag减少深度,同时,我们还要保证中序遍历不变。
这就需要分情况讨论了。我们设目标节点为ro,操作节点为x,x的父亲为y,y的父亲为z
第一种情况:
如果z是ro了,那么已经到了“肉眼可见“范围内了,我们只需要根据左右执行一次zig或zag即可。
如果z不是ro。那么再分两种情况讨论:
如果y是z的左儿子,且x是y的左儿子,那么zig(y),zig(x) 如果x是y的右儿子,需要先zag(x)再zig(x)一次
同理 如果是右儿子 则zag(y) zag(x) 或者 zig(x) zag(x) 即可
然后就一直向上跑就行了
在有了这个操作之后,我们可以用splay树处理区间问题了。
我们维护一个数组下标为1-n的数组,就需要建立一棵中序遍历为1-n的平衡树。
我们查询a[i] 就需要查询期中第i大的数,这个由查询size得到。
我们想要得到区间[l,r]的信息,就需要把l-1翻转到根,r+1翻转到根的儿子位置,这样根据中序遍历,根的右儿子的左儿子就是代表[l,r]区间的子树了。
如果我们想翻转区间[l,r],只需要改变中序遍历,就把[l,r]子树的左右子树全部交换即可。
如果想要区间修改,也在[l,r]子树操作就行。不过注意打lazy标记,标记下传顺序也需要仔细考虑,在此不再赘述了。
一道完整的模板题:bzoj1500

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
using namespace std;
const int N=510005*2;
int root,cnt,a[N],ht,es[N];
int n,m,pos,len;
char s[30];
struct splaytree
{
    int l,r,f,v,size;
    int rev;
    int ch;
    int tot;
    int lmx,rmx,smx;
}tr[N],zero;
inline void down_rev(int &p)
{
	if(!p)return;
	swap(tr[p].l,tr[p].r);
	swap(tr[p].lmx,tr[p].rmx);
	tr[p].rev^=1;
}
inline void down_same(int &p,int x)
{
	if(!p)return;
	tr[p].v=x;
	tr[p].tot=x*tr[p].size;
	tr[p].lmx=tr[p].rmx=tr[p].smx=max(x,x*tr[p].size);
	tr[p].ch=1;
}
inline void down(int &p)
{
	if(p==0)return;
    if(tr[p].ch==1)
    {
    	down_same(tr[p].l,tr[p].v);
    	down_same(tr[p].r,tr[p].v);
        tr[p].ch=0;
    }
    if(tr[p].rev==1)
    {
    	down_rev(tr[p].l);
    	down_rev(tr[p].r);
        tr[p].rev=0;
    }
}
inline void up(int p)
{
	if(p==0)return;
    int l=tr[p].l,r=tr[p].r;
    tr[p].size=tr[l].size+tr[r].size+1;
    tr[p].tot=tr[l].tot+tr[r].tot+tr[p].v;
    tr[p].lmx=max(tr[l].lmx,tr[l].tot+tr[p].v+max(0,tr[r].lmx));
    tr[p].rmx=max(tr[r].rmx,tr[r].tot+tr[p].v+max(0,tr[l].rmx));
    tr[p].smx=max(0,tr[l].rmx)+tr[p].v+max(0,tr[r].lmx);
    tr[p].smx=max(tr[p].smx,max(tr[l].smx,tr[r].smx));
}
inline void zig(int &x)
{
    int y=tr[x].f,z=tr[y].f;
    down(y),down(x);
    if(y==tr[z].l)tr[z].l=x;
    else tr[z].r=x;
    tr[x].f=z;
    tr[y].l=tr[x].r,tr[tr[x].r].f=y;
    tr[x].r=y,tr[y].f=x;
    up(y);up(x);
    if(y==root)root=x;
}
inline void zag(int &x)
{
    int y=tr[x].f,z=tr[y].f;
    down(y),down(x);
    if(y==tr[z].l)tr[z].l=x;
    else tr[z].r=x;
    tr[x].f=z;
    tr[y].r=tr[x].l,tr[tr[x].l].f=y;
    tr[x].l=y,tr[y].f=x;
    up(y);up(x);
    if(y==root)root=x;
}
inline void splay(int &x,int ro)
{
    if(x==0||x==ro)return;
    down(x);
    while(tr[x].f!=ro)
    {
        int y=tr[x].f,z=tr[y].f;
        if(z==ro)
        {
            down(y);
            down(x);
            if(tr[y].l==x)zig(x);
            else zag(x);
        }
        else
        {
            down(z);
            down(y);
            down(x);
            if(tr[z].l==y)
            {
                if(tr[y].l==x)zig(y),zig(x);
                else zag(x),zig(x);
            }
            else
            {
                if(tr[y].r==x)zag(y),zag(x);
                else zig(x),zag(x);
            }
        }
    }
    if(ro==0)root=x;
    up(x);
}
void build(int &p,int l,int r,int f)
{
    if(l>r)return;
    if(ht>0)p=es[ht],ht--;
    else p=++cnt;
    tr[p]=zero;
    int mid=(l+r)/2;
    tr[p].tot=tr[p].lmx=tr[p].rmx=tr[p].smx=tr[p].v=a[mid];
    tr[p].size=1;
    tr[p].f=f;
    build(tr[p].l,l,mid-1,p);
    build(tr[p].r,mid+1,r,p);
    up(p);
}
int ask_kth(int &p,int x)
{
    if(p==0)return 0;
    down(p);
    if(x<=tr[tr[p].l].size)return ask_kth(tr[p].l,x);
    if(x>tr[tr[p].l].size+1)return ask_kth(tr[p].r,x-tr[tr[p].l].size-1);
    return p;
}
inline void insert(int pos,int a[],int len)
{
    int l=ask_kth(root,pos+1);//pos
    int r=ask_kth(root,pos+2);//pos+1
    splay(l,0);
    splay(r,root);
    build(tr[r].l,1,len,tr[root].r);
    up(tr[root].r);
    up(root);
}
void hs(int t)
{
    if(t==0)return;
    es[++ht]=t;
    hs(tr[t].l);
    hs(tr[t].r);
//  tr[t]=zero;
}
inline void del(int pos,int len)
{
    int l=ask_kth(root,pos);//pos+1
    int r=ask_kth(root,pos+len+1);//pos+1
    splay(l,0);
    splay(r,root);
    hs(tr[r].l);
    tr[r].l=0;
    up(r);
    up(root);
}
inline void change(int pos,int len,int x)
{
    int l=ask_kth(root,pos);//pos+1
    int r=ask_kth(root,pos+len+1);//pos+1
    splay(l,0);
    splay(r,root);
    int p=tr[r].l;
    down_same(p,x);
    up(r);
    up(root);
}
inline void rever(int pos,int len)
{
    int l=ask_kth(root,pos);//pos+1
    int r=ask_kth(root,pos+len+1);//pos+1
    splay(l,0);
    splay(r,root);
    int p=tr[r].l;
    down_rev(p);
    up(tr[root].r);
    up(root);
}
inline int get_sum(int pos,int len)
{
    int l=ask_kth(root,pos);//pos+1
    int r=ask_kth(root,pos+len+1);//pos+1
    splay(l,0);
    splay(r,root);
    return tr[tr[r].l].tot;
}
inline int get_mx(int pos,int len)
{
    int l=ask_kth(root,pos);//pos+1
    int r=ask_kth(root,pos+len+1);//pos+1
    splay(l,0);
    splay(r,root);
    return tr[tr[r].l].smx;
}
int main()
{
    scanf("%d%d",&n,&m);
    a[0]=-1;
    a[n+1]=-1;
    for(int i=1;i<=n;i++)scanf("%d",&a[i]);
   	tr[0].lmx=tr[0].rmx=tr[0].smx=-0x3f3f3f3f;	
    build(root,0,n+1,0);
  	up(tr[root].r);
	up(root);
    while(m--)
    {
        scanf("%s",s);
        if(s[0]=='I')
        {
            int len;scanf("%d%d",&pos,&len);
            for(int i=1;i<=len;i++)scanf("%d",&a[i]);
            insert(pos,a,len);
        }
        else if(s[0]=='D')
        {
            scanf("%d%d",&pos,&len);
            del(pos,len);
        }
        else if(s[0]=='M'&&s[2]=='K')
        {
            int x;
            scanf("%d%d%d",&pos,&len,&x);
            change(pos,len,x);
        }
        else if(s[0]=='R')
        {
            scanf("%d%d",&pos,&len);
            rever(pos,len);
        }
        else if(s[0]=='G')
        {
            scanf("%d%d",&pos,&len);
            printf("%d\n",get_sum(pos,len));
        }
        else printf("%d\n",get_mx(1,tr[root].size-2));
    }
    return 0;
}