算法学习笔记:WQS二分

引例

luogu2619 [国家集训队2]Tree I

给你一个无向带权连通图,每条边是黑色或白色。让你求一棵最小权的恰好有$need$条白色边的生成树。

题目保证有解,点数与边数均为$10^5$级别。

思路

设选$x$条白边的最小代价为$f(x)$,不难看出$f(x)$是下凸函数,而题目要求的是$f(need)$。

$f(x)$的最小值很好求,直接跑一边MST(最小生成树)即可,设$f(x)$的最小值在$f(t)$处取到,也就是直接跑的MST中有$t$条白边。

最小值$f(t)$虽然容易求,但看起来好像并没有什么用。不妨设$t>need$,那么我们肯定想让算法中少取几条白边。

一个很简单的想法是,增加一个奖惩机制,把所有白边的权值增大$k$,再跑MST。当$k$设得够大时,最小值点$t$肯定会左移。

设改变权值后,MST中取到$x$条白边的最小代价为$g(x)$,可以发现$g(x)=f(x)+k\cdot x$。因为在取$x$条白边的方案中,$f(x)$的方案是最优的,并且在改变权值时,这些方案得到的奖惩值相同,所以$f(x)$的方案仍然最优。

假设我们运气很好,随便设了一个$k$,求出来的最小值$g(t)$刚好满足$t=need$。于是我们可以直接输出$f(t)=g(t)-k\cdot x$(把白边的权值还原)。

但是写算法肯定不能靠运气,现在的问题是如何找到一个$k$,使得$g(x)$最小值在$need$处取到。这时候就需要用到$f(x)$的特性:下凸。

我们可以用数形结合的方式理解一下算法过程:有个下凸壳$f(x)$,我们把它减去一条直线$y=kx$,然后求新函数的最小值,并且我们希望这个最小值在$need$处取到。

可以发现最小值点即为用一条斜率为$w$的直线去切$f(x)$的切点。那么显然随着斜率$k$的递增,切点会单调右移。于是我们可以二分$k$,直到切点即最小值点$t=need$时输出答案。

还有一些细节问题:一条直线同时切到$need$与相邻共线的若干个点的情况。此时,我们不妨规定在黑白边权值相等时优先取白边,这样我们枚举到$w$时,得到的切点一定是这条直线上最右边的点。那么我们只要在$t\ge need$时更新答案就好了。

另外,因为这题边权为整数,所以可以证明二分斜率时也只需要二分整数。设$need-1$到$need$的斜率为$k_1$,$need$到$need+1$的斜率为$k_2$。因为横坐标差值为$1$,则斜率等于纵坐标差值,所以$k_1,k_2$均为整数。若$k_1=k_2$,则$w=k_1$时得到答案。否则显然$k_1+1 \le k_2$(纵坐标差值至少增加$1$)。而$w\in [k_1,k_2]$时都可以得到合法解,并且$[k_1,k_2]$中一定存在整数。所以二分整数即可。

最后还有如何定二分斜率上下界的问题,斜率的范围肯定在【相邻两个点纵坐标差值】的最小值与最大值之间,也就是多取/少取一条白边,$f(x)$的值最多会变动多少。此题边权为$[0,100]$,所以二分范围为$[-100,100]$。

这就是WQS二分。对于在若干个东西中取恰好$k$个的最优化问题,若关于$k$的价值函数为凸函数,则可以利用二分把取$k$个的限制去掉。

代码

直接二分+MST的复杂度是$O(E\log E\log W)$的($W$为二分范围)。复杂度瓶颈在于每次做MST都对所有边重新排序。但事实上我们每次只是把所有白边加上或减去一个相同的值,并不改变黑/白边内部的顺序,所以可以预处理把黑边和白边分别排好序,然后每次用归并排序做。复杂度$O(E\log E+E\log W)$。

#include <bits/stdc++.h>

const int N = 5e4+50;
const int M = 1e5+50;

struct Edge {
    int u, v, w;

    bool operator < (const Edge &t) const {
        return w < t.w;
    }
};

int n, m, need;
Edge ew[M], eb[M];
int mw, mb;

struct DS {
    int fa[N];

    void init() {
        for (int i = 0; i < n; i++)
            fa[i] = i;
    }

    int find(int x) {
        if (fa[x] != x) fa[x] = find(fa[x]);
        return fa[x];
    }
} ds;

bool add(const Edge &e) {
    int fu = ds.find(e.u), fv = ds.find(e.v);
    if (fu == fv) return false;
    ds.fa[fu] = fv;
    return true;
}

std::pair<int, int> kruscal(int delta) {
    ds.init();
    int cnt = 0, i = 1, j = 1, sum = 0, wcnt = 0;
    while (cnt < n-1) {
        if (i <= mw && (j > mb || ew[i].w + delta <= eb[j].w)) {
            if (!add(ew[i])) {
                i++;
                continue;
            }
            cnt++;
            wcnt++;
            sum += ew[i++].w + delta;
        } else {
            if (!add(eb[j])) {
                j++;
                continue;
            }
            cnt++;
            sum += eb[j++].w;
        }
    }
    return std::make_pair(sum, wcnt);
}

int main() {
    scanf("%d%d%d", &n, &m, &need);
    for (int i = 1; i <= m; i++) {
        Edge e;
        int c;
        scanf("%d%d%d%d", &e.u, &e.v, &e.w, &c);
        if (c == 0) ew[++mw] = e;
        else eb[++mb] = e;
    }
    std::sort(ew+1, ew+mw+1);
    std::sort(eb+1, eb+mb+1);
    int L = -100, R = 100, w = -546;
    while (L <= R) {
        int mid = L + R >> 1;
        std::pair<int, int> res = kruscal(mid);
        if (res.second >= need) {
            w = res.first - need * mid;
            L = mid + 1;
        } else {
            R = mid - 1;
        }
    }
    printf("%d\n", w);
    return 0;
}


转载请注明出处。


评论列表,共 0 条评论

    暂无评论

发表评论