Codeforces 997C - Sky Full of Stars

链接

传送门

题意

有一个\(n \times n\)的矩阵,用红绿蓝其进行染色,求出至少有一列或者一行为同色的染色方案数。

思路

官方题解:Codeforces Round #493 — Editorial

\(A_i\)表示第\(i\)行颜色相同的矩阵数量,\(B_i\)表示第\(i\)列颜色相同的矩阵数量。

那么要求的答案即为\(\left| A_1 \bigcup A_2 \dots A_n \bigcup B_1 \bigcup B_2 \dots B_n \right|\)

更一般的,由于对称性,具体的某一行和某一列并不重要,只需要统计有多少行和列满足同色就可以了。

这样得到\(ans = \sum_{i = 0 \dots n, j = 0 \dots n, i + j > 0}{C_n^iC_n^j (-1)^{i + j + 1}f(i, j)}\)

其中,\(f(i, j)\)表示前\(i\)\(j\)列同色的矩阵个数。

对于\(i=0\)\(j=0\)的情况,\(f(0, k) = f(k, 0) = 3^k \cdot 3^{n(n - k)}\),这部分用朴素的方法求和复杂度为\(O(n)\)

对于其他情况,\(f(i, j) = 3 \cdot 3^{(n - i)(n - j)}\),这部分用朴素的方法求和复杂度为\(O(n^2)\)\(O(n^2\log n)\),因此需要化简。

\[ans = \sum_{i = 1}^n\sum_{j = 1}^nC_n^iC_n^j (-1)^{i + j + 1}3 \cdot 3^{(n - i)(n - j)}\]

\(i\)替换为\(n-i\)\(j\)替换为\(n-j\)

\[ans = 3\sum_{i = 0}^{n - 1}\sum_{j = 0}^{n - 1}C_n^{n - i}C_n^{n - j} (-1)^{n - i + n - j + 1}\cdot 3^{ij}\]

由于\(C_n^{n-i} = C_n^i\)\((-1)^{2n}=1\)\((-1)^{-i}=(-1)^i\),得

\[ans = 3\sum_{i = 0}^{n - 1}\sum_{j = 0}^{n - 1}C_n^iC_n^j(-1)^{i + j + 1}\cdot 3^{ij}\]

由于\((a + b)^n=\sum_{i=0}^nC_n^ia^ib^{n-i}\),作如下变换:

\[ans = 3\sum_{i = 0}^{n - 1}C_n^i (-1)^{i + 1}\sum_{j = 0}^{n - 1} C_n^j(-1)^j\cdot \left(3^i\right)^j\]

\[ans = 3\sum_{i = 0}^{n - 1}C_n^i (-1)^{i + 1} \sum_{j = 0}^{n - 1} C_n^j\left(-3^i\right)^j\]

\[ans = 3\sum_{i = 0}^{n - 1}C_n^i (-1)^{i + 1}\left[\left(1 + \left(-3^i\right) \right)^n - \left(-3^i\right)^n\right]\]

可以在\(O(n)\)的时间内完成计算。

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#include <cstdio>
#include <cstring>
#include <algorithm>

using namespace std;

typedef long long LL;

const int maxn = 1000010;
const LL mod = 998244353;

LL c[maxn];
LL inv[maxn];
LL pow3[maxn];

LL pow_mod(LL x, LL k) {
LL res = 1, cur = x;
while (k) {
if (k & 1) {
res = res * cur % mod;
}
cur = cur * cur % mod;
k >>= 1;
}
return res;
}

int main() {
int n;
scanf("%d", &n);
inv[1] = 1;
for (int i = 2; i <= n; ++i) {
inv[i] = (mod - mod / i) * inv[mod % i] % mod;
}
c[0] = pow3[0] = 1;
for (int i = 1; i <= n; ++i) {
c[i] = c[i - 1] * (n - i + 1) % mod * inv[i] % mod;
pow3[i] = pow3[i - 1] * 3 % mod;
}
LL ans1 = 0;
for (int i = 1; i <= n; ++i) {
LL sig = i & 1 ? 1 : -1;
ans1 = (ans1 + c[i] * sig * pow3[i] % mod * pow_mod(pow3[n], n - i)) % mod;
}
ans1 = ((ans1 + mod) << 1) % mod;
LL ans2 = 0;
for (int i = 0; i < n; ++i) {
LL sig = i & 1 ? 1 : -1;
LL tmp = mod - pow3[i];
ans2 = (ans2 + c[i] * sig * (pow_mod(1 + tmp, n) - pow_mod(tmp, n))) % mod;
}
ans2 = (ans2 + mod) * 3 % mod;
printf("%lld\n", (ans1 + ans2) % mod);
return 0;
}