Codechef 13.8 PRIMEDST 树分治+FFT
题目大意:
给定一棵$n$个点的树,每条边的长度都为$1$,求在上面随机一条路径,这条路径的长度为质数的概率,误差不超过$10^{-6}$就会被认为是正确的。
数据范围$2\leq{n}\leq{50000}$。
题解:
我们只需求出每种长度的路径各有多少条,就能知道所有路径的数目以及长度为质数的路径的数目,用两者相除就能得到我们想要的概率。
我们利用点分治统计所有通过当前有根树的根节点的路径。
考虑所有只经过一棵子树的路径,我们只需要从根出发向下DFS一下就能求出这些路径的贡献。
考虑所有跨越两棵子树的路径,对于每棵子树,我们维护一个数组$C$,其中$C_i$表示子树内深度为$i$的节点的数目,我们对两棵子树的$C$数组做一个卷积$D$,则$D_i$表示的就是跨越两棵子树且长度为$i$的路径的数目。
我们可以依次枚举每棵子树,将这棵子树的$C$数组与前面的子树的$C$数组的前缀和的卷积$D$数组统计入答案。
我们分析复杂度:对于一棵子树,$C$数组的长度就是子树内节点的最大深度;对于两个长度分别为$n,m$的$C$数组,卷积花费的时间复杂度为:
$O(max(n,m)\log max(n,m))$。
因此,我们只要将所有子树按照最大深度从小到大的顺序进行排序,并按照上面的过程利用卷积统计答案,对于一棵点数为$n$的树,显然就能在不超过$O(n\log n)$的时间复杂度内完成统计。
再加上树分治,则总时间复杂度为$O(n\log ^2n)$。
时空复杂度:
时间复杂度$O(n\log ^2n)$,空间复杂度$O(n)$。
代码:
#include <cstdio> #include <cstring> #include <cctype> #include <iostream> #include <algorithm> #include <cmath> #include <vector> using namespace std; typedef double db; int getc() { static const int L = 1 << 15; static char buf[L], *S = buf, *T = buf; if (S == T) { T = (S = buf) + fread(buf, 1, L, stdin); if (S == T) return EOF; } return *S++; } int getint() { int c; while (!isdigit(c = getc())); int x = c - '0'; while (isdigit(c = getc())) x = (x << 1) + (x << 3) + c - '0'; return x; } #define N 50010 int head[N], nxt[N << 1], to[N << 1]; void addedge(int a, int b) { static int q = 1; to[q] = b; nxt[q] = head[a]; head[a] = q++; } void make(int a, int b) { addedge(a, b); addedge(b, a); } int prime[N], id; bool np[N]; void pre(int n) { int i, j; for (i = 2; i <= n; ++i) { if (!np[i]) prime[++id] = i; for (j = 1; j <= id && prime[j] * i <= n; ++j) { np[i * prime[j]] = 1; if (i % prime[j] == 0) break; } } } struct Cp { db u, v; Cp() {} Cp(db _u, db _v) : u(_u), v(_v) {} friend Cp operator + (const Cp &a, const Cp &b) { return Cp(a.u + b.u, a.v + b.v); } friend Cp operator - (const Cp &a, const Cp &b) { return Cp(a.u - b.u, a.v - b.v); } friend Cp operator * (const Cp &a, const Cp &b) { return Cp(a.u * b.u - a.v * b.v, a.u * b.v + a.v * b.u); } Cp operator * (const db &p) const { return Cp(u * p, v * p); } }; static const db pi = acos(-1.0); vector<int> Reverse[18]; void init() { int i, j, k, num; for (i = 1; i <= 17; ++i) { for (j = 0; j < (1 << i); ++j) { for (num = k = 0; k < i; ++k) if ((j >> k) & 1) num |= (1 << (i - k - 1)); Reverse[i].push_back(num); } } } int getbit(int n) { int cnt = 0; for (; n != 1; n >>= 1, ++cnt); return cnt; } void fastFourierTransform(Cp A[], int n, int rev) { static Cp B[131072], wn, w, t; int i, j, k, bit = getbit(n); for (i = 0; i < n; ++i) B[i] = A[Reverse[bit][i]]; for (i = 0; i < n; ++i) A[i] = B[i]; for (k = 2; k <= n; k <<= 1) { for (wn = Cp(cos(2 * pi / k), rev * sin(2 * pi / k)), i = 0; i < n; i += k) { for (w = Cp(1, 0), j = 0; j < (k >> 1); ++j, w = w * wn) { t = w * A[i + j + (k >> 1)]; A[i + j + (k >> 1)] = A[i + j] - t; A[i + j] = A[i + j] + t; } } } if (rev == -1) for (i = 0; i < n; ++i) A[i] = A[i] * (1.0 / n); } bool del[N]; long long ans[131072]; int siz[N]; void calSize(int x, int fa) { siz[x] = 1; for (int j = head[x]; j; j = nxt[j]) if (!del[to[j]] && to[j] != fa) { calSize(to[j], x); siz[x] += siz[to[j]]; } } int mn, g; void findGravity(int x, int fa, int total) { int ans = 0; for (int j = head[x]; j; j = nxt[j]) if (!del[to[j]] && to[j] != fa) ans = max(ans, siz[to[j]]); ans = max(ans, total - siz[x]); if (ans < mn) { mn = ans; g = x; } for (int j = head[x]; j; j = nxt[j]) if (!del[to[j]] && to[j] != fa) findGravity(to[j], x, total); } int pa[N]; vector<pair<int, int> > root_dep; vector<int> save_dep[N]; void dfs(int x, int fa, int root, int dep) { ++ans[dep]; save_dep[root].push_back(dep); for (int j = head[x]; j; j = nxt[j]) if (!del[to[j]] && to[j] != fa) dfs(to[j], x, root, dep + 1); } Cp c[131072], c1[131072], c2[131072]; void divideConquer(int x) { calSize(x, -1); if (siz[x] == 1) return; mn = 0x3f3f3f3f; findGravity(x, -1, siz[x]); root_dep.clear(); int i, j, M; for (j = head[g]; j; j = nxt[j]) if (!del[to[j]]) { save_dep[to[j]].clear(); dfs(to[j], g, to[j], 1); sort(save_dep[to[j]].begin(), save_dep[to[j]].end()); root_dep.push_back(make_pair(save_dep[to[j]].back(), to[j])); } sort(root_dep.begin(), root_dep.end()); for (i = 0; i <= siz[x]; ++i) c[i] = Cp(0, 0); for (i = 0; i < root_dep.size(); ++i) { for (M = 1; M <= (root_dep[i].first << 1); M <<= 1); for (j = 0; j < M; ++j) c1[j] = c2[j] = Cp(0, 0); for (j = 1; j <= root_dep[i].first; ++j) c1[j] = c[j]; for (j = 0; j < save_dep[root_dep[i].second].size(); ++j) c2[save_dep[root_dep[i].second][j]].u += 1; fastFourierTransform(c1, M, 1); fastFourierTransform(c2, M, 1); for (j = 0; j < M; ++j) c1[j] = c1[j] * c2[j]; fastFourierTransform(c1, M, -1); for (j = 0; j < M; ++j) ans[j] += (int)(c1[j].u + .5); for (j = 0; j < save_dep[root_dep[i].second].size(); ++j) c[save_dep[root_dep[i].second][j]].u += 1; } del[g] = 1; for (int j = head[g]; j; j = nxt[j]) if (!del[to[j]]) divideConquer(to[j]); } int main() { //freopen("tt.in", "r", stdin); int n = getint(); pre(n); int i, a, b; for (i = 1; i < n; ++i) { a = getint(); b = getint(); make(a, b); } init(); //for (i = 0; i < 8; ++i) // c1[i] = Cp(i, 0); //fastFourierTransform(c1, 8, 1); //fastFourierTransform(c1, 8, -1); //for (i = 0; i < 8; ++i) // printf("%lf %lf\n", c1[i].u, c1[i].v); //return 0; divideConquer(1); long long total_path = 0; for (i = 1; i <= n; ++i) total_path += ans[i]; long long prime_path = 0; for (i = 1; i <= id; ++i) prime_path += ans[prime[i]]; printf("%.10lf", (db)prime_path / total_path); return 0; }