欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

隐马尔可夫模型之:前向算法

程序员文章站 2022-05-06 12:21:42
...
隐马尔可夫模型(hidden markov model 简称hmm)广泛应用于语音识别,机器翻译等领域。

隐马尔可夫模型的具体定义,请参考著名论文《A tutorial on Hidden Markov Models and selected applications in speech recognition》,在阅读以下内容之前,建议读者阅读这篇论文的第I II III 节,理论性的东西在此不做赘述。

hmm通常解决以下三类问题:

1.给定一个hmm和观察序列,判断生成这个观察序列的可能性;
2.给定一个hmm和观察序列,给出最可能生成这个观察序列的隐藏序列;
3.给定一个观察序列,训练一个hmm。

第1个问题,通常称为评估问题,可以用前向算法(forward algorithm)来解决,使用了动态规划技术,将该问题的时间复杂度降为O(N*N*T),其中N为隐藏状态的个数,T为给定的观察序列的长度,下面给出java代码:

package hmm;

import java.util.HashMap;
import java.util.Map;

/**
 * 隐马尔可夫模型
 * @author xuguanglv
 *
 */
public class Hmm {
	//初始概率向量
	private static double[] pai = {0.63, 0.17, 0.20};
	
	//状态转移矩阵
	private static double[][] A = {{0.500, 0.375, 0.125},
							        {0.250, 0.125, 0.625},
							        {0.250, 0.375, 0.375}};
	
	//混淆矩阵
	private static double[][] B = {{0.60, 0.20, 0.15, 0.05},
							        {0.25, 0.25, 0.25, 0.25},
							        {0.05, 0.10, 0.35, 0.50}};
	
	//隐藏状态索引
	private static Map<String, Integer> hiddenStateIndex = new HashMap<String, Integer>();
	static{
		hiddenStateIndex.put("S(0)", 0);
		hiddenStateIndex.put("S(1)", 1);
		hiddenStateIndex.put("S(2)", 2);
	}
	
	//观察状态索引
	private static Map<String, Integer> observableStateIndex = new HashMap<String, Integer>();
	static{
		observableStateIndex.put("O(0)", 0);
		observableStateIndex.put("O(1)", 1);
		observableStateIndex.put("O(2)", 2);
		observableStateIndex.put("O(3)", 3);
	}
	
	//前向算法 根据观察序列和已知的隐马尔可夫模型 返回这个模型生成这个观察序列的概率
	//alpha[t][j]表示t时刻由隐藏状态S(j)生成观察状态O(t)的概率
	public static double forward(String[] observedSequence){
		double[][] alpha = new double[observedSequence.length][A.length];
		
		//利用动态规划计算出alpha数组
		//初始化
		for(int i = 0; i <= A.length - 1; i++){
			int index = observableStateIndex.get(observedSequence[0]);
			alpha[0][i] = pai[i] * B[i][index];
		}
		for(int t = 1; t <= observedSequence.length - 1; t++){
			for(int j = 0; j <= A.length - 1; j++){
				double sum = 0;
				for(int i = 0; i <= A.length - 1; i++){
					sum += (alpha[t - 1][i] * A[i][j]);
				}
				int index = observableStateIndex.get(observedSequence[t]);
				alpha[t][j] = sum * B[j][index];
			}
		}
		double prob = 0;
		for(int i = 0; i <= A.length - 1; i++){
			prob += alpha[observedSequence.length - 1][i];
		}
		return prob;
	}
	
	public static void main(String[] args){
		String[] observedSequence = {"O(0)", "O(2)", "O(3)"};
		System.out.println(forward(observedSequence));
	}
}