欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

opencv3/C++ 机器学习-EM算法/Expectation Maximization

程序员文章站 2022-07-14 14:38:28
...

EM算法/Expectation Maximization

EM算法包含两步:E,求期望(Expectation),利用概率模型参数的现有估计值,计算隐藏变量的期望;M,求极大(Maximization),利用E 步上求得的隐藏变量的期望,对参数模型进行最大似然估计。所得参数估计值用于下个E步的计算,重复至收敛。

期望最大化/EM算法以具有一定数量混合物的高斯混合分布的形式估计多变量概率密度函数的参数。

考虑从高斯混合模型画出的d维欧几里得空间中的N个特征向量{ x1x2...xN}的集合:

 p(x;ak;Sk;πk)=k=1mπkpk(x),πk0,k=1mπk=1

 pk(x)=φ(x;ak,Sk)=1(2π)d/2|Sk|1/2exp{12(xak)TSk1(xak)}

其中m是高斯混合模型的数量, pk是具有均值 ak和协方差矩阵 Sk的正态分布密度, πk是第k个高斯混合模型的权重。 给定高斯混合模型个数M和样本 xii=1...N,算法找到所有高斯混合模型参数的最大似然估计(MLE),即 akSk πk

L(x,θ)=logp(x,θ)=i=1nlog(k=1mπkpk(x))maxθΘ,

Θ=
{(ak,Sk,πk):akRd,Sk=SkT>0,SkRd×d,πk0,k=1mπk=1}

EM算法是一个迭代过程。 每次迭代包括两个步骤。 在第一步E步即预期步骤中,使用当前可用的混合参数估计值,可以找出样本i属于混合模型k的概率 pik(在下面的公式中表示为 αik):

aki=πkφ(x;ak,Sk)j=1mπjφ(x;aj,Sj)

在第二步M步即最大化步骤中,使用计算出的概率对高斯混合模型的参数估计值进行细化:

πk=1Ni=1Naki;

πk=i=1Nakixii=1Naki ;

Sk=i=1Naki(xiak)(xiak)Ti=1Naki

或者,当提供 pik的初始值时,该算法可以从M步开始。 当 pik未知时的另一种选择是使用更简单的聚类算法对输入采样进行预先聚类,从而获得初始的 pik(通常用k-means算法实现)。

EM算法的一个主要问题是需要估计大量参数。 大多数参数存在于协方差矩阵中,这些矩阵大小为d×d,其中d是特征空间维度。 但在许多实际问题中,协方差矩阵接近于对角线或者甚至接近μkI,其中I是单位矩阵,μk是混合相关的“比例”参数。 因此,一个健壮的计算方案是对协方差矩阵加较强的约束,然后用估计的参数作为较少约束优化问题的输入(通常对角协方差矩阵已经足够了)。

OpenCV EM类

相关函数

  • virtual void setClustersNumber(int val);
    高斯混合模型中混合成分的数量。默认值是EM :: DEFAULT_NCLUSTERS = 5。
  • virtual void setCovarianceMatrixType(int val);
    协方差矩阵的类型。协方差矩阵的约束定义了其类型。
    COV_MAT_SPHERICAL= 0:缩放的单位矩阵 μkI。 对每个矩阵估计唯一的参数μk。用于约束条件相关时或作为优化的第一步(例如数据用PCA预处理时)。
    COV_MAT_DIAGONAL= 1:具有正对角元素的对角矩阵。 每个矩阵d个*参数。 (常选项,估算结果良好)
    COV_MAT_GENERIC= 2:对称正定矩阵。 每个矩阵中的*参数大约d2/2个。 不建议使用此选项,除非对参数或大量训练样本有相当准确的初始估计。

  • virtual bool trainEM(InputArray samples,OutputArray logLikelihoods=noArray(),OutputArray labels=noArray(),OutputArray probs=noArray()) ;
    估计样本集中的高斯混合模型参数。
    这种变化开始于Expectation步。模型参数的初始值通过k-means估计。与许多ML模型不同,EM是无监督学习算法,因此训练时不用输入类标签。通过样本数据计算高斯混合参数的最大似然估计,将结构中的所有参数进行存储:pik存概率, ak 存均值,Sk存covs [k],πk存权重,并且可选地为每个样本计算输出“类别标签”: labelsi=arg maxk(pi,k),i=1..N(每个样本的最可能的模型分量的索引)。训练好的模型可以用于预测。

