luogu2664 树上游戏

题目描述

lrb有一棵树,树的每个节点有个颜色。给一个长度为n的颜色序列,定义s(i,j) 为i 到j 的颜色数量。以及

$$sum_i=\sum_{i=1}^{n}s(i,j)$$

现在他想让你求出所有的sum[i]

输入格式

第一行为一个整数n,表示树节点的数量

第二行为n个整数,分别表示n个节点的颜色c[1],c[2]……c[n]

接下来n-1行,每行为两个整数x,y,表示x和y之间有一条边

输出格式

输出n行,第i行为sum[i]

样例输入

5
1 2 3 2 3
1 2
2 3
2 4
1 5

样例输出

10
9
11
9
12

说明

sum[1]=s(1,1)+s(1,2)+s(1,3)+s(1,4)+s(1,5)=1+2+3+2+2=10
sum[2]=s(2,1)+s(2,2)+s(2,3)+s(2,4)+s(2,5)=2+1+2+1+3=9
sum[3]=s(3,1)+s(3,2)+s(3,3)+s(3,4)+s(3,5)=3+2+1+2+3=11
sum[4]=s(4,1)+s(4,2)+s(4,3)+s(4,4)+s(4,5)=2+1+2+1+3=9
sum[5]=s(5,1)+s(5,2)+s(5,3)+s(5,4)+s(5,5)=2+3+3+3+1=12

对于40%的数据,n<=2000

对于100%的数据,1<=n,c[i]<=10^5


分析

常规解法是点分治。对于这种统计树上所有路径信息的问题,一般可以使用点分治解决。

有篇洛谷题解写得很到位,我就不再写了。

考虑点分治,对于当前分治中心,统计出它自己出发到分治块内的所有路径对自己答案的贡献,和经过它的路径对当前分治块内点的贡献。自己出发到分治块内的所有路径对自己答案的贡献很好求,现在考虑怎么求经过它的路径对当前分治块内点的贡献。 我们对于当前分治中心的每一个子树分别考虑,令$cnt[i]$为从分治中心出发的进入其他子树的所有路径中,包含颜色$i$的路径条数,$size$为除了该子树外当前分治块内所有的点的个数。那么,我们dfs这棵子树计算贡献,假设当前dfs到$x$,首先给$sum_x$加上$\sum cnt[i]$,即所有在其他子树中出现的颜色的贡献总和,然后计算$x$到分治中心的路径上颜色的贡献。 对于一个出现在分治中心到$x$的路径上的颜色$c$,它对$x$的贡献为$size-cnt[c]$,因为$c$已经在一些路径上出现,它现在能产生的额外贡献为它原来没有出现的路径条数。所以我们给$sum_x$还要加上$size-cnt[c]$。同时$c$也会对$x$的子树内所有点产生贡献,所以这个贡献要像标记一样往下传递,然后标记一下$c$的贡献已经被计算过,往下dfs时就不用再次计算了。 所以只需要对于当前分治中心求出$cnt$数组和每棵子树的$size$,进入一棵子树时减去子树自己内部对$cnt$产生的贡献。同时为了防止复杂度退化,我们不能对于所有颜色求$cnt$,要先统计一下当前分治块内有哪些颜色出现了,这样枚举块内所有颜色的复杂度才是O(分治块大小)。 复杂度$O(n\log n)$。 ——洛谷题解 @Salamander

另外一种解法是$O(n)$的,非常巧妙。

考虑某个颜色$c$产生的贡献。把树上颜色为$c$的点都删掉,树会分成若干个连通块。对于点$i$,只有不与$i$在同一连通块里的点会对$i$产生贡献。设$siz(c,i)$表示此时点$i$所在的连通块大小,那么颜色$c$对点$i$的贡献为$n-siz(c,i)$。特殊地,若$i$被删掉了,即$color(i)=c$,则颜色$c$对$i$的贡献为$n$。

