欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

Agc019_F Yes or No

程序员文章站 2022-05-01 23:05:36
...

传送门

题目大意

有若干道判断题,其中有$n$道答案是$Yes$,另外$m$道答案是$No$,问题除了答案差异本质相同。这些题一道都不会做,但是事先知道$n$和$m$的数量。每次机器会事先等概率地排列着$n+m$个答案(共$\dbinom{n+m}{n}$种可能),概率地选择一道没有问过的题目询问,然后答题者就必须给出答案,随后机器就会立即反馈你这道题是否判断错误,求如果采用最优策略,期望最多猜对多少道题,答案对$998244353$取模。

 

题解

神仙题,$tourist$出的是一个很难略微复杂的解法,然后被人用冷静思考大力分析给碾过去了$Orz$...

由于交换$n,m$对答案并无影响,不妨设$n\leq m$。

考虑每一种答案的排列对应着从$(m,n)$到$(0,0)$的每次只能沿着正下或正左走$1$的距离的一条路径。

由于答案取每种路径的概率是相等的,我们只需要算所有路径的期望之和即可。

假设对于某一条路径,到达了$x,y$,当$x<y$时,我们一定会猜它向下走,当$x>y$时,我们一定会才它向左走。

所以对于路径上所有$x\ne y$的状态猜对的数量路径与下图红色边的交集大小。

通过找规律可以发现这个值一定是$m$。因为当$x\ne y$时,一定会有某一个状态走到$x=y$,所以可以证明。

对于所有$x=y$的状态,不论怎么猜都会毫无头绪,所以贡献是$\frac {1}{2}$。

这部分的贡献是所有路径经过的$(x,y)(x=y,x,y>0)$的点的数量,除以所有路径数。

可以预处理阶乘组合数$O(n)$解决。

 

#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
#include<cmath>
#define LL long long
#define mod 998244353
#define inv2 499122177
#define M 1000020
using namespace std;
namespace IO{
	int Top=0; char SS[20];
	void write(int x){
		if(!x){putchar('0');return;} if(x<0) x=-x,putchar('-');
		while(x) SS[++Top]=x%10,x/=10;
		while(Top) putchar(SS[Top]+'0'),--Top;
	}
	int read(){
		int nm=0,fh=1; char cw=getchar();
		for(;!isdigit(cw);cw=getchar()) if(cw=='-') fh=-fh;
		for(;isdigit(cw);cw=getchar()) nm=nm*10+(cw-'0');
		return nm*fh;
	}
}using namespace IO;
int mul(int x,int y){return (LL)x*(LL)y%mod;}
int add(int x,int y){return (x+y>=mod)?x+y-mod:x+y;}
int mus(int x,int y){return (x-y<0)?x-y+mod:x-y;}
int qpow(int x,int sq){
	int res=1;
	for(;sq;sq>>=1,x=mul(x,x)) if(sq&1) res=mul(res,x);
	return res;
}
int n,m,fac[M],ifac[M],ans;
int C(int tot,int tk){return mul(fac[tot],mul(ifac[tot-tk],ifac[tk]));}
int main(){
	n=read(),m=read(),fac[0]=1; if(n>m) swap(n,m);
	for(int i=1;i<=n+m;i++) fac[i]=mul(fac[i-1],i); ifac[n+m]=qpow(fac[n+m],mod-2);
	for(int i=n+m;i;i--) ifac[i-1]=mul(ifac[i],i); 
	for(int i=1;i<=n;i++) ans=add(ans,mul(C(i<<1,i),C(n+m-(i<<1),n-i)));
	write(add(mul(ans,mul(inv2,qpow(C(n+m,m),mod-2))),m)),putchar('\n'); return 0;
}