luogu2109 [NOI2007]生成树计数

题目描述

有$n$个点,每个点只能向编号在$[i-k, i-1]$内的点连边,求有标号生成树的个数。

$n\ge 10^{15}$,$k\le 5$。

输入格式

输入文件中包含两个整数 k, n,由一个空格分隔。

输出格式

输出文件输出一个整数,表示生成树的个数。由于答案可能比较大,所以你只要输出答案除 65521 的余数即可。

样例输入

3 5

样例输出

75

分析

可以发现$k$只有$5$,那么可以状压记录$[i-k-1,i-1]$的连通情况,再枚举$i$的连边,就可以递推了。

但是$n$特别大。观察到$f(i-1,s)$到$f(i,t)$的转移只与状态$s,t$有关,所以可以先构造出转移矩阵,然后用矩阵快速幂优化。

连通状态用最小表示法表示。dfs得状态数只有52个,很容易构造转移矩阵。

另外当$i=k$时的$f$初值要先预处理。对于每个连通块,设大小为$m$,则生成树个数为$m^{m-2}$。

代码

#include <bits/stdc++.h>
typedef long long ll;

const int K = 6;
const int S = 55;
const int MXS = 1 << 16;
const int P = 65521;
const int POW[] = {1, 1, 1, 3, 16, 125};

int k;
ll n;
int tot, st[S], id[MXS];
int bl[K];
int f[S];

void madd(int &x, int y) {
    x += y;
    if (x >= P) x -= P;
    if (x < 0) x += P;
}

struct Matrix {
    int a[S][S];

    Matrix(bool p=false) {
        memset(a, 0, sizeof a);
        if (p) {
            for (int i = 0; i < tot; i++)
                a[i][i] = 1;
        }
    }

    int* operator [] (int k) {
        return a[k];
    }

    const int* operator [] (int k) const {
        return a[k];
    }

    Matrix operator * (const Matrix &b) const {
        Matrix c;
        for (int i = 0; i < tot; i++)
            for (int j = 0; j < tot; j++)
                for (int t = 0; t < tot; t++)
                    madd(c[i][j], (ll)a[i][t] * b[t][j] % P);
        return c;
    }

    Matrix operator ^ (ll t) const {
        Matrix rtn(1), b = *this;
        for (; t; t >>= 1, b = b * b)
            if (t & 1) rtn = rtn * b;
        return rtn;
    }
} g;

void decode(int s, int b[]) {
    for (int i = k-1; i >= 0; i--) {
        b[i] = s & 7;
        s >>= 3;
    }
}

int encode(int b[]) {
    int t[K];
    memset(t, -1, sizeof t);
    int top = 0, s = 0;
    for (int i = 0; i < k; i++) {
        if (t[b[i]] == -1)
            t[b[i]] = ++top;
        s = s << 3 | t[b[i]];
    }
    return s;
}

void dfs(int u, int mx) {
    if (u == k) {
        int s = encode(bl);
        st[tot] = s;
        id[s] = tot++;
        return;
    }
    for (int i = 1; i <= mx; i++) {
        bl[u] = i;
        dfs(u+1, mx);
    }
    bl[u] = mx + 1;
    dfs(u+1, mx+1);
}

void getNxt(int u, int conn) {
    int s = st[u];
    decode(s, bl);
    bl[k] = 0;
    for (int i = 0; i < k; i++) {
        if (conn & 1 << i) {
            if (!bl[k]) bl[k] = bl[i];
            else if (bl[k] != bl[i]) {
                int t = bl[i];
                for (int j = 0; j < k; j++) {
                    if (bl[j] == t)
                        bl[j] = bl[k];
                }
            } else return;
        }
    }
    bool ok = false;
    for (int i = 1; i <= k; i++)
        if (bl[0] == bl[i])
            ok = true;
    if (!ok) return;
    int t = encode(bl+1);
    g[u][id[t]]++;
}

int main() {
    scanf("%d%lld", &k, &n);
    dfs(0, 0);
    for (int i = 0; i < tot; i++) {
        int s = st[i];
        decode(s, bl);

        int cnt[K] = {0};
        for (int j = 0; j < k; j++)
            cnt[bl[j]]++;
        f[i] = 1;
        for (int j = 0; j < k; j++)
            f[i] *= POW[cnt[j]];

        for (int conn = 0, lim = 1 << k; conn < lim; conn++)
            getNxt(i, conn);
    }
    g = g ^ (n-k);
    int ans = 0;
    for (int i = 0; i < tot; i++)
        madd(ans, (ll)f[i] * g[i][0] % P);
    printf("%d\n", ans);
    return 0;
}


转载请注明出处。


评论列表,共 0 条评论

    暂无评论

发表评论