https://codeforces.com/contest/1156/problem/D
题目大意
给出一棵 $n$ 个点的树,边上有 $0$ 或 $1$ 的权值。问有多少个 $u$ 到 $v$ 的简单路径($u \neq v$)满足该路径通过任何边权为 $1$ 的边后不再经过边权为 $0$ 的边。
$2 \le n \le 2 \times 10^5$。
简要题解
如果我们固定了某个根,让其作为起点,那么我们相当于先 dfs 出所有根能通过 $0$ 到达的点,再以这些点为起点 dfs 出所有通过 $1$ 的边能到达的点,则这样的点的集合就是对于根而言所有合法的 $v$。
这个过程我们也可以反着通过 dp 来做。
$dp[0][u]$ 表示之后还可以通过 $0$ 的边所能到达的 $u$ 子树中所有点的数量。 $dp[1][u]$ 表示之后只可以通过 $1$ 的边所能到达的 $u$ 子树中所有点的数量。
$dp[0][u] = \sum_{v \ is \ son \ of \ u} dp[w][v]$。
$dp[1][u] = \sum_{v in V} dp[1][v]$。其中 $V$ 表示 $e(u, v) = 1$ 的 $u$ 的儿子的集合。
那么下一步就是发现某个子树贡献是很容易从当前节点中加上或减去的,也就是说我们可以很容易的进行换根。于是就做完了。
感觉上点分治之类的做法应该也是可以做的,但应该不如换根 dp 写法简单。
复杂度
$T$:$O(n)$
$S$:$O(n)$
代码实现
#include <bits/stdc++.h>
using namespace std;
int io_=[](){ ios::sync_with_stdio(false); cin.tie(nullptr); return 0; }();
using LL = long long;
using ULL = unsigned long long;
using LD = long double;
using PII = pair<int, int>;
using VI = vector<int>;
using MII = map<int, int>;
template<typename T> void cmin(T &x,const T &y) { if(y<x) x=y; }
template<typename T> void cmax(T &x,const T &y) { if(x<y) x=y; }
template<typename T> bool ckmin(T &x,const T &y) {
return y<x ? (x=y, true) : false; }
template<typename T> bool ckmax(T &x,const T &y) {
return x<y ? (x=y, true) : false; }
template<typename T> void cmin(T &x,T &y,const T &z) {// x<=y<=z
if(z<x) { y=x; x=z; } else if(z<y) y=z; }
template<typename T> void cmax(T &x,T &y,const T &z) {// x>=y>=z
if(x<z) { y=x; x=z; } else if(y<z) y=z; }
// mt19937 rnd(chrono::system_clock::now().time_since_epoch().count());
// mt19937_64 rnd_64(chrono::system_clock::now().time_since_epoch().count());
template<typename Func> struct YCombinatorResult {
Func func;
template<typename T>
explicit YCombinatorResult(T &&func) : func(std::forward<T>(func)) {}
template<class ...Args> decltype(auto) operator()(Args &&...args) {
return func(std::ref(*this), std::forward<Args>(args)...);
}
};
template<typename Func> decltype(auto) y_comb(Func &&fun) {
return YCombinatorResult<std::decay_t<Func>>(std::forward<Func>(fun));
}
/*
---------1---------2---------3---------4---------5---------6---------7---------
1234567890123456789012345678901234567890123456789012345678901234567890123456789
*/
void solve() {
int n; cin >> n;
vector<vector<PII>> g(n);
int u, v, w;
for (int i = 1; i < n; i++) {
cin >> u >> v >> w; u--; v--;
g[u].push_back({v, w});
g[v].push_back({u, w});
}
vector<vector<int>> dp(2, vector<int>(n));
y_comb([&](auto dfs, int u, int fa) -> void {
dp[0][u] = dp[1][u] = 1;
for (auto [v, w] : g[u]) {
if (v == fa) continue;
dfs(v, u);
if (!w) {
dp[0][u] += dp[0][v];
} else {
dp[0][u] += dp[1][v];
dp[1][u] += dp[1][v];
}
}
})(0, -1);
LL ans = 0;
y_comb([&](auto dfs, int u, int fa) -> void {
ans += dp[0][u];
for (auto [v, w] : g[u]) {
if (v == fa) continue;
int dpu0 = dp[0][u];
int dpu1 = dp[1][u];
int dpv0 = dp[0][v];
int dpv1 = dp[1][v];
if (!w) {
dp[0][u] -= dpv0;
dp[0][v] += dp[0][u];
} else {
dp[0][u] -= dpv1;
dp[1][u] -= dpv1;
dp[0][v] += dp[1][u];
dp[1][v] += dp[1][u];
}
dfs(v, u);
dp[0][u] = dpu0;
dp[1][u] = dpu1;
dp[0][v] = dpv0;
dp[1][v] = dpv1;
}
})(0, -1);
ans -= n;
cout << ans << '\n';
}
int main() {
int t = 1;
// cin >> t;
while (t--) {
solve();
}
return 0;
}
Next: [CF] C. Match Points - Educational Codeforces Round 64 (Rated for Div. 2)