JZOJ 5378. 【NOIP2017提高A组模拟9.19】闷声刷大题(60分)
程序员文章站
2024-02-12 22:54:34
...
题目
60分题解
最小费用流!!!!!
看到选k个a[i],b[i],并且选的每个b的下标不小于对应的a的下标。
考虑最小费用流!!!!!
将每一天拆成两个点(A部分和B部分)。
S向每个点的A部分连边(费用为a[i],容量为1)。每个点的A向B连一条边(费用为0,容量为1)。
这样就可以表达选的A的下标小于等于对应的B的下标的意思。
首先,普通的SPFA求费用流,只有40分。
然而,ZKW费用流,60分,但我不会打。
怎么从40优化到60?
①在spfa中加一个STL优化(因为这个SPFA是求最短路的)即最新加入队列中的元素比目前队头更优,就先让最新加入队列中的元素出队。
②加入读入优化。(这个显然可以优化一点点时间)
③能把long long写成int的变量就写成int,这样子也会快很多。
代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
#define N 300010
#define LL long long
#define fo(i,a,b) for(i=a;i<=b;i++)
#define fd(i,a,b) for(i=a;i>=b;i--)
#define eg(i,x) for(i=head[x];i;i=edge[i].next)
using namespace std;
struct note{
int to,next;
LL flow,val;
};note edge[N*10];
LL head[N],tot,ans;
LL S,T,a[N],b[N];
LL dis[N];
int pre[N],qu[N*10];
bool bz[N],p;
LL i,j,k,l,n,m;
LL read()
{
LL fh=1,res=0;char ch;
while((ch<'0'||ch>'9')&&ch!='-')ch=getchar();
if (ch=='-')fh=-1,ch=getchar();
while(ch>='0'&&ch<='9')res=res*10+ch-'0',ch=getchar();
return res*fh;
}
void lb(int x,int y,LL z,LL w){
edge[++tot].to=y;edge[tot].next=head[x];edge[tot].flow=z;edge[tot].val=w;head[x]=tot;
edge[++tot].to=x;edge[tot].next=head[y];edge[tot].flow=0;edge[tot].val=-w;head[y]=tot;
}
bool spfa(){
memset(dis,127,sizeof(dis));dis[S]=0;
memset(bz,0,sizeof(bz));
memset(pre,0,sizeof(pre));
int t=0,w=1,i,now;
qu[1]=S;
while(t<w){
now=qu[++t];
eg(i,now){
if (edge[i].flow>0 && dis[edge[i].to]>dis[now]+edge[i].val){
dis[edge[i].to]=dis[now]+edge[i].val;
pre[edge[i].to]=i;
if(!bz[edge[i].to]){
bz[edge[i].to]=1;
qu[++w]=edge[i].to;
if (dis[qu[w]]<dis[qu[t+1]])swap(qu[t+1],qu[w]);
}
}
}
bz[now]=0;
}
if (dis[T]!=9187201950435737471) return 1;else return 0;
}
void find(){
LL x=T,sum=0,mx=9187201950435737471;
while (x!=S){
mx=min(mx,edge[pre[x]].flow);
x=edge[pre[x]^1].to;
}
x=T;
while (x!=S){
edge[pre[x]].flow-=mx;
edge[pre[x]^1].flow+=mx;
sum+=edge[pre[x]].val;
x=edge[pre[x]^1].to;
}
ans+=sum*mx;
}
bool cmp(LL x,LL y){return x<y;}
int main(){
freopen("orz.in","r",stdin);
freopen("orz.out","w",stdout);
n=read();k=read();
fo(i,1,n)a[i]=read();
fo(i,1,n)b[i]=read();
p=0;
fo(i,2,n)if(a[i]!=a[i-1]){
p=1;
break;
}
if(!p){
ans=0;
sort(b+1,b+n+1,cmp);
fo(i,1,k) ans+=b[i]+a[i];
printf("%lld",ans);
return 0;
}
tot=1;S=0,T=2*n+2;
fo(i,1,n){
lb(S,i,1,a[i]);
lb(i,i+n,1,0);
if (i<n) lb(i+n,i+n+1,9187201950435737471,0);
lb(i+n,1+n*2,1,b[i]);
}
lb(n*2+1,T,k,0);
ans=0;
while (spfa())find();
printf("%lld",ans);
return 0;
}