samples ::样本。单通道矩阵,每一行为一个样本。若矩阵不是CV_64F类型,则将被转换为此类型的内部矩阵。
logLikelihoods ::可选输出矩阵,包含每个样本的似然对数值。大小nsamples×1,类型CV_64FC1。
labels ::每个样本的输出“类别标签”:labelsi=arg maxkpiki=1..N每个样本最可能的高斯混合模型分量)。大小nsamples×1 ,类型CV_32SC1。
probs ::可选输出矩阵,包含每个给定样本的各个高斯混合模型分量的后验概率。大小 nsamples×nclusters ,类型CV_64FC1。

应用示例

图像分割

使用EM算法对图像进行分割。

#include <opencv2/opencv.hpp>
#include <iostream>
using namespace std;
using namespace cv;
using namespace cv::ml;

int main()
{
    Vec3b colors[] ={Vec3b(0, 0, 255), Vec3b(0, 255, 0), Vec3b(255, 100, 100), Vec3b(255, 0, 255)};
    Mat data, labels, src, dst;
    src = imread("E:/image/image/red.jpg", 1);
    resize(src, src, Size(src.cols/1.5,src.rows/1.5));
    if(src.empty())
    {
        printf("can not load image \n");
        return -1;
    }
    src.copyTo(dst);
    for (int i = 0; i < src.rows; i++)
    for (int j = 0; j < src.cols; j++)
    {
        Vec3b point = src.at<Vec3b>(i, j);
        Mat tmp = (Mat_<float>(1, 3) << point[0], point[1], point[2]);
        data.push_back(tmp);
    }

    Ptr<EM> model = EM::create();
    model->setClustersNumber(4); //类个数
    model->setCovarianceMatrixType(EM::COV_MAT_SPHERICAL);
    model->setTermCriteria(TermCriteria(TermCriteria::COUNT + TermCriteria::EPS, 300, 0.1));
    model->trainEM(data, noArray(), labels, noArray());

    int n = 0;
    //显示结果,不同的类别用不同的颜色
    for (int i = 0; i < dst.rows; i++)
    for (int j = 0; j < dst.cols; j++)
    {
        int index = labels.at<int>(n);
        dst.at<Vec3b>(i, j) = colors[index];
        n++;
    }
    imshow("src", src);
    imshow("dst", dst);
    waitKey(0);

    return 0;
}

opencv3/C++ 机器学习-EM算法/Expectation Maximization

点坐标分类

从文件points.txt中读取点坐标以及对应的分类,然后使用EM算法对点所在区域进行划分。

#include <opencv2/opencv.hpp>
#include "opencv2/ml.hpp"
#include <iostream>  
#include <fstream> 

using namespace std;
using namespace cv;
using namespace cv::ml;

