GMM
程序员文章站
2024-03-25 08:13:34
...
许多概率模型有一系列可见变量和一系列潜变量,这时常常会涉及推断困难,就是指难以计算或其期望,而这样的操作在一些诸如最大似然学习的任务中往往是必需的。为此可以把精确推断问题描述为一个优化问题,借此推导出推断算法。为了构造这样一个优化问题,假设一个具有可见变量和潜变量的概率模型,按照最大似然估计,我们希望计算观察数据的对数概率,则有
有时候,边缘化消去的操作很费时或难以计算,作为替代,我们可以计算一个的证据下界
其中表示给定一个数据的可见变量,其潜变量为的概率,对于一个选择的合适分布来说,是容易计算的,对任意分布的选择来说,提供了似然函数的一个下界。越好地近似的分布,其下界越紧,当时,这个近似是完美的,也意味着,因此可以将推断问题看作找一个分布使最大的过程。
在潜变量模型中,EM算法是非常常见的训练方法,这是一种能够学到近似后验的算法。
- E步:令表示在这一步开始时的参数值。对任何我们想要训练的(对所有的或者小批量数据均成立)索引为的训练样本,令。通过这个定义,我们认为在当前参数下定义,如果我们改变,那么将会相应变化,但是还是不变并且等于。
- M步:使用选择的优化算法完全地或部分地关于最大化。
对于高斯混合模型,其概率分布为
其中,这时的是数据样本,潜变量代表数据样本是否来自个高斯分布中的第个高斯分布,在E步中有
在M步中,为了求解模型参数,需要对各参数求偏导
设样本数为,因为,所以有
同理求导则有
Matlab代码实现
classdef GMM
properties
numOfGMs; %高斯分布的个数
pDim; %可见变量的空间维数
pMix; %每个分布的混合比例
mu; %均值参数
sigma; %方差参数
trainData_x; %训练数据
trainData_q_h; %训练数据的潜变量编码分布
N_sample; %训练数据样本总数
batchSize; %批次大小
IterMax; %最大迭代次数
iter_train; %当前迭代次数
end
methods
function obj = GMM(pDim,numOfGMs) %类库构造函数
obj.numOfGMs = numOfGMs; %初始化类对象参数
obj.pDim = pDim;
obj.pMix = 1 / numOfGMs * ones(1,numOfGMs);
obj.mu = rand(numOfGMs,pDim);
obj.sigma = ones(pDim,pDim,numOfGMs);
for k = 1 : obj.numOfGMs
obj.sigma(:,:,k) = eye(obj.pDim);
end
obj.batchSize = 100;
obj.iter_train = 0;
obj.IterMax = 100;
end
function obj = train(obj,trainData_x,Iter)%训练
if(nargin == 2)
Iter = obj.IterMax;
end
obj.trainData_x = trainData_x;
obj.N_sample = size(obj.trainData_x,1);
batch_size = obj.batchSize;
trainData_q_h_batch = zeros(batch_size,obj.numOfGMs);
diagMat = eye(batch_size);
diagMat0 = find(diagMat);
findDiag = diagMat0;
while obj.iter_train < Iter
obj.iter_train = obj.iter_train + 1;
trainData_x_batch = obj.trainData_x(randperm(obj.N_sample , batch_size) , :);
%E步,求解q(h|v)
for k = 1 : obj.numOfGMs
trainData_q_h_batch(:,k) = obj.pMix(k) * mvnpdf(trainData_x_batch,obj.mu(k,:),obj.sigma(:,:,k));
end
trainData_q_h_batch = trainData_q_h_batch ./ sum(trainData_q_h_batch,2);
%M步,求解模型参数
obj.pMix = mean(trainData_q_h_batch);
sumTrainDataQH = sum(trainData_q_h_batch)';
obj.mu = trainData_q_h_batch' * trainData_x_batch ./ sumTrainDataQH;
for k = 1 : obj.numOfGMs
diagMat(findDiag) = trainData_q_h_batch(:,k);
obj.sigma(:,:,k) = (trainData_x_batch - obj.mu(k,:))' * diagMat * (trainData_x_batch - obj.mu(k,:)) / sumTrainDataQH(k);
end
end
end
function [obj] = predict(obj) %预测重构
batch_size = obj.N_sample;
indexRandBatch = randperm(obj.N_sample , batch_size);
testData_x_batch = obj.trainData_x(indexRandBatch , :);
%求解q(h|v)
trainData_q_h_batch = zeros(batch_size,obj.numOfGMs);
for k = 1 : obj.numOfGMs
trainData_q_h_batch(:,k) = obj.pMix(k) * mvnpdf(testData_x_batch,obj.mu(k,:),obj.sigma(:,:,k));
end
trainData_q_h_batch = trainData_q_h_batch ./ sum(trainData_q_h_batch,2);
obj.trainData_q_h(indexRandBatch,:) = trainData_q_h_batch;
end
end
end