本篇文章前面都是废话,重点在于 Splay 压行,猛击键盘 End 键前往页面底部

定义 Definition

Splay 树是一种二叉平衡树的替代品,支持平衡树的所有操作,和 AVL 树严格的 $O(logN)$ 不同,其每次操作的均摊复杂度为 $O(logN)$。它不根据子树 size 或其他附加域进行调整,而是在每次操作后对树进行 Splay (伸展)操作。

Splay 树往往用来维护序列。将每个节点的权值赋为序列下标,建树,这样得到的树的中序遍历即是原数列的下标(1, 2, 3, 4 …)。又因为旋转操作可以维持二叉查找树左小右大的性质,即中序遍历不变,因而建立在旋转操作基础上的伸展操作也不会改变其中序遍历。

操作 Operation

伸展 Splay

我们要支持这样一种操作:Splay(o, t),将以节点 o 为根的子树旋转为节点 t 的一个儿子。保证 o 是 t 的后代。

容易想到,我们可以将 o 不断向上旋转直到它为 t 的一个儿子。这样可以实现单旋 Splay 树,但是这种数据结构的时间复杂度没有保证。

使用分三种情况的旋转策略可以使均摊复杂度达到 $O(logN)$。相关资料有很多,不再赘述。

平衡树操作

求前驱、后继、第 k 大、插入操作都可以按二叉搜索树的方法实现,最后将操作节点伸展到根来保证均摊复杂度。

由于我们实现了伸展操作,所以我们还能更简单的实现合并和删除。

合并 Merge

设 o 子树中的所有元素都小于 t 子树中的元素。Merge(o, t) 将 o 作为 t 的左儿子。

将 t 中的最小元素伸展到根,此时 t 子树一定没有左儿子。因此可以将 o 作为 t 的左儿子。

删除 Delete

将待删除节点伸展到根,丢弃根节点,合并根节点的两颗子树。
区间删除同样很简单,留给读者思考。

扩展

标记 Tag

思想和线段树的标记是一样的,大部分线段树可做的都可做。

关键在于标记下传和更新信息的时机。

回忆一下线段树 down 的时机:访问子树时。

这一点在 Splay 树中仍然成立,不过发生了一些微妙的变化。

我们知道,线段树中当当前区间完全被待查区间包含时就可以立即返回,此时不需要访问当前区间的子区间,因此不用 down。

但是,Splay 树中当前结点即为待查节点时,不能立即返回,而是应该先 push 再返回。

我们明明没有访问子树,为什么也要 push 呢?

关键在于,我们实际上访问子树了。因为每次操作后都要将待查节点伸展到根,所以每次查询必定会对待查节点进行旋转。旋转中用到了待查节点的子节点。我们忘记了这一次访问。

清除标记 Clear

如果我们通过某种方式得到了要伸展的结点,但是从该节点到根的路径上有标记,那么我们就不能直接伸展它。
而自底向上的边伸展边 push 显然是错的。

所以我们可以自底向上走直到根,结点入栈。然后不断对栈顶结点执行 push 直到栈空。

懒惰插入/删除 Lazy Operations

实践中常用 cnt[] 表示某个节点上的数的出现次数。注意 cnt 虽然可以是 0, 但不能是负数。

有了懒惰插入/删除,大部分操作还是很容易实现。

不过求前驱/后继的时候开始变得有点微妙。

如果没有空点,求 o 的前驱,可以把 o 伸展到根,从 o 的左儿子开始,一直向右走直到没有右儿子为止。对称地,有求后继的方法。

由于有了空点,我们甚至不能判断一个节点有没有左右儿子。因为以左/右儿子为根的子树里可能全是空点。

从 o 的左儿子开始:
如果当前节点是空点,那么如果右子树不全空就走右节点,否则就走左节点。
如果当前结点不是空点,那么如果右子树不全空就走右节点,否则当前结点就是后继。

它的正确性基于这样的想法:先前不存在空点时,我们可以根据当前节点为界缩小搜索范围,因为后继大于等于当前结点,所以只用考虑右子树。但是现在引入了空点,当前结点为空时根本无法根据当前结点的值缩小搜索范围,也就是必须同时考虑左右子树。但由于右子树一定比左子树优,所以如果右子树非空,只用考虑右子树。

实现

下文讨论数组+迭代实现。

储存

为了储存一棵树,我们维护以下信息:

