Treap是 二叉搜索树(BST)和二叉堆(Heap)的结合。二叉搜索树支持Treap的所有一般功能,例如查排名,查第k大,前驱,后继,删除,插入。它的特点是左子树小于等于根,右子树大于等于根。但是它的复杂度依赖于树的高度,而树的高度很容易被数据卡成链。
Heap是一种完全二叉树,它的树高为log(n)。一般有小根堆和大根堆。我们以小根堆为例,它的特点是根节点小于等于子节点。
其实到这里我们会发现两种树的性质是矛盾的,BST要求左儿子小于等于根,右儿子大于等于根,Heap要求根既小于等于左儿子,又小于等于右儿子,同时人为规定左儿子小于右儿子。
所以为了让二者结合,我们给二叉搜索树的每个节点赋予一个随机值,通过这个随机值来维护堆的性质。
首先我们要知道,BST的中序遍历是不变的(最小的在最左边),而Heap的后序遍历是不变的(最大的在最右边)。数据是出题人给的,也就是我们无法改变中序遍历。现在假设我们的随机数组是B,那么根据这个数组可以得到唯一的后序遍历,有了后序遍历和中序遍历我们就可以唯一的确定一棵树。Treap一个重要的特性就是树的形状是确定的。在随机的情况下,树的高度是log(n)的。证明方法就不展开讲了。
Treap是可以通过Zig和Zag来维护的,但是今天讲的是利用分裂还有合并来保持平衡,单次操作的期望是log(N)。
声明一些变量
const int N = 2e6+5;
int val[N], ch[N][2], siz[N], rnd[N], q[N];
int root, cnt, x, y, z, an, tot;
分裂是按照权值来分,(其实也可以按照size来分,一般做区间问题的时侯会用到)。把树分为两个子树,其中左子树小于等于a,右子树大于a。
void split(int rt, int a, int &x, int &y) {
if(!rt) x = y = 0;else {
if(val[rt] <= a) x = rt, split(ch[rt][1], a, ch[rt][1], y);else y = rt, split(ch[rt][0], a, x, ch[rt][0]);up(rt);}
}
合并是按照随机数数合并的,一定要注意顺序,x一定是更小节点的根。然后按照堆的性质来合并即可
int merge(int x, int y) {
if(!x || !y) return x + y;if(rnd[x] < rnd[y]) {
ch[x][1] = merge(ch[x][1], y);up(x); return x;} else {
ch[y][0] = merge(x, ch[y][0]);up(y); return y;}
}
void insert(int a) {
// 插入split(root, a, x, y);root = merge(merge(x, new_node(a)), y);
}
void del(int a) {
// 删除,加入了垃圾回收split(root, a, x, y);split(x, a-1, x, z);q[++tot] = z;z = merge(ch[z][0], ch[z][1]);root = merge(merge(x, z), y);
}
int _rank(int a) {
//查找a的排名split(root, a-1, x, y);an = siz[x] + 1;root = merge(x, y);return an;
}
int _kth(int rt, int k) {
// 查找排名为k的数while(rt) {
if(k <= siz[ch[rt][0]]) rt = ch[rt][0];else if(k == siz[ch[rt][0]] + 1) return val[rt];else k -= siz[ch[rt][0]]+1, rt = ch[rt][1];}
}
int pre(int a) {
// 查找前驱split(root, a-1, x, y);an = _kth(x, siz[x]);root = merge(x, y);return an;
}
int nxt(int a) {
// 查找后继split(root, a, x, y);an = _kth(y, 1);root = merge(x, y);return an;
}
完整代码
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int INF = 0x3f3f3f3f;
const double Pi = acos(-1);
namespace {
template <typename T> inline void read(T &x) {
x = 0; T f = 1;char s = getchar();for(; !isdigit(s); s = getchar()) if(s == '-') f = -1;for(; isdigit(s); s = getchar()) x = (x << 3) + (x << 1) + (s ^ 48);x *= f;}
}
#define fio ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
#define _for(n,m,i) for (register int i = (n); i < (m); ++i)
#define _rep(n,m,i) for (register int i = (n); i <= (m); ++i)
#define _srep(n,m,i)for (register int i = (n); i >= (m); i--)
#define _sfor(n,m,i)for (register int i = (n); i > (m); i--)
#define lson rt << 1, l, mid
#define rson rt << 1 | 1, mid + 1, r
#define lowbit(x) x & (-x)
#define pii pair<int,int>
#define fi first
#define se second
const int N = 2e6+5;
int val[N], ch[N][2], siz[N], rnd[N], q[N];
int root, cnt, x, y, z, an, tot;
void up(int rt) {
siz[rt] = 1 + siz[ch[rt][0]] + siz[ch[rt][1]];
}
int new_node(int x) {
int new_cnt = tot ? q[tot--] : ++cnt;val[new_cnt] = x;siz[new_cnt] = 1;rnd[new_cnt] = rand();ch[new_cnt][0] = ch[new_cnt][1] = 0;// cout << "debug: " << new_cnt << endl;return new_cnt;
}
void split(int rt, int a, int &x, int &y) {
if(!rt) x = y = 0;else {
if(val[rt] <= a) x = rt, split(ch[rt][1], a, ch[rt][1], y);else y = rt, split(ch[rt][0], a, x, ch[rt][0]);up(rt);}
}
int merge(int x, int y) {
if(!x || !y) return x + y;if(rnd[x] < rnd[y]) {
ch[x][1] = merge(ch[x][1], y);up(x); return x;} else {
ch[y][0] = merge(x, ch[y][0]);up(y); return y;}
}
void insert(int a) {
split(root, a, x, y);root = merge(merge(x, new_node(a)), y);
}
void del(int a) {
split(root, a, x, y);split(x, a-1, x, z);q[++tot] = z;z = merge(ch[z][0], ch[z][1]);root = merge(merge(x, z), y);
}
int _rank(int a) {
split(root, a-1, x, y);an = siz[x] + 1;root = merge(x, y);return an;
}
int _kth(int rt, int k) {
while(rt) {
if(k <= siz[ch[rt][0]]) rt = ch[rt][0];else if(k == siz[ch[rt][0]] + 1) return val[rt];else k -= siz[ch[rt][0]]+1, rt = ch[rt][1];}
}
int pre(int a) {
split(root, a-1, x, y);an = _kth(x, siz[x]);root = merge(x, y);return an;
}
int nxt(int a) {
split(root, a, x, y);an = _kth(y, 1);root = merge(x, y);return an;
}
int main() {
srand(time(0));int n,m, op, a, la = 0, ans = 0; read(n); read(m);while(n--) {
read(a); insert(a);}while(m--) {
read(op); read(a);a ^= la;if(op == 1) insert(a);else if(op == 2) del(a);else if(op == 3) la = _rank(a),ans ^= la;else if(op == 4) la = _kth(root, a), ans ^= la;else if(op == 5) la = pre(a), ans ^= la;else la = nxt(a), ans ^= la;} printf("%d\n", ans);
}