//EM算法
int main()
{
    Mat src, dst;
    vector<Point>  trainedPoints;
    vector<int>    trainedPointsMarkers;
    //读取文件中的点坐标
    FILE *fp;
    int flge = 0;
    int fpoint,flabel;
    fp = fopen("E:\\points.txt", "r+");
    if (fp == NULL)
    {
        printf("Cannot open the file!\n");
        exit(0);
    }
    Point point;
    while (!feof(fp))
    {   
        fscanf(fp, "%d", &fpoint);
        if (feof(fp)) break;
        //依次为横坐标、纵坐标、分类
        if ((
            flge%3==0? point.x = fpoint: 
            flge%3==1? point.y = fpoint:
            flge%3==2? flabel = fpoint : -1)<0) 
        return -1;

        if (flge%3==2)
        {
            trainedPoints.push_back(point);
            trainedPointsMarkers.push_back(flabel);
        }
        flge++;
    }

    vector<Vec3b>  colors(4);
    colors[0] = Vec3b(0, 255, 0);
    colors[1] = Vec3b(0, 0, 255);
    colors[2] = Vec3b(0, 255, 255);
    colors[3] = Vec3b(255, 0, 0);

    src.create( 480, 640, CV_8UC3 );
    src = Scalar::all(0);
    // 绘制点
    for( size_t i = 0; i < trainedPoints.size(); i++ )
    {
        Scalar c = colors[trainedPointsMarkers[i]];
        circle( src, trainedPoints[i], 3, c, -1 );
    }
    src.copyTo(dst);
    imshow( "points", src );

    Mat samples;
    Mat(trainedPoints).reshape(1, (int)trainedPoints.size()).convertTo(samples, CV_32F);
    int nmodels = (int)colors.size();
    vector<Ptr<EM> > em_models(nmodels);
    Mat modelSamples;

    for( int i = 0; i < nmodels; i++ )
    {
        modelSamples.release();
        for( int j = 0; j < samples.rows; j++ )
        {
            if( trainedPointsMarkers[j] == i )
                modelSamples.push_back(samples.row(j));
        }
        // 训练模型
        if( !modelSamples.empty() )
        {
            const int componentCount = 5;
            Ptr<EM> em = EM::create();
            //高斯混合模型中混合成分的数量
            em->setClustersNumber(componentCount);
            //协方差矩阵的类型。
            em->setCovarianceMatrixType(EM::COV_MAT_DIAGONAL);
            //训练模型
            em->trainEM(modelSamples, noArray(), noArray(), noArray());
            em_models[i] = em;
        }
    }
    Mat testSample(1, 2, CV_32FC1 );
    Mat logLikelihoods(1, nmodels, CV_64FC1, Scalar(-DBL_MAX));
    for( int y = 0; y < src.rows; y += 3 )
    {
        for( int x = 0; x < src.cols; x += 3 )
        {
            testSample.at<float>(0) = (float)x;
            testSample.at<float>(1) = (float)y;
            for( int i = 0; i < nmodels; i++ )
            {
                if( !em_models[i].empty() )
                    logLikelihoods.at<double>(i) = em_models[i]->predict2(testSample, noArray())[0];
            }
            Point maxLoc;
            minMaxLoc(logLikelihoods, 0, 0, 0, &maxLoc);
            dst.at<Vec3b>(y, x) = colors[maxLoc.x];
        }
    }
    imshow( "EM", dst );

    waitKey();
    return 0;
}

opencv3/C++ 机器学习-EM算法/Expectation Maximizationopencv3/C++ 机器学习-EM算法/Expectation Maximization

文件points.txt中的内容为:
(依次为横坐标、纵坐标、分类)

