欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  互联网

HDU5955 H - Guessing the Dice Roll (AC自动机 + 高斯消元)

程序员文章站 2022-07-04 20:59:56
题意:你有一个骰子,有六面分别为1到6,等概率的出现六个数其中之一,然后一共有n名玩家,每个玩家给出自己的一个长度为L的序列,在投掷筛子的过程中,如果有一名玩家给出的序列和目前筛出的序列的最后L位相同的话,这名玩家就赢得了游戏的胜利。问每名玩家胜利的概率?思路:n个人(多模式串)假如我先把n个人每个人长度为L的序列当做模式串插入,然后每次的扔骰子那么就相当于站在当前节点去走向下一个节点,那么我们就可以利用AC自动机fail数组建立的过程去进行状态的转移,表示出节点之间能否互相到达的关系。然后我们就可利....

HDU5955 H - Guessing the Dice Roll (AC自动机 + 高斯消元)
题意:你有一个骰子,有六面分别为1到6,等概率的出现六个数其中之一,然后一共有n名玩家,每个玩家给出自己的一个长度为L的序列,在投掷筛子的过程中,如果有一名玩家给出的序列和目前筛出的序列的最后L位相同的话,这名玩家就赢得了游戏的胜利。问每名玩家胜利的概率?

思路:n个人(多模式串)假如我先把n个人每个人长度为L的序列当做模式串插入,然后每次的扔骰子那么就相当于站在当前节点去走向下一个节点,那么我们就可以利用AC自动机fail数组建立的过程去进行状态的转移,表示出节点之间能否互相到达的关系。

然后我们就可利用AC自动机所完成的状态转移关系,构造出一个A[i][j]数组,表示从j节点转移到i节点的概率为多少。对于某个节点i,它的概率应该是等于所有能一步到达它的节点的概率再乘1/6,那么对于整个自动机的中的每个节点我们都可以列一个方程,ex:
x 1 = a 12 ∗ x 2 + a 13 ∗ x 3 + . . . . . + a 1 n ∗ x n x_1 = a_{12}*x_2 + a_{13}*x_3 + .....+a_{1n}*x_n x1=a12x2+a13x3+.....+a1nxn
这样我们如果让 a i i = − 1 a_{ii} = -1 aii=1的话就可以表示出一个A系数矩阵和X矩阵的线性方程组。

注意:到根节点的概率为1,也就是 x r o o t = 1 x_{root} = 1 xroot=1,就相当于根节点的这个概率提供了矩阵方程组的常数矩阵。

然后就可以AC自动机建立节点之间的状态转移关系,高斯消元解矩阵方程组,然后最后我们的答案就是那些叶子节点的的概率即可。

我写的代码跑的挺慢的,等有空我再研究研究
代码:

#include <bits/stdc++.h>

using namespace std;
typedef long long ll;
const int MAXN = 110;
const double eps = 1e-6;
int n;

struct Aca{
	int Next[MAXN][7],end[MAXN],fail[MAXN];
	int size;
	queue<int>q;
	
	int newnode(){
		++size;
		memset(Next[size],0,sizeof(Next[size]));
		end[size] = 0;
		return size;
  	}

  	void init(){
  		memset(Next[1],0,sizeof(Next[1]));
  		end[1] = 0;
  		size = 1;
  	}

	void insert(int *s,int l,int k){//指针s代表要插入的串为什么 l为它的长度
		int p = 1;
		for(int i = 1;i <= l;i ++){
			int ch = s[i];
			if(!Next[p][ch]) Next[p][ch] = newnode();
			p = Next[p][ch];
		}
		end[p] = k;//end数组存的是 插入的节点是第几个插入的串
	}

	void build(){
		fail[1] = 1;
		for(int i = 1;i <= 6;i ++){//把初始与1相连的节点 放进去
			int &u = Next[1][i];
			if(!u) u = 1;
			else {
				fail[u] = 1;
				q.push(u);
			}
		}
		/*借助AC自动机失配指针 来转移相关的状态 */
		while(!q.empty()){
			int p = q.front();
			q.pop();
			for(int i = 1;i <= 6;i ++){
				int now = Next[p][i];
				if(!now){
					Next[p][i] = Next[fail[p]][i];//如果当前节点不存在 那就把它引导父节点的失配节点的对应位置
					continue;
				}
				fail[now] =  Next[fail[p]][i];
				end[now] |= end[fail[now]];//关系的转移继承
				q.push(now);
			}
		}
	}

}ac;
/**********高斯消元************/
double A[MAXN<<2][MAXN<<2],x[MAXN<<2],ans[MAXN<<2];
int equ,var;//等式的个数 变量的个数

int Gauss(){

	for(int k = 1,col = 1;k <= equ && col <= var;k ++,col ++){
		int maxr = k;
		//找出绝对值最大的一行进行消除 这样是为了 提高数值稳定性 应该可以理解为准确度
		for(int i = k + 1;i <= equ;i ++){
			if(fabs(A[i][col] > fabs(A[maxr][col])))
				maxr = i;
			if(fabs(A[maxr][col]) < eps) return 0;//无解
			if(k != maxr){//交换当前行 和 最大行
				for(int j = col;j <= var;j ++)
					swap(A[k][j],A[maxr][j]);
				swap(x[k],x[maxr]);
			}
			x[k] /= A[k][col];
			for(int j = col + 1;j <= var;j ++)
				A[k][j] /= A[k][col];
			A[k][col] = 1;
			for(int i = 1;i <= equ;i ++){
				if(i != k){
					x[i] -= x[k]*A[i][k];
					for(int j = col + 1;j <= var;j ++)
						A[i][j] -= A[k][j]*A[i][col];
					A[i][col] = 0;//已经变为零了
				}
			}
		}
	}
	return 1;
}
/***************************/

int s[27];

int main(){
	int t,l;
	scanf("%d",&t);
	while(t--){
		ac.init();
		scanf("%d%d",&n,&l);
		for(int i = 1;i <= n;i ++){
			for(int j = 1;j <= l;j ++){
				scanf("%d",&s[j]);
			}
			ac.insert(s,l,i);
		}
		// cout<<ac.size<<endl;
		ac.build();
		// cout<<"***************"<<endl;
		memset(A,0,sizeof(A));
		memset(x,0,sizeof(x));
		memset(ans,0,sizeof(ans));
		equ = ac.size,var = ac.size;
		for(int i = 1;i <= ac.size;i ++)
			A[i][i] = -1;//自己到自己变成-1 相当于等式右边
		x[1] = -1;//用来提供 常矩阵 相当于

		for(int i = 1;i <= ac.size;i ++){
			if(!ac.end[i]){
				for(int j = 1;j <= 6;j ++){
					A[ac.Next[i][j]][i] += 1.0/6;
				}
			}
		}
		// cout<<"-------"<<endl;
		// for(int i = 1;i <= equ;i ++){
		// 	for(int j = 1;j <= var;j ++){
		// 		printf(j == var?"%f\n":"%f ",A[i][j]);
		// 	}
		// }
		// cout<<"-------"<<endl;

		Gauss();

		for(int i = 1;i <= ac.size;i ++){
			if(ac.end[i])
				ans[ac.end[i]] = x[i];
		}
		for(int i = 1;i <= n;i ++){
			if(i == n) printf("%f\n",ans[i]*(-1.0));
			else printf("%f ",ans[i]);
		}
	}
	return 0;
}

本文地址:https://blog.csdn.net/weixin_45672411/article/details/110198946

相关标签: 字符串 数学