HDU-6314 Matrix(计数)
Matrix
Time Limit: 2000/1000 MS (Java/Others) Memory Limit: 332768/332768 K (Java/Others)
Total Submission(s): 449 Accepted Submission(s): 107
Problem Description
Samwell Tarly is learning to draw a magical matrix to protect himself from the White Walkers.
The magical matrix is a matrix with n rows and m columns, and every single block must be painted either black or white.
Sam wants to know how many ways there are to paint the matrix such that, in the final matrix, at least A rows and at least B columns are painted completely black. Because the answer might be too big, you only need to output it modulo 998244353.
Input
There might be multiple test cases, no more than 5. You need to read till the end of input.
For each test case, a line containing four integers n,m,A,B.
1≤n,m,A,B≤3000.
Output
For each test case, output a line containing the answer modulo 998244353.
Sample Input
3 4 1 2
Sample Output
169
#include <bits/stdc++.h>
// Shorthand for std::pair members. NOTE: these macros textually rewrite every
// later identifier spelled `x` or `y`, including local variable names.
#define x first
#define y second
using namespace std;
typedef pair<int, int> PII;
typedef long long LL;
// Table dimension: the statement bounds n, m, A, B by 3000; a few spare slots are added.
const int MX = 3e3 + 5;
// Modulus required by the problem statement.
const int mod = 998244353;
// c[i][j] = C(i, j) mod `mod` for 0 <= j <= i < MX; entries with j > i stay 0.
int c[MX][MX];
// tw[k] = 2^k mod `mod` for 0 <= k < MX * MX.
int tw[MX * MX];
// Per-test-case scratch arrays for the row (f1) and column (f2) weights.
LL f1[MX], f2[MX];

/**
 * Fills the global lookup tables used by every test case.
 *
 *  - c  : binomial coefficients modulo `mod`, built with Pascal's rule
 *         C(i, j) = C(i-1, j) + C(i-1, j-1).
 *  - tw : powers of two modulo `mod`.
 *
 * Must be called once before either table is read. Takes no arguments and
 * returns nothing; it only writes the two global tables.
 */
void pre_solve() {
    // Row 0 is the true base case (C(0,0) = 1). Seeding it lets the general
    // recurrence build row 1 as well, instead of leaving c[0][0] at zero.
    c[0][0] = 1;
    for (int i = 1; i < MX; i++) {
        c[i][0] = 1;
        for (int j = 1; j <= i; j++) {
            // For j == i this reads c[i-1][i], a zero-initialised global, so
            // the diagonal correctly evaluates to 1.
            c[i][j] = c[i - 1][j] + c[i - 1][j - 1];
            // Both operands are below `mod`, so one conditional subtraction
            // is enough to reduce the sum (and the sum cannot overflow int).
            if (c[i][j] >= mod) c[i][j] -= mod;
        }
    }

    tw[0] = 1;
    for (int i = 1; i < MX * MX; i++) {
        // tw[i-1] < mod < 2^30, so doubling stays within a 32-bit int.
        tw[i] = (tw[i - 1] << 1);
        if (tw[i] >= mod) tw[i] -= mod;
    }
}
int main() {
freopen ("in.txt", "r", stdin);
int n, m, a, b;
pre_solve();
while (~scanf ("%d%d%d%d", &n, &m, &a, &b) ) {
//f1[n]:恰好染黑n行的方案数
//恰好染黑x行时,恰好染黑i行的方案已经被计算过C[x][i]次
f1[a] = 1;
for (int x = a + 1; x <= n; x++) {
f1[x] = 1;
for (int i = a; i < x; i++) {
f1[x] = (f1[x] - f1[i] * c[x][i] % mod + mod) % mod;
}
f1[x] = f1[x] % mod;
}
//f2[m]:恰好染黑m列的方案数
//恰好染黑y列时,恰好染黑j列的方案已经被计算过C[y][j]次
f2[b] = 1;
for (int y = b + 1; y <= m; y++) {
f2[y] = 1;
for (int j = b; j < y; j++) {
f2[y] = (f2[y] - f2[j] * c[y][j] % mod + mod) % mod;
}
f2[y] = f2[y] % mod;
}
//ans = sum_{i=a...n,j=b...m} f1[i] * f2[i] * c[n][i] * c[m][j] * 2^((n-i)*(m-j))
for (int i = a; i <= n; i++) f1[i] = f1[i] * c[n][i] % mod;
for (int j = a; j <= m; j++) f2[j] = f2[j] * c[m][j] % mod;
LL ans = 0;
for (int i = a; i <= n; i++) {
for (int j = b; j <= m; j++) {
ans = (ans + tw[ (n - i) * (m - j)] * f1[i] % mod * f2[j]) % mod;
}
}
printf ("%lld\n", ans);
}
}