281 234 0
265 227 0
261 204 0
273 185 0
298 171 0
326 178 0
328 206 0
330 226 0
323 245 0
300 256 0
280 259 0
262 257 0
245 240 0
238 229 0
236 206 0
239 184 0
265 158 0
279 154 0
297 153 0
325 153 0
338 168 0
352 206 0
353 229 0
305 211 0
305 237 0
240 308 3
219 300 3
200 281 3
189 237 3
184 196 3
191 163 3
214 140 3
246 128 3
288 127 3
337 128 3
265 122 3
306 123 3
287 113 3
322 114 3
349 118 3
368 135 3
389 174 3
399 197 3
399 233 3
388 261 3
365 295 3
344 310 3
280 319 3
265 319 3
309 316 3
262 333 3
223 320 3
193 303 3
180 290 3
175 262 3
165 228 3
169 183 3
177 143 3
186 127 3
205 115 3
260 104 3
235 97  3
293 97  3
357 102 3
392 119 3
408 150 3
396 149 3
412 200 3
419 247 3
405 291 3
366 320 3
303 332 3
341 333 3
330 105 3
273 92  3
280 209 0
333 256 0
298 271 0
165 163 3
157 205 3
159 246 3
159 276 3
171 309 3
189 327 3
207 333 3
239 340 3
284 352 3
339 351 3
371 334 3
382 311 3
394 282 3
423 269 3
425 218 3
424 179 3
411 157 3
340 85  3
308 85  3
221 118 3
151 175 3
254 88  3
281 79  3
185 374 1
171 359 1
151 342 1
137 323 1
122 298 1
116 272 1
116 233 1
116 196 1
119 157 1
119 139 1
137 103 1
145 92  1
157 73  1
181 62  1
218 55  1
260 51  1
317 49  1
362 53  1
393 65  1
427 83  1
454 112 1
468 130 1
490 168 1
504 197 1
511 219 1
516 235 1
539 276 1
563 335 1
580 373 1
593 404 1
605 425 1
616 444 1
628 466 1
222 382 1
259 384 1
306 389 1
293 401 1
274 402 1
247 402 1
210 402 1
193 393 1
159 380 1
141 369 1
123 339 1
116 316 1
105 294 1
95  270 1
92  244 1
89  217 1
88  188 1
89  160 1
90  111 1
93  94  1
81  132 1
110 117 1
116 73  1
149 45  1
165 36  1
222 35  1
290 27  1
251 31  1
351 40  1
329 35  1
284 40  1
384 48  1
410 55  1
441 77  1
461 102 1
484 130 1
510 171 1
524 201 1
535 232 1
546 260 1
562 297 1
579 322 1
595 351 1
615 379 1
632 406 1
620 400 1
473 230 2
479 268 2
477 299 2
473 316 2
467 332 2
462 252 2
462 303 2
457 330 2
450 346 2
428 375 2
395 397 2
381 406 2
320 436 2
315 437 2
283 447 2
224 450 2
190 445 2
177 442 2
140 428 2
125 416 2
99  405 2
84  392 2
79  387 2
59  347 2
56  336 2
45  299 2
40  259 2
37  224 2
25  155 2
19  97  2
17  43  2
17  26  2
20  75  2
18  132 2
18  170 2
11  202 2
13  233 2
18  265 2
25  307 2
32  342 2
44  375 2
51  389 2
59  410 2
69  430 2
117 459 2
179 460 2
110 442 2
186 472 2
286 476 2
353 467 2
377 457 2
236 462 2
318 459 2
353 447 2
386 427 2
404 419 2
448 381 2
413 431 2
428 399 2
470 342 2
447 401 2
375 438 2
369 421 2
337 445 2
265 464 2
220 471 2
182 470 2
138 453 2
85  417 2
119 429 2
149 447 2
157 465 2
561 274 1
553 317 1
586 359 1
586 300 1
624 382 1
623 318 1
608 297 1
577 240 1
556 219 1
550 205 1
522 152 1
619 325 1
628 347 1
627 279 1
578 215 1
546 189 1
601 264 1
587 231 1
496 137 1
492 104 1
479 85  1
464 66  1
441 49  1
423 42  1
364 24  1
345 21  1
314 18  1
262 17  1
243 17  1
202 22  1
175 46  1
152 61  1
125 91  1
101 151 1
99  172 1
102 233 1
104 246 1
76  159 1
83  196 1
80  235 1
87  258 1
90  277 1
97  307 1
110 323 1
136 360 1
160 366 1
185 383 1
179 401 1
221 401 1
265 396 1
230 390 1
267 412 1
309 404 1
323 399 1
341 390 1
328 386 1
463 366 2
433 409 2
428 431 2
401 454 2
358 468 2
329 472 2
294 464 2
250 468 2
224 476 2
205 463 2
249 446 2
210 439 2
161 427 2
115 404 2
98  393 2
84  377 2
73  367 2
52  334 2
41  324 2
20  303 2
16  284 2
10  320 2
14  353 2
8   353 2
9   384 2
55  418 2
36  391 2
29  395 2
32  208 2
15  165 2
8   127 2
6   98  2
5   72  2
5   63  2
5   4   2
6   29  2
31  242 2
33  275 2
7   300 2
20  338 2
28  367 2
25  386 2
29  414 2
43  430 2
68  450 2
102 469 2
97  457 2
92  439 2
100 430 2
135 473 2
292 412 1
244 414 1
214 411 1
253 411 1
275 390 1
207 377 1
334 399 1
296 416 1
243 416 1
199 406 1
356 379 1
374 374 1
361 383 1
353 391 1
342 395 1
333 401 1
313 414 1
245 382 1
242 393 1
230 408 1
176 384 1
154 392 1
134 381 1
121 359 1
113 344 1
291 384 1
285 397 1
285 407 1
279 412 1
107 211 1
101 195 1
444 362 2
424 386 2
350 423 2
338 429 2
304 444 2
263 453 2
405 394 2
427 365 2
443 334 2
449 315 2
458 291 2
463 275 2
478 284 2
463 390 2
447 420 2
438 435 2
416 458 2
413 462 2
392 474 2
378 475 2