欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

数位dp踩坑

程序员文章站 2022-06-20 20:12:32
前言 数位DP是什么?以前总觉得这个概念很高大上,最近闲的没事,学了一下发现确实挺神奇的。 从一道简单题说起 "hdu 2089 "不要62"" 一个数字,如果包含'4'或者'62',它是不吉利的。给定m,n,0 using namespace std; typedef long long ll; ......

前言


数位dp是什么?以前总觉得这个概念很高大上,最近闲的没事,学了一下发现确实挺神奇的。

从一道简单题说起


一个数字,如果包含'4'或者'62',它是不吉利的。给定m,n,0<m≤n<10^6,统计[m,n]范围内吉利数的个数。

这题的数据范围比较小,只有1e6,理论上暴力也是可以解的。但数位dp的题目数据范围通常很大,往往达到1e18甚至更大,暴力法o(n)显然会tle。这个时候需要一种时间复杂度近似o(logn)的算法。仔细思考一下,其实可以使用排除法。从高位到低位依次排除0~1e6中不符合条件的数。

举个例子,1~999999中不包含4的数,步骤如下:
1.先排除6位数中最高位是4的数,即400000~499999,只需要判断最高位,就排除了10万个数。
2.接着排除次高位是4的数(此时默认最高位不是4),比如最高位是1时可以排除140000~149999,共1万个。注意,首位可以是0,此时相当于考虑000000~099999(初学者可能有点不理解,看完后面一些不考虑前导0的题就懂了)。
3.同理,继续排除4位数,3位数,直到结束。

数位dp 是对数字的“位”进行的和计数有关的dp,数的每一个位称为数位,一个数有个位、十位、百位、千位……数位dp用来解决和数字操作有关的问题,比如某区间内数字和,特定数字问题等。这些问题的数字范围通常很大,无法暴力解决,必须用接近o(logn)的算法。通过dp对“数位”操作,记录算过的区间的状态,用于后续计算快速筛选数字。

那么这道题用数位dp怎么实现呢?数位dp一般有两种方法,一种是dp预处理乱搞,另一种是记忆化dfs。前者不适合模板化,而后者效率高,模板易记,可以快速上手。

数位dp模板


#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
ll dp[20][20];
int a[20];
ll n,m;
// pos:当前到第几位 pre:上一位数字是什么 lead:是否有前导0 limit:是否有上界限制 (参数通常有pos和limit,是否有lead看题目要求,此外还有可能增加其他参数)
ll dfs(int pos,int pre,bool lead,bool limit){
    if(pos == 0) return 1; //搜完了,返回1.这里通常是返回1,但有些题有限制条件,返回时要判断是否符合限制条件.
    if(!lead && !limit && dp[pos][pre] != -1) return dp[pos][pre]; //记忆化
    int up = limit? a[pos] : 9;  //是否有上界,如果有上界那么不能超过上界.
    ll ans = 0;
    for(int i = 0;i <= up;i++){
        if(lead) ans += dfs(pos - 1,i,lead && i == 0,limit && i == a[pos]);
        else{
            if() continue;
            ans += dfs(pos - 1,i,lead && i == 0,limit && i == a[pos]);
        }
    }
    if(!lead && !limit) dp[pos][pre] = ans;
    return ans;
}
ll solve(ll x){
    int len = 0;
    while(x){   //拆位,a[i]表示这个数的第i位
        a[++len] = x % 10;
        x /= 10;
    }
    memset(dp,-1,sizeof(dp)); //初始化dp数组
    return dfs(len,0,true,true); //从高位向低位枚举
}
int main()
{
    cin >> n >> m;
    cout << solve(m) - solve(n - 1) << endl;
    return 0;
}

这里解答几个初学者常见的疑问

  • 为什么dfs的时候要控制上界(limit)?
    举一个简单的例子,我们要求这个区间[25,628]内满足某种条件的数的个数,这样我们枚举的数肯定不能超过628,因为628是上界。我们从高位往低位枚举,百位的所有可能是0,1,2,3,4,5,6。当百位枚举了1,十位可以枚举0-9,此时相当于十位的枚举没有限制(最大可以到9)。但是,如果百位枚举了6,那么十位只能枚举0-2,因为不能超过628。同理,如果百位枚举了6,十位枚举了2,那么个位只能枚举0-8了。
  • 前导零的影响(lead)?
    为什么百位可以枚举0呢?道理很简单,百位等于0时,相当于此时我们枚举的数是一个两位数,同理,百位和十位都等于0,相当于我们在枚举一个一位数。不过这样会带来前导0的影响,在某些题目,比如:windy数,这种考虑相邻两个数的某种关系的题目的时候,会让记忆化搜索出错,这个时候参数lead就非常重要了,我们通过增加参数来消除影响。

