Codechef 13.12 QTREE6 树链剖分+线段树
题目大意:
给定一棵$n$个点的树,其中每个点都有一个颜色,初始均为黑色。
两个不同的点在同一个连通块中,当且仅当两个点之间的路径上的所有点的颜色相同。
现在有$q$个询问,要么将一个点的颜色取反(黑色变成白色,白色变成黑色),要么询问一个点所在的连通块的点的数目。
数据范围$n,q\leq{10^5}$。
算法讨论:
考虑利用轻重链树链剖分,这样的话,任意一个点到根的路径上最多只会经过$O(\log n)$条轻边。
对于每个节点,我们维护在有根树意义下,以这个节点为根的连通块的大小。
那么答案就应该是,在这个点所在的重链上,从这个点开始深度递增的连续一段的点的数目,再加上以这些点连出去的所有颜色相同的虚儿子为根的连通块的大小。
那么我们只需在线段树上维护最长的白色节点前缀长度以及这些节点的答案之和,最长的黑色节点前缀长度以及这些节点的答案之和,深度最大的的白色节点、黑色节点。
考虑询问,我们只需找出一个点$x$到根的路径上深度最大的与这个点颜色不同的节点$y$,并找到$x$的深度$>y$且最小的祖先$z$,回答以$z$为根的连通块的答案就行了。利用倍增处理,时间复杂度$O(\log n)$。
考虑修改,我们依次处理每条重链,修改会对这条重链深度最小的节点信息产生影响,然后这个深度最小的节点又会沿着轻边对他的父亲重链的某个点产生修改。依此类推,这样只会对$O(\log n)$条重链进行修改,每次修改都是一次线段树上的操作,因此总时间复杂度$O(\log ^2n)$。
时空复杂度:
时间复杂度$O(n+q\log ^2n)$,空间复杂度$O(n\log n)$。
代码:
#include <cstdio> #include <cstring> #include <cctype> #include <iostream> #include <algorithm> #include <vector> using namespace std; int getc() { static const int L = 1 << 15; static char buf[L], *S = buf, *T = buf; if (S == T) { T = (S = buf) + fread(buf, 1, L, stdin); if (S == T) return EOF; } return *S++; } int getint() { int c; while (!isdigit(c = getc())); int x = c - '0'; while (isdigit(c = getc())) x = (x << 1) + (x << 3) + c - '0'; return x; } #define N 100010 int n, head[N], nxt[N << 1], to[N << 1]; void addedge(int a, int b) { static int q = 1; to[q] = b; nxt[q] = head[a]; head[a] = q++; } void make(int a, int b) { addedge(a, b); addedge(b, a); } int pa[N][17], dep[N], siz[N], son[N]; void dfs(int x, int fa) { int mx = 0; siz[x] = 1; for (int j = head[x]; j; j = nxt[j]) if (to[j] != fa) { dep[to[j]] = dep[x] + 1; pa[to[j]][0] = x; dfs(to[j], x); siz[x] += siz[to[j]]; if (siz[to[j]] > mx) { mx = siz[to[j]]; son[x] = to[j]; } } } int top[N], down[N], p_id[N], id_p[N], id; void create(int x, int Top) { top[x] = Top; p_id[x] = ++id; id_p[id] = x; if (son[x]) create(son[x], Top); else down[Top] = x; for (int j = head[x]; j; j = nxt[j]) if (to[j] != pa[x][0] && to[j] != son[x]) create(to[j], to[j]); } #define inf 0x3f3f3f3f bool col[N]; struct ans { int mx_dep[2], sum[2], mx_pre[2], size; void init(int x) { mx_dep[0] = dep[x]; mx_dep[1] = 0; sum[0] = 1; for (int j = head[x]; j; j = nxt[j]) if (to[j] != pa[x][0] && to[j] != son[x]) sum[0] += siz[to[j]]; sum[1] = 0; mx_pre[0] = 1; mx_pre[1] = 0; size = 1; } friend ans operator + (const ans &l, const ans &r) { ans re; re.mx_dep[0] = max(l.mx_dep[0], r.mx_dep[0]); //re.mx_dep_black = max(l.mx_dep_black, r.mx_dep_black); re.mx_dep[1] = max(l.mx_dep[1], r.mx_dep[1]); //re.mx_dep_white = max(l.mx_dep_white, r.mx_dep_white); re.sum[0] = l.sum[0] + r.sum[0]; re.sum[1] = l.sum[1] + r.sum[1]; //re.sum_black = l.sum_black + r.sum_black; //re.sum_white = l.sum_white + r.sum_white; re.mx_pre[0] = l.mx_pre[0] + ((l.mx_pre[0] == l.size) ? r.mx_pre[0] : 0); //if (l.mx_pre_black == l.sum_black) // re.mx_pre_black = l.mx_pre_black + r.mx_pre_black; //else // re.mx_pre_black = l.mx_pre_black; re.mx_pre[1] = l.mx_pre[1] + ((l.mx_pre[1] == l.size) ? r.mx_pre[1] : 0); //if (l.mx_pre_white == l.sum_white) // re.mx_pre_white = l.mx_pre_white + r.mx_pre_white; //else // re.mx_pre_white = l.mx_pre_white; re.size = l.size + r.size; return re; } }; #define ls(x) S[x].l #define rs(x) S[x].r struct Node { ans x; int l, r; }S[N << 1]; int cnt; int build(int tl, int tr) { int q = ++cnt; S[q].x.init(id_p[tl]); if (tl == tr) return q; int mid = (tl + tr) >> 1; ls(q) = build(tl, mid); rs(q) = build(mid + 1, tr); S[q].x = S[ls(q)].x + S[rs(q)].x; return q; } ans query(int q, int tl, int tr, int dl, int dr) { if (dl <= tl && tr <= dr) return S[q].x; int mid = (tl + tr) >> 1; if (dr <= mid) return query(ls(q), tl, mid, dl, dr); else if (dl > mid) return query(rs(q), mid + 1, tr, dl, dr); else return query(ls(q), tl, mid, dl, mid) + query(rs(q), mid + 1, tr, mid + 1, dr); } ans queryChain(int x, int up) { static vector<ans> save; save.clear(); while (top[x] != top[up]) { save.push_back(query(1, 1, n, p_id[top[x]], p_id[x])); x = pa[top[x]][0]; } save.push_back(query(1, 1, n, p_id[up], p_id[x])); for (int i = save.size() - 2; i >= 0; --i) save.back() = save.back() + save[i]; return save.back(); } int climb(int x, int to_dep) { for (int i = 16; i >= 0; --i) if (dep[pa[x][i]] >= to_dep) x = pa[x][i]; return x; } int calcLen(int x, int c) { return query(1, 1, n, p_id[x], p_id[down[x]]).mx_pre[c]; } int calcNum(int x, int c) { static int len; len = calcLen(x, c); return query(1, 1, n, p_id[x], p_id[x] + len - 1).sum[c]; } void Modify(int q, int tl, int tr, int ins) { static int c; if (tl == tr) { c = col[id_p[tl]]; S[q].x.mx_pre[c] = 0; S[q].x.mx_pre[1 ^ c] = 1; S[q].x.mx_dep[c] = 0; S[q].x.mx_dep[1 ^ c] = dep[id_p[tl]]; S[q].x.sum[c]--; S[q].x.sum[1 ^ c]++; return; } int mid = (tl + tr) >> 1; if (ins <= mid) Modify(ls(q), tl, mid, ins); else Modify(rs(q), mid + 1, tr, ins); S[q].x = S[ls(q)].x + S[rs(q)].x; } void ModifySum(int q, int tl, int tr, int ins, int c, int v) { if (tl == tr) { S[q].x.sum[c] += v; return; } int mid = (tl + tr) >> 1; if (ins <= mid) ModifySum(ls(q), tl, mid, ins, c, v); else ModifySum(rs(q), mid + 1, tr, ins, c, v); S[q].x = S[ls(q)].x +S[rs(q)].x; } int main() { //freopen("tt.in", "r", stdin); n = getint(); int i, j; int a, b; for (i = 1; i < n; ++i) { a = getint(); b = getint(); make(a, b); //printf("%d %d\n", a, b); } dep[1] = 1; dfs(1, -1); for (j = 1; j <= 16; ++j) for (i = 1; i <= n; ++i) pa[i][j] = pa[pa[i][j - 1]][j - 1]; create(1, 1); for (i = 1; i <= n; ++i) down[i] = down[top[i]]; build(1, n); int q = getint(); int type, x, y; ans temp; int lastcol, lastsum, _lastcol, _lastsum, nowcol, nowsum; while (q--) { type = getint(); x = getint(); temp = queryChain(x, 1); if (type == 0) { y = climb(x, temp.mx_dep[col[x] ^ 1] + 1); printf("%d\n", calcNum(y, col[x])); } else { lastcol = col[top[x]]; lastsum = calcNum(top[x], col[top[x]]); Modify(1, 1, n, p_id[x]); col[x] ^= 1; while (pa[top[x]][0]) { nowcol = col[top[x]]; nowsum = calcNum(top[x], col[top[x]]); _lastcol = col[top[pa[top[x]][0]]]; _lastsum = calcNum(top[pa[top[x]][0]], col[top[pa[top[x]][0]]]); ModifySum(1, 1, n, p_id[pa[top[x]][0]], lastcol, -lastsum); ModifySum(1, 1, n, p_id[pa[top[x]][0]], nowcol, nowsum); x = pa[top[x]][0]; lastcol = _lastcol; lastsum = _lastsum; } } } return 0; }