欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

动态dp初探

程序员文章站 2022-04-28 13:28:50
动态区间最大子段和问题 给出长度为$n$的序列和$m$次操作,每次修改一个元素的值或查询区间的最大字段和(SP1714 GSS3)。 设$f[i]$为以下标$i$结尾的最大子段和,$g[i]$表示从起始位置到$i$以内的最大子段和。 $$ f[i]=\max(f[i 1]+a[i],a[i])\\g ......

动态区间最大子段和问题

给出长度为\(n\)的序列和\(m\)次操作,每次修改一个元素的值或查询区间的最大字段和(sp1714 gss3)。

\(f[i]\)为以下标\(i\)结尾的最大子段和,\(g[i]\)表示从起始位置到\(i\)以内的最大子段和。
\[ f[i]=\max(f[i-1]+a[i],a[i])\\g[i]=\max(g[i-1],f[i]) \]
定义如下的矩阵乘法,显然这满足乘法结合律和分配律。
\[ c=ab\\c[i,j]=\max_{k}(a[i,k]+b[k,j]) \]
将转移写为矩阵(注意\(g[i]=\max(g[i-1],f[i-1]+a[i],a[i])\)
\[ \begin{bmatrix} f[i]\\ g[i]\\ 0\end{bmatrix} = \begin{bmatrix} a[i]&-\infty&a[i]\\ a[i]&0&a[i]\\ -\infty&-\infty&0\end{bmatrix} \begin{bmatrix} f[i-1]\\ g[i-1]\\ 0\end{bmatrix} \]
可知每个元素\(a[i]\)都对应了一个矩阵,可以认为区间\([l,r]​\)的答案所在矩阵正是
\[ (\prod_{i=l}^k \begin{bmatrix} a[i]&-\infty&a[i]\\ a[i]&0&a[i]\\ -\infty&-\infty&0 \end{bmatrix} )\begin{bmatrix} 0\\ -\infty\\ 0\end{bmatrix} \]
因此可以用线段树维护区间矩阵乘积。

#include <bits/stdc++.h>
using namespace std;
const int n=5e4+10;
const int inf=0x3f3f3f3f;

struct mtr {
    int a[3][3];
    int*operator[](int d) {return a[d];}
    inline mtr() {}
    inline mtr(int val) {
        a[0][0]=a[0][2]=a[1][0]=a[1][2]=val;
        a[0][1]=a[2][0]=a[2][1]=-inf;
        a[1][1]=a[2][2]=0;
    }
    mtr operator*(mtr b) {
        static mtr c;
        memset(&c,-inf,sizeof c);
        for(int i=0; i<3; ++i) 
        for(int k=0; k<3; ++k) 
        for(int j=0; j<3; ++j) 
            c[i][j]=max(c[i][j],a[i][k]+b[k][j]);
        return c; 
    }
} t,a[n<<2];

#define ls (x<<1)
#define rs (x<<1|1)
void build(int x,int l,int r) {
    if(l==r) {
        scanf("%d",&r);
        a[x]=mtr(r);
        return;
    }
    int mid=(l+r)>>1;
    build(ls,l,mid);
    build(rs,mid+1,r);
    a[x]=a[ls]*a[rs];
}
void modify(int x,int l,int r,int p,int val) {
    if(l==r) {
        a[x]=mtr(val);
        return;
    }
    int mid=(l+r)>>1;
    if(p<=mid) modify(ls,l,mid,p,val);
    else modify(rs,mid+1,r,p,val);
    a[x]=a[ls]*a[rs];
}
mtr query(int x,int l,int r,int l,int r) {
    if(l<=l&&r<=r) return a[x];
    int mid=(l+r)>>1;
    if(r<=mid) return query(ls,l,mid,l,r);
    if(mid<l) return query(rs,mid+1,r,l,r);
    return query(ls,l,mid,l,r)*query(rs,mid+1,r,l,r); 
}

int main() {
    memset(&t,-inf,sizeof t); //notice
    t[0][0]=t[2][0]=0;
    int n,q;
    scanf("%d",&n);
    build(1,1,n);
    scanf("%d",&q);
    for(int op,l,r; q--; ) {
        scanf("%d%d%d",&op,&l,&r);
        if(op==0) modify(1,1,n,l,r);
        else {
            mtr ret=query(1,1,n,l,r)*t;
            printf("%d\n",max(ret[0][0],ret[1][0]));
        } 
    }
    return 0;
}

动态树上最大权独立集

注意断句 给出一棵\(n​\)个节点树和\(m​\)次操作,每次操作修改一个点权并计算当前树上的最大权独立集的权值。

重链剖分,设\(y\)\(x\)的某个儿子,\(s\)是重儿子,\(f[x,t]\)表示在以\(x\)为根的子树中不选/选\(x\)时的最大权独立集权值,\(g[x,t]\)表示在以\(x\)的为根的子树中除去以\(s\)为根的子树部分内不选/选\(x\)的最大权独立集权值,。
\[ g[x,0]=\sum_{y\not=s}\max(f[y,0],f[y,1])\\ g[x,1]=a[x]+\sum_{y\not=s} f[y,0]\\ f[x,0]=g[x,0]+\max(f[s,0],f[s,1])\\ f[x,1]=g[x,1]+f[s,0] \]
然后改写为矩阵乘法
\[ \begin{bmatrix} f[x,0]\\ f[x,1] \end{bmatrix}= \begin{bmatrix} g[x,0]&g[x,0]\\ g[x,1]&-\infty \end{bmatrix} \begin{bmatrix} f[s,0]\\ f[s,1] \end{bmatrix} \]
\(s\)不存在时,钦定\(f[s,0]=0\)\(f[s,1]=-\infty\)。进一步可发现在一条链内,链顶的\(f[,t]​\)值正是链上所有的“g矩阵”(应该明白指的是那个吧)乘起来的第一列值。

因此我们可以树剖维护重链上这些矩阵的乘积,更新时从修改点跳到每条重链的链顶,计算链顶部\(f[,t]\),更新他父亲的\(g[,t]\)(显然他不是父亲的重儿子),然后再跳往父亲所在重链……。

也可以lct来做,(试了一下树剖发现麻烦爆了)每次access修改点到根,然后对这部分重计算就好了。

#include <bits/stdc++.h>
using namespace std;
const int n=1e5+10;
const int inf=0x3f3f3f3f;

struct mtr {
    int a[2][2];
    int*operator[](int d) {return a[d];}
    mtr() {memset(a,-inf,sizeof a);}
    mtr operator*(mtr b) {
        mtr c;
        for(int i=0; i<2; ++i) 
        for(int k=0; k<2; ++k) 
        for(int j=0; j<2; ++j) 
            c[i][j]=max(c[i][j],a[i][k]+b[k][j]);
        return c; 
    }
} g[n],pg[n]; 

int n,m,a[n];
int head[n],to[n<<1],last[n<<1];
int fa[n],ch[n][2],dp[n][2];

void add_edge(int x,int y) {
    static int cnt=0;
    to[++cnt]=y,last[cnt]=head[x],head[x]=cnt;
}
void dfs(int x) {
    dp[x][1]=a[x];
    for(int i=head[x]; i; i=last[i]) {
        if(to[i]==fa[x]) continue;
        fa[to[i]]=x;
        dfs(to[i]);
        dp[x][0]+=max(dp[to[i]][0],dp[to[i]][1]);
        dp[x][1]+=dp[to[i]][0];
    }
    g[x][0][0]=g[x][0][1]=dp[x][0];
    g[x][1][0]=dp[x][1];
    pg[x]=g[x];
} 
void update(int x) {
    pg[x]=g[x];
    if(ch[x][0]) pg[x]=pg[ch[x][0]]*pg[x]; //无交换律 
    if(ch[x][1]) pg[x]=pg[x]*pg[ch[x][1]];
}
int get(int x) {
    return ch[fa[x]][0]==x?0:(ch[fa[x]][1]==x?1:-1);
}
void rotate(int x) {
    int y=fa[x],k=get(x);
    if(~get(y)) ch[fa[y]][get(y)]=x;
    fa[x]=fa[y];
    fa[ch[y][k]=ch[x][k^1]]=y;
    fa[ch[x][k^1]=y]=x;
    update(y);
    update(x); 
}
void splay(int x) {
    while(~get(x)) {
        int y=fa[x];
        if(~get(y)) rotate(get(x)^get(y)?x:y);
        rotate(x);
    }
} 
void access(int x) {
    for(int y=0; x; x=fa[y=x]) {
        splay(x);
        if(ch[x][1]) { //旧的重儿子 
            g[x][0][0]+=max(pg[ch[x][1]][0][0],pg[ch[x][1]][1][0]);
            g[x][1][0]+=pg[ch[x][1]][0][0];
        }
        if(y) { //新的重儿子 
            g[x][0][0]-=max(pg[y][0][0],pg[y][1][0]);
            g[x][1][0]-=pg[y][0][0];
        }
        g[x][0][1]=g[x][0][0]; //别忘了 
        ch[x][1]=y;
        update(x);
    }
}
void modify(int x,int y) {
    access(x);
    splay(x);
    g[x][1][0]+=y-a[x];
    update(x);
    a[x]=y;
}

int main() {
    scanf("%d%d",&n,&m);
    for(int i=1; i<=n; ++i) scanf("%d",a+i);
    for(int x,y,i=n; --i; ) {
        scanf("%d%d",&x,&y);
        add_edge(x,y);
        add_edge(y,x);
    }
    dfs(1); //所有连边是轻边 
    for(int x,y; m--; ) {
        scanf("%d%d",&x,&y);
        modify(x,y);
        splay(1);
        printf("%d\n",max(pg[1][0][0],pg[1][1][0]));
    }
    return 0;
}

全局平衡二叉树

然后讲一讲这道题的毒瘤加强版。传送门

数据加强并且经过特殊构造,树剖和lct都过不了了。树剖本身复杂度太大, o(\(m\log^2n\))过不了百万是很正常的;而lct虽然只有一个\(\log\) ,但由于常数过大也被卡了。

树剖的两个 \(\log\) 基本上可以放弃治疗了。但是我们不禁要问,lct究竟慢在哪里?

仔细想想,lct的access复杂度之所以是一个 \(\log​\) ,是由于splay的势能分析在整棵lct上依然成立,也就是说可以把lct看作一棵大splay,在这棵大splay上的一次access只相当于一次splay。

话虽然是这么说,但是实际上当我们不停地随机access的时候,要调整的轻重链数量还是很多的。感受一下,拿极端情形来说,如果树是一条链,一开始全是轻边,那么对链末端的结点access一次显然应该是 \(o(n)\)的。所以其实lct的常数大就大在它是靠势能法得到的 \(o(\log n)\),这么不靠谱的玩意是容易gg的。

但是如果我们不让lct放任*地access,而是一开始就给它构造一个比较优雅的姿态并让它静止(本来这棵树也不需要动),那么它也许就有救了。我们可以按照树链剖分的套路先划分出轻重边,然后对于重链建立一棵形态比较好的splay,至于轻儿子就跟原来的lct一样直接用轻边挂上即可。什么叫“形态比较好”呢?我们给每个点 \(x​\) 定义其权重为 size[x]-size[son[x]],其中 son[x] 是它的重儿子,那么对于一条重链,我们可以先找到它的带权重心作为当前节点,然后对左右分别递归建树。

by gkxx

似乎较lct常数更小,也蛮好些的。

#include <bits/stdc++.h> /*卡着时限过*/
using namespace std;

namespace io {
    const unsigned buffsize=1<<24,output=1<<24;
    static char ch[buffsize],*st=ch,*t=ch;
    inline char getc() {
        return((st==t)&&(t=(st=ch)+fread(ch,1,buffsize,stdin),st==t)?0:*st++);
    }
    static char out[output],*nowps=out;
    inline void flush() {
        fwrite(out,1,nowps-out,stdout);
        nowps=out;
    }
    template<typename t>inline void read(t&x) {
        x=0;
        static char ch;
        t f=1;
        for(ch=getc(); !isdigit(ch); ch=getc())if(ch=='-')f=-1;
        for(; isdigit(ch); ch=getc())x=x*10+(ch^48);
        x*=f;
    }
    template<typename t>inline void write(t x,char ch='\n') {
        if(!x)*nowps++=48;
        if(x<0)*nowps++='-',x=-x;
        static unsigned sta[111],tp;
        for(tp=0; x; x/=10)sta[++tp]=x%10;
        for(; tp; *nowps++=sta[tp--]^48);
        *nowps++=ch;
        flush();
    }
}
using io::read;
using io::write;

const int n=1e6+10;
const int inf=0x3f3f3f3f;

struct mtr {
    int a[2][2];
    int*operator[](int x) {return a[x]; }
    inline mtr() {}
    inline mtr(int g0,int g1) {
        a[0][0]=a[0][1]=g0;
        a[1][0]=g1;
        a[1][1]=-inf;
    }
    inline mtr operator*(mtr b) {
        mtr c;
        c[0][0]=max(a[0][0]+b[0][0],a[0][1]+b[1][0]);
        c[0][1]=max(a[0][0]+b[0][1],a[0][1]+b[1][1]);
        c[1][0]=max(a[1][0]+b[0][0],a[1][1]+b[1][0]);
        c[1][1]=max(a[1][0]+b[0][1],a[1][1]+b[1][1]);
        return c;
    }
    void print() {
        printf("%d %d\n%d %d\n\n",a[0][0],a[0][1],a[1][0],a[1][1]); 
    }
};

int n,m,a[n];
int head[n],to[n<<1],last[n<<1];
int siz[n],son[n],g[n][2];
inline void add_edge(int x,int y) {
    static int cnt=0;
    to[++cnt]=y,last[cnt]=head[x],head[x]=cnt;
}
void dfs1(int x,int pa) {
    siz[x]=1;
    g[x][1]=a[x];
    for(int i=head[x]; i; i=last[i]) {
        if(to[i]==pa) continue;
        dfs1(to[i],x);
        siz[x]+=siz[to[i]];
        if(siz[to[i]]>siz[son[x]]) son[x]=to[i];
        g[x][0]+=max(g[to[i]][0],g[to[i]][1]);
        g[x][1]+=g[to[i]][0];
    }
}
void dfs2(int x,int pa) {
    if(!son[x]) return;
    g[x][0]-=max(g[son[x]][0],g[son[x]][1]);
    g[x][1]-=g[son[x]][0];
    for(int i=head[x]; i; i=last[i]) 
        if(to[i]!=pa) dfs2(to[i],x); 
}

mtr g[n],pg[n];
int root,fa[n],ch[n][2];
int stk[n],tp;
bool is_root[n];

inline void update(int x) {
    pg[x]=g[x];
    if(ch[x][0]) pg[x]=pg[ch[x][0]]*pg[x];
    if(ch[x][1]) pg[x]=pg[x]*pg[ch[x][1]];
}
int chain(int l,int r) {
    if(r<l) return 0;
    int sum=0,pre=0;
    for(int i=l; i<=r; ++i) sum+=siz[stk[i]]-siz[son[stk[i]]];
    for(int i=l; i<=r; ++i) {
        pre+=siz[stk[i]]-siz[son[stk[i]]];
        if((pre<<1)>=sum) {
            int x=stk[i];
            ch[x][0]=chain(l,i-1);
            ch[x][1]=chain(i+1,r);
            if(ch[x][0]) fa[ch[x][0]]=x;
            if(ch[x][1]) fa[ch[x][1]]=x;
            update(x);
            return x;
        }
    }
    return 2333;
}
int tree(int top,int pa) {
    for(int x=top; x; x=son[pa=x]) {
        for(int i=head[x]; i; i=last[i]) {
            if(to[i]!=son[x]&&to[i]!=pa) {
                fa[tree(to[i],x)]=x;
            }
        } 
        g[x]=mtr(g[x][0],g[x][1]);
    }
    tp=0;
    for(int x=top; x; x=son[x]) stk[++tp]=x;
    return chain(1,tp);
}
inline void build() {
    root=tree(1,0);
    for(int i=1; i<=n; ++i) {
        is_root[i]=ch[fa[i]][0]!=i&&ch[fa[i]][1]!=i;
    }
}
inline int solve(int x,int y) {
    g[x][1]+=y-a[x];
    a[x]=y;
    for(int f0,f1; x; x=fa[x]) {
        f0=pg[x][0][0];
        f1=pg[x][1][0];
        g[x]=mtr(g[x][0],g[x][1]);
        update(x);
        if(fa[x]&&is_root[x]) {
            g[fa[x]][0]+=max(pg[x][0][0],pg[x][1][0])-max(f0,f1);
            g[fa[x]][1]+=pg[x][0][0]-f0;
        }
    }
    return max(pg[root][0][0],pg[root][1][0]);
}

int main() {
    read(n);
    read(m);
    for(int i=1; i<=n; ++i) read(a[i]);
    for(int x,y,i=n; --i; ) {
        read(x);
        read(y);
        add_edge(x,y);
        add_edge(y,x);
    }
    dfs1(1,0);
    dfs2(1,0);
    build();
    int lastans=0;
    for(int x,y; m--; ) {
        read(x);
        read(y);
        x^=lastans;
        lastans=solve(x,y);
        write(lastans);
    }
    return 0;
}

noip18 保卫王国

给出一棵\(n​\)个节点树和\(m​\)次询问,每次询问强制选/不选两个点然后计算当前树上的最小覆盖集,询问互相独立。

提示:强制选一个点就是把它的点权改成0,强制不选就是改成 \(\infty\);最小覆盖权值+最大独立集权值=总权值。

#include <bits/stdc++.h>
using namespace std;

const int n=1e6+10;
const long long inf=1e10;

struct mtr {
    long long a[2][2];
    long long*operator[](int x) {return a[x]; }
    inline mtr() {} 
    inline mtr(long long g0,long long g1) {
        a[0][0]=a[0][1]=g0;
        a[1][0]=g1;
        a[1][1]=-inf;
    }
    inline mtr operator*(mtr b) {
        mtr c;
        c[0][0]=max(a[0][0]+b[0][0],a[0][1]+b[1][0]);
        c[0][1]=max(a[0][0]+b[0][1],a[0][1]+b[1][1]);
        c[1][0]=max(a[1][0]+b[0][0],a[1][1]+b[1][0]);
        c[1][1]=max(a[1][0]+b[0][1],a[1][1]+b[1][1]);
        return c;
    }
};

int n,m;
long long a[n],g[n][2];
int head[n],to[n<<1],last[n<<1];
int prt[n],siz[n],son[n];
inline void add_edge(int x,int y) {
    static int cnt=0;
    to[++cnt]=y,last[cnt]=head[x],head[x]=cnt;
}
void dfs1(int x,int pa) {
    siz[x]=1;
    g[x][1]=a[x];
    for(int i=head[x]; i; i=last[i]) {
        if(to[i]==pa) continue;
        prt[to[i]]=x;
        dfs1(to[i],x);
        siz[x]+=siz[to[i]];
        if(siz[to[i]]>siz[son[x]]) son[x]=to[i];
        g[x][0]+=max(g[to[i]][0],g[to[i]][1]);
        g[x][1]+=g[to[i]][0];
    }
}
void dfs2(int x,int pa) {
    if(!son[x]) return;
    g[x][0]-=max(g[son[x]][0],g[son[x]][1]);
    g[x][1]-=g[son[x]][0];
    for(int i=head[x]; i; i=last[i]) 
        if(to[i]!=pa) dfs2(to[i],x); 
}

mtr g[n],pg[n];
int root,fa[n],ch[n][2];
int stk[n],tp;
bool is_root[n];

inline void update(int x) {
    pg[x]=g[x];
    if(ch[x][0]) pg[x]=pg[ch[x][0]]*pg[x];
    if(ch[x][1]) pg[x]=pg[x]*pg[ch[x][1]];
}
int chain(int l,int r) {
    if(r<l) return 0;
    int sum=0,pre=0;
    for(int i=l; i<=r; ++i) sum+=siz[stk[i]]-siz[son[stk[i]]];
    for(int i=l; i<=r; ++i) {
        pre+=siz[stk[i]]-siz[son[stk[i]]];
        if((pre<<1)>=sum) {
            int x=stk[i];
            ch[x][0]=chain(l,i-1);
            ch[x][1]=chain(i+1,r);
            if(ch[x][0]) fa[ch[x][0]]=x;
            if(ch[x][1]) fa[ch[x][1]]=x;
            update(x);
            return x;
        }
    }
    return 2333; 
}
int tree(int top,int pa) {
    for(int x=top; x; x=son[pa=x]) {
        for(int i=head[x]; i; i=last[i]) {
            if(to[i]!=son[x]&&to[i]!=pa) {
                fa[tree(to[i],x)]=x;
            }
        } 
        g[x]=mtr(g[x][0],g[x][1]);
    }
    tp=0;
    for(int x=top; x; x=son[x]) stk[++tp]=x;
    return chain(1,tp);
}
inline void build() {
    root=tree(1,0);
    for(int i=1; i<=n; ++i) {
        is_root[i]=ch[fa[i]][0]!=i&&ch[fa[i]][1]!=i;
    }
}
long long tot,res;
inline void solve(int x,long long y) {
    tot+=y;
    g[x][1]+=y;
    for(long long f0,f1; x; x=fa[x]) {
        f0=pg[x][0][0];
        f1=pg[x][1][0];
        g[x]=mtr(g[x][0],g[x][1]);
        update(x);
        if(fa[x]&&is_root[x]) {
            g[fa[x]][0]+=max(pg[x][0][0],pg[x][1][0])-max(f0,f1);
            g[fa[x]][1]+=pg[x][0][0]-f0;
        }
    }
}
inline void solve(int x,int p,int y,int q) {
    long long sx,sy;
    if(!p&&!q) sx=inf,sy=inf,res=0;
    else if(!p&&q) sx=inf,sy=0,res=a[y];
    else if(p&&!q) sx=0,sy=inf,res=a[x];
    else sx=0,sy=0,res=a[x]+a[y];
    solve(x,sx-a[x]);
    solve(y,sy-a[y]);
    res+=tot-max(pg[root][0][0],pg[root][1][0]);
    solve(x,a[x]-sx);
    solve(y,a[y]-sy);
}

char type[10]; 
int main() { //此代码 在-o2时极快
    freopen("defense.in","r",stdin);
    freopen("defense.ans","w",stdout); 
    scanf("%d%d%s",&n,&m,type);
    for(int i=1; i<=n; ++i) {
        scanf("%lld",a+i);
        tot+=a[i];
    }
    for(int x,y,i=n; --i; ) {
        scanf("%d%d",&x,&y);
        add_edge(x,y);
        add_edge(y,x);
    }
    dfs1(1,0);
    dfs2(1,0);
    build();
    for(int x,p,y,q; m--; ) {
        scanf("%d%d%d%d",&x,&p,&y,&q);
        if(!p&&!q&&(prt[x]==y||prt[y]==x)) {
            puts("-1");
            continue;
        }
        solve(x,p,y,q);
        printf("%lld\n",res);
    }
    return 0;
}

更多习(tian)题(keng)

bzoj4911 [sdoi2017]切树游戏

bzoj4721 洪水