fa[o]; // father of o
ch[o][0], ch[o][1] // left child and right child of o
sz[o]; // size of o
rt; // root of tree

虚点

为了方便,我们定义 0 号节点为 null。于是同时规定:

sz[0] = 0;
fa[0] = 0;
//...

0 号节点为逻辑 null。它不能被任何函数操作。也就是任何以 0 为操作节点的函数应该立刻返回。

辅助过程

为了操作树的结构,我们编写以下函数:

int gd(int o) {// return x is which child
  return ch[fa[o]][1] == o;
}
void lk(int x, int y, int d){ // make x become a child of y
  //ignore virutal node
  if (y) ch[y][d] = x;
  if (x) fa[x] = y;
}
int cut(int o) { // cut o from its father. rarely used.
  ch[fa[o]][gd(o)] = 0;
  fa[o] = 0;
  return o;
}
void rot(int o) {//roate o up
  // must guratee that there is no tag from o to the root
  int x = fa[o], d = gd(o);
  lk(o, fa[x], gd(x));
  lk(ch[o][d^1], x, d);
  lk(x, o, d^1);
  up(x);// up x before o becasue x is child of o
  up(o);
  if (x == rt) rt = o;
}

为了维护标记:

void up(int o); // update info from children of o
void down(int o); // give tags of o to its children

为了清楚, 在操作树的函数中,我们只允许 rot 调用 updown 维护附加域。其他场合 lk , cut 必须和 up , down 配套使用。而不是把 updown 写在他们里面。

up 中要对虚点的附加域赋无影响的值或者特判儿子是否为虚点。

主要操作

Splay 树的核心操作在于伸展。自底向上的伸展操作要求被伸展节点到根的链上没有标记。这一条件实际上很好满足。我们在查找满足条件的节点时就已经对这一条链上的标记全部 push 了。

int kth(int o, int k) {// find k-th node
  while (k) {
  down(o);
    if (k == sz[lch] + 1) return o;
    else if (k <= sz[lch]) o = lch;
    else k -= sz[lch] + 1, o = rch;
  }
  return o;
}
int splay(int o, int t) {// rotate o up until o is a child of t, used with kth()
  while(fa[o] != t) {
    int x = fa[o];
    if (fa[x] != t) gd(o) == gd(x) ? rot(x) : rot(o);
    rot(o);
  }
  return o;
}

有了以上函数就可以进行 [l, r] 区间修改了。只需要把一个端点伸展到根,另一个端点伸展到根的子树。两个端点“之间夹住的部分”就是要修改的区间。注意此处的端点不一定就是第 l 大和第 r 大,细节留给读者思考。

不过有些题目更加复杂,要求支持区间 “剪切”,即把一段区间拿出来插入到另一个位置。这时候组合使用上述函数就好了。此处我想给出一个比较精炼的操作,合并 Merge。

inline int merge(int o, int t) {//make o become the left child of t
  kth(t, 1);
  t = splay(kth(t, 1), fa[t]);
  lk(o, t, 0);
  up(t);
  return t;
}

刚刚提到的剪切操作就可以用 kth + splay + merge 实现。

上面给出的 merge 暴露了笔者实现中最 ugly 的一个地方:t = splay(kth(t, 1), fa[t]);。这行的含义是将以 t 为根的树中的最左节点伸展到根。那为什么是个赋值语句呢?原因在于,以 t 为根的子树的根改变了。而旋到 rt 的时候不需要赋值,因为旋到 rt 的时候,rot过程中的一个if语句修复了这个问题。
可能另一种写法会让这个实现稍微漂亮一些:去掉 rot 中的判断,靠人在每次调用时判断是否要赋值。这样根的行为统一了,也让维护 Splay 森林变得可行。代价是:你必须非常小心,在必要的时候修改 rt 。

应用

Prob Hint
BZOJ 3323 文艺平衡树 区间翻转
BZOJ 1251 序列终结者 区间翻转,询问最值
BZOJ 1895 supermemo 区间加,翻转,剪切,询问最值。点插入,删除。
BZOJ 1056 排名系统 专治操作完不伸展
BZOJ 1552 robotic sort 区间反转,清除标记,splay 的灵活运用
BZOJ 3224 普通平衡树 像普通平衡树一样使用 Splay

代码

这是普通平衡树的代码。
前驱后继写的比较飘逸。

