Codechef 12.7 DGCD 数学+树链剖分+线段树
题目大意:
给定一棵$n$个点的树,一开始时树上编号为$i$的点的权值为$v_i$,支持两种操作:要么将一条路径上的点的权值均加上一个数$d$,要么询问一条路径上的点的权值的最大公约数。
数据范围$n,q\leq{50000},v_i,d\leq{10000}$。
算法讨论:
首先注意到一个事实:
\[gcd(a_1,a_2,...,a_n)=gcd(a_1,a_2-a_1,...,a_n-a_{n-1})\]
这个结论我们可以使用数学归纳法证明,这里略过。
那么考虑如何在序列上解决这个问题,考虑区间加上一个数对于区间的影响:
\[gcd(a_l+d,a_{l+1}+d,...a_{r}+d)=gcd(a_l+d,a_{l+1}-a_{l},...,a_{r}-a_{r-1})\]
因此我们利用线段树在每个区间长度为$n$节点上维护区间左端点的值,以及$n-1$个差分数值的最大公约数,这个区间的最大公约数就是二者的最大公约数,对区间进行修改只需要改左端点的值就行了,利用延迟标记,我们就能在$O(\log n)$的时间里解决这个问题。
为了将序列拓展到树上,我们只需要利用树链剖分,将路径拆分为最多$O(\log n)$条链求解,同时最大公约数是容易合并的,那么只需依次在每条链上用序列的求解方法求得答案然后合并就行了。
单组询问时间复杂度$O(\log ^2n)$。
时空复杂度:
时间复杂度$O(n+q\log ^2n)$,空间复杂度$O(n)$。
代码:
#include <cstdio> #include <cstring> #include <cctype> #include <iostream> #include <algorithm> using namespace std; int getc() { static const int L = 1 << 15; static char buf[L], *S = buf, *T = buf; if (S == T) { T = (S = buf) + fread(buf, 1, L, stdin); if (S == T) return EOF; } return *S++; } int getint() { static int c, x; while (!isdigit(c = getc())); x = c - '0'; while (isdigit(c = getc())) x = (x << 1) + (x << 3) + c - '0'; return x; } int getch() { static int c; while ((c = getc()) != 'F' && c != 'C'); return c; } #define N 100010 int n, head[N], nxt[N << 1], to[N << 1]; void addedge(int a, int b) { static int q = 1; to[q] = b; nxt[q] = head[a]; head[a] = q++; } void make(int a, int b) { addedge(a, b); addedge(b, a); } int siz[N], dep[N], pa[N], son[N]; void dfs(int x, int fa) { int mx = 0; siz[x] = 1; for (int j = head[x]; j; j = nxt[j]) { if (to[j] != fa) { dep[to[j]] = dep[x] + 1; pa[to[j]] = x; dfs(to[j], x); siz[x] += siz[to[j]]; if (siz[to[j]] > mx) { mx = siz[to[j]]; son[x] = to[j]; } } } } int top[N], p_id[N], id_p[N], id; void create(int x, int Top) { p_id[x] = ++id; id_p[id] = x; top[x] = Top; if (son[x]) create(son[x], Top); for (int j = head[x]; j; j = nxt[j]) if (to[j] != son[x] && to[j] != pa[x]) create(to[j], to[j]); } int gcd(int a, int b) { return (!b) ? a : gcd(b, a % b); } struct Node { Node *ls, *rs; int lv, rv, diff, ans, c; void addit(int _c) { lv += _c; rv += _c; c += _c; ans = gcd(lv, diff); } void down() { if (c) { if (ls) ls->addit(c); if (rs) rs->addit(c); c = 0; } } void up() { lv = ls->lv; rv = rs->rv; diff = gcd(rs->lv - ls->rv, gcd(ls->diff, rs->diff)); ans = gcd(lv, diff); } }mem[N << 1], *P = mem, *root; Node *newnode(int v) { P->lv = P->rv = P->ans = v; P->diff = 0; P->c = 0; return P++; } int w[N]; Node *build(int tl, int tr) { if (tl == tr) return newnode(w[id_p[tl]]); int mid = (tl + tr) >> 1; Node *q = P++; q->ls = build(tl, mid); q->rs = build(mid + 1, tr); q->up(); return q; } int query(Node *q, int tl, int tr, int dl, int dr) { if (dl <= tl && tr <= dr) return q->ans; q->down(); int mid = (tl + tr) >> 1; if (dr <= mid) return query(q->ls, tl, mid, dl, dr); else if (dl > mid) return query(q->rs, mid + 1, tr, dl, dr); else return gcd(query(q->ls, tl, mid, dl, mid), query(q->rs, mid + 1, tr, mid + 1, dr)); } void modify(Node *q, int tl, int tr, int dl, int dr, int c) { if (dl <= tl && tr <= dr) { q->addit(c); return; } q->down(); int mid = (tl + tr) >> 1; if (dr <= mid) modify(q->ls, tl, mid, dl, dr, c); else if (dl > mid) modify(q->rs, mid + 1, tr, dl, dr, c); else { modify(q->ls, tl, mid, dl, mid, c); modify(q->rs, mid + 1, tr, mid + 1, dr, c); } q->up(); } int Query(int x, int y) { int ans = 0; while (top[x] != top[y]) { if (dep[top[x]] < dep[top[y]]) swap(x, y); ans = gcd(ans, query(root, 1, n, p_id[top[x]], p_id[x])); x = pa[top[x]]; } if (dep[x] < dep[y]) swap(x, y); ans = gcd(ans, query(root, 1, n, p_id[y], p_id[x])); return ans; } void Modify(int x, int y, int c) { while (top[x] != top[y]) { if (dep[top[x]] < dep[top[y]]) swap(x, y); modify(root, 1, n, p_id[top[x]], p_id[x], c); x = pa[top[x]]; } if (dep[x] < dep[y]) swap(x, y); modify(root, 1, n, p_id[y], p_id[x], c); } int myAbs(int x) { return x < 0 ? -x : x; } int main() { #ifndef ONLINE_JUDGE freopen("tt.in", "r", stdin); #endif n = getint(); int i, j, a, b, c; for (i = 1; i < n; ++i) { a = getint() + 1; b = getint() + 1; make(a, b); } dep[1] = 1; dfs(1, -1); create(1, 1); for (i = 1; i <= n; ++i) w[i] = getint(); root = build(1, n); int q = getint(); char qte; while (q--) { qte = getch(); if (qte == 'F') { a = getint() + 1; b = getint() + 1; printf("%d\n", myAbs(Query(a, b))); } else { a = getint() + 1; b = getint() + 1; c = getint(); Modify(a, b, c); } } return 0; }