Codechef 13.2 QUERY 可持久化线段树+可持久化数据结构+标记永久化
Codechef 12.7 DGCD 数学+树链剖分+线段树

Codechef 12.11 COUNTARI FFT+分块

shinbokuow posted @ Nov 20, 2015 04:26:55 PM in Something with tags FFT 分块 , 1511 阅读

 

题目大意:
给定一个长度为$n$的正整数序列$a$,问有多少三元组$(i,j,k)$满足$1\leq{i}<j<k\leq{n}$且$a_j-a_i=a_k-a_j$。
数据范围$n\leq{10^5}$,$1\leq{a_i}\leq{30000}$。
算法讨论:
不妨将序列分成$m$块,每块$\frac{n}{m}$个元素。
考虑三元组分布的情况:
(1)分布在同一块中。
(2)前两个分布在同一块中,最后一个分布在另一块中。
(3)后两个分布在同一块中,第一个分布在另一块中。
这三种情况,我们都可以通过暴力枚举$j$和同一块中的$i$(或$k$),并利用$2a_j=a_i+a_k$统计此时合法的$k$(或$i$)的数目,只需维护一个带时间戳的计数数组就能完成统计。
时间复杂度$O(m\times{(\frac{n}{m})}^2+n)$。
(4)三个元素分布在不同块中。
只需枚举$j$所在的块,对于两侧的两个序列做卷积,就能对于这个块中的每个$j$算出有多少对合法的$(i,k)$。
不妨令$w=max\{a_i\}$。
时间复杂度$O(m(n+w\log w))$。
实际测试中令$m=30$就能通过全部测试数据了。
时空复杂度:
时间复杂度$O(m\times{(\frac{n}{m})^2}+n+m(n+w\log w))$,空间复杂度$O(n+w)$。
代码:
#include <cstdio>
#include <cstring>
#include <cctype>
#include <iostream>
#include <algorithm>
#include <cmath>
#include <vector>
using namespace std;
typedef long long ll;
typedef long double ldb;
struct Complex {
    ldb u, v;
    Complex() {}
    Complex(ldb _u, ldb _v) : u(_u), v(_v) {}
    friend Complex operator + (const Complex &a, const Complex &b) {
        return Complex(a.u + b.u, a.v + b.v);
    }
    friend Complex operator - (const Complex &a, const Complex &b) {
        return Complex(a.u - b.u, a.v - b.v);
    }
    friend Complex operator * (const Complex &a, const Complex &b) {
        return Complex(a.u * b.u - a.v * b.v, a.u * b.v + a.v * b.u);
    }
    friend Complex operator * (const Complex &a, const ldb &p) {
        return Complex(a.u * p, a.v * p);
    }
    friend Complex operator / (const Complex &a, const ldb &p) {
        return Complex(a.u / p, a.v / p);
    }
};
vector<int> Rev[18];
//vector<Complex> wn[18], inv_wn[18];
static const ldb pi = acos(-1);
void init() {
    int i, j, k, add;
    for (i = 1; i <= 17; ++i) {
        for (j = 0; j < (1 << i); ++j) {
            add = 0;
            for (k = 0; k < i; ++k)
                if ((j >> k) & 1)
                    add |= (1 << (i - k - 1));
            Rev[i].push_back(add);
        }
        /*wn[i].push_back(Complex(1, 0));
        inv_wn[i].push_back(Complex(1, 0));
        wn[i].push_back(Complex(cos(pi / (1 << (i - 1))), sin(pi / (1 << (i - 1)))));
        inv_wn[i].push_back(Complex(cos(pi / (1 << (i - 1))), -sin(pi / (1 << (i - 1)))));
        for (j = 2; j <= 1 << (17 - i); ++j) {
            wn[i].push_back(wn[i].back() * wn[i][1]);
            inv_wn[i].push_back(inv_wn[i].back() * inv_wn[i][1]);
        }*/
    }
}
int get_bit(int n) {
    static int re;
    for (re = 0; n > 1; n >>= 1, ++re);
    return re;
}
/*void FFT(Complex A[], int n, bool rev) {
    static Complex B[131072], t;
    static int m, i, j, k, half;
    m = get_bit(n);
    for (i = 0; i < n; ++i)
        B[Rev[m][i]] = A[i];
    for (i = 0; i < n; ++i)
        A[i] = B[i];
    for (k = 1; k <= m; ++k) {
        half = 1 << (k - 1);
        for (i = 0; i < n; i += half << 1) {
            for (j = 0; j < half; ++j) {
                t = A[i + j + half] * (rev ? inv_wn[k][j] : wn[k][j]);
                A[i + j + half] = A[i + j] - t;
                A[i + j] = A[i + j] + t;
            }
        }
    }
    if (rev) {
        for (i = 0; i < n; ++i)
            A[i] = A[i] / n;
    }
}*/
void FFT(Complex A[], int n, bool rev) {
    static Complex B[65536], w, wn, t;
    static int m, i, j, k, half;
    m = get_bit(n);
    for (i = 0; i < n; ++i)
        B[Rev[m][i]] = A[i];
    for (i = 0; i < n; ++i)
        A[i] = B[i];
    for (k = 2; k <= n; k <<= 1) {
        half = k >> 1;
        for (wn = Complex(cos(2 * pi / k), (rev ? -1 : 1) * sin(2 * pi / k)), i = 0; i < n; i += k) {
            for (w = Complex(1, 0), j = 0; j < half; ++j, w = w * wn) {
                t = A[i + j + half] * w;
                A[i + j + half] = A[i + j] - t;
                A[i + j] = A[i + j] + t;
            }
        }
    }
    if (rev) {
        for (i = 0; i < n; ++i)
            A[i] = A[i] / n;
    }
}
#define N 100010
#define W 30010
int n, a[N];
int getint() {
    static int c, x;
    while (!isdigit(c = getchar()));
    x = c - '0';
    while (isdigit(c = getchar()))
        x = (x << 1) + (x << 3) + c - '0';
    return x;
}
int l[N], r[N], cnt;
int c[W << 1], t[W << 1], tclock;
int get(int x) {
    if (t[x] != tclock) {
        t[x] = tclock;
        c[x] = 0;
    }
    return c[x];
}
void add(int x) {
    if (t[x] != tclock) {
        t[x] = tclock;
        c[x] = 0;
    }
    ++c[x];
}
Complex A[65536], B[65536];
int main() {
#ifndef ONLINE_JUDGE
    freopen("tt.in", "r", stdin);
#endif
    init();
    n = getint();
    int i, j, k, mx = 0;
    for (i = 1; i <= n; ++i) {
        a[i] = getint();
        mx = max(mx, a[i]);
    }
    //int m = ceil(n / (sqrt(n/ (log(mx) / log(2)))));
    int m = ceil(n / 30.0);
    //int m = 1;
    //int m = ceil(sqrt(n));
    int point = 0;
    //printf("%d\n", m);
    while (point < n) {
        l[++cnt] = point + 1;
        for (i = 1; i <= m && point < n; ++i)
            ++point;
        r[cnt] = point;
    }
    //printf("%d\n", cnt);
    //return 0;
    long long ans = 0;
    for (i = 1; i <= cnt; ++i) {
        ++tclock;
        for (j = r[i]; j >= l[i]; --j) {
            for (k = j - 1; k >= l[i]; --k)
                if (2 * a[j] > a[k])
                    ans += get(2 * a[j] - a[k]);
            add(a[j]);
        }
    }
    //puts("OK1");
    ++tclock;
    for (i = cnt; i >= 1; --i) {
        for (j = r[i]; j >= l[i]; --j)
            for (k = j - 1; k >= l[i]; --k)
                if (2 * a[j] > a[k])
                    ans += get(2 * a[j] - a[k]);
        for (j = r[i]; j >= l[i]; --j)
            add(a[j]);
    }
    //puts("OK2");
    ++tclock;
    for (i = 1; i <= cnt; ++i) {
        for (j = l[i]; j <= r[i]; ++j)
            for (k = j + 1; k <= r[i]; ++k)
                if (2 * a[j] > a[k])
                    ans += get(2 * a[j] - a[k]);
        for (j = l[i]; j <= r[i]; ++j)
            add(a[j]);
    }
    //puts("OK3");
    int mxl, mxr, M;
    for (i = 2; i < cnt; ++i) {
        mxl = 0;
        for (j = 1; j < l[i]; ++j)
            mxl = max(mxl, a[j]);
        mxr = 0;
        for (j = r[i] + 1; j <= n; ++j)
            mxr = max(mxr, a[j]);
        for (M = 1; M < (mxl + mxr + 1); M <<= 1);
        //printf("%d\n", M);
        for (j = 0; j < M; ++j)
            A[j] = B[j] = Complex(0, 0);
        for (j = 1; j < l[i]; ++j)
            ++A[a[j]].u;
        for (j = r[i] + 1; j <= n; ++j)
            ++B[a[j]].u;
        FFT(A, M, 0);
        FFT(B, M, 0);
        for (j = 0; j < M; ++j)
            A[j] = A[j] * B[j];
        FFT(A, M, 1);
        //for (j = 0; j < M; ++j)
        //  printf("%.10lf %.10lf\n", (double)A[j].u, (double)A[j].v);
        for (j = l[i]; j <= r[i]; ++j)
            if (2 * a[j] < M)
                ans += (long long)(A[2 * a[j]].u + .5);
    }
    
    cout << ans << endl;
    
    return 0;
}

 


登录 *


loading captcha image...
(输入验证码)
or Ctrl+Enter