https://codeforces.com/contest/2063/problem/E
题目大意
给出一棵 $n \ (\le 3 \times 10^5)$ 个点的树,根为 $1$,边长全部为 $1$。设 $l = lca(u, v)$,定义 $f(u, v)$ 代表,能和 $dist(u, l)$ 与 $dist(v, l)$ 组成非退化的三角形的第三边的整数值个数。求
$$ \sum_{i = 1}^{n - 1}\sum_{j = i + 1}^{n} f(i, j) $$简要题解
首先我们考虑如果有边长 $x$ 与 $y$,不妨设 $x \le y$,则第三边 $z$ 满足 $y - x < z < y + x$。也就是说其实 $z$ 有 $2x - 1$ 种不同选择。
因为都是通过 LCA 的,我们不妨枚举 LCA。然后因为我们永远只关注较小的长度,因此想到可以启发式合并,永远把最大长度小的合并到最大长度大的。
$cnt[u][i]$ 表示到 $u$ 点距离为 $i + 1$ 的儿子的数量。我们先用 $cnt[u]$ 统计儿子的情况,最后再把 $u$ 自己加入。(注意这个定义下好推导,但是需要在头部插入,实际的实现数组是反过来的)。
由此我们可以得出总的贡献是
$$ \sum_{i = 0} \sum_{j > i} cnt[u][i] \cdot cnt[v][j] \cdot (2(i + 1) - 1) + \sum_{i = 0} \sum_{j \ge i} cnt[v][i] \cdot cnt[u][j] \cdot (2(i + 1) - 1) $$因为有一个后缀和,所以我们额外保存一个 $sz$ 来加速计算。
复杂度
$T$:$O(n \log n)$
$S$:$O(n \log n)$
代码实现
#include <bits/stdc++.h>
using namespace std;
int io_=[](){ ios::sync_with_stdio(false); cin.tie(nullptr); return 0; }();
using LL = long long;
using ULL = unsigned long long;
using LD = long double;
using PII = pair<int, int>;
using VI = vector<int>;
using MII = map<int, int>;
template<typename T> void cmin(T &x,const T &y) { if(y<x) x=y; }
template<typename T> void cmax(T &x,const T &y) { if(x<y) x=y; }
template<typename T> bool ckmin(T &x,const T &y) {
return y<x ? (x=y, true) : false; }
template<typename T> bool ckmax(T &x,const T &y) {
return x<y ? (x=y, true) : false; }
template<typename T> void cmin(T &x,T &y,const T &z) {// x<=y<=z
if(z<x) { y=x; x=z; } else if(z<y) y=z; }
template<typename T> void cmax(T &x,T &y,const T &z) {// x>=y>=z
if(x<z) { y=x; x=z; } else if(y<z) y=z; }
// mt19937 rnd(chrono::system_clock::now().time_since_epoch().count());
// mt19937_64 rnd_64(chrono::system_clock::now().time_since_epoch().count());
template<typename Func> struct YCombinatorResult {
Func func;
template<typename T>
explicit YCombinatorResult(T &&func) : func(std::forward<T>(func)) {}
template<class ...Args> decltype(auto) operator()(Args &&...args) {
return func(std::ref(*this), std::forward<Args>(args)...);
}
};
template<typename Func> decltype(auto) y_comb(Func &&fun) {
return YCombinatorResult<std::decay_t<Func>>(std::forward<Func>(fun));
}
/*
---------1---------2---------3---------4---------5---------6---------7---------
1234567890123456789012345678901234567890123456789012345678901234567890123456789
*/
/*
x, y -> y - x < z < x + y -> ans = (x - 1) * 2 + 1
*/
void solve() {
int n; cin >> n;
vector<vector<int>> g(n);
int u, v;
for (int i = 1; i < n; i++) {
cin >> u >> v; u--; v--;
g[u].push_back(v);
g[v].push_back(u);
}
vector<vector<int>> cnt(n);
vector<int> sz(n);
LL ans = 0;
auto merge = [&](int u, int v) {
if (cnt[u].size() < cnt[v].size()) {
swap(cnt[u], cnt[v]);
swap(sz[u], sz[v]);
}
int cur = sz[u];
LL cur2 = 0;
for (int i = cnt[v].size() - 1, j = cnt[u].size() - 1; i >= 0; i--, j--) {
ans += 1LL * ((cnt[v].size() - i) * 2 - 1) * cnt[v][i] * cur;
ans += cur2 * cnt[v][i];
cur -= cnt[u][j];
cur2 += 1LL * cnt[u][j] * ((cnt[v].size() - i) * 2 - 1);
cnt[u][j] += cnt[v][i];
}
sz[u] += sz[v];
};
y_comb([&](auto dfs, int u, int fa) -> void {
for (int v : g[u]) {
if (v == fa) continue;
dfs(v, u);
merge(u, v);
}
cnt[u].push_back(1);
sz[u]++;
})(0, -1);
cout << ans << '\n';
}
int main() {
int t = 1;
cin >> t;
while (t--) {
solve();
}
return 0;
}
标答做法
我们可以回到刚才对于 $f$ 的研究。
$$ f(u, v) = 2\min(dist(u, l), dist(v, l)) - 1 $$我们可以用到根的距离把 $dist$ 展开,即
$$ f(u, v) = 2\min(d[u] - d[l], d[v] - d[l]) - 1 $$这样我们可以把 $d[l]$ 拿出来
$$ f(u, v) = 2\min(d[u], d[v]) - 2d[l] - 1 $$这样我们可以完全将贡献分开算。
我们统计 $cntd$ 为各个深度节点数,以及 $sz$ 为子树中节点数。
首先,任意两个在同一深度的,必然在不同的子树,我们可以先算出这部分贡献
$$ \sum cntd[i] (cntd[i] - 1) / 2 \cdot (2i - 1) $$再看 $d[u] \neq d[v]$ 的情况。不妨设 $d[u] < d[v]$ 则当且仅当 $v$ 在 $u$ 的子树中时没有贡献,其他时候 $f(u, v)$ 产生贡献。这部分相当于
$$ \sum_{u} \left(\left( \left(\sum_{i = d[u] + 1} ^ {\inf} cntd[i] \right) - sz[u] + 1 \right) \cdot (2d[u] - 1) \right) $$注意到这里总是有一个后缀和,可以先把它预处理了。
最后再来看处理所有的 LCA 的部分。由于所有通过 $l$ 的对都需要减去 $2d[l]$。因此这部分是
$$ \sum_{u} \left(\sum_{v \ is \ son \ of \ u} (sz[u] - sz[v] - 1)sz[v] \right)$$最后把三部分贡献合并起来即可。
复杂度
$T$:$O(n \log n)$
$S$:$O(n \log n)$
代码实现
#include <bits/stdc++.h>
using namespace std;
int io_=[](){ ios::sync_with_stdio(false); cin.tie(nullptr); return 0; }();
using LL = long long;
using ULL = unsigned long long;
using LD = long double;
using PII = pair<int, int>;
using VI = vector<int>;
using MII = map<int, int>;
template<typename T> void cmin(T &x,const T &y) { if(y<x) x=y; }
template<typename T> void cmax(T &x,const T &y) { if(x<y) x=y; }
template<typename T> bool ckmin(T &x,const T &y) {
return y<x ? (x=y, true) : false; }
template<typename T> bool ckmax(T &x,const T &y) {
return x<y ? (x=y, true) : false; }
template<typename T> void cmin(T &x,T &y,const T &z) {// x<=y<=z
if(z<x) { y=x; x=z; } else if(z<y) y=z; }
template<typename T> void cmax(T &x,T &y,const T &z) {// x>=y>=z
if(x<z) { y=x; x=z; } else if(y<z) y=z; }
// mt19937 rnd(chrono::system_clock::now().time_since_epoch().count());
// mt19937_64 rnd_64(chrono::system_clock::now().time_since_epoch().count());
template<typename Func> struct YCombinatorResult {
Func func;
template<typename T>
explicit YCombinatorResult(T &&func) : func(std::forward<T>(func)) {}
template<class ...Args> decltype(auto) operator()(Args &&...args) {
return func(std::ref(*this), std::forward<Args>(args)...);
}
};
template<typename Func> decltype(auto) y_comb(Func &&fun) {
return YCombinatorResult<std::decay_t<Func>>(std::forward<Func>(fun));
}
/*
---------1---------2---------3---------4---------5---------6---------7---------
1234567890123456789012345678901234567890123456789012345678901234567890123456789
*/
/*
x, y -> y - x < z < x + y -> ans = (x - 1) * 2 + 1
*/
void solve() {
int n; cin >> n;
vector<vector<int>> g(n);
int u, v;
for (int i = 1; i < n; i++) {
cin >> u >> v; u--; v--;
g[u].push_back(v);
g[v].push_back(u);
}
vector<int> cntd(n + 1);
vector<int> sz(n);
y_comb([&](auto dfs, int u, int fa, int d) -> void {
sz[u] = 1;
cntd[d]++;
for (int v : g[u]) {
if (v == fa) continue;
dfs(v, u, d + 1);
sz[u] += sz[v];
}
})(0, -1, 0);
LL ans = 0;
for (int i = n - 1; i >= 0; i--) {
ans += 1LL * cntd[i] * (cntd[i] - 1) / 2 * (i * 2 - 1);
cntd[i] += cntd[i + 1];
}
y_comb([&](auto dfs, int u, int fa, int d) -> void {
for (int v : g[u]) {
if (v == fa) continue;
dfs(v, u, d + 1);
ans -= 1LL * (sz[u] - 1 - sz[v]) * sz[v] * d;
}
ans += 1LL * (cntd[d + 1] - sz[u] + 1) * (d * 2 - 1);
})(0, -1, 0);
cout << ans << '\n';
}
int main() {
int t = 1;
cin >> t;
while (t--) {
solve();
}
return 0;
}
Next: [CF] C. Minimum Ties - Educational Codeforces Round 104 (Rated for Div. 2)