Codechef 15.5 CBAL 分块
题目大意:
定义一个字符串是平衡的当且仅当这个字符串中的所有字符能被完全分成两个相同的集合。对于一个字符串$s$,$s$的长度用$|s|$表示,$s[l,r]$表示$s$的第$l$个字符到第$r$个字符形成的字符串($1\leq{l}\leq{r}\leq{|s|}$)。
现在给定一个长度为$n$的字符串,另有$q$组询问,每次给定一个区间$[l,r]$,再给定一个非负整数$type$,求:
\[\sum_{l\leq{l'}\leq{r'}\leq{r}}|s[l',r']|^{type}[IsBalanceString(s[l',r'])]\]
要求强制在线,数据范围$n,q\leq{10^5},0\leq{type}\leq{2}$。
算法讨论:
一个串是平衡的当且仅当串中的每种字符出现的次数都是偶数,一个串中每种字符的出现次数都能用两个前缀相减,因此如果一个串是平衡的,那么两个前缀中每种字符出现次数的奇偶性相同。
我们用一个$2^{26}$以内的整数存下每个前缀每种字符的奇偶性,并去重重新标号为区间$[1,n+1]$中的数,令$label_i$表示长度为$i$的前缀的标号。
那么对于询问$[l,r]$,其实也就是询问:
\[\sum_{l-1\leq{l'}<r'\leq{r}}(r'-l')^{type}[label_{l'}==label_{r'}]\]
如果对于区间$[l,r]$已经有答案,我们能够方便的将区间拓展到$[l-1,r]$或者$[l,r+1]$,只需要在区间中维护每种$label$对应下标的数目、和、平方和就能够$O(1)$实现答案更新。时间复杂度$O(n\sqrt{n})$。
考虑分块,将$n+1$个前缀分成$\sqrt{n}$块,利用刚才的算法预处理每两个块之间的答案,然后再维护以每个块为结尾的前缀中每种$label$对应下标的数目、和、平方和,这样查询的时候对于整块部分我们直接调用答案,对于零散的部分我们利用预处理的前缀和信息暴力计算即可,单组询问时间复杂度$O(\sqrt{n})$。
时空复杂度:
时间复杂度$O((n+q)\sqrt{n})$,空间复杂度$O(n\sqrt{n})$。
代码:
#include <cstdio> #include <cstring> #include <cctype> #include <iostream> #include <algorithm> #include <cmath> #include <map> using namespace std; typedef long long ll; int getc() { static const int L = 1 << 15; static char buf[L], *S = buf, *T = buf; if (S == T) { T = (S = buf) + fread(buf, 1, L, stdin); if (S == T) return EOF; } return *S++; } int getint() { static int c, x; while (!isdigit(c = getc())); x = c - '0'; while (isdigit(c = getc())) x = (x << 1) + (x << 3) + c - '0'; return x; } #define N 100010 char s[N]; int n; void get_string() { static int c; n = 0; while ((c = getc()) < 'a' || c > 'z'); s[++n] = c; while ((c = getc()) >= 'a' && c <= 'z') s[++n] = c; } map<int, int> M; int label[N], id; int sum0[321][N]; ll sum1[321][N], sum2[321][N]; ll ans0[321][321], ans1[321][321], ans2[321][321]; int l[N], r[N], block[N]; struct Array { ll sum[N]; int t[N], tclock; ll operator [] (const int &x) { if (t[x] != tclock) { t[x] = tclock; sum[x] = 0; } return sum[x]; } void add(const int &x, const ll &_add) { if (t[x] != tclock) { t[x] = tclock; sum[x] = 0; } sum[x] += _add; } }_sum0, _sum1, _sum2; ll _ans0, _ans1, _ans2; ll myAbs(const ll &x) { return x < 0 ? -x : x; } void Update(int x) { _ans0 += _sum0[label[x]]; _ans1 += myAbs((ll)_sum0[label[x]] * x - _sum1[label[x]]); _ans2 += (ll)_sum0[label[x]] * x * x + _sum2[label[x]] - 2ll * x * _sum1[label[x]]; _sum0.add(label[x], 1); _sum1.add(label[x], x); _sum2.add(label[x], (ll)x * x); } ll solve(int L, int R, int type) { static int i; ++_sum0.tclock; ++_sum1.tclock; ++_sum2.tclock; if (block[R] - block[L] <= 1) { _ans0 = _ans1 = _ans2 = 0; for (i = L; i <= R; ++i) Update(i); } else { _ans0 = ans0[block[L] + 1][block[R] - 1]; _ans1 = ans1[block[L] + 1][block[R] - 1]; _ans2 = ans2[block[L] + 1][block[R] - 1]; for (i = l[block[R]]; i <= R; ++i) { if (_sum0.t[label[i]] != _sum0.tclock) { _sum0.t[label[i]] = _sum1.t[label[i]] = _sum2.t[label[i]] = _sum0.tclock; _sum0.sum[label[i]] = sum0[block[R] - 1][label[i]] - sum0[block[L]][label[i]]; _sum1.sum[label[i]] = sum1[block[R] - 1][label[i]] - sum1[block[L]][label[i]]; _sum2.sum[label[i]] = sum2[block[R] - 1][label[i]] - sum2[block[L]][label[i]]; } Update(i); } for (i = r[block[L]]; i >= L; --i) { if (_sum0.t[label[i]] != _sum0.tclock) { _sum0.t[label[i]] = _sum1.t[label[i]] = _sum2.t[label[i]] = _sum0.tclock; _sum0.sum[label[i]] = sum0[block[R] - 1][label[i]] - sum0[block[L]][label[i]]; _sum1.sum[label[i]] = sum1[block[R] - 1][label[i]] - sum1[block[L]][label[i]]; _sum2.sum[label[i]] = sum2[block[R] - 1][label[i]] - sum2[block[L]][label[i]]; } Update(i); } } if (type == 0) return _ans0; if (type == 1) return _ans1; if (type == 2) return _ans2; } int main() { #ifndef ONLINE_JUDGE freopen("tt.in", "r", stdin); #endif static int T = getint(), i, j, k, Q, x, y, type, temp; static ll A, B, ans; while (T--) { get_string(); id = 0; temp = 0; M.clear(); for (i = 0; i <= n; ++i) { if (i >= 1) temp ^= (1 << (s[i] - 'a')); if (M.count(temp) == 0) { label[i] = M[temp] = ++id; //printf("%d\n", M.count(temp)); } else label[i] = M[temp]; } int m = ceil(sqrt(n)), cnt = 0, point = -1; while (point < n) { l[++cnt] = point + 1; for (i = 1; point < n && i <= m; ++i) ++point; r[cnt] = point; for (i = l[cnt]; i <= r[cnt]; ++i) block[i] = cnt; } for (i = 1; i <= cnt; ++i) { for (j = 1; j <= id; ++j) { sum0[i][j] = sum0[i - 1][j]; sum1[i][j] = sum1[i - 1][j]; sum2[i][j] = sum2[i - 1][j]; } for (j = l[i]; j <= r[i]; ++j) { sum0[i][label[j]]++; sum1[i][label[j]] += j; sum2[i][label[j]] += (ll)j * j; } } for (i = 1; i <= cnt; ++i) { ++_sum0.tclock; ++_sum1.tclock; ++_sum2.tclock; _ans0 = _ans1 = _ans2 = 0; for (j = i; j <= cnt; ++j) { for (k = l[j]; k <= r[j]; ++k) Update(k); ans0[i][j] = _ans0; ans1[i][j] = _ans1; ans2[i][j] = _ans2; } } A = 0, B = 0; Q = getint(); while (Q--) { x = getint(); y = getint(); type = getint(); x = (A + x) % n + 1; y = (B + y) % n + 1; if (x > y) swap(x, y); ans = solve(x - 1, y, type); printf("%I64d\n", ans); //printf("%lld\n", ans); A = B, B = ans; } } return 0; }