下面我们通过模板来秒掉这道题目吧

dp[pos][sta]表示第pos位前一位是6(sta = 1),或者不是6(sta = 0)时满足条件的数的数目,这里由于数的前一位是否是6会对答案构成影响,所以只需要在统计时把答案分成两类。

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
int n,m;
ll dp[20][2];
int a[20];
int dfs(int pos,int pre,int sta,bool limit){
    if(pos == 0) return 1;
    if(!limit && dp[pos][sta] != -1) return dp[pos][sta];
    int up = limit?a[pos] : 9;
    int sum = 0;
    for(int i = 0;i <= up;i++){
        if(i == 4) continue;
        if(pre == 6 && i == 2) continue;
        sum += dfs(pos - 1,i,i == 6,limit && i == a[pos]);
    }
    if(!limit) dp[pos][sta] = sum;
    return sum;
}
int solve(int x){
    int len = 0;
    while(x){
        a[++len] = x % 10;
        x /= 10;
    }
    return dfs(len,-1,0,true);
}
int main(){
    memset(dp,-1,sizeof(dp));
    while(~scanf("%d %d",&m,&n) && n && m){
        cout << solve(n) - solve(m - 1) << endl;
    }
    return 0;
}

蓝桥杯 k好数

这题定义了一个叫"k好数"的概念,即该数在k进制下任意相邻两位的数不能相邻(即相差不能等于1),然后要统计l位k进制中k好数的数目。这道题的坑点在于要考虑前导0的影响,比如4进制下的两位数,11,20是k好数,但是32不是k好数。但如果我们求的是4进制下的一位数,问题就来了,如果没有lead这个参数,1这个数就不会被统计,因为传参时默认前一位是0,而0和1相邻,这个时候只需要加一个判断即可。

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
ll mod = 1e9 + 7;
ll dp[105][105];
int k,l;
ll dfs(int pos,int pre,bool lead,bool limit){
    if(pos == 0) return 1;
    if(!lead && !limit && dp[pos][pre] != -1) return dp[pos][pre];
    int up = k - 1;
    ll ans = 0;
    for(int i = 0;i <= up;i++){
        if(lead) ans += dfs(pos - 1,i,lead && i == 0,limit && i == k - 1);
        else{
            if(abs(i - pre) == 1) continue;
            ans = (ans + dfs(pos - 1,i,lead && i == 0,limit && i == k - 1)) % mod;
        }
    }
    if(!lead && !limit) dp[pos][pre] = ans;
    return ans;
}
ll solve(int len){
    memset(dp,-1,sizeof(dp));
    return dfs(len,0,true,true);
}
int main()
{
    cin >> k >> l;
    cout << (solve(l) - solve(l - 1) + mod ) % mod << endl;
    return 0;
}

b-number hdu3652

这题定义了一个叫b数的东西,这个数含"13"这个串并且可以被13整除。难点在于怎么保存被13整除这个状态,其实很简单,再加一个参数tot记录余数,每次枚举%一下13,当pos=0时判断即可。

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
ll dp[15][10][13][2];
int a[15];
ll n;
ll dfs(int pos,int pre,int tot,int ok,bool limit){
    if(pos == 0) return ok == 1 && tot == 0;
    if(!limit && dp[pos][pre][tot][ok] != -1) return dp[pos][pre][tot][ok];
    int up = limit? a[pos] : 9;
    ll ans = 0;
    for(int i = 0;i <= up;i++){
        if(pre == 1 && i == 3) ans += dfs(pos - 1,i,(tot * 10 + i) % 13,1,limit && i == a[pos]);
        else ans += dfs(pos - 1,i,(tot * 10 + i) % 13,ok,limit && i == a[pos]);
    }
    if(!limit) dp[pos][pre][tot][ok] = ans;
    return ans;
}
ll solve(ll x){
    int len = 0;
    while(x){
        a[++len] = x % 10;
        x /= 10;
    }
    return dfs(len,0,0,false,true);
}
int main()
{
    memset(dp,-1,sizeof(dp));
    while(~scanf("%lld",&n)) cout << solve(n) << endl;
    return 0;
}

花神的数论题

