Commit 5cafd0ee authored by 24OI-bot's avatar 24OI-bot
Browse files

style: format markdown files with remark-lint

parent ae7ecae3
Loading
Loading
Loading
Loading
+142 −113
Original line number Diff line number Diff line
在学习本章前请确认你已经学习了 [矩阵](../math/matrix.md)  [树链剖分](../graph/hld.md) 

动态DP问题是猫锟在WC2018讲得黑科技,一般用来解决树上的DP问题,同时支持点权(边权)修改操作。
动态 DP 问题是猫锟在 WC2018 讲得黑科技,一般用来解决树上的 DP 问题,同时支持点权边权修改操作。

因为 NOIP2018D2T3 考了所以突然风靡 Oier 圈。

@@ -18,6 +18,7 @@
$$
C_{i,j}=max_{k=1}^{n}(A_{i,k}+B_{k,j})
$$

相当于将普通的矩阵乘法中的乘变为加,加变为 $max$ 操作

同时广义矩阵乘法满足结合律,所以可以使用矩阵快速幂
@@ -41,6 +42,7 @@ $$
假设我们已知 $g_{i,0/1}$ 那么有 DP 方程 $\begin{cases}f_{i,0}=g_{i,0}+max(f_{son_i,0},f_{son_i,1})\\f_{i,1}=g_{i,1}+f_{son_i,0}\end{cases}$ ,答案是 $max(f_{root,0},f_{root,1})$ 

可以构造出矩阵

$$
\left[
\begin{matrix}
@@ -59,16 +61,17 @@ f_{i,0}\\f_{i,1}
\end{matrix}
\right]
$$

注意,我们这里使用的是广义乘法规则

可以发现,修改操作时只需要修改 $g_{i,1}$ 和每条往上的重链即可

#### 具体思路

*   DFS预处理求出 $f_{i,0/1}$ 和 $g_{i,0/1}$ 
*   对这棵树进行树剖(注意,因为我们对一个点进行询问需要计算从该点到该点所在的重链末尾的区间矩阵乘,所以对于每一个点记录 $End_i$ 表示 $i$ 所在的重链末尾节点编号),每一条重链建立线段树,线段树维护 $g$ 矩阵和 $g$ 矩阵区间乘积
*   修改时首先修改 $g_{i,1}$ 和线段树中 $i$ 节点的矩阵,计算 $top_i$ 矩阵的变化量,修改到 $fa_{top_i}$ 矩阵
*   查询时就是1到其所在的重链末尾的区间乘,最后取一个 $max$ 即可
-   DFS 预处理求出 $f_{i,0/1}$ 和 $g_{i,0/1}$ 
-   对这棵树进行树剖注意,因为我们对一个点进行询问需要计算从该点到该点所在的重链末尾的区间矩阵乘,所以对于每一个点记录 $End_i$ 表示 $i$ 所在的重链末尾节点编号,每一条重链建立线段树,线段树维护 $g$ 矩阵和 $g$ 矩阵区间乘积
-   修改时首先修改 $g_{i,1}$ 和线段树中 $i$ 节点的矩阵,计算 $top_i$ 矩阵的变化量,修改到 $fa_{top_i}$ 矩阵
-   查询时就是 1 到其所在的重链末尾的区间乘,最后取一个 $max$ 即可

#### 代码

@@ -88,34 +91,37 @@ const int maxn = 500010;
const int INF = 0x3f3f3f3f;

int Begin[maxn], Next[maxn], To[maxn], e, n, m;
int size[maxn], son[maxn], top[maxn], fa[maxn], dis[maxn], p[maxn], id[maxn], End[maxn];
int size[maxn], son[maxn], top[maxn], fa[maxn], dis[maxn], p[maxn], id[maxn],
    End[maxn];
// p[i]表示i树剖后的编号,id[p[i]] = i
int cnt, tot, a[maxn], f[maxn][2];

