Codechef 12.11 COUNTARI FFT+分块
题目大意:
给定一个长度为$n$的正整数序列$a$,问有多少三元组$(i,j,k)$满足$1\leq{i}<j<k\leq{n}$且$a_j-a_i=a_k-a_j$。
数据范围$n\leq{10^5}$,$1\leq{a_i}\leq{30000}$。
算法讨论:
不妨将序列分成$m$块,每块$\frac{n}{m}$个元素。
考虑三元组分布的情况:
(1)分布在同一块中。
(2)前两个分布在同一块中,最后一个分布在另一块中。
(3)后两个分布在同一块中,第一个分布在另一块中。
这三种情况,我们都可以通过暴力枚举$j$和同一块中的$i$(或$k$),并利用$2a_j=a_i+a_k$统计此时合法的$k$(或$i$)的数目,只需维护一个带时间戳的计数数组就能完成统计。
时间复杂度$O(m\times{(\frac{n}{m})}^2+n)$。
(4)三个元素分布在不同块中。
只需枚举$j$所在的块,对于两侧的两个序列做卷积,就能对于这个块中的每个$j$算出有多少对合法的$(i,k)$。
不妨令$w=max\{a_i\}$。
时间复杂度$O(m(n+w\log w))$。
实际测试中令$m=30$就能通过全部测试数据了。
时空复杂度:
时间复杂度$O(m\times{(\frac{n}{m})^2}+n+m(n+w\log w))$,空间复杂度$O(n+w)$。
代码:
#include <cstdio> #include <cstring> #include <cctype> #include <iostream> #include <algorithm> #include <cmath> #include <vector> using namespace std; typedef long long ll; typedef long double ldb; struct Complex { ldb u, v; Complex() {} Complex(ldb _u, ldb _v) : u(_u), v(_v) {} friend Complex operator + (const Complex &a, const Complex &b) { return Complex(a.u + b.u, a.v + b.v); } friend Complex operator - (const Complex &a, const Complex &b) { return Complex(a.u - b.u, a.v - b.v); } friend Complex operator * (const Complex &a, const Complex &b) { return Complex(a.u * b.u - a.v * b.v, a.u * b.v + a.v * b.u); } friend Complex operator * (const Complex &a, const ldb &p) { return Complex(a.u * p, a.v * p); } friend Complex operator / (const Complex &a, const ldb &p) { return Complex(a.u / p, a.v / p); } }; vector<int> Rev[18]; //vector<Complex> wn[18], inv_wn[18]; static const ldb pi = acos(-1); void init() { int i, j, k, add; for (i = 1; i <= 17; ++i) { for (j = 0; j < (1 << i); ++j) { add = 0; for (k = 0; k < i; ++k) if ((j >> k) & 1) add |= (1 << (i - k - 1)); Rev[i].push_back(add); } /*wn[i].push_back(Complex(1, 0)); inv_wn[i].push_back(Complex(1, 0)); wn[i].push_back(Complex(cos(pi / (1 << (i - 1))), sin(pi / (1 << (i - 1))))); inv_wn[i].push_back(Complex(cos(pi / (1 << (i - 1))), -sin(pi / (1 << (i - 1))))); for (j = 2; j <= 1 << (17 - i); ++j) { wn[i].push_back(wn[i].back() * wn[i][1]); inv_wn[i].push_back(inv_wn[i].back() * inv_wn[i][1]); }*/ } } int get_bit(int n) { static int re; for (re = 0; n > 1; n >>= 1, ++re); return re; } /*void FFT(Complex A[], int n, bool rev) { static Complex B[131072], t; static int m, i, j, k, half; m = get_bit(n); for (i = 0; i < n; ++i) B[Rev[m][i]] = A[i]; for (i = 0; i < n; ++i) A[i] = B[i]; for (k = 1; k <= m; ++k) { half = 1 << (k - 1); for (i = 0; i < n; i += half << 1) { for (j = 0; j < half; ++j) { t = A[i + j + half] * (rev ? inv_wn[k][j] : wn[k][j]); A[i + j + half] = A[i + j] - t; A[i + j] = A[i + j] + t; } } } if (rev) { for (i = 0; i < n; ++i) A[i] = A[i] / n; } }*/ void FFT(Complex A[], int n, bool rev) { static Complex B[65536], w, wn, t; static int m, i, j, k, half; m = get_bit(n); for (i = 0; i < n; ++i) B[Rev[m][i]] = A[i]; for (i = 0; i < n; ++i) A[i] = B[i]; for (k = 2; k <= n; k <<= 1) { half = k >> 1; for (wn = Complex(cos(2 * pi / k), (rev ? -1 : 1) * sin(2 * pi / k)), i = 0; i < n; i += k) { for (w = Complex(1, 0), j = 0; j < half; ++j, w = w * wn) { t = A[i + j + half] * w; A[i + j + half] = A[i + j] - t; A[i + j] = A[i + j] + t; } } } if (rev) { for (i = 0; i < n; ++i) A[i] = A[i] / n; } } #define N 100010 #define W 30010 int n, a[N]; int getint() { static int c, x; while (!isdigit(c = getchar())); x = c - '0'; while (isdigit(c = getchar())) x = (x << 1) + (x << 3) + c - '0'; return x; } int l[N], r[N], cnt; int c[W << 1], t[W << 1], tclock; int get(int x) { if (t[x] != tclock) { t[x] = tclock; c[x] = 0; } return c[x]; } void add(int x) { if (t[x] != tclock) { t[x] = tclock; c[x] = 0; } ++c[x]; } Complex A[65536], B[65536]; int main() { #ifndef ONLINE_JUDGE freopen("tt.in", "r", stdin); #endif init(); n = getint(); int i, j, k, mx = 0; for (i = 1; i <= n; ++i) { a[i] = getint(); mx = max(mx, a[i]); } //int m = ceil(n / (sqrt(n/ (log(mx) / log(2))))); int m = ceil(n / 30.0); //int m = 1; //int m = ceil(sqrt(n)); int point = 0; //printf("%d\n", m); while (point < n) { l[++cnt] = point + 1; for (i = 1; i <= m && point < n; ++i) ++point; r[cnt] = point; } //printf("%d\n", cnt); //return 0; long long ans = 0; for (i = 1; i <= cnt; ++i) { ++tclock; for (j = r[i]; j >= l[i]; --j) { for (k = j - 1; k >= l[i]; --k) if (2 * a[j] > a[k]) ans += get(2 * a[j] - a[k]); add(a[j]); } } //puts("OK1"); ++tclock; for (i = cnt; i >= 1; --i) { for (j = r[i]; j >= l[i]; --j) for (k = j - 1; k >= l[i]; --k) if (2 * a[j] > a[k]) ans += get(2 * a[j] - a[k]); for (j = r[i]; j >= l[i]; --j) add(a[j]); } //puts("OK2"); ++tclock; for (i = 1; i <= cnt; ++i) { for (j = l[i]; j <= r[i]; ++j) for (k = j + 1; k <= r[i]; ++k) if (2 * a[j] > a[k]) ans += get(2 * a[j] - a[k]); for (j = l[i]; j <= r[i]; ++j) add(a[j]); } //puts("OK3"); int mxl, mxr, M; for (i = 2; i < cnt; ++i) { mxl = 0; for (j = 1; j < l[i]; ++j) mxl = max(mxl, a[j]); mxr = 0; for (j = r[i] + 1; j <= n; ++j) mxr = max(mxr, a[j]); for (M = 1; M < (mxl + mxr + 1); M <<= 1); //printf("%d\n", M); for (j = 0; j < M; ++j) A[j] = B[j] = Complex(0, 0); for (j = 1; j < l[i]; ++j) ++A[a[j]].u; for (j = r[i] + 1; j <= n; ++j) ++B[a[j]].u; FFT(A, M, 0); FFT(B, M, 0); for (j = 0; j < M; ++j) A[j] = A[j] * B[j]; FFT(A, M, 1); //for (j = 0; j < M; ++j) // printf("%.10lf %.10lf\n", (double)A[j].u, (double)A[j].v); for (j = l[i]; j <= r[i]; ++j) if (2 * a[j] < M) ans += (long long)(A[2 * a[j]].u + .5); } cout << ans << endl; return 0; }