Codechef 12.6 CLOSEST KDTree
题目大意:
给定三维空间中的$n$个点,另有$Q$组询问,每次给定一个三维空间中的点,求到这个点的欧几里得距离最小的点的下标。
数据范围$n,Q\leq{50000}$,坐标数值范围绝对值$\leq{10^9}$,得分与输出正确的数量有关。
算法讨论:
考虑K-Dimension Tree来进行修改和查询,具体细节在这里不再赘述。
直接进行每次查询期望是$O(\sqrt{n})$的,但是点集可能并不是随机的,可能会超时,于是我们进行卡时,利用KDTree回答若干个询问直到时间所剩无几,然后对于剩下的询问输出随机数。
时空复杂度:
时间复杂度$O(Q\sqrt{n})$,空间复杂度$O(n)$。
代码:
#include <cstdio> #include <cstring> #include <cctype> #include <iostream> #include <algorithm> #include <ctime> #include <cstdlib> using namespace std; int getc() { static const int L = 1 << 15; static char buf[L], *S = buf, *T = buf; if (S == T) { T = (S = buf) + fread(buf, 1, L, stdin); if (S == T) return EOF; } return *S++; } int getint() { static int x, c, sign; while (!isdigit(c = getc()) && c != '-'); if (c == '-') sign = -1, x = 0; else sign = 1, x = c - '0'; while (isdigit(c = getc())) x = (x << 1) + (x << 3) + c - '0'; return sign * x; } typedef long long ll; #define N 50010 struct Point { int x[3], id; Point() {} friend ll getdis(const Point &a, const Point &b) { static ll ans, i; ans = 0; for (i = 0; i < 3; ++i) ans += (ll)(a.x[i] - b.x[i]) * (a.x[i] - b.x[i]); return ans; } }P[N], _P[N]; struct Node { Node *ls, *rs; int mx[3], mn[3], id; ll _minDist(const Point &a, int d) { if (a.x[d] > mx[d]) return (ll)(a.x[d] - mx[d]) * (a.x[d] - mx[d]); else if (a.x[d] < mn[d]) return (ll)(mn[d] - a.x[d]) * (mn[d] - a.x[d]); else return 0; } ll minDist(const Point &a) { ll ans = 0; for (int j = 0; j < 3; ++j) ans += _minDist(a, j); return ans; } }mem[N], *G = mem; bool cmp0(const Point &a, const Point &b) { return a.x[0] < b.x[0]; } bool cmp1(const Point &a, const Point &b) { return a.x[1] < b.x[1]; } bool cmp2(const Point &a, const Point &b) { return a.x[2] < b.x[2]; } Node *build(int tl, int tr, int d) { if (tl > tr) return NULL; int mid = (tl + tr) >> 1; if (d == 0) nth_element(P + tl, P + mid, P + tr + 1, cmp0); else if (d == 1) nth_element(P + tl, P + mid, P + tr + 1, cmp1); else nth_element(P + tl, P + mid, P + tr + 1, cmp2); Node *q = G++; q->id = P[mid].id; for (int j = 0; j < 3; ++j) { q->mn[j] = 0x3f3f3f3f; q->mx[j] = -0x3f3f3f3f; } for (int i = tl; i <= tr; ++i) { for (int j = 0; j < 3; ++j) { q->mx[j] = max(q->mx[j], P[i].x[j]); q->mn[j] = min(q->mn[j], P[i].x[j]); } } q->ls = build(tl, mid - 1, (d + 1) % 3); q->rs = build(mid + 1, tr, (d + 1) % 3); return q; } ll ans; int ans_id; void query(Node *q, int d, Point p) { ll dis = getdis(p, _P[q->id]); if (ans > dis) { ans = dis; ans_id = q->id; } ll l_dis, r_dis; l_dis = q->ls ? q->ls->minDist(p) : 1ll << 60; r_dis = q->rs ? q->rs->minDist(p) : 1ll << 60; if (l_dis < r_dis) { if (l_dis < ans) { if (q->ls) query(q->ls, (d + 1) % 3, p); if (r_dis < ans && q->rs) query(q->rs, (d + 1) % 3, p); } } else { if (r_dis < ans) { if (q->rs) query(q->rs, (d + 1) % 3, p); if (l_dis < ans && q->ls) query(q->ls, (d + 1) % 3, p); } } } int main() { #ifndef ONLINE_JUDGE freopen("tt.in", "r", stdin); freopen("tt.out", "w", stdout); #endif clock_t begin = clock(); int n = getint(), i, j; for (i = 1; i <= n; ++i) { P[i].id = i; for (j = 0; j < 3; ++j) P[i].x[j] = getint(); _P[i] = P[i]; } Node *root = build(1, n, 0); Point p; int q = getint(); for (i = 1; i <= q; ++i) { if (clock() - begin > 900) break; for (j = 0; j < 3; ++j) p.x[j] = getint(); ans = 1ll << 60; query(root, 0, p); printf("%d\n", ans_id - 1); //if (ans != getdis(p, _P[ans_id])) // puts("WA"); //ll std_ans = 1ll << 60; //for (j = 1; j <= n; ++j) // std_ans = min(std_ans, getdis(p, P[j])); //if (ans != std_ans) // puts("WA"); } while (i <= q) { printf("%d\n", (long long)rand() * rand() % n); ++i; } return 0; }