如果处理“所有经过某一个顶点的链对答案的贡献”的时间复杂度为 O ( n ) O(n) O(n)或者 O ( n l o g n ) O(nlogn) O(nlogn),那么运用点分治的思想可以把问题规模降为 O ( n l o g n ) O(nlogn) O(nlogn)或 O ( n l o g 2 n ) O(nlog^2n) O(nlog2n),而非暴力枚举顶点计算答案的 O ( n 2 ) O(n^2) O(n2)。
所以说,点分治是一种在树上统计合法链个数的思想。显而易见的,对于当前顶点 x x x,任意一条链要么经过 x x x,要么不经过 x x x。于是我们只需要计算那些经过 x x x的链,剩下不经过 x x x的链则统统放在子树中递归计算,此为分治。
那么,怎么分治?按照常规的递推思路来考虑,处理完当前顶点后,递归进入所有与顶点相邻的顶点并分别处理它们?不行,因为很容易可以找到一条退化成链的树(见下图),每一层递归只能使问题规模减少1(去掉一个顶点),却要递归到第n层,显然,时间复杂度是 O ( n 2 ) O(n^2) O(n2)的。
比起点分治,这似乎更像树形dp,但树形dp的前提是每一个点的状态都可以借助其子结点 O ( 1 ) O(1) O(1)处理(在例题中可以看到一个既能用点分治,又能用树形dp来解决的例子)。但如果每一条链的状态都不能够合并,无法用dp来降低复杂度呢?实际上,点分治可以说就是为了解决一些无法保存状态的树形dp而发明的一种“优雅的暴力”。
点分治维护其复杂度的关键在于“找重心”。依然从最难处理的链来考虑问题,如果每次递归进入一颗子树时,先找子树的重心,再以重心来分割子树,可以发现每一次分割都使问题的规模减少为原来的一半,那么经过 O ( l o g n ) O(logn) O(logn)次分割,问题规模一定被减少到1,也就是每个顶点都已被考虑到。如果用递归树来刻画这一过程(见下图),可以发现树高为 l o g ( n ) log(n) log(n),每一层的时间之和都为 T ( n ) T(n) T(n),于是,我们成功将 O ( n 2 ) O(n^2) O(n2)的复杂度,通过找重心降低到 O ( n l o g n ) O(nlogn) O(nlogn)。
点分治的题目中,“分治”部分几乎是一成不变的,无非是找到重心后,以重心为根重新计算子树大小,并利用计算结果进一步去找子树的重心。简而言之就是两个函数Getroot和Getsz而已。
int Getroot(int x, int f) { int sum = 0, mx = 0, tmp; for (auto T : G[x]) { int to = T.to; if (vis[to] || to == f) continue; tmp = Getroot(to, x); sum += tmp; mx = max(mx, tmp); } sum++; mx = max(mx, tot - sum); if (mx < MN) { MN = mx; rt = x; } return sum; } void Getsz(int x, int f) { sz[x] = 0; for (auto T : G[x]) { int to = T.to; if (vis[to] || to == f) continue; Getsz(to, x); sz[x] += sz[to]; } sz[x]++; }但这只是第一步。点分治中最重要,最灵活,也最需要投入思考的点,其实是一开始提到的如何计算“所有经过某一个顶点的链对答案的贡献”。
P3806 【模板】点分治1
以这道最基本的模板题来说。相对而言比较经典的计算过程是这样的。
1、一次性收集所有链信息,存放在一个vector或数组中。
void Getdis(int x, int len, int f) { D.push_back(len); for (auto T : G[x]) { int to = T.to, v = T.v; if (vis[to] || to == f) continue; Getdis(to, len + v, x); } }2、利用数据结构(这里用的是桶)计算链对答案的贡献
void calc(int x, int len, int type) { D.clear(); Getdis(x, len, -1); for (int y : D) { if (y > M) continue; for (int i = 1; i <= m; i++) { if (q[i] - y < 0) continue; cnt[i] += type * bucket[q[i] - y]; } bucket[y]++; } for (int y : D) { if (y > M) continue; bucket[y]--; } }3、先计算子树内所有链的贡献(cal(x, 0, 1)),再利用容斥原理,去除不合法的链(calc(to, v, -1))
void solve(int x) { vis[x] = 1; calc(x, 0, 1); for (auto T : G[x]) { int to = T.to, v = T.v; if (vis[to]) continue; calc(to, v, -1); tot = sz[to], MN = INF; Getroot(to, -1); Getsz(rt, -1); solve(rt); } }关于第三点或许需要再举个例子说明一下。
若当前以A为根收集链信息,一共可以得到五条链:
A
A->B
A->B->D
A->B->E
A->C
而用于计算答案的链,实际上是从收集到的链中任选两条,组合成一条新链再进行计算的。比如选择A->B和A->C,实际上是B->A->C这条链,选择A和A->B实际上就是A->B这条链,这两条都是合法的。但如果选择的是A->B->D和A->B->E时,选择的其实是D->B->A->B->E,但D到E的简单路径是D->B->E,也就是这条链不存在,这是不合法的。实际上,同时经过A和B的两条链组合之后总是不合法的,而cal(to, v, -1)的意义就在于选出这样的不合法的链,并把它们对答案的贡献消除。
完整的代码贴在这里
#include <bits/stdc++.h> #define debug(x) cerr << #x << " : " << x << endl using namespace std; typedef long long LL; const int N = 1e4 + 5, M = 1e7 + 5, P = 1e2 + 5, INF = 0x3f3f3f3f; struct Edge { int to, v; }; int n, m, k, q[P], cnt[P], bucket[M]; vector<int> D; /**variables of tree divide*/ int tot, sz[N], MN, rt; bool vis[N]; vector<Edge> G[N]; int Getroot(int x, int f) { int sum = 0, mx = 0, tmp; for (auto T : G[x]) { int to = T.to; if (vis[to] || to == f) continue; tmp = Getroot(to, x); sum += tmp; mx = max(mx, tmp); } sum++; mx = max(mx, tot - sum); if (mx < MN) { MN = mx; rt = x; } return sum; } void Getsz(int x, int f) { sz[x] = 0; for (auto T : G[x]) { int to = T.to; if (vis[to] || to == f) continue; Getsz(to, x); sz[x] += sz[to]; } sz[x]++; } void Getdis(int x, int len, int f) { D.push_back(len); for (auto T : G[x]) { int to = T.to, v = T.v; if (vis[to] || to == f) continue; Getdis(to, len + v, x); } } void calc(int x, int len, int type) { D.clear(); Getdis(x, len, -1); for (int y : D) { if (y > M) continue; for (int i = 1; i <= m; i++) { if (q[i] - y < 0) continue; cnt[i] += type * bucket[q[i] - y]; } bucket[y]++; } for (int y : D) { if (y > M) continue; bucket[y]--; } } void solve(int x) { vis[x] = 1; calc(x, 0, 1); for (auto T : G[x]) { int to = T.to, v = T.v; if (vis[to]) continue; calc(to, v, -1); tot = sz[to], MN = INF; Getroot(to, -1); Getsz(rt, -1); solve(rt); } } int main() { cin >> n >> m; for (int i = 1, u, v, w; i <= n - 1; i++) { scanf("%d %d %d", &u, &v, &w); G[u].push_back({v, w}); G[v].push_back({u, w}); } for (int i = 1; i <= m; i++) scanf("%d", &q[i]); tot = n, MN = INF; Getroot(1, -1); Getsz(rt, -1); solve(rt); for (int i = 1; i <= m; i++) { printf("%s\n", cnt[i] ? "AYE" : "NAY"); } return 0; }至此,一个完整的点分治就结束了。
但这并不是唯一的计算贡献的方法。同样以这道题目来说,我们其实不需要一次性收集所有的链,而是可以先收集某个子树中的链,根据桶的信息修改答案,把收集到的链存入桶中,再去下一个子树收集链信息。如此一来就可以避免计算不合法的链,也就不需要利用容斥来消除它们。
void calc(int x) { judge[0] = 1; for (auto T : G[x]) { int to = T.to; if (vis[to]) continue; D.clear(); Getdis(to, T.v, x); for (int i : D) { for (int j = 1; j <= m; j++) if (query[j] >= i) ans[j] |= judge[query[j] - i]; } for (int i : D) { if (i < 10000010 && !judge[i]) { C.push_back(i); judge[i] = 1; } } } for (int i : C) judge[i] = 0; C.clear(); }我个人认为点分治只能说是一种思想,而不能称为一种算法,就在于每道题计算贡献的方法都不一样。因此,考虑是否要运用点分治解决问题,最关键的就是能否找到一个计算贡献的方法。许多计算方法都要套一个树状数组或者线段树之类的结构,这也都会对最终的时间复杂度产生影响。
P2634 [国家集训队]聪聪可可
最简单的点分治, O ( n ) O(n) O(n)收集信息, O ( 1 ) O(1) O(1)修改答案即可。
#include <bits/stdc++.h> #define debug(x) cerr << #x << " : " << x << endl using namespace std; typedef long long LL; const int N = 2e5 + 5, INF = 0x3f3f3f3f; struct Edge { int to, v; }; vector<Edge> G[N]; bool vis[N]; int n, tot, sz[N], MN, rt, a[N]; LL ans[3], cnt[3]; int add(int x, int y) {return x + y >= 3 ? x + y - 3 : x + y;} int Getroot(int x, int f) { int sum = 0, mx = 0, tmp; for (auto T : G[x]) { int to = T.to; if (vis[to] || to == f) continue; tmp = Getroot(to, x); sum += tmp; mx = max(mx, tmp); } sum++; mx = max(mx, tot - sum); if (mx < MN) { MN = mx; rt = x; } return sum; } void Getsz(int x, int f) { sz[x] = 0; for (auto T : G[x]) { int to = T.to; if (vis[to] || to == f) continue; Getsz(to, x); sz[x] += sz[to]; } sz[x]++; } void Getdis(int x, int len, int f) { cnt[len]++; for (auto T : G[x]) { int to = T.to, v = T.v; if (vis[to] || to == f) continue; Getdis(to, add(len, v), x); } } void calc(int x, int len, int type) { memset(cnt, 0, sizeof(cnt)); Getdis(x, len, -1); for (int i = 0; i <= 2; i++) { for (int j = 0; j <=2; j++) { ans[add(i, j)] += cnt[i] * cnt[j] * type; } } } void solve(int x) { vis[x] = 1; calc(x, 0, 1); for (auto T : G[x]) { int to = T.to, v = T.v; if (vis[to]) continue; calc(to, v, -1); tot = sz[to], MN = INF; Getroot(to, -1); Getsz(rt, -1); solve(rt); } } int main() { cin >> n; for (int i = 1, u, v, w; i <= n - 1; i++) { scanf("%d %d %d", &u, &v, &w); w %= 3; G[u].push_back({v, w}); G[v].push_back({u, w}); } tot = n, MN = INF; Getroot(1, -1); Getsz(rt, -1); solve(rt); LL gcd = __gcd(ans[0], ans[0] + ans[1] + ans[2]); printf("%lld/%lld", ans[0] / gcd, (ans[0] + ans[1] + ans[2]) / gcd); return 0; }可能因为要处理的信息太简单了,用树形dp一样可以过这道题,复杂度还更低。
#include <bits/stdc++.h> using namespace std; typedef long long LL; const int N = 2e4 + 5; int n; LL dp[N][3], sum[N][3], ans[3]; struct Edge { int to, v; }; vector<Edge> G[N]; int add(int x, int y) {return x + y >= 3 ? x + y - 3 : x + y;} int sub(int x, int y) {return x - y < 0 ? x - y + 3 : x - y;} void dfs(int x, int f) { dp[x][0]++; for (auto T : G[x]) { int to = T.to, v = T.v; if (to == f) continue; dfs(to, x); for (int i = 0; i <= 2; i++) { for (int j = 0; j <= 2; j++) { ans[add(i, j)] += dp[to][sub(i, v)] * dp[x][j]; } } for (int i = 0; i <= 2; i++) { dp[x][i] += dp[to][sub(i, v)]; } } } int main() { cin >> n; for (int i = 1, u, v, w; i <= n - 1; i++) { scanf("%d %d %d", &u, &v, &w); w %= 3; G[u].push_back({v,w}); G[v].push_back({u,w}); } dfs(1, -1); ans[0] *= 2; ans[1] *= 2; ans[2] *= 2; ans[0] += n; LL gcd = __gcd(ans[0], ans[0] + ans[1] + ans[2]); printf("%lld/%lld", ans[0] / gcd, (ans[0] + ans[1] + ans[2]) / gcd); return 0; }P4178 Tree
利用树状数组统计答案,非常经典。
#include <bits/stdc++.h> using namespace std; const int N = 4e4 + 5, M = 2e4 + 5, INF = 0x3f3f3f3f; struct Edge { int to, v; }; struct BIT { int t[N]; int lowbit(int x) {return x &(-x);} void add(int x, int v) { x += 1; while(x < N) { t[x] += v; x += lowbit(x); } } int ask(int x) { x += 1; int res = 0; while(x) { res += t[x]; x -= lowbit(x); } return res; } }bit; vector<Edge> G[N]; vector<int> D; bool vis[N]; int n, k, tot, sz[N], MN, rt, ans; int Getroot(int x, int f) { int sum = 0, mx = 0, tmp; for (auto T : G[x]) { int to = T.to; if (vis[to] || to == f) continue; tmp = Getroot(to, x); sum += tmp; mx = max(mx, tmp); } sum++; mx = max(mx, tot - sum); if (mx < MN) { MN = mx; rt = x; } return sum; } void Getsz(int x, int f) { sz[x] = 0; for (Edge T : G[x]) { int to = T.to; if (vis[to] || to == f) continue; Getsz(to, x); sz[x] += sz[to]; } sz[x]++; } void Getdis(int x,int len, int f) { D.push_back(len); for (Edge T : G[x]) { int to = T.to, v = T.v; if (vis[to] || to == f) continue; Getdis(to, len + v, x); } } int calc(int x, int len) { int res = 0; Getdis(x, len, -1); D.push_back(0); for (int y : D) { if (y > k) continue; res += bit.ask(k - y); bit.add(y, 1); } for (int y : D) { if (y > k) continue; bit.add(y, -1); } D.clear(); return res; } void solve(int x) { vis[x] = 1; ans += calc(x, 0); for (Edge T : G[x]) { int to = T.to, v = T.v; if (vis[to]) continue; ans -= calc(to, v); tot = sz[to], MN = INF; Getroot(to, -1); Getsz(rt, -1); solve(rt); } } int main() { cin >> n; for (int i = 1, u, v, w; i <= n - 1; i++) { scanf("%d %d %d", &u, &v, &w); G[u].push_back({v, w}); G[v].push_back({u, w}); } cin >> k; tot = n, MN = INF; Getroot(1, -1); Getsz(rt, -1); solve(rt); printf("%d\n", ans - n); return 0; }Constructing Ranches
需要利用排序、离散化、树状数组来计算贡献,可能还要卡卡常,比较清奇。
#include <bits/stdc++.h> using namespace std; typedef long long LL; const int N = 2e5 + 5, INF = 0x3f3f3f3f; struct BIT { int t[N]; int lowbit(int x) {return x &(-x);} void add(int x, int v) { x++; while(x < N) { t[x] += v; x += lowbit(x); } } int ask(int x) { x++; int res = 0; while(x) { res += t[x]; x -= lowbit(x); } return res; } }bit; struct Edge{ LL len; int mx; }; vector<int> G[N], dc, add; vector<Edge> D; unordered_map<LL, int> pos; bool vis[N]; int T, n, tot, sz[N], MN, rt, a[N]; LL ans; int Getroot(int x, int f) { int sum = 0, mx = 0, tmp; for (int to : G[x]) { if (vis[to] || to == f) continue; tmp = Getroot(to, x); sum += tmp; mx = max(mx, tmp); } sum++; mx = max(mx, tot - sum); if (mx < MN) { MN = mx; rt = x; } return sum; } void Getsz(int x, int f) { sz[x] = 0; for (int to : G[x]) { if (vis[to] || to == f) continue; Getsz(to, x); sz[x] += sz[to]; } sz[x]++; } void Getdis(int x,LL len, int mx, int f) { D.push_back({len, mx}); for (int to : G[x]) { if (vis[to] || to == f) continue; Getdis(to, len + a[to], max(mx, a[to]), x); } } LL calc(int x, LL len, int mx, int top) { dc.clear(); pos.clear(); D.clear(); add.clear(); if (top == 0) top = a[x]; LL res = 0; Getdis(x, len, mx, -1); sort(D.begin(), D.end(), [&](Edge u, Edge v) { return u.mx > v.mx; }); for (auto y : D) { dc.push_back(y.len - top); } sort(dc.begin(), dc.end()); dc.erase(unique(dc.begin(), dc.end()), dc.end()); for (auto y : D) { int idx = lower_bound(dc.begin(), dc.end(), y.len - top) - dc.begin(); res += bit.ask(idx); idx = lower_bound(dc.begin(), dc.end(), y.mx * 2 - y.len + 1) - dc.begin(); if (idx >= 0) { bit.add(idx, 1); add.push_back(idx); } } for (auto y : add) bit.add(y, -1); return res; } void solve(int x) { vis[x] = 1; ans += calc(x, a[x], a[x], 0); for (int to : G[x]) { if (vis[to]) continue; ans -= calc(to, a[x] + a[to], max(a[x], a[to]), a[x]); tot = sz[to], MN = INF; Getroot(to, -1); Getsz(rt, -1); solve(rt); } } int main() { cin >> T; while(T--) { cin >> n; for (int i = 1; i <= n; i++) { vis[i] = 0; scanf("%d", &a[i]); G[i].clear(); } ans = 0; for (int i = 1, u, v, w; i <= n - 1; i++) { scanf("%d %d", &u, &v); G[u].push_back(v); G[v].push_back(u); } tot = n, MN = INF; Getroot(1, -1); Getsz(rt, -1); solve(rt); printf("%lld\n", ans); } return 0; }