现在我们的目标是,求出$ans(i)=n*colorCnt-\sum_{c\not=color(i)}siz(c,i)$,其中$colorCnt$为总颜色个数。

在我们从$fa(i)$移动到点$i$时,考虑$\sum_{c\not=color(i)}siz(c,i)$的改变,发现只需要修改$siz(color(fa(i)),i)$。所以我们先做一遍预处理,求出每个点的$siz(color(fa(i)),i)$即可。这个不难求,直接看代码就好了。

在实现中需要用一些树上差分的技巧。复杂度$O(n)$。

代码

点分治:

#include <bits/stdc++.h>
typedef long long ll;

const int N = 1e5+50;

struct Graph {
    struct Edge {
        int to, nxt;
    } e[N<<1];
    int top, head[N];

    Graph(): top(0) {
        memset(head, -1, sizeof head);
    }

    void add(int u, int v) {
        e[top] = (Edge){v, head[u]};
        head[u] = top++;
    }

    void add2(int u, int v) {
        add(u, v);
        add(v, u);
    }
} G;

int n;
bool vis[N];
int color[N];
int siz[N];
bool exist[N];
int buc[N];
std::vector<int> colors;
ll sum, otherSum;
int otherCnt;
int cnt[N], cntTmp[N], cntv[N];
ll ans[N];

void getSiz(int u, int fa) {
    siz[u] = 1;
    for (int i = G.head[u]; ~i; i = G.e[i].nxt) {
        int v = G.e[i].to;
        if (v == fa || vis[v]) continue;
        getSiz(v, u);
        siz[u] += siz[v];
    }
}

void getColors(int u, int fa, int cnt[N]) {
    if (!exist[color[u]]) {
        exist[color[u]] = true;
        colors.push_back(color[u]);
    }
    if (++buc[color[u]] == 1)
        cnt[color[u]] += siz[u];
    for (int i = G.head[u]; ~i; i = G.e[i].nxt) {
        int v = G.e[i].to;
        if (v == fa || vis[v]) continue;
        getColors(v, u, cnt);
    }
    buc[color[u]]--;
}

void getAns(int u, int fa, ll tag) {
    if (++buc[color[u]] == 1)
        tag += otherCnt - cnt[color[u]];
    ans[u] += sum + tag;
    for (int i = G.head[u]; ~i; i = G.e[i].nxt) {
        int v = G.e[i].to;
        if (v == fa || vis[v]) continue;
        getAns(v, u, tag);
    }
    buc[color[u]]--;
}

void calc(int u) {
    getSiz(u, 0);
    colors.clear();
    getColors(u, 0, cnt);
    std::vector<int> colorsTmp = colors;
    sum = 0;
    for (int i = 0; i < (int)colors.size(); i++) {
        int c = colors[i];
        exist[c] = false;
        sum += cnt[c];
        cntTmp[c] = cnt[c];
    }
    ll sumTmp = sum;
    ans[u] += sum;
    for (int i = G.head[u]; ~i; i = G.e[i].nxt) {
        int v = G.e[i].to;
        if (vis[v]) continue;

        colors.clear();
        buc[color[u]]++;
        exist[color[u]] = true;
        getColors(v, u, cntv);
        for (int j = 0; j < (int)colors.size(); j++) {
            int c = colors[j];
            exist[c] = false;
            cnt[c] -= cntv[c];
            sum -= cntv[c];
        }
        cnt[color[u]] -= siz[v];
        sum -= siz[v];
        buc[color[u]]--;
        exist[color[u]] = false;

        otherCnt = siz[u] - siz[v];
        getAns(v, u, 0);

        sum = sumTmp;
        cnt[color[u]] = cntTmp[color[u]];
        for (int j = 0; j < (int)colors.size(); j++) {
            int c = colors[j];
            cntv[c] = 0;
            cnt[c] = cntTmp[c];
        }
    }

    for (int i = 0; i < (int)colorsTmp.size(); i++)
        cnt[colorsTmp[i]] = 0;
}

