Hoi_koro memo

Editorialと異なる解法をしたら更新するblog

AtCoder Educational DP Contest V Subtree

atcoder.jp

問題概要

$ N $頂点の木(辺は$(x_1,y_1),\dots,(x_{N-1},y_{N-1})$)の頂点を白と黒に塗り分ける。ただし、黒い頂点どうしは相互に黒い頂点のみを通って到達可能でなければならない。そのような塗り方のうち、頂点$i$が黒く塗られるものは何通りあるか。$1\leq i\leq n$のそれぞれについて、$ M $で割った余りを答えよ。

制約

  • $1\leq N\leq 10^5$
  • $2\leq M\leq 10^9$

解法

特定の頂点$i$について答えるなら、頂点$i$を根とした木DPによって$\mathcal O(N)$で答えられる。この問題ではすべての$i$について答えなければならないが、全方位木DPをすればやはり$\mathcal O(N)$でできる。

辺で結ばれた頂点$i,j$について、

$dp[i][j]=$頂点$i$から見た$j$の部分木(色付きの部分)を塗る組み合わせのうち頂点$j$を黒く塗る方法の数 

とする。実装上はこの値を辺の情報として持たせている。 f:id:hiko3r:20190107134524p:plain

1回目のDFS

適当な頂点を根とする。頂点$v$の親を$p$,子を$c_1,\dots,c_r$ ($ r $は$ v $の次数-1)とする。頂点$v$を黒く塗るとき、頂点$c_i$以下の部分木を塗る方法は、

  • $c_i$を黒く塗る:$ dp[v][c_i] $通り
  • $c_i$を白く塗る:1通り

となる。よってこのときの遷移は、

\begin{align} dp[p][v] = \prod_{k\in\{1,\dots,r\}}(dp[v][c_k]+1) \end{align}

2回目のDFS

頂点$v$に隣接する頂点を$u_1,\dots,u_q$ ($ q $は$ v $の次数)とする。$dp[v][u_1],\dots,dp[v][u_q]$が求まっていれば、

\begin{align} dp[u_i][v] = \prod_{k\in\{1,\dots,q\}-\{i\}}(dp[v][u_k]+1) \end{align}

で求められる。これを求めるときに \begin{align} &\prod_{k\in\{1,\dots,q\}-\{i\}}(dp[v][u_k]+1)\\ =& \frac{\prod_{k\in\{1,\dots,q\}}(dp[v][u_k]+1)}{dp[v][u_i]+1}\\ =&\prod_{k\in\{1,\dots,q\}}(dp[v][u_k]+1)\times (dp[v][u_i]+1)^{-1}~({\rm mod}~ M) \end{align} として計算しようとすると、$(dp[v][u_i]+1)=0 ~({\rm mod}~M)$のときに$0^{-1}~({\rm mod}~M)$が現れてしまい破綻する。この問題を回避するには、

\begin{align} &\prod_{k\in\{1,\dots,q\}-\{i\}}(dp[v][u_k]+1)\\ =&\prod_{k\in\{1,\dots,i-1\}}(dp[v][u_k]+1)\times \prod_{k\in\{i+1,\dots,q\}}(dp[v][u_k]+1) \end{align}

と考えればよい。前後両側からの累積積を前計算しておけば、$\prod_{k\in\{1,\dots,i-1\}}(dp[v][u_k]+1),\prod_{k\in\{i+1,\dots,q\}}(dp[v][u_k]+1)$を除算なしで求められる。

解答例

#include <bits/stdc++.h>

using LL = long long;

namespace Problem {
using namespace std;

class Solver {
 public:
  int n;
  LL mod;
  vector<LL> ans;
  struct Edge {
    int to, rev;
    LL val;
  };
  vector<vector<Edge>> t;

  Solver(LL n, LL m) : n(n), mod(m), ans(n), t(n){};

  void solve() {
    for (int i = 0; i < n - 1; ++i) {
      int a, b;
      cin >> a >> b;
      --a;
      --b;
      t[a].push_back({b, (int)t[b].size(), 1});
      t[b].push_back({a, (int)t[a].size() - 1, 1});
    }
    if (n == 1) {
      cout << 1 << endl;
      return;
    }
    dfs(0);
    dfs2(0);
    for (int i = 0; i < n; ++i) {
      cout << ans[i] << endl;
    }
  }
  void dfs(int v, int p = -1) {
    LL res = 1;
    int par;
    for (int i = 0; i < (int)t[v].size(); ++i) {
      if (t[v][i].to != p) {
        dfs(t[v][i].to, v);
        res *= t[v][i].val + 1;
        res %= mod;
      } else {
        par = i;
      }
    }
    if (p == -1) {
      ans[v] = res;
    } else {
      t[p][t[v][par].rev].val = res;
    }
  }
  void dfs2(int v, int p = -1) {
    LL res = 1;
    for (int i = 0; i < t[v].size(); ++i) {
      res *= t[v][i].val + 1;
      res %= mod;
    }
    ans[v] = res;

    //累積積の計算
    vector<LL> prodl(t[v].size(), 1), prodr(t[v].size(), 1);
    prodl[0] = (t[v][0].val + 1) % mod;
    for (int i = 1; i < t[v].size(); ++i) {
      prodl[i] = prodl[i - 1] * (t[v][i].val + 1) % mod;
    }
    prodr[t[v].size() - 1] = (t[v][t[v].size() - 1].val + 1) % mod;
    for (int i = (int)t[v].size() - 2; i >= 0; --i) {
      prodr[i] = prodr[i + 1] * (t[v][i].val + 1) % mod;
    }

    for (int i = 0; i < t[v].size(); ++i) {
      if (t[v][i].to != p) {
        if (i > 0) {
          t[t[v][i].to][t[v][i].rev].val *= prodl[i - 1];
          t[t[v][i].to][t[v][i].rev].val %= mod;
        }
        if (i < (int)t[v].size() - 1) {
          t[t[v][i].to][t[v][i].rev].val *= prodr[i + 1];
          t[t[v][i].to][t[v][i].rev].val %= mod;
        }
        dfs2(t[v][i].to, v);
      }
    }
  }
};
}  // namespace Problem

int main() {
  std::cin.tie(0);
  std::ios_base::sync_with_stdio(false);
  // std::cout << std::fixed << std::setprecision(12);
  long long n = 0, m;
  std::cin >> n >> m;

  Problem::Solver sol(n, m);
  sol.solve();
  return 0;
}

note

Educational DP Contest のうち、自分にとって最も教育的だった問題。

コンテスト中は「 $ M $が素数でなくて面倒だな。合成数でmodを計算させる教育的問題なのかな?」と思っていた。だが、この問題で注意すべきなのは2回目のDFSで$0^{-1}$をかける計算が発生しないようにすること。それは$ M $の値が素数でも合成数でも同じ。