Codechef 14.2 COT5 Treap+线段树
题目大意:
Treap是一种二叉平衡树,维护一些节点$(k,w)$,其中$k$是键值、$w$是权值。
Treap的中序遍历要使得键值是递增的,还要保证每一个节点的权值都要小于他的父亲节点的权值。
现有$n$个询问,要么插入一个节点$(k,w)$,要么删除一个键值为$k$的节点,要么询问两个键值分别为$ku,kv$的节点在Treap上的距离。
数据范围$n\leq{200000}$。
保证任意时刻树中不存在两个节点的键值相同或者权值相同。
算法讨论:
首先假如我们有一个关于键值的有序序列,那么如何建树呢?
我们首先找出序列中权值最大的点作为根节点,然后将剩余的两部分作为左子树和右子树递归下去。
显然哪部分是左子树,哪部分是右子树对于答案是没有影响的。
那我们会发现,对于两个序列中的点$x,y$,如果以两个点为两端的区间中不存在比两个点权值都大的点,那么$x,y$必定有一个是另一个的祖先——权值大的那个是祖先;否则$x,y$一定在某次递归中被划分到了两棵子树中。
同时区间中权值最大的那个点就是两个点的lca,这个结论也是比较显然的。
那么我们要算两个点之间距离,在已经知道lca的情况下,只要能算出每个点的深度就行了。
考虑我们如何找出一个点所有的祖先,祖先的权值应该大于这个点的权值,并且区间中的点的权值均小于祖先的权值。
我们发现只要从这个点在序列中的位置开始,分别找出向前和向后的权值单调递增链的长度就行了。
现在先只考虑向后权值递增链的长度。
不妨使用线段树来维护。
对于线段树上的每个节点,记录从这个节点的左端点开始向右权值单调递增的链的长度以及链上最后一个点的权值。
合并时,我们只需要计算左儿子链上最后一个点在右儿子链上的排名就行了。
而这个排名能利用类似树上二分的方法做到$O(\log n)$。
于是合并时间复杂度为$O(\log n)$,修改一个点的权值只需要将叶子节点到根的路径上的$O(\log n)$个节点都更新一下就行了,时间复杂度$O(\log ^2n)$。
时空复杂度:
时间复杂度$O(n\log ^2n)$,空间复杂度$O(n)$。
代码:
#include <cstdio> #include <cstring> #include <cctype> #include <climits> #include <iostream> #include <algorithm> using namespace std; int getc() { static const int L = 1 << 15; static char buf[L], *S = buf, *T = buf; if (S == T) { T = (S = buf) + fread(buf, 1, L, stdin); if (S == T) return EOF; } return *S++; } unsigned getint() { static int c; static unsigned x; while (!isdigit(c = getc())); x = c - '0'; while (isdigit(c = getc())) x = (x << 1) + (x << 3) + c - '0'; return x; } #define N 200010 int n, type[N]; unsigned v[N], w[N], ku[N], kv[N]; struct Node { Node *ls, *rs, *pa; int l_len, r_len, mxid; unsigned l_mx, r_mx, mx; int lFind(unsigned); int rFind(unsigned); void up(); }mem[N << 1], *P = mem; Node *newnode() { return P++; } int Node::lFind(unsigned x) { if (x > l_mx) return 0; if (!ls) return x < l_mx ? 1 : 0; if (x > ls->l_mx) return rs->lFind(x); else return ls->lFind(x) + l_len - ls->l_len; } int Node::rFind(unsigned x) { if (x > r_mx) return 0; if (!ls) return x < r_mx ? 1 : 0; if (x > rs->r_mx) return ls->rFind(x); else return rs->rFind(x) + r_len - rs->r_len; } void Node::up() { if (ls->mx > rs->mx) { mx = ls->mx; mxid = ls->mxid; } else { mx = rs->mx; mxid = rs->mxid; } l_len = ls->l_len + rs->lFind(ls->l_mx); l_mx = max(ls->l_mx, rs->l_mx); r_len = rs->r_len + ls->rFind(rs->r_mx); r_mx = max(rs->r_mx, ls->r_mx); } Node *build(int tl, int tr) { Node *q = newnode(); if (tl == tr) { q->l_len = q->r_len = 1; q->l_mx = q->r_mx = q->mx = 0; q->mxid = tl; return q; } int mid = (tl + tr) >> 1; q->ls = build(tl, mid); q->rs = build(mid + 1, tr); q->ls->pa = q->rs->pa = q; q->up(); return q; } void modify(Node *q, int tl, int tr, int x, int v) { if (tl == tr) { q->mx = q->l_mx = q->r_mx = v; return; } int mid = (tl + tr) >> 1; if (x <= mid) modify(q->ls, tl, mid, x, v); else modify(q->rs, mid + 1, tr, x, v); q->up(); } pair<unsigned, int> queryMaxid(Node *q, int tl, int tr, int dl, int dr) { if (dl <= tl && tr <= dr) return make_pair(q->mx, q->mxid); int mid = (tl + tr) >> 1; if (dr <= mid) return queryMaxid(q->ls, tl, mid, dl, dr); else if (dl > mid) return queryMaxid(q->rs, mid + 1, tr, dl, dr); else return max(queryMaxid(q->ls, tl, mid, dl, mid), queryMaxid(q->rs, mid + 1, tr, mid + 1, dr)); } Node *findLeaf(Node *q, int tl, int tr, int x) { if (tl == tr) return q; int mid = (tl + tr) >> 1; if (x <= mid) return findLeaf(q->ls, tl, mid, x); else return findLeaf(q->rs, mid + 1, tr, x); } int left_extend(Node *q) { int re = 1; unsigned mx = q->mx; while (q->pa) { if (q == q->pa->rs) { re += q->pa->ls->rFind(mx); mx = max(mx, q->pa->ls->r_mx); } q = q->pa; } return re; } int right_extend(Node *q) { int re = 1; unsigned mx = q->mx; while (q->pa) { if (q == q->pa->ls) { re += q->pa->rs->lFind(mx); mx = max(mx, q->pa->rs->l_mx); } q = q->pa; } return re; } Node *root; int calcDep(int x) { static Node *leaf; leaf = findLeaf(root, 1, n, x); return left_extend(leaf) + right_extend(leaf) - 1; } unsigned global_hash[N]; int num; int Find(unsigned x) { return lower_bound(global_hash + 1, global_hash + num + 1, x) - global_hash; } int main() { #ifndef ONLINE_JUDGE //freopen("tt.in", "r", stdin); #endif n = getint(); int i; for (i = 1; i <= n; ++i) { type[i] = (int)getint(); if (type[i] == 0) { v[i] = getint(); global_hash[++num] = v[i]; w[i] = getint(); } else if (type[i] == 1) v[i] = getint(); else { ku[i] = getint(); kv[i] = getint(); } } sort(global_hash + 1, global_hash + num + 1); num = unique(global_hash + 1, global_hash + num + 1) - global_hash - 1; root = build(1, n); int id, _id, __id; for (i = 1; i <= n; ++i) { if (type[i] == 0) { id = Find(v[i]); modify(root, 1, n, id, w[i]); } else if (type[i] == 1) { id = Find(v[i]); modify(root, 1, n, id, 0); } else { id = Find(ku[i]); _id = Find(kv[i]); if (id > _id) swap(id, _id); __id = queryMaxid(root, 1, n, id, _id).second; printf("%d\n", calcDep(id) + calcDep(_id) - calcDep(__id) * 2); } } return 0; }