int getRoot(int u, int fa, int tot) {
    static int mx[N];
    siz[u] = 1;
    mx[u] = 0;
    int rt = 0;
    for (int i = G.head[u]; ~i; i = G.e[i].nxt) {
        int v = G.e[i].to;
        if (v == fa || vis[v]) continue;
        int vrt = getRoot(v, u, tot);
        siz[u] += siz[v];
        mx[u] = std::max(mx[u], siz[v]);
        if (!rt || mx[vrt] < mx[rt])
            rt = vrt;
    }
    mx[u] = std::max(mx[u], tot - siz[u]);
    if (!rt || mx[u] < mx[rt])
        rt = u;
    return rt;
}

void dfs(int u, int tot) {
    vis[u] = true;
    calc(u);
    for (int i = G.head[u]; ~i; i = G.e[i].nxt) {
        int v = G.e[i].to;
        if (vis[v]) continue;
        int vsiz = (siz[v] < siz[u] ? siz[v] : tot - siz[u]);
        int vrt = getRoot(v, 0, vsiz);
        dfs(vrt, vsiz);
    }
}

int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n; i++)
        scanf("%d", &color[i]);
    for (int i = 1, u, v; i < n; i++) {
        scanf("%d%d", &u, &v);
        G.add2(u, v);
    }
    int root = getRoot(1, 0, n);
    dfs(root, n);
    for (int i = 1; i <= n; i++)
        printf("%lld\n", ans[i]);
    return 0;
}

O(n)做法:

#include <bits/stdc++.h>
typedef long long ll;

const int N = 1e5+50;
const int M = 1e5+50;

struct Graph {
    struct Edge {
        int to, nxt;
    } e[N<<1];
    int top, head[N];

    Graph(): top(0) {
        memset(head, -1, sizeof head);
    }

    void add(int u, int v) {
        e[top] = (Edge){v, head[u]};
        head[u] = top++;
    }

    void add2(int u, int v) {
        add(u, v);
        add(v, u);
    }
} G;

int n, m, kind;
int color[N];
bool vis[M];
int siz[N];
int cnt[N], sub[M];
int clrCnt[M];
ll tot, ans[N];


void dfs(int u, int fa) {
    siz[u] = 1;
    int tmp = sub[color[fa]];
    for (int i = G.head[u]; ~i; i = G.e[i].nxt) {
        int v = G.e[i].to;
        if (v == fa) continue;
        dfs(v, u);
        siz[u] += siz[v];
    }
    sub[color[u]]++;
    if (fa) {
        cnt[u] = siz[u] - (sub[color[fa]] - tmp);
        sub[color[fa]] += cnt[u];
    }
}

void getAns(int u, int fa) {
    int tmp = clrCnt[color[fa]];
    tot += cnt[u] - clrCnt[color[fa]];
    clrCnt[color[fa]] = cnt[u];
    ans[u] = (ll)n*kind - (tot - clrCnt[color[u]]);
    for (int i = G.head[u]; ~i; i = G.e[i].nxt) {
        int v = G.e[i].to;
        if (v == fa) continue;
        getAns(v, u);
    }
    clrCnt[color[fa]] = tmp;
    tot -= cnt[u] - clrCnt[color[fa]];
}

int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) {
        scanf("%d", &color[i]);
        if (!vis[color[i]]) {
            vis[color[i]] = true;
            kind++;
        }
        m = std::max(m, color[i]);
    }
    for (int i = 1, u, v; i < n; i++) {
        scanf("%d%d", &u, &v);
        G.add2(u, v);
    }
    dfs(1, 0);
    for (int i = 1; i <= m; i++) {
        if (!vis[i]) continue;
        tot += n - sub[i];
        clrCnt[i] = n - sub[i];
    }
    getAns(1, 0);
    for (int i = 1; i <= n; i++)
        printf("%lld\n", ans[i]);
    return 0;
}


转载请注明出处。


评论列表,共 0 条评论

    暂无评论

发表评论