loj2537 「PKUWC2018」Minimax

题目传送门

传送门


分析

先把权值离散化。题目要求的是对于$1\le i\le m$($m$为权值个数),树根最终权值为$i$的概率$f(i)$。

递推顺序肯定是从儿子推到父亲。先推个$O(n^2)$的方程:

记左子树的$f$数组为$a$,右子树的$f$数组为$b$.

记当前节点取左右子树最大值的概率为$pmax$,取最小值的概率为$pmin$。

若权值$i$在左子树中出现:

$$f(i)=a(i)\cdot [pmin\cdot \sum_{j>i}b_j+pmax\cdot \sum_{j<i}b_j]$$

右子树中同理,只需把$a,b$反过来。

预处理前后缀和可以做到$O(n^2)$。

考虑如何优化这个方程,因为是若干个单点的信息逐渐合并起来,可以考虑线段树合并。

线段树上的某个节点表示权值在$[L,R]$之间的概率。记左子树的线段树为$A$,右子树为$B$。

设我们当前走到线段树节点$u$,区间为$[L,R]$。

若只有A中存在权值在$[L,R]$之间的点,那么我们需要把这个点及子树的概率都乘上$pmin\cdot \sum_{j>i}b_j+pmax\cdot \sum_{j<i}b_j$,这个值在线段树从根走到$u$的过程中不难统计。只有$B$中存在时同理。

若$A,B$中都存在权值在$[L,R]$之间的点,那么先往左右子树递归合并,然后$f(u)=f(l(u))+f(r(u))$。

代码

#include <bits/stdc++.h>

typedef long long ll;

const int N = 3e5+50;
const int P = 998244353;
const int IW = 796898467;
const int SZ = N * 20;

int n;
int son[N][2];
int w[N], p[N];
int dif, disc[N];
int top, ls[SZ], rs[SZ], sum[SZ], mul[SZ];
int pmx;

int build(int L, int R, int pos) {
    int u = ++top;
    sum[u] = mul[u] = 1;
    if (L == R) return u;
    int mid = L + R >> 1;
    if (pos <= mid) ls[u] = build(L, mid, pos);
    else rs[u] = build(mid+1, R, pos);
    return u;
}

void putMul(int u, int val) {
    mul[u] = (ll)mul[u] * val % P;
    sum[u] = (ll)sum[u] * val % P;
}

void pushDown(int u) {
    if (mul[u] == 1) return;
    putMul(ls[u], mul[u]);
    putMul(rs[u], mul[u]);
    mul[u] = 1;
}

void pushUp(int u) {
    sum[u] = (sum[ls[u]] + sum[rs[u]]) % P;
}

int merge(int x, int y, int tagx, int tagy) {
    if (!y) {
        putMul(x, tagx);
        return x;
    }
    if (!x) {
        putMul(y, tagy);
        return y;
    }
    pushDown(x);
    pushDown(y);
    int sx0 = sum[ls[x]], sx1 = sum[rs[x]], sy0 = sum[ls[y]], sy1 = sum[rs[y]];
    ls[x] = merge(ls[x], ls[y], (tagx + (ll)(1-pmx) * sy1) % P, (tagy + (ll)(1-pmx) * sx1) % P);
    rs[x] = merge(rs[x], rs[y], (tagx + (ll)pmx * sy0) % P, (tagy + (ll)pmx * sx0) % P);
    pushUp(x);
    return x;
}

int dfs(int u) {
    if (!son[u][0])
        return build(1, dif, w[u]);
    int lu = dfs(son[u][0]);
    if (!son[u][1]) return lu;
    int ru = dfs(son[u][1]);
    pmx = p[u];
    return merge(lu, ru, 0, 0);
}

int solve(int u, int L, int R) {
    if (L == R)
        return (ll)L * disc[L] % P * sum[u] % P * sum[u] % P;
    pushDown(u);
    int mid = L + R >> 1;
    return (solve(ls[u], L, mid) + solve(rs[u], mid+1, R)) % P;
}

int main() {
    scanf("%d%*d", &n);
    for (int i = 2, fa; i <= n; i++) {
        scanf("%d", &fa);
        son[fa][(bool)son[fa][0]] = i;
    }
    for (int i = 1, r; i <= n; i++) {
        scanf("%d", &r);
        if (!son[i][0]) {
            w[i] = r;
            disc[++dif] = w[i];
        }
        else p[i] = (ll)r * IW % P;
    }
    std::sort(disc+1, disc+dif+1);
    for (int i = 1; i <= n; i++) {
        if (!son[i][0]) {
            w[i] = std::lower_bound(disc+1, disc+dif+1, w[i]) - disc;
        }
    }
    int root = dfs(1);
    printf("%d\n", (solve(root, 1, dif) + P) % P);
    return 0;
}


转载请注明出处。


评论列表,共 0 条评论

    暂无评论

发表评论