
GMM


  Many probabilistic models involve a set of visible variables $v$ and a set of latent variables $h$. Inference in such models is often difficult, meaning it is hard to compute $p(h|v)$ or to take expectations with respect to it, yet operations of this kind are required by tasks such as maximum likelihood learning. One remedy is to cast exact inference as an optimization problem and derive inference algorithms from it. To set up this optimization problem, consider a probabilistic model with visible variables $v$ and latent variables $h$. Under maximum likelihood estimation we want to compute the log-probability of the observed data, $\log p(v;\theta)$:

$$\log p(v;\theta)=\log \sum_{h}p(v,h;\theta) \tag{1}$$
  Sometimes marginalizing out $h$ is expensive or even intractable. As an alternative, we can compute an evidence lower bound $\mathcal{L}(v,\theta,q)$ on $\log p(v;\theta)$:

$$\begin{aligned}\log p(v;\theta)&=\log \sum_{h}p(v,h;\theta)=\log \sum_h q(h|v)\frac{p(v,h;\theta)}{q(h|v)}\\ &\geq\sum_h q(h|v)\log\frac{p(v,h;\theta)}{q(h|v)}=\mathbb{E}_{h\sim q}[\log p(h,v;\theta)-\log q(h|v)]\\ &=\log p(v;\theta)-\mathbb{E}_{h\sim q}[\log q(h|v)-\log p(h,v;\theta)+\log p(v;\theta)]\\ &=\log p(v;\theta)-\mathbb{E}_{h\sim q}\log \frac{q(h|v)}{p(v,h;\theta)/p(v;\theta)}=\log p(v;\theta)-\mathbb{E}_{h\sim q}\log \frac{q(h|v)}{p(h|v)}\\ &=\log p(v;\theta)-D_{KL}\big(q(h|v)\,\|\,p(h|v;\theta)\big)=\mathbb{E}_{h\sim q}[\log p(v,h)]+H(q)=\mathcal{L}(v,\theta,q)\end{aligned} \tag{2}$$
  Here $q(h|v)$ is a distribution over the latent variable $h$ given the visible variables $v$ of a data point. For a suitably chosen $q$, $\mathcal{L}$ is easy to compute, and for any choice of $q$ it provides a lower bound on the log-likelihood. The better $q(h|v)$ approximates $p(h|v)$, the tighter the bound; when $q(h|v)=p(h|v)$ the bound is exact, i.e. $\mathcal{L}(v,\theta,q)=\log p(v;\theta)$. Inference can therefore be viewed as the search for a distribution $q$ that maximizes $\mathcal{L}$.
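  As a quick numeric check of (2), the short MATLAB snippet below (a minimal sketch; the joint probabilities are made-up numbers) evaluates $\mathcal{L}$ for a toy model with one binary latent variable: the bound holds for an arbitrary $q$ and becomes tight when $q$ equals the posterior.

% Toy joint p(v,h) for a fixed v and a binary latent h (made-up numbers)
p_vh   = [0.3 0.1];                           % p(v,h=1), p(v,h=2)
log_pv = log(sum(p_vh));                      % exact log p(v)
elbo   = @(q) sum(q .* (log(p_vh) - log(q))); % L(q) from Eq. (2)
q_any  = [0.5 0.5];                           % any valid q gives a lower bound
q_post = p_vh / sum(p_vh);                    % q(h|v) = p(h|v) makes it tight
fprintf('log p(v) = %.6f, ELBO(q_any) = %.6f, ELBO(q_post) = %.6f\n', ...
        log_pv, elbo(q_any), elbo(q_post));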
  In latent variable models, the EM algorithm is a very common training method; it is an algorithm that learns with an approximate posterior. It alternates between the following two steps (a minimal standalone sketch follows the list):

  • E-step: Let $\theta^{(0)}$ denote the parameter values at the start of this step. For every training example $v^{(i)}$ with index $i$ that we want to train on (this holds for the full dataset as well as for a minibatch), set $q(h|v^{(i)})=p(h|v^{(i)};\theta^{(0)})$. By this definition, $q$ is fixed at the current parameters $\theta^{(0)}$: if we then change $\theta$, $p(h|v;\theta)$ changes accordingly, but $q(h|v)$ remains equal to $p(h|v;\theta^{(0)})$.
  • M-step: Using an optimization algorithm of choice, completely or partially maximize $\sum_i\mathcal{L}(v^{(i)},\theta,q)=\sum_i\sum_h q(h|v^{(i)})\log \frac{p(v^{(i)},h;\theta)}{q(h|v^{(i)})}$ with respect to $\theta$.
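  To make the E/M alternation concrete outside the Gaussian case, here is a minimal MATLAB sketch of EM for a two-coin Bernoulli mixture (a standard illustration, not part of the original derivation; the coin biases and sample sizes are made up). Each trial secretly picks one of two coins and flips it n times; only the head counts are observed.

% EM for a two-coin Bernoulli mixture (minimal sketch, made-up parameters)
n = 10; m = 500;
true_p = [0.8 0.3];                        % hidden coin biases (simulation only)
z = (rand(m,1) < 0.5) + 1;                 % hidden coin choice for each trial
X = binornd(n, true_p(z)');                % observed head counts, m x 1
p = [0.6 0.5]; pi_mix = [0.5 0.5];         % initial guesses
for iter = 1:200
    lik = zeros(m,2);                      % E-step: q(h|v) via Bayes' rule
    for k = 1:2
        lik(:,k) = pi_mix(k) * binopdf(X, n, p(k));
    end
    q = lik ./ sum(lik, 2);
    pi_mix = mean(q);                      % M-step: closed-form maximization
    p = (q' * X)' ./ (n * sum(q));
end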
      For a Gaussian mixture model, the distribution is

$$p(v;\theta)=\sum_{k=1}^K\pi_k\mathcal{N}(v|\mu_k,\Sigma_k) \tag{3}$$
    where $\sum_{k=1}^K\pi_k=1$. Here $v$ is a data sample, and the latent variable $h=\{h_1,h_2,\cdots,h_K\}$ with $h_k\in\{0,1\}$ indicates which of the $K$ Gaussian components the sample came from. The E-step gives

$$q(h_k=1|v^{(i)})=\frac{\pi_k\mathcal{N}(v^{(i)}|\mu_k,\Sigma_k)}{\sum_{j=1}^K\pi_j\mathcal{N}(v^{(i)}|\mu_j,\Sigma_j)}=\frac{\pi_k\mathcal{N}(v^{(i)}|\mu_k,\Sigma_k)}{p(v^{(i)};\theta)} \tag{4}$$
      In the M-step, to solve for the model parameters $\pi_k,\mu_k,\Sigma_k$, we set the partial derivative with respect to each parameter to zero. For $\pi_k$, attaching a Lagrange multiplier for the constraint $\sum_k\pi_k=1$ gives the objective

$$J_\pi=\sum_i\sum_k q(h_k|v^{(i)})\log \pi_k+\lambda\Big(\sum_k\pi_k-1\Big) \tag{5}$$

$$\frac{\partial J_\pi}{\partial\pi_k}=\sum_i q(h_k|v^{(i)})\frac{1}{\pi_k}+\lambda=0 \tag{6}$$
    Let $m$ be the number of samples. Multiplying (6) by $\pi_k$ and summing over $k$, then using $\sum_k\pi_k=1$ and $\sum_k q(h_k|v^{(i)})=1$, we get

$$\sum_i\sum_k q(h_k|v^{(i)})=-\lambda\sum_k\pi_k=-\lambda=\sum_i 1=m \tag{7}$$

    so that

$$\pi_k=\frac{\sum_i q(h_k|v^{(i)})}{m} \tag{8}$$
    Differentiating in the same way with respect to $\mu_k$ and $\Sigma_k$ yields

$$\mu_k=\frac{\sum_i q(h_k|v^{(i)})\,v^{(i)}}{\sum_i q(h_k|v^{(i)})} \tag{9}$$

$$\Sigma_k=\frac{\sum_i q(h_k|v^{(i)})(v^{(i)}-\mu_k)(v^{(i)}-\mu_k)^T}{\sum_i q(h_k|v^{(i)})} \tag{10}$$
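    Written as matrix operations (which is exactly what the class below does per minibatch), the updates (8)-(10) take only a few lines. The following is a minimal sketch with placeholder data: X is an m-by-d data matrix and Q an m-by-K responsibility matrix produced by the E-step.

% M-step updates (8)-(10) given data X (m x d) and responsibilities Q (m x K)
m = 6; d = 2; K = 2;
X = randn(m, d);                      % placeholder data
Q = rand(m, K); Q = Q ./ sum(Q, 2);   % placeholder responsibilities
Nk = sum(Q, 1)';                      % K x 1, effective count per component
pi_mix = Nk' / m;                     % Eq. (8)
mu = (Q' * X) ./ Nk;                  % Eq. (9), K x d via implicit expansion
Sigma = zeros(d, d, K);
for k = 1:K
    Xc = X - mu(k,:);                 % center the data at the component mean
    Sigma(:,:,k) = (Xc .* Q(:,k))' * Xc / Nk(k);   % Eq. (10)
end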

MATLAB implementation

classdef GMM
    properties
        numOfGMs;                         % number of Gaussian components
        pDim;                             % dimensionality of the visible variables
        pMix;                             % mixing proportion of each component
        mu;                               % mean parameters
        sigma;                            % covariance parameters
        trainData_x;                      % training data
        trainData_q_h;                    % responsibilities q(h|v) of the training data
        N_sample;                         % total number of training samples
        batchSize;                        % minibatch size
        IterMax;                          % maximum number of iterations
        iter_train;                       % current iteration count
    end
    methods
        function obj = GMM(pDim,numOfGMs) % constructor
            obj.numOfGMs = numOfGMs;      % initialize object parameters
            obj.pDim = pDim;
            obj.pMix = 1 / numOfGMs * ones(1,numOfGMs); % uniform mixing proportions
            obj.mu = rand(numOfGMs,pDim);               % random initial means
            obj.sigma = ones(pDim,pDim,numOfGMs);
            for k = 1 : obj.numOfGMs
                obj.sigma(:,:,k) = eye(obj.pDim);       % identity initial covariances
            end
            obj.batchSize = 100;
            obj.iter_train = 0;
            obj.IterMax = 100;
        end
        function obj = train(obj,trainData_x,Iter) % minibatch EM training
            if(nargin == 2)
                Iter = obj.IterMax;
            end
            obj.trainData_x = trainData_x;
            obj.N_sample = size(obj.trainData_x,1);
            batch_size = min(obj.batchSize , obj.N_sample); % batch cannot exceed the sample count
            trainData_q_h_batch = zeros(batch_size,obj.numOfGMs);
            diagMat = eye(batch_size);
            findDiag = find(diagMat);     % linear indices of the diagonal
            while obj.iter_train < Iter
                obj.iter_train = obj.iter_train + 1;
                trainData_x_batch = obj.trainData_x(randperm(obj.N_sample , batch_size) , :);
                % E-step: compute the responsibilities q(h|v), Eq. (4)
                for k = 1 : obj.numOfGMs
                    trainData_q_h_batch(:,k) = obj.pMix(k) * mvnpdf(trainData_x_batch,obj.mu(k,:),obj.sigma(:,:,k));
                end
                trainData_q_h_batch = trainData_q_h_batch ./ sum(trainData_q_h_batch,2);
                % M-step: update the model parameters
                obj.pMix = mean(trainData_q_h_batch);                    % Eq. (8)
                sumTrainDataQH = sum(trainData_q_h_batch)';
                obj.mu = trainData_q_h_batch' * trainData_x_batch ./ sumTrainDataQH; % Eq. (9)
                for k = 1 : obj.numOfGMs
                    diagMat(findDiag) = trainData_q_h_batch(:,k);        % diag of q(h_k|v) as weights
                    obj.sigma(:,:,k) = (trainData_x_batch - obj.mu(k,:))' * diagMat * (trainData_x_batch - obj.mu(k,:)) / sumTrainDataQH(k); % Eq. (10)
                end
            end
        end
        function [obj] = predict(obj)     % infer responsibilities for the training data
            batch_size = obj.N_sample;
            indexRandBatch = randperm(obj.N_sample , batch_size);
            testData_x_batch = obj.trainData_x(indexRandBatch , :);
            % compute q(h|v), Eq. (4)
            trainData_q_h_batch = zeros(batch_size,obj.numOfGMs);
            for k = 1 : obj.numOfGMs
                trainData_q_h_batch(:,k) = obj.pMix(k) * mvnpdf(testData_x_batch,obj.mu(k,:),obj.sigma(:,:,k));
            end
            trainData_q_h_batch = trainData_q_h_batch ./ sum(trainData_q_h_batch,2);
            obj.trainData_q_h(indexRandBatch,:) = trainData_q_h_batch;
        end
    end
end
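
A short usage example (a sketch with made-up synthetic data; save the class as GMM.m so MATLAB can find the classdef): draw samples from two known Gaussians, fit the model, and inspect the recovered parameters.

% Usage sketch: fit a 2-component GMM to synthetic 2-D data (made-up parameters)
rng(0);
X = [mvnrnd([0 0], eye(2), 500);
     mvnrnd([4 4], 0.5*eye(2), 500)];
gmm = GMM(2, 2);            % pDim = 2, numOfGMs = 2
gmm = gmm.train(X, 200);    % 200 minibatch EM iterations
gmm = gmm.predict();        % responsibilities for every training sample
disp(gmm.pMix)              % mixing proportions, roughly [0.5 0.5]
disp(gmm.mu)                % component means, near [0 0] and [4 4]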