[CF] E. Triangle Tree - Codeforces Round 1000 (Div. 2)

https://codeforces.com/contest/2063/problem/E

题目大意

给出一棵 $n \ (\le 3 \times 10^5)$ 个点的树,根为 $1$,边长全部为 $1$。设 $l = lca(u, v)$,定义 $f(u, v)$ 代表,能和 $dist(u, l)$ 与 $dist(v, l)$ 组成非退化的三角形的第三边的整数值个数。求

$$ \sum_{i = 1}^{n - 1}\sum_{j = i + 1}^{n} f(i, j) $$

简要题解

首先我们考虑如果有边长 $x$ 与 $y$,不妨设 $x \le y$,则第三边 $z$ 满足 $y - x < z < y + x$。也就是说其实 $z$ 有 $2x - 1$ 种不同选择。

因为都是通过 LCA 的,我们不妨枚举 LCA。然后因为我们永远只关注较小的长度,因此想到可以启发式合并,永远把最大长度小的合并到最大长度大的。

$cnt[u][i]$ 表示到 $u$ 点距离为 $i + 1$ 的儿子的数量。我们先用 $cnt[u]$ 统计儿子的情况,最后再把 $u$ 自己加入。(注意这个定义下好推导,但是需要在头部插入,实际的实现数组是反过来的)。

由此我们可以得出总的贡献是

$$ \sum_{i = 0} \sum_{j > i} cnt[u][i] \cdot cnt[v][j] \cdot (2(i + 1) - 1) + \sum_{i = 0} \sum_{j \ge i} cnt[v][i] \cdot cnt[u][j] \cdot (2(i + 1) - 1) $$

因为有一个后缀和,所以我们额外保存一个 $sz$ 来加速计算。

复杂度

$T$:$O(n \log n)$

$S$:$O(n \log n)$

代码实现

#include <bits/stdc++.h>
using namespace std;

int io_=[](){ ios::sync_with_stdio(false); cin.tie(nullptr); return 0; }();

using LL = long long;
using ULL = unsigned long long;
using LD = long double;
using PII = pair<int, int>;
using VI = vector<int>;
using MII = map<int, int>;

template<typename T> void cmin(T &x,const T &y) { if(y<x) x=y; }
template<typename T> void cmax(T &x,const T &y) { if(x<y) x=y; }
template<typename T> bool ckmin(T &x,const T &y) { 
    return y<x ? (x=y, true) : false; }
template<typename T> bool ckmax(T &x,const T &y) { 
    return x<y ? (x=y, true) : false; }
template<typename T> void cmin(T &x,T &y,const T &z) {// x<=y<=z 
    if(z<x) { y=x; x=z; } else if(z<y) y=z; }
template<typename T> void cmax(T &x,T &y,const T &z) {// x>=y>=z
    if(x<z) { y=x; x=z; } else if(y<z) y=z; }

// mt19937 rnd(chrono::system_clock::now().time_since_epoch().count());
// mt19937_64 rnd_64(chrono::system_clock::now().time_since_epoch().count());
template<typename Func> struct YCombinatorResult {
  Func func;
  template<typename T>
  explicit YCombinatorResult(T &&func) : func(std::forward<T>(func)) {}
  template<class ...Args> decltype(auto) operator()(Args &&...args) {
    return func(std::ref(*this), std::forward<Args>(args)...);
  }
};
template<typename Func> decltype(auto) y_comb(Func &&fun) {
  return YCombinatorResult<std::decay_t<Func>>(std::forward<Func>(fun));
}
/*
---------1---------2---------3---------4---------5---------6---------7---------
1234567890123456789012345678901234567890123456789012345678901234567890123456789
*/

/*
x, y -> y - x < z < x + y -> ans = (x - 1) * 2 + 1 
*/

void solve() {
  int n; cin >> n;
  vector<vector<int>> g(n);
  int u, v;
  for (int i = 1; i < n; i++) {
    cin >> u >> v; u--; v--;
    g[u].push_back(v);
    g[v].push_back(u);
  }

  vector<vector<int>> cnt(n);
  vector<int> sz(n);
  LL ans = 0;
  
  auto merge = [&](int u, int v) {
    if (cnt[u].size() < cnt[v].size()) {
      swap(cnt[u], cnt[v]);
      swap(sz[u], sz[v]);
    }

    int cur = sz[u];
    LL cur2 = 0;
    for (int i = cnt[v].size() - 1, j = cnt[u].size() - 1; i >= 0; i--, j--) {
      ans += 1LL * ((cnt[v].size() - i) * 2 - 1) * cnt[v][i] * cur;
      ans += cur2 * cnt[v][i];
      cur -= cnt[u][j];
      cur2 += 1LL * cnt[u][j] * ((cnt[v].size() - i) * 2 - 1);
      cnt[u][j] += cnt[v][i];
    }

    sz[u] += sz[v];
  };

  y_comb([&](auto dfs, int u, int fa) -> void {
    for (int v : g[u]) {
      if (v == fa) continue;
      dfs(v, u);
      merge(u, v);
    }
    cnt[u].push_back(1);
    sz[u]++;
  })(0, -1);

  cout << ans << '\n';
}

int main() {
  int t = 1; 
  cin >> t;
  while (t--) {
    solve();
  }
  return 0;
}

标答做法