struct matrix
{
struct matrix {
  int g[2][2];
  matrix() { memset(g, 0, sizeof(g)); }
  matrix operator*(const matrix &b) const  // 重载矩阵乘
  {
    matrix c;
        REP(i, 0, 1) REP(j, 0, 1) REP(k, 0, 1) c.g[i][j] = max(c.g[i][j], g[i][k] + b.g[k][j]);
    REP(i, 0, 1)
    REP(j, 0, 1) REP(k, 0, 1) c.g[i][j] = max(c.g[i][j], g[i][k] + b.g[k][j]);
    return c;
  }
} Tree[maxn], g[maxn];  // Tree[]是建出来的线段树,g[]是维护的每个点的矩阵

inline void PushUp(int root) { Tree[root] = Tree[lson] * Tree[rson]; }

inline void Build(int root, int l, int r)
{
    if ( l == r ) { Tree[root] = g[id[l]]; return ; }
inline void Build(int root, int l, int r) {
  if (l == r) {
    Tree[root] = g[id[l]];
    return;
  }
  int Mid = l + r >> 1;
    Build(lson, l, Mid); Build(rson, Mid + 1, r);
  Build(lson, l, Mid);
  Build(rson, Mid + 1, r);
  PushUp(root);
}

inline matrix Query(int root, int l, int r, int L, int R)
{
inline matrix Query(int root, int l, int r, int L, int R) {
  if (L <= l && r <= R) return Tree[root];
  int Mid = l + r >> 1;
  if (R <= Mid) return Query(lson, l, Mid, L, R);
@@ -124,60 +130,78 @@ inline matrix Query(int root, int l, int r, int L, int R)
  // 注意查询操作的书写
}

inline void Modify(int root, int l, int r, int pos)
{
    if ( l == r ) { Tree[root] = g[id[l]]; return ; }
inline void Modify(int root, int l, int r, int pos) {
  if (l == r) {
    Tree[root] = g[id[l]];
    return;
  }
  int Mid = l + r >> 1;
    if ( pos <= Mid ) Modify(lson, l, Mid, pos);
    else Modify(rson, Mid + 1, r, pos);
  if (pos <= Mid)
    Modify(lson, l, Mid, pos);
  else
    Modify(rson, Mid + 1, r, pos);
  PushUp(root);
}

inline void Update(int x, int val)
{
    g[x].g[1][0] += val - a[x]; a[x] = val;
inline void Update(int x, int val) {
  g[x].g[1][0] += val - a[x];
  a[x] = val;
  // 首先修改x的g矩阵
    while ( x ) 
    {
  while (x) {
    matrix last = Query(1, 1, n, p[top[x]], End[top[x]]);
    // 查询top[x]的原本g矩阵
        Modify(1, 1, n, p[x]); // 进行修改(x点的g矩阵已经进行修改但线段树上的未进行修改)
    Modify(1, 1, n,
           p[x]);  // 进行修改(x点的g矩阵已经进行修改但线段树上的未进行修改)
    matrix now = Query(1, 1, n, p[top[x]], End[top[x]]);
    // 查询top[x]的新g矩阵
    x = fa[top[x]];
        g[x].g[0][0] += max(now.g[0][0], now.g[1][0]) - max(last.g[0][0], last.g[1][0]); 
    g[x].g[0][0] +=
        max(now.g[0][0], now.g[1][0]) - max(last.g[0][0], last.g[1][0]);
    g[x].g[0][1] = g[x].g[0][0];
    g[x].g[1][0] += now.g[0][0] - last.g[0][0];
    // 根据变化量修改fa[top[x]]的g矩阵
  }
}

inline void add(int u, int v) { To[++ e] = v; Next[e] = Begin[u]; Begin[u] = e; }
inline void add(int u, int v) {
  To[++e] = v;
  Next[e] = Begin[u];
  Begin[u] = e;
}

inline void DFS1(int u)
{
    size[u] = 1; int Max = 0; f[u][1] = a[u];
    for ( int i = Begin[u]; i; i = Next[i] ) 
    {
        int v = To[i]; if ( v == fa[u] ) continue ;
        dis[v] = dis[u] + 1; fa[v] = u;
        DFS1(v); size[u] += size[v];
        if ( size[v] > Max ) { Max = size[v]; son[u] = v; }
inline void DFS1(int u) {
  size[u] = 1;
  int Max = 0;
  f[u][1] = a[u];
  for (int i = Begin[u]; i; i = Next[i]) {
    int v = To[i];
    if (v == fa[u]) continue;
    dis[v] = dis[u] + 1;
    fa[v] = u;
    DFS1(v);
    size[u] += size[v];
    if (size[v] > Max) {
      Max = size[v];
      son[u] = v;
    }
    f[u][1] += f[v][0];
    f[u][0] += max(f[v][0], f[v][1]);
    // DFS1过程中同时求出f[i][0/1]
  }
}

inline void DFS2(int u, int t)
{
    top[u] = t; p[u] = ++ cnt; id[cnt] = u; End[t] = cnt;
    g[u].g[1][0] = a[u]; g[u].g[1][1] = -INF;
inline void DFS2(int u, int t) {
  top[u] = t;
  p[u] = ++cnt;
  id[cnt] = u;
  End[t] = cnt;
  g[u].g[1][0] = a[u];
  g[u].g[1][1] = -INF;
  if (!son[u]) return;
  DFS2(son[u], t);
    for ( int i = Begin[u]; i; i = Next[i] ) 
    {
        int v = To[i]; if ( v == fa[u] || v == son[u] ) continue ;
  for (int i = Begin[u]; i; i = Next[i]) {
    int v = To[i];
    if (v == fa[u] || v == son[u]) continue;
    DFS2(v, v);
    g[u].g[0][0] += max(f[v][0], f[v][1]);
    g[u].g[1][0] += f[v][0];
@@ -186,20 +210,26 @@ inline void DFS2(int u, int t)
  g[u].g[0][1] = g[u].g[0][0];
}

int main()
{
int main() {
#ifndef ONLINE_JUDGE
  freopen("input.txt", "r", stdin);
  freopen("output.txt", "w", stdout);
#endif
  scanf("%d%d", &n, &m);
  REP(i, 1, n) scanf("%d", &a[i]);
    REP(i, 1, n - 1) { int u, v; scanf("%d%d", &u, &v); add(u, v); add(v, u); }
    dis[1] = 1; DFS1(1); DFS2(1, 1);
  REP(i, 1, n - 1) {
    int u, v;
    scanf("%d%d", &u, &v);
    add(u, v);
    add(v, u);
  }
  dis[1] = 1;
  DFS1(1);
  DFS2(1, 1);
  Build(1, 1, n);
    REP(i, 1, m)
    {
        int x, val; scanf("%d%d", &x, &val);
  REP(i, 1, m) {
    int x, val;
    scanf("%d%d", &x, &val);
    Update(x, val);
    matrix ans = Query(1, 1, n, 1, End[1]);  // 查询1所在重链的矩阵乘
    printf("%d\n", max(ans.g[0][0], ans.g[1][0]));
@@ -207,4 +237,3 @@ int main()
  return 0;
}
```