Codechef 12.12 DIFTRIP Trie树+后缀自动机+STL
题目大意:
给定一棵$n$个点的有根树,定义两条深度递增的路径是相似的当且仅当两条路径长度相同并且对应位置的点的度数相同。
问不相似的路径一共有多少种。
数据范围$n\leq{10^5}$。
算法讨论:
相当于是一颗Trie树,其中的每个点都写着一个与这个点在树上的度数相等的的字符,那么只需要求出Trie树上有多少种不同子串就行了。
利用后缀自动机在Trie树上的拓展就能解决这个问题了。
注意这里的字符集大小是$O(n)$级别的,所以需要用$map$来存指针。
时空复杂度:
时间复杂度$O(n\log n)$,空间复杂度$O(n)$。
代码:
#include <cstdio> #include <cstring> #include <cctype> #include <iostream> #include <algorithm> #include <map> using namespace std; int getc() { static const int L = 1 << 15; static char buf[L], *S = buf, *T = buf; if (S == T) { T = (S = buf) + fread(buf, 1, L, stdin); if (S == T) return EOF; } return *S++; } int getint() { static int c, x; while (!isdigit(c = getc())); x = c - '0'; while (isdigit(c = getc())) x = (x << 1) + (x << 3) + c - '0'; return x; } #define N 100010 int n, head[N], nxt[N << 1], to[N << 1], d[N], fa[N]; bool vis[N]; void addedge(int a, int b) { static int q = 1; to[q] = b; nxt[q] = head[a]; head[a] = q++; } void make(int a, int b) { ++d[a]; ++d[b]; addedge(a, b); addedge(b, a); } map<int, int> tr[N << 1]; int pa[N << 1], len[N << 1], cnt, root, last; int newnode(int _len) { len[++cnt] = _len; return cnt; } int lastnode[N], qu[N], fr, ta; int main() { #ifndef ONLINE_JUDGE freopen("tt.in", "r", stdin); #endif n = getint(); int i, j, a, b; for (i = 1; i < n; ++i) { a = getint(); b = getint(); make(a, b); } root = last = newnode(0); lastnode[0] = root; vis[qu[ta++] = 1] = 1; int p, np, q, nq, y; while (fr != ta) { i = qu[fr++]; last = lastnode[fa[i]]; y = d[i]; if (tr[last].count(y)) { q = tr[last][y]; if (len[q] == len[last] + 1) np = q; else { nq = newnode(len[last] + 1); pa[nq] = last; pa[q] = nq; tr[nq] = tr[q]; for (p = last; p && tr[p].count(y) && tr[p][y] == q; p = pa[p]) tr[p][y] = nq; np = nq; } } else { np = newnode(len[last] + 1); for (p = last; p && !tr[p].count(y); p = pa[p]) tr[p][y] = np; if (!p) pa[np] = root; else { if (len[q = tr[p][y]] == len[p] + 1) pa[np] = q; else { nq = newnode(len[p] + 1); pa[nq] = pa[q]; pa[q] = pa[np] = nq; tr[nq] = tr[q]; for (; p && tr[p].count(y) && tr[p][y] == q; p = pa[p]) tr[p][y] = nq; } } } lastnode[i] = np; for (j = head[i]; j; j = nxt[j]) if (!vis[to[j]]) { fa[to[j]] = i; vis[qu[ta++] = to[j]] = 1; } } long long ans = 0; for (i = 2; i <= cnt; ++i) ans += len[i] - len[pa[i]]; cout << ans << endl; return 0; }