opencv3/C++ 机器学习-决策树/DTrees
决策树/Decision Tree
决策树/Decision Tree是一棵二叉树(每棵非叶子节点有两个子节点的树)。可用于分类或回归问题。对于分类问题(形成分类树),每个叶节点都标有一个类标签;多个叶节点可能具有相同的标签。对于回归问题(形成回归树),每个叶结点分配一个常量,所以回归函数是分段常量。
决策树从根结点递归构造。所有训练数据(特征向量和响应)用于分割根节点。在每个节点中,根据一些标准找到最佳决策规则(最好的“主要”分割)。如分类问题用“不/纯度”,回归问题用方差和。
关于不纯度,不同算法使用的计算方法不一,如ID3用信息增益/Information Gain作为不纯度;C4.5用信息增益率/Information Gain Ratio作为不纯度;CART用基尼系数/Gini Index作为不纯度。
然后,若有必要,找到替代分裂点。替代分裂点类似于训练数据的主要分割结果。 所有的数据根据初始和替代分裂点来划分给左、右孩子结点(就像在预测算法里做的一样)。然后算法递归地继续分裂左右孩子结点。
节点递归过程的终止条件:
- 树的深度达到了指定的最大值。
- 在该结点训练样本的数目少于指定阈值。
- 在该结点所有的样本属于同一类(如果是回归的话,变化已非常小)。
- 能选择到的最好的分裂跟随机选择相比已经基本没有什么有意义的改进了。
树创建好之后,如有必要,可以使用交叉验证对其进行修剪。将可能导致模型过拟合的某些分支剪掉。通常仅适用于单决策树。树集合通常会建立一些足够小的树并且用他们自身的保护机制来防止过拟合。
变量重要性:
决策树除了用于预测之外,还可以用在多变量分析上。 构建的决策树算法的一个关键特性是它能够计算每个变量的重要性(相对决策力)。 每个变量的重要性的计算是在所有的在这个变量上的分裂进行的,不管是初始的还是替代的。这样的话,要准确计算变量重要性,即使没有缺失数据,替代分裂也必须包含在训练参数中。
OpenCV DTrees类
DTree可以表示一个单独的决策树,也可以表示树集成分类器中的一个基础分类器(Boosting或Random Trees)。
常用函数
virtual void setMaxDepth(int val) ;
树的最大可能深度。 训练算法在节点深度小于maxDepth的情况下分割节点。根节点具有零深度。如果符合其他终止标准或修剪树,则实际深度会更小。默认值为INT_MAX。virtual void setMinSampleCount(int val) ;
节点最小样本数量。若节点中的样本数量小于该值,则不会被分割。默认为10。virtual void setUseSurrogates(bool val) ;
若为true,则建立替代分裂点。 这些分裂点可以处理丢失的数据并正确计算变量的重要性。 默认值为false。virtual void setCVFolds(int val);
如果CVFolds> 1,则算法使用K折叠交叉验证修剪构建好的决策树,其中K等于CVFolds。 默认值是10。virtual void setUse1SERule(bool val);
若为true,则修剪将更加严格,使树更紧凑,抗噪声能力更强,但会降低部分准确度。 默认值为true。virtual void setTruncatePrunedTree(bool val);
若为true,则修剪后的分支会被完全移除。否则分支将被保留,并可能从原决策树中获得结果。 默认值为true。
决策树示例
从文件points.txt中读取点坐标以及对应的分类,然后建立决策树对点所在区域进行划分。
#include <opencv2/opencv.hpp>
#include <opencv2/ml.hpp>
#include <iostream>
#include <fstream>
using namespace std;
using namespace cv;
using namespace cv::ml;
int main()
{
vector<Point> trainedPoints;
vector<int> trainedPointsMarkers;
//读取文件中的点坐标
FILE *fp;
int flge = 0;
int fpoint,flabel;
Point point;
fp = fopen("E:\\points.txt", "r+");
if (fp == NULL)
{
printf("Cannot open the file!\n");
return -1;
}
while (!feof(fp))
{
fscanf(fp, "%d", &fpoint);
if (feof(fp)) break;
//依次为横坐标、纵坐标、分类
if ((
flge%3==0? point.x = fpoint:
flge%3==1? point.y = fpoint:
flge%3==2? flabel = fpoint : -1)<0)
return -1;
if (flge%3==2)
{
trainedPoints.push_back(point);
trainedPointsMarkers.push_back(flabel);
}
flge++;
}
vector<Vec3b> colors(2);
colors[0] = Vec3b(0, 255, 0);
colors[1] = Vec3b(0, 0, 255);
Mat src, dst;
src.create( 480, 640, CV_8UC3 );
src = Scalar::all(0);
src.copyTo(dst);
// 绘制点
for( size_t i = 0; i < trainedPoints.size(); i++ )
{
Scalar c = colors[trainedPointsMarkers[i]];
circle( src, trainedPoints[i], 3, c, -1 );
circle( dst, trainedPoints[i], 3, c, -1 );
}
imshow( "points", src );
//训练数据
Mat samples;
Mat(trainedPoints).reshape(1, (int)trainedPoints.size()).convertTo(samples, CV_32F);
//建立模型
Ptr<DTrees> model = DTrees::create();
//树的最大可能深度
model->setMaxDepth(8);
//节点最小样本数量
model->setMinSampleCount(2);
//是否建立替代分裂点
model->setUseSurrogates(false);
//交叉验证次数
model->setCVFolds(0);
//是否严格修剪
model->setUse1SERule(false);
//分支是否完全移除
model->setTruncatePrunedTree(false);
//训练
model->train(TrainData::create(samples, ROW_SAMPLE, Mat(trainedPointsMarkers)));
//显示结果
Mat testSample( 1, 2, CV_32FC1 );
for( int y = 0; y < dst.rows; y += 3 )
{
for( int x = 0; x < dst.cols; x += 3 )
{
testSample.at<float>(0) = (float)x;
testSample.at<float>(1) = (float)y;
int response = (int)model->predict( testSample );
dst.at<Vec3b>(y, x) = colors[response];
}
}
imshow( "Decision Tree", dst );
waitKey();
return 0;
}
文件points.txt中的内容为:
(依次为横坐标、纵坐标、分类)
269 263 0
242 244 0
221 224 0
210 180 0
227 142 0
257 120 0
280 105 0
338 95 0
423 92 0
384 93 0
458 100 0
497 133 0
514 162 0
546 219 0
568 272 0
582 326 0
593 372 0
605 410 0
613 432 0
612 453 0
298 279 0
334 283 0
375 291 0
357 312 0
313 312 0
266 299 0
218 278 0
201 259 0
193 224 0
186 189 0
186 156 0
199 119 0
227 102 0
261 90 0
303 87 0
348 81 0
328 72 0
377 73 0
411 75 0
452 79 0
490 107 0
525 147 0
549 189 0
572 244 0
597 306 0
615 354 0
624 393 0
630 423 0
336 300 0
257 280 0
224 255 0
239 287 0
204 151 0
220 121 0
294 73 0
355 60 0
391 60 0
446 64 0
496 94 0
528 145 0
550 182 0
532 178 0
564 217 0
589 241 0
602 283 0
603 297 0
571 300 0
584 272 0
609 336 0
597 350 0
332 182 1
340 168 1
380 151 1
359 162 1
366 179 1
357 183 1
397 182 1
409 198 1
421 221 1
429 201 1
443 221 1
452 247 1
473 274 1
482 305 1
485 351 1
480 367 1
465 393 1
415 418 1
352 424 1
294 414 1
231 401 1
174 382 1
135 361 1
120 338 1
99 302 1
84 266 1
71 233 1
63 201 1
43 127 1
27 69 1
15 30 1
15 58 1
9 112 1
39 98 1
38 146 1
28 174 1
52 169 1
48 218 1
45 243 1
17 138 1
22 93 1
27 127 1
42 194 1
69 247 1
72 294 1
98 333 1
132 333 1
142 352 1
156 389 1
182 408 1
219 426 1
252 431 1
329 441 1
384 444 1
308 437 1
329 425 1
429 429 1
391 428 1
446 407 1
463 357 1
465 325 1
453 289 1
450 265 1
419 226 1
393 196 1
354 174 1
309 166 1
286 167 1
274 168 1
329 153 1
337 158 1
359 153 1
382 166 1
420 175 1
439 206 1
464 239 1
480 268 1
497 310 1
505 350 1
496 392 1
482 415 1
447 439 1
384 455 1
334 453 1
292 449 1
265 431 1
243 411 1
194 391 1
17 193 1
41 251 1
67 307 1
80 350 1
99 380 1
160 413 1
202 432 1
249 448 1
299 465 1
330 462 1
364 463 1
395 462 1
412 454 1
285 325 0
330 337 0
353 336 0
364 324 0
388 309 0
393 291 0
358 298 0
327 320 0
203 98 0
248 70 0
300 41 0
284 54 0
371 50 0
344 45 0
411 45 0
457 51 0
507 76 0
554 136 0
580 196 0
609 272 0
626 319 0
529 119 0
295 196 1
311 186 1
323 178 1
290 179 1
391 158 1
453 187 1
467 228 1
497 262 1
507 306 1
469 204 1
492 256 1
482 236 1
481 333 1
491 284 1
469 420 1
452 382 1
427 396 1