数位dp踩坑
前言
数位dp是什么?以前总觉得这个概念很高大上,最近闲的没事,学了一下发现确实挺神奇的。
从一道简单题说起
一个数字,如果包含'4'或者'62',它是不吉利的。给定m,n,0<m≤n<10^6,统计[m,n]范围内吉利数的个数。
这题的数据范围比较小,只有1e6,理论上暴力也是可以解的。但数位dp的题目数据范围通常很大,往往达到1e18甚至更大,暴力法o(n)显然会tle。这个时候需要一种时间复杂度近似o(logn)的算法。仔细思考一下,其实可以使用排除法。从高位到低位依次排除0~1e6中不符合条件的数。
举个例子,1~999999中不包含4的数,步骤如下:
1.先排除6位数中最高位是4的数,即400000~499999,只需要判断最高位,就排除了10万个数。
2.接着排除次高位是4的数(此时默认最高位不是4),比如最高位是1时可以排除140000~149999,共1万个。注意,首位可以是0,此时相当于考虑000000~099999(初学者可能有点不理解,看完后面一些不考虑前导0的题就懂了)。
3.同理,继续排除4位数,3位数,直到结束。
数位dp 是对数字的“位”进行的和计数有关的dp,数的每一个位称为数位,一个数有个位、十位、百位、千位……数位dp用来解决和数字操作有关的问题,比如某区间内数字和,特定数字问题等。这些问题的数字范围通常很大,无法暴力解决,必须用接近o(logn)的算法。通过dp对“数位”操作,记录算过的区间的状态,用于后续计算快速筛选数字。
那么这道题用数位dp怎么实现呢?数位dp一般有两种方法,一种是dp预处理乱搞,另一种是记忆化dfs。前者不适合模板化,而后者效率高,模板易记,可以快速上手。
数位dp模板
#include <bits/stdc++.h> using namespace std; typedef long long ll; ll dp[20][20]; int a[20]; ll n,m; // pos:当前到第几位 pre:上一位数字是什么 lead:是否有前导0 limit:是否有上界限制 (参数通常有pos和limit,是否有lead看题目要求,此外还有可能增加其他参数) ll dfs(int pos,int pre,bool lead,bool limit){ if(pos == 0) return 1; //搜完了,返回1.这里通常是返回1,但有些题有限制条件,返回时要判断是否符合限制条件. if(!lead && !limit && dp[pos][pre] != -1) return dp[pos][pre]; //记忆化 int up = limit? a[pos] : 9; //是否有上界,如果有上界那么不能超过上界. ll ans = 0; for(int i = 0;i <= up;i++){ if(lead) ans += dfs(pos - 1,i,lead && i == 0,limit && i == a[pos]); else{ if() continue; ans += dfs(pos - 1,i,lead && i == 0,limit && i == a[pos]); } } if(!lead && !limit) dp[pos][pre] = ans; return ans; } ll solve(ll x){ int len = 0; while(x){ //拆位,a[i]表示这个数的第i位 a[++len] = x % 10; x /= 10; } memset(dp,-1,sizeof(dp)); //初始化dp数组 return dfs(len,0,true,true); //从高位向低位枚举 } int main() { cin >> n >> m; cout << solve(m) - solve(n - 1) << endl; return 0; }
这里解答几个初学者常见的疑问
- 为什么dfs的时候要控制上界(limit)?
举一个简单的例子,我们要求这个区间[25,628]内满足某种条件的数的个数,这样我们枚举的数肯定不能超过628,因为628是上界。我们从高位往低位枚举,百位的所有可能是0,1,2,3,4,5,6。当百位枚举了1,十位可以枚举0-9,此时相当于十位的枚举没有限制(最大可以到9)。但是,如果百位枚举了6,那么十位只能枚举0-2,因为不能超过628。同理,如果百位枚举了6,十位枚举了2,那么个位只能枚举0-8了。 - 前导零的影响(lead)?
为什么百位可以枚举0呢?道理很简单,百位等于0时,相当于此时我们枚举的数是一个两位数,同理,百位和十位都等于0,相当于我们在枚举一个一位数。不过这样会带来前导0的影响,在某些题目,比如:windy数,这种考虑相邻两个数的某种关系的题目的时候,会让记忆化搜索出错,这个时候参数lead就非常重要了,我们通过增加参数来消除影响。
下面我们通过模板来秒掉这道题目吧
dp[pos][sta]表示第pos位前一位是6(sta = 1),或者不是6(sta = 0)时满足条件的数的数目,这里由于数的前一位是否是6会对答案构成影响,所以只需要在统计时把答案分成两类。
#include <bits/stdc++.h> using namespace std; typedef long long ll; int n,m; ll dp[20][2]; int a[20]; int dfs(int pos,int pre,int sta,bool limit){ if(pos == 0) return 1; if(!limit && dp[pos][sta] != -1) return dp[pos][sta]; int up = limit?a[pos] : 9; int sum = 0; for(int i = 0;i <= up;i++){ if(i == 4) continue; if(pre == 6 && i == 2) continue; sum += dfs(pos - 1,i,i == 6,limit && i == a[pos]); } if(!limit) dp[pos][sta] = sum; return sum; } int solve(int x){ int len = 0; while(x){ a[++len] = x % 10; x /= 10; } return dfs(len,-1,0,true); } int main(){ memset(dp,-1,sizeof(dp)); while(~scanf("%d %d",&m,&n) && n && m){ cout << solve(n) - solve(m - 1) << endl; } return 0; }
蓝桥杯 k好数
这题定义了一个叫"k好数"的概念,即该数在k进制下任意相邻两位的数不能相邻(即相差不能等于1),然后要统计l位k进制中k好数的数目。这道题的坑点在于要考虑前导0的影响,比如4进制下的两位数,11,20是k好数,但是32不是k好数。但如果我们求的是4进制下的一位数,问题就来了,如果没有lead这个参数,1这个数就不会被统计,因为传参时默认前一位是0,而0和1相邻,这个时候只需要加一个判断即可。
#include <bits/stdc++.h> using namespace std; typedef long long ll; ll mod = 1e9 + 7; ll dp[105][105]; int k,l; ll dfs(int pos,int pre,bool lead,bool limit){ if(pos == 0) return 1; if(!lead && !limit && dp[pos][pre] != -1) return dp[pos][pre]; int up = k - 1; ll ans = 0; for(int i = 0;i <= up;i++){ if(lead) ans += dfs(pos - 1,i,lead && i == 0,limit && i == k - 1); else{ if(abs(i - pre) == 1) continue; ans = (ans + dfs(pos - 1,i,lead && i == 0,limit && i == k - 1)) % mod; } } if(!lead && !limit) dp[pos][pre] = ans; return ans; } ll solve(int len){ memset(dp,-1,sizeof(dp)); return dfs(len,0,true,true); } int main() { cin >> k >> l; cout << (solve(l) - solve(l - 1) + mod ) % mod << endl; return 0; }
b-number hdu3652
这题定义了一个叫b数的东西,这个数含"13"这个串并且可以被13整除。难点在于怎么保存被13整除这个状态,其实很简单,再加一个参数tot记录余数,每次枚举%一下13,当pos=0时判断即可。
#include <bits/stdc++.h> using namespace std; typedef long long ll; ll dp[15][10][13][2]; int a[15]; ll n; ll dfs(int pos,int pre,int tot,int ok,bool limit){ if(pos == 0) return ok == 1 && tot == 0; if(!limit && dp[pos][pre][tot][ok] != -1) return dp[pos][pre][tot][ok]; int up = limit? a[pos] : 9; ll ans = 0; for(int i = 0;i <= up;i++){ if(pre == 1 && i == 3) ans += dfs(pos - 1,i,(tot * 10 + i) % 13,1,limit && i == a[pos]); else ans += dfs(pos - 1,i,(tot * 10 + i) % 13,ok,limit && i == a[pos]); } if(!limit) dp[pos][pre][tot][ok] = ans; return ans; } ll solve(ll x){ int len = 0; while(x){ a[++len] = x % 10; x /= 10; } return dfs(len,0,0,false,true); } int main() { memset(dp,-1,sizeof(dp)); while(~scanf("%lld",&n)) cout << solve(n) << endl; return 0; }
花神的数论题
这道题挺有意思的,要统计的是每个数二进制下1的个数sum(i),然后求sum(1)到sum(n)的累乘。
我们要转换一下思维,n上限是1e18,意味着这个数的二进制最多有50位,即最多就50个1,我们分别统计二进制下有一个1,两个1,三个1……五十个1的数的个数,然后用快速幂进行优化,问题就迎刃而解了。
#include <bits/stdc++.h> using namespace std; typedef long long ll; const ll mod = 1e7 + 7; ll dp[51][51][51]; int a[51]; ll ans[51]; ll n; ll poww(ll x,ll y,ll p){ ll ans = 1; while(y){ if(y & 1) ans = ans * x % p; x = x * x % p; y >>= 1; } return ans; } //位置 当前统计到的1的个数 目标要统计的1的个数 上界限制 ll dfs(int pos,int tmp,int tot,bool limit){ if(pos == 0) return tmp == tot; if(!limit && dp[pos][tmp][tot] != -1) return dp[pos][tmp][tot]; int up = limit? a[pos] : 1; ll sum = 0; for(int i = 0;i <= up;i++){ sum += dfs(pos - 1,tmp + (i == 1),tot,limit && a[pos] == i); } if(!limit) dp[pos][tmp][tot] = sum; return sum; } ll solve(ll x){ int len = 0; ll sum = 1; while(x){ //拆出二进制的每一位 a[++len] = x & 1; x >>= 1; } memset(dp,-1,sizeof(dp)); //这里是一个很巧妙的优化 分别统计二进制中1的个数为i的数的个数,然后快速幂优化 for(int i = 1;i <= len;i++){ ans[i] = dfs(len,0,i,true); sum = sum * poww(i,ans[i],mod) % mod; } return sum; } int main() { cin >> n; cout << solve(n) << endl; return 0; }
洛谷2602 数字计数
这道题大意是求某区间内数字0-9出现的次数。我大概的想法是,分别统计区间内存在1个1,2个1,3个1……的数的个数,存在1个2,2个2,……,1个9,2个9,……的数的个数,然后求累加,当然题解有更好的方法,以下仅供参考(注意前导0的处理,否则统计0的时候会出错)
#include <bits/stdc++.h> using namespace std; typedef long long ll; ll dp[15][15][15]; int a[15]; ll ans[10][15]; ll total[10]; ll n,m; int num; //前导0很关键!! ll dfs(int pos,int now,int tot,bool lead,bool limit){ if(pos == 0) return now == tot; if(!limit && !lead && dp[pos][now][tot] != -1) return dp[pos][now][tot]; int up = limit? a[pos] : 9; ll sum = 0; for(int i = 0;i <= up;i++){ if(lead && i == 0) sum += dfs(pos - 1,now,tot,true,limit && i == a[pos]); else sum += dfs(pos - 1,now + (i == num),tot,lead && i == 0,limit && i == a[pos]); } if(!limit && !lead) dp[pos][now][tot] = sum; return sum; } ll solve(ll x,bool ok){ int len = 0; while(x){ a[++len] = x % 10; x /= 10; } memset(dp,-1,sizeof(dp)); memset(ans,0,sizeof(ans)); for(int i = 0;i < 10;i++){ num = i; for(int j = 1;j <= len;j++){ ans[i][j] = dfs(len,0,j,true,true); if(ok) total[i] += ans[i][j] * j; else total[i] -= ans[i][j] * j; } } } int main() { cin >> n >> m; solve(m,1); solve(n - 1,0); for(int i = 0;i < 10;i++) cout << total[i] << " "; return 0; }
洛谷4999 烦人的数学作业
题目的大意是求l-r区间内每个数的各位数字和,假设这个数共len位,那么它的各位数字之和不会超过9*len(因为每位数字最大是9),我们只需要从小到大枚举统计即可,本质思想和上面的题目是一样的。
#include <bits/stdc++.h> using namespace std; typedef long long ll; const ll mod = 1e9 + 7; int t; ll n,m; ll dp[20][180][180],ans[180]; int a[20]; //位置 当前的各位和 目标和 上界限制 ll dfs(int pos,int now,int tot,bool limit){ if(pos == 0) return now == tot; if(!limit && dp[pos][now][tot] != -1) return dp[pos][now][tot]; int up = limit? a[pos] : 9; ll sum = 0; for(int i = 0;i <= up;i++){ sum += dfs(pos - 1,now + i,tot,limit && a[pos] == i); } if(!limit) dp[pos][now][tot] = sum; return sum; } ll solve(ll x){ int len = 0; ll sum = 0; while(x){ a[++len] = x % 10; x /= 10; } //统计各位和为1-9*len的所有情况 for(int i = 1;i <= 9 * len;i++){ ans[i] = dfs(len,0,i,true); sum = (sum + i * (ans[i] % mod)) % mod; } return sum; } int main() { scanf("%d",&t); memset(dp,-1,sizeof(dp)); while(t--){ scanf("%lld %lld",&n,&m); //这里减法要处理!!! cout << (solve(m) - solve(n - 1) + mod) % mod << endl; } return 0; }
蒟蒻的第一篇blog,有错还请各位神犇指出~~