#include <cstdio>
#define lch ch[o][0]
#define rch ch[o][1]
namespace I {
  const int L = 1 << 15;
  char buf[L], *s, *t;
  inline char gc() {
    if (s == t) t = (s = buf) + fread(buf, 1, L, stdin);
    if (s == t) return EOF;
    return *s++;
  }
  inline int gi() {
    int r = 0, f = 1, ch = gc();
    while (!(ch <= '9' && ch >= '0')) {
      if (ch == '-') f = -1;
      ch = gc();
    }
    while (ch <= '9' && ch >= '0') r = r*10 + ch - '0', ch = gc();
    return r*f;
  }
}using I::gi;
const int N = 1e5 + 10;
int n, st[N], nsz = 1, ch[N][2], cnt[N], sz[N], rt = 1, fa[N];
inline int gd(int o) {return ch[fa[o]][1] == o;}
inline void up(int o) {sz[o] = sz[lch] + sz[rch] + cnt[o];}
inline void lk(int x, int y, int d) {
  if (x) fa[x] = y;
  if (y) ch[y][d] = x;
}
inline void rot(int o) {
  int x = fa[o], d = gd(o);
  lk(o, fa[x], gd(x));
  lk(ch[o][d^1], x, d);
  lk(x, o, d^1);
  up(x);
  up(o);
  if (x == rt) rt = o;
}
void splay(int o, int t) {
  while (fa[o] != t) {
    int x = fa[o];
    if (fa[x] != t) gd(o) == gd(x) ? rot(x) : rot(o);
    rot(o);
  }
}
int tar;
inline int newnode(int v) {
  st[++nsz] = v;
  cnt[nsz] = 1;
  up(nsz);
  return nsz;
}
void insert(int o, int v) {
  if (v < st[o]) {
    if (lch) insert(lch, v);
    else lk(tar = newnode(v), o, 0);
  }
  else if (v > st[o]) {
    if (rch) insert(rch, v);
    else lk(tar = newnode(v), o, 1);
  }
  else ++cnt[tar = o];
  up(o);
}
void remove(int o, int v) {
  if (!o) return;
  if (v < st[o]) remove(lch, v);
  else if (v > st[o]) remove(rch, v);
  else --cnt[o];
  up(o);
}
int rank(int o, int v) {
  if (v < st[o]) return rank(lch, v);
  else if (v > st[o]) return sz[lch] + cnt[o] + rank(rch, v);
  else return tar = o, sz[lch] + 1;
}
int kth(int o, int k) {
  while (k) {
    if (k <= sz[lch]) o = lch;
    else if ((k -= sz[lch]) <= cnt[o]) return st[tar = o];
    else k -= cnt[o], o = rch;
  }
  return st[tar = o];
}
inline int sp(int v, int t) {
  insert(rt, v);
  remove(rt, v);
  rank(rt, v);
  splay(tar, 0);
  int o = ch[rt][t], r;
  t ^= 1;
  while (o) {
    if (cnt[o]) {
      r = st[o];
      if (sz[ch[o][t]]) o = ch[o][t];
      else break;
    }
    else o = sz[ch[o][t]] ? ch[o][t] : ch[o][t^1];
  }
  return r;
}

int main() {
  n = gi();
  while (n--) {
    int op = gi(), x = gi();
    if (op == 1) insert(rt, x), splay(tar, 0);
    else if (op == 2) remove(rt, x);
    else if (op == 3) printf("%d\n", rank(rt, x)), splay(tar, 0);
    else if (op == 4) printf("%d\n", kth(rt, x)), splay(tar, 0);
    else if (op == 5) printf("%d\n", sp(x, 0));
    else if (op == 6) printf("%d\n", sp(x, 1));
  }
}
/* Debug
 * rank 的递归边界是 sz[lch] + 1
 * 有虚点求前驱后继
 */

网上有篇文章说是“终极”模板。我觉得这个板子写的很好。推荐看一看。

借鉴了一下,我写出了更短的代码,没有强行压行。

三行插入,两行伸展,一行求前驱后继。

——容我中二。我这个是极限模板。比他的模板短 34 行。

当然还有强行压行的空间,但是那样并没有简化逻辑。