我们可以回到刚才对于 $f$ 的研究。

$$ f(u, v) = 2\min(dist(u, l), dist(v, l)) - 1 $$

我们可以用到根的距离把 $dist$ 展开,即

$$ f(u, v) = 2\min(d[u] - d[l], d[v] - d[l]) - 1 $$

这样我们可以把 $d[l]$ 拿出来

$$ f(u, v) = 2\min(d[u], d[v]) - 2d[l] - 1 $$

这样我们可以完全将贡献分开算。

我们统计 $cntd$ 为各个深度节点数,以及 $sz$ 为子树中节点数。

首先,任意两个在同一深度的,必然在不同的子树,我们可以先算出这部分贡献

$$ \sum cntd[i] (cntd[i] - 1) / 2 \cdot (2i - 1) $$

再看 $d[u] \neq d[v]$ 的情况。不妨设 $d[u] < d[v]$ 则当且仅当 $v$ 在 $u$ 的子树中时没有贡献,其他时候 $f(u, v)$ 产生贡献。这部分相当于

$$ \sum_{u} \left(\left( \left(\sum_{i = d[u] + 1} ^ {\inf} cntd[i] \right) - sz[u] + 1 \right) \cdot (2d[u] - 1) \right) $$

注意到这里总是有一个后缀和,可以先把它预处理了。

最后再来看处理所有的 LCA 的部分。由于所有通过 $l$ 的对都需要减去 $2d[l]$。因此这部分是

$$ \sum_{u} \left(\sum_{v \ is \ son \ of \ u} (sz[u] - sz[v] - 1)sz[v] \right)$$

最后把三部分贡献合并起来即可。

复杂度

$T$:$O(n \log n)$

$S$:$O(n \log n)$

代码实现

#include <bits/stdc++.h>
using namespace std;

int io_=[](){ ios::sync_with_stdio(false); cin.tie(nullptr); return 0; }();

using LL = long long;
using ULL = unsigned long long;
using LD = long double;
using PII = pair<int, int>;
using VI = vector<int>;
using MII = map<int, int>;

template<typename T> void cmin(T &x,const T &y) { if(y<x) x=y; }
template<typename T> void cmax(T &x,const T &y) { if(x<y) x=y; }
template<typename T> bool ckmin(T &x,const T &y) { 
    return y<x ? (x=y, true) : false; }
template<typename T> bool ckmax(T &x,const T &y) { 
    return x<y ? (x=y, true) : false; }
template<typename T> void cmin(T &x,T &y,const T &z) {// x<=y<=z 
    if(z<x) { y=x; x=z; } else if(z<y) y=z; }
template<typename T> void cmax(T &x,T &y,const T &z) {// x>=y>=z
    if(x<z) { y=x; x=z; } else if(y<z) y=z; }

// mt19937 rnd(chrono::system_clock::now().time_since_epoch().count());
// mt19937_64 rnd_64(chrono::system_clock::now().time_since_epoch().count());
template<typename Func> struct YCombinatorResult {
  Func func;
  template<typename T>
  explicit YCombinatorResult(T &&func) : func(std::forward<T>(func)) {}
  template<class ...Args> decltype(auto) operator()(Args &&...args) {
    return func(std::ref(*this), std::forward<Args>(args)...);
  }
};
template<typename Func> decltype(auto) y_comb(Func &&fun) {
  return YCombinatorResult<std::decay_t<Func>>(std::forward<Func>(fun));
}
/*
---------1---------2---------3---------4---------5---------6---------7---------
1234567890123456789012345678901234567890123456789012345678901234567890123456789
*/

/*
x, y -> y - x < z < x + y -> ans = (x - 1) * 2 + 1 
*/

void solve() {
  int n; cin >> n;
  vector<vector<int>> g(n);
  int u, v;
  for (int i = 1; i < n; i++) {
    cin >> u >> v; u--; v--;
    g[u].push_back(v);
    g[v].push_back(u);
  }

  vector<int> cntd(n + 1);
  vector<int> sz(n);
  y_comb([&](auto dfs, int u, int fa, int d) -> void {
    sz[u] = 1;
    cntd[d]++;
    for (int v : g[u]) {
      if (v == fa) continue;
      dfs(v, u, d + 1);
      sz[u] += sz[v];
    }
  })(0, -1, 0);
  
  LL ans = 0;
  for (int i = n - 1; i >= 0; i--) {
    ans += 1LL * cntd[i] * (cntd[i] - 1) / 2 * (i * 2 - 1);
    cntd[i] += cntd[i + 1];
  }

  y_comb([&](auto dfs, int u, int fa, int d) -> void {
    for (int v : g[u]) {
      if (v == fa) continue;
      dfs(v, u, d + 1);

      ans -= 1LL * (sz[u] - 1 - sz[v]) * sz[v] * d;
    }
    ans += 1LL * (cntd[d + 1] - sz[u] + 1) * (d * 2 - 1);
  })(0, -1, 0);

  cout << ans << '\n';
}

int main() {
  int t = 1; 
  cin >> t;
  while (t--) {
    solve();
  }
  return 0;
}
Prev: [CF] F. Isomorphic Strings - Educational Codeforces Round 44 (Rated for Div. 2)
Next: [CF] C. Minimum Ties - Educational Codeforces Round 104 (Rated for Div. 2)