置换群 + 生成函数 + NTT + 启发式合并/分治
题意
给一个 1-n 的排列 p 和一个非负整数 k,求大小为 k 的 {1, 2, 3,... n} 的子集合 T 的数量,满足
即 T 的元素按 p 置换一轮后和自身没有交集
思路
-
\(i\) -> \(p_i\) 连边,找到 m 个环
-
设某个环的大小为 a,要找 b 个元素进入 T 集合,若这 b 个被置换一轮后与自身没有交集,即这 b 个元素在这个环上是不相邻的,求出该环的生成函数,\([x^b]f(x)\) 表示当前这个环中,取 b 个不相邻的元素的方案数
-
对于这 m 个环,这 m 个多项式求卷积,\([x^k]F(x)\) 即为答案
然而 m 与 n 是一个数量级的,暴力求卷积的复杂度为 \(n^2logn\), 有启发式合并和分治两种方法优化
由于这 m 个环一共有 n 个元素,所以这 m 个多项式的项数之和是 m,因此可以启发式合并,用堆来找到当前项数最小的两个多项式,优先合并,这样复杂度为 \(nlog^2n\)
也可以分治,复杂度为 \(T(n)=T(\frac n2)+O(nlogn)\), 解得 \(T(n)=O(nlog^2n)\)
求每个环的生成函数
步骤 2 中求生成函数是本题的关键,抽象为 a 个相同的球围成一个环,从中抽 b 个互不相邻的球的方案数
考虑链式的情况
如果 a 个球排成一排,抽 b 个互不相邻的球,可以先固定这 b 个球,然后把剩下 a - b 个球放到这 b + 1 个空位中,设每个空位放的球的个数为 \(x_0,x_1,x_2...x_b\)
因为这 b 个球要互不相邻,所以 \(x_1,x_2...x_{b-1}\) 至少为 1,\(x_0,x_b\) 可以为 0
若给 \(x_0,x_b\) 都提前放一个球,则转化为求 \(x_0+x_1+...+x_{b-1}+x_b=a-b+2\) 的正整数解个数
用隔板法,\(g(a,b)=\binom {a-b+1}{b}\)
考虑环上的情况
先随便选一个球,这个球有两种情况
- 被选在那 b 个球中,所以它相邻两个球不能选,并且可以在这个球处把环断开,方案数为 \(g(a-3,b-1)\)
- 没有被选中,也可以在这个球处断开,方案数为 \(g(a-1,b)\)
因此生成函数为 \(f(a,b)=g(a-3,b-1)+g(a-1,b)=\binom {a-b-1}{b-1}+\binom {a-b}b\)
代码
#include <bits/stdc++.h>
using namespace std;
#define clog(x) std::clog << (#x) << " is " << (x) << '\n';using ll = long long;namespace NFTS {
const int M = 998244353, g = 3;
std::vector<int> rev, roots{0, 1};
int powMod(int x, int n) {int r(1);while (n) {if (n&1) r = 1ll * r * x % M;n >>= 1; x = 1ll * x * x % M;}return r;
}
void dft(std::vector<int> &a) {int n = a.size();if ((int)rev.size() != n) {int k = __builtin_ctz(n) - 1;rev.resize(n);for (int i = 0; i < n; ++i) {rev[i] = rev[i >> 1] >> 1 | (i & 1) << k;}}if ((int)roots.size() < n) {int k = __builtin_ctz(roots.size());roots.resize(n);while ((1 << k) < n) {int e = powMod(g, (M - 1) >> (k + 1));for (int i = 1 << (k - 1); i < (1 << k); ++i) {roots[2 * i] = roots[i];roots[2 * i + 1] = 1ll * roots[i] * e % M;}++k;}}for (int i = 0; i < n; ++i) if (rev[i] < i) {std::swap(a[i], a[rev[i]]);}for (int k = 1; k < n; k *= 2) {for (int i = 0; i < n; i += 2 * k) {for (int j = 0; j < k; ++j) {int u = a[i + j];int v = 1ll * a[i + j + k] * roots[k + j] % M;int x = u + v, y = u - v;if (x >= M) x -= M;if (y < 0) y += M;a[i + j] = x;a[i + j + k] = y;}}}
}
void idft(std::vector<int> &a) {int n = a.size();std::reverse(a.begin() + 1, a.end());dft(a);int inv = powMod(n, M - 2);for (int i = 0; i < n; ++i) {a[i] = 1ll * a[i] * inv % M;}
}
} //namespace NFTS// 如果需要换模 M,那么就把 M 当做全局变量提出来,NFT 中 g, rev, root 需要重新初始化。
// 如果只需要换模几个常数 M,可使用 template<ll M>(但是不是特别推荐)
class PolyS {void reverse() {std::reverse(a.begin(), a.end());}
public:static const int M = NFTS::M;static const int inv2 = (M + 1) / 2;std::vector<int> a;PolyS() {}PolyS(int x) { if (x) a = {x};}PolyS(const std::vector<int> _a) : a(_a) {}int size() const { return a.size();}int& operator[](int id) { return a[id];}int at(int id) const {if (id < 0 || id >= (int)a.size()) return 0;return a[id];}PolyS operator-() const {auto A = *this;for (auto &x : A.a) x = (x == 0 ? 0 : M - x);return A;}PolyS mulXn(int n) const {auto b = a;b.insert(b.begin(), n, 0);return PolyS(b);}PolyS modXn(int n) const {if (n > size()) return *this;return PolyS({a.begin(), a.begin() + n});}PolyS divXn(int n) const {if (size() <= n) return PolyS();return PolyS({a.begin() + n, a.end()});}PolyS &operator+=(const PolyS &rhs) {if (size() < rhs.size()) a.resize(rhs.size());for (int i = 0; i < rhs.size(); ++i) {if ((a[i] += rhs.a[i]) >= M) a[i] -= M;}return *this;}PolyS &operator-=(const PolyS &rhs) {if (size() < rhs.size()) a.resize(rhs.size());for (int i = 0; i < rhs.size(); ++i) {if ((a[i] -= rhs.a[i]) < 0) a[i] += M;}return *this;}PolyS &operator*=(PolyS rhs) {int n = size(), m = rhs.size(), tot = std::max(1, n + m - 1);int sz = 1 << std::__lg(tot * 2 - 1);a.resize(sz);rhs.a.resize(sz);NFTS::dft(a);NFTS::dft(rhs.a);for (int i = 0; i < sz; ++i) {a[i] = 1ll * a[i] * rhs.a[i] % M;}NFTS::idft(a);return *this;}PolyS operator+(const PolyS &rhs) const {return PolyS(*this) += rhs;}PolyS operator-(const PolyS &rhs) const {return PolyS(*this) -= rhs;}PolyS operator*(PolyS rhs) const {return PolyS(*this) *= rhs;}
}; // PolyS 全家桶测试:https://www.luogu.com.cn/training/3015#information
const int N = 5e5 + 10;
const int mod = NFTS::M;
int n, k, idx;
int p[N], cnt[N];
bool st[N];
ll fac[N], finv[N];
PolyS a[N/2];
ll qmi(ll a, ll b)
{ll ans = 1;while(b){if (b & 1)ans = ans * a % mod;b >>= 1;a = a * a % mod;}return ans;
}
void presolve(int n)
{fac[0] = finv[0] = 1;for (int i = 1; i <= n; i++)fac[i] = fac[i-1] * i % mod;finv[n] = qmi(fac[n], mod - 2);for (int i = n - 1; i >= 1; i--)finv[i] = finv[i+1] * (i + 1) % mod;
}
void init()
{idx = 0;fill(st, st + n + 2, false);
}
ll C(int n, int m)
{if (m < 0 || n - m < 0)return 0;return fac[n] * finv[m] % mod * finv[n-m] % mod;
}struct cmp
{bool operator()(int x, int y){return a[x].size() > a[y].size();}
};
//启发式合并
ll solve()
{priority_queue<int, vector<int>, cmp> heap;for (int i = 0; i < idx; i++)heap.push(i);while(heap.size() > 1){int p = heap.top(); heap.pop();int q = heap.top(); heap.pop();a[p] *= a[q];a[p] = a[p].modXn(k + 1);heap.push(p);}int ver = heap.top();return a[ver].at(k);
}
//分治
PolyS solve(int l, int r)
{if (l == r)return a[l];int mid = l + r >> 1;return solve(l, mid) * solve(mid + 1, r);
}
int main()
{int T;scanf("%d", &T);presolve(N - 5);while(T--){scanf("%d%d", &n, &k);init();for (int i = 1; i <= n; i++)scanf("%d", p + i);for (int i = 1; i <= n; i++){if (st[i])continue;st[i] = true;int tot = 1, now = i;while(i != p[now]){tot++;st[p[now]] = true;now = p[now];}if (tot == 1)continue;cnt[idx++] = tot;}for (int i = 0; i < idx; i++){int tot = cnt[i];vector<int> tmp(min(k + 1, tot / 2 + 1));for (int j = 0; j <= min(k, tot / 2); j++)tmp[j] = (C(tot - j, j) + C(tot - j - 1, j - 1)) % mod;a[i] = PolyS(tmp);}ll ans = solve();printf("%lld\n", ans);}return 0;
}