这道题挺有意思的,要统计的是每个数二进制下1的个数sum(i),然后求sum(1)到sum(n)的累乘。
    我们要转换一下思维,n上限是1e18,意味着这个数的二进制最多有50位,即最多就50个1,我们分别统计二进制下有一个1,两个1,三个1……五十个1的数的个数,然后用快速幂进行优化,问题就迎刃而解了。

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod = 1e7 + 7;
ll dp[51][51][51];
int a[51];
ll ans[51];
ll n;
ll poww(ll x,ll y,ll p){
    ll ans = 1;
    while(y){
        if(y & 1) ans = ans * x % p;
        x = x * x % p;
        y >>= 1;
    }
    return ans;
}
//位置 当前统计到的1的个数 目标要统计的1的个数 上界限制
ll dfs(int pos,int tmp,int tot,bool limit){
    if(pos == 0) return tmp == tot;
    if(!limit && dp[pos][tmp][tot] != -1) return dp[pos][tmp][tot];
    int up = limit? a[pos] : 1;
    ll sum = 0;
    for(int i = 0;i <= up;i++){
        sum += dfs(pos - 1,tmp + (i == 1),tot,limit && a[pos] == i);
    }
    if(!limit) dp[pos][tmp][tot] = sum;
    return sum;
}
ll solve(ll x){
    int len = 0;
    ll sum = 1;
    while(x){ //拆出二进制的每一位
        a[++len] = x & 1;
        x >>= 1;
    }
    memset(dp,-1,sizeof(dp));
    //这里是一个很巧妙的优化 分别统计二进制中1的个数为i的数的个数,然后快速幂优化
    for(int i = 1;i <= len;i++){
        ans[i] = dfs(len,0,i,true);
        sum = sum * poww(i,ans[i],mod) % mod;
    }
    return sum;
}
int main()
{
    cin >> n;
    cout << solve(n) << endl;
    return 0;
}

洛谷2602 数字计数

这道题大意是求某区间内数字0-9出现的次数。我大概的想法是,分别统计区间内存在1个1,2个1,3个1……的数的个数,存在1个2,2个2,……,1个9,2个9,……的数的个数,然后求累加,当然题解有更好的方法,以下仅供参考(注意前导0的处理,否则统计0的时候会出错)

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
ll dp[15][15][15];
int a[15];
ll ans[10][15];
ll total[10];
ll n,m;
int num;
//前导0很关键!!
ll dfs(int pos,int now,int tot,bool lead,bool limit){
    if(pos == 0) return now == tot;
    if(!limit && !lead && dp[pos][now][tot] != -1) return dp[pos][now][tot];
    int up = limit? a[pos] : 9;
    ll sum = 0;
    for(int i = 0;i <= up;i++){
        if(lead && i == 0) sum += dfs(pos - 1,now,tot,true,limit && i == a[pos]);
        else sum += dfs(pos - 1,now + (i == num),tot,lead && i == 0,limit && i == a[pos]);
    }
    if(!limit && !lead) dp[pos][now][tot] = sum;
    return sum;
}
ll solve(ll x,bool ok){
    int len = 0;
    while(x){
        a[++len] = x % 10;
        x /= 10;
    }
    memset(dp,-1,sizeof(dp));
    memset(ans,0,sizeof(ans));
    for(int i = 0;i < 10;i++){
        num = i;
        for(int j = 1;j <= len;j++){
            ans[i][j] = dfs(len,0,j,true,true);
            if(ok) total[i] += ans[i][j] * j;
            else total[i] -= ans[i][j] * j;
        }
    }
}
int main()
{
    cin >> n >> m;
    solve(m,1);
    solve(n - 1,0);
    for(int i = 0;i < 10;i++) cout << total[i] << " ";
    return 0;
}

洛谷4999 烦人的数学作业

题目的大意是求l-r区间内每个数的各位数字和,假设这个数共len位,那么它的各位数字之和不会超过9*len(因为每位数字最大是9),我们只需要从小到大枚举统计即可,本质思想和上面的题目是一样的。

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod = 1e9 + 7;
int t;
ll n,m;
ll dp[20][180][180],ans[180];
int a[20];
//位置 当前的各位和 目标和 上界限制
ll dfs(int pos,int now,int tot,bool limit){
    if(pos == 0) return now == tot;
    if(!limit && dp[pos][now][tot] != -1) return dp[pos][now][tot];
    int up = limit? a[pos] : 9;
    ll sum = 0;
    for(int i = 0;i <= up;i++){
        sum += dfs(pos - 1,now + i,tot,limit && a[pos] == i);
    }
    if(!limit) dp[pos][now][tot] = sum;
    return sum;
}
ll solve(ll x){
    int len = 0;
    ll sum = 0;
    while(x){
        a[++len] = x % 10;
        x /= 10;
    }
    //统计各位和为1-9*len的所有情况
    for(int i = 1;i <= 9 * len;i++){
        ans[i] = dfs(len,0,i,true);
        sum = (sum + i * (ans[i] % mod)) % mod;
    }
    return sum;
}
int main()
{
    scanf("%d",&t);
    memset(dp,-1,sizeof(dp));
    while(t--){
        scanf("%lld %lld",&n,&m);
        //这里减法要处理!!!
        cout << (solve(m) - solve(n - 1) + mod) % mod << endl;
    }
    return 0;
}

蒟蒻的第一篇blog,有错还请各位神犇指出~~