BZOJ4513: [Sdoi2016]储能表(数位dp)
程序员文章站
2022-06-29 13:31:17
题意 "题目链接" Sol 一点思路都没有,只会暴力,没想到标算是数位dp??Orz 首先答案可以分成两部分来统计 设 $$ f_{i,j}= \begin{aligned} i\oplus j &\left( i\oplus j k\right) \\ 0 &\left( i\oplus j de ......
题意
sol
一点思路都没有,只会暴力,没想到标算是数位dp??orz
首先答案可以分成两部分来统计
设
\[ f_{i,j}= \begin{aligned} i\oplus j &\left( i\oplus j >k\right) \\ 0 &\left( i\oplus j <=k\right) \end{aligned} \]
那么我们要求的就是
\[\sum_{i=0}^{n - 1} \sum_{j = 0}^{m - 1} f(i, j) - k * \sum_{i = 0}^{n - 1} \sum_{j = 0}^{m - 1} [f(i, j)]\]
也就是说,我们要统计出满足条件的数的异或和以及满足条件的数的对数
考虑直接在二进制下数位dp,注意这里我们要记三维状态
\(f[len][0/1][0/1][0/1]\)表示此时到第\(len\)位,是否顶着\(n\)的上界,是否顶着\(m\)的上界,是否顶着\(k\)的下界
然后直接dp就可以了
// luogu-judger-enable-o2 #include<bits/stdc++.h> #define pair pair<ll, ll> #define mp make_pair #define fi first #define se second #define ll long long #define int long long using namespace std; const int maxn = 233; inline ll read() { char c = getchar(); int x = 0, f = 1; while(c < '0' || c > '9') {if(c == '-') f = -1; c = getchar();} while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar(); return x * f; } ll n, m, k, mod, lim, vis[maxn][2][2][2]; pair f[maxn][2][2][2]; void add2(ll &x, ll y) { if(x + y < 0) x = (x + y + mod); else x = (x + y >= mod ? x + y - mod : x + y); } ll add(ll x, ll y) { if(x + y < 0) return x + y + mod; return x + y >= mod ? x + y - mod : x + y; } ll mul(ll x, ll y) { return 1ll * x % mod * y % mod; } int get(ll x) { int len = 0; while(x) x >>= 1, len++; return len; } pair dfs(int now, int f1, int f2, int f3) { if(now > lim) return mp(0, 1); if(vis[now][f1][f2][f3]) return f[now][f1][f2][f3]; vis[now][f1][f2][f3] = 1; pair ans = mp(0, 0); int l1 = (n >> lim - now) & 1, l2 = (m >> lim - now) & 1, l3 = (k >> lim - now) & 1; //cout << (f1 &&(!l1)) << endl; for(int i = 0; i <= (f1 ? l1 : 1); i++) { for(int j = 0; j <= (f2 ? l2 : 1); j++) { if(f3 && ((i ^ j) < l3)) continue; pair nxt = dfs(now + 1, f1 && (i == l1), f2 && (j == l2), f3 && ((i ^ j) == l3)); add2(ans.se, nxt.se); add2(ans.fi, add(nxt.fi, mul(nxt.se, mul((i ^ j), (1ll << lim - now))))); } } return f[now][f1][f2][f3] = ans; } int solve() { memset(vis, 0, sizeof(vis)); memset(f, 0, sizeof(f)); lim = 0; n = read(); m = read(); k = read(); mod = read(); n--; m--; lim = max(get(n), max(get(k), get(m))); pair ans = dfs(1, 1, 1, 1); return add(ans.fi, -mul(k, ans.se)); } signed main() { for(int t = read(); t; t--, printf("%lld\n", solve())); return 0; } /* 5000 504363800392059286 554192717354508770 21453916680846604 401134357 */