#include <cstdio>
#include <algorithm>
#include <cstdlib>
using std::min;
const int N = 1e5 + 10, INF = ~0u>>2;
int n, st[N], nsz, rt, ch[N][2], fa[N], ans;
inline void lk(int x, int y, int d) {
  if (x) fa[x] = y;
  if (y) ch[y][d] = x;
}
inline int gd(int o) {return ch[fa[o]][1] == o;}
inline void rot(int o) {
  int x = fa[o], d = gd(o);
  lk(o, fa[x], gd(x));
  lk(ch[o][d^1], x, d);
  lk(x, o, d^1);
  if (x == rt) rt = o;
}
inline int newnode(int v) {
  st[++nsz] = v;
  return nsz;
}
void splay(int o) {
  for (int x; x = fa[o]; rot(o))
    if(fa[x]) rot(gd(o) == gd(x) ? x : o);
}
void insert(int o, int v) {
  for (int p; p = ch[o][v >= st[o]]; o = p) ;
  lk(newnode(v), o, v >= st[o]);
  splay(nsz);
}
int sp(int o, int t) {
  for (o = ch[o][t], t ^= 1; ch[o][t]; o = ch[o][t]) ;
  return st[o];
}
int main() {
  scanf("%d", &n);
  rt = newnode(INF);
  insert(rt, -INF);
  bool f = 0;
  while (n--) {
    int x;
    scanf("%d", &x);
    insert(rt, x);
    ans += f ? min(abs(x - sp(rt, 0)), abs(x - sp(rt, 1))) : (f = 1, x);
  }
  printf("%d\n", ans);
}
#include <cstdio>
#include <algorithm>
#include <cstring>
#include <cmath>
#include <cstdlib>
const int V = 80010, M = 1000000;
int n, sta = 3, ch[V][2], fa[V], cnt[V], st[V], top, rt, sz, ans;
inline int gd(int o) {return ch[fa[o]][1] == o;}
inline void lk(int x, int y, int d) {
    if (x) fa[x] = y;
    if (y) ch[y][d] = x;
}
inline void rot(int o) {
    int x = fa[o], d = gd(o);
    lk(o, fa[x], gd(x));
    lk(ch[o][d^1], x, d);
    lk(x, o, d^1);
    if (x == rt) rt = o;
}
inline void splay(int o) {
    for (;fa[o];rot(o)) if (fa[fa[o]]) rot(gd(o) == gd(fa[o]) ? fa[o] : o);
}
inline int sp(int o, int d){// 0 -> prev, 1 - > succ, assume o is rt
    for (o = ch[o][d], d^=1; ch[o][d]; o = ch[o][d]);
    return o;
}
inline int insert(int x) {
    for (int o = rt; o; ) if (x == st[o]) {cnt[o]++;return o;}
    else if (ch[o][x > st[o]]) o = ch[o][x > st[o]];
    else {
        st[++top] = x;
        cnt[top] = 1;
        lk(top, o, x > st[o]);
        return top;
    }
    return 0;
}
inline int find(int x) {
    int o = rt;
    while(1) if (x == st[o]) return o;
    else if (ch[o][x > st[o]]) o = ch[o][x > st[o]];
    else break;
    return o;
}
void remove(int o) {
    cnt[o]--;
    if (!cnt[o]) {
        splay(o);
        int d = bool(ch[o][1]), x = ch[o][d];
        if (!x) return;
        fa[rt = x] = 0;
        for (d^=1;ch[x][d];x = ch[x][d]);
        lk(ch[o][d], x, d);
    }
}
int main() {
//    freopen("input", "r", stdin);
    scanf("%d", &n);
    while (n--) {
        int t, x;scanf("%d%d", &t, &x);
        if (t == sta) splay(insert(x)), sz++;
        else if (t == (sta^1)) {
            int o = find(x);
            splay(o);
            if (st[o] == x) ;
            else {
                int tt = sp(o, 0);
                if (tt && abs(st[tt] - x) <= abs(st[o] - x)) o = tt;
                tt = sp(o, 1);
                if (tt && abs(st[tt] - x) < abs(st[o] - x)) o = tt;
            }
            remove(o);
            ans = (ans + abs(st[o] - x))%M;
            sz--;
        }
        else {
            st[rt = ++top] = x;
            cnt[top] = 1;
            sta = t;
            sz++;
        }
        if (!sz) sta = 3;
    }
    printf("%d\n", ans);
}

后记

Splay 树的实现太多了。它本身就很灵活。记录下几个调的时候杀时间的点,按杀伤力排行:

  1. 前驱后继中的空点
  2. rank 的递归边界是 lch + 1
  3. insert 中使用 lk + newnode 新建节点
  4. rot 中修改 rt