欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  科技

大数据技术GBDT算法解析

程序员文章站 2022-07-10 10:53:50
大数据技术GBDT算法解析。 1. GBDT的基本思想 单模型情况下预测结果容易产生过拟合,例如普通决策树,要想达到比较好的预测效果,需要将树的深度调得比较深,叶节点的最大样本数目调得小一...

大数据技术GBDT算法解析。

1. GBDT的基本思想

单模型情况下预测结果容易产生过拟合,例如普通决策树,要想达到比较好的预测效果,需要将树的深度调得比较深,叶节点的最大样本数目调得小一点等才能达到比较高的准确率。但是这样会带来严重的过拟合问题,针对这些问题,GBDT采用多颗决策树组合的方法来实现比较高的准确率,同时避免过拟合问题。
假设待分类样本为((X(1),y(1)),(X(2),y(2)),?,(X(m),y(m))),其中m为样本数量,X(i)为第i个样本的特征,y(i)为第i个样本的类标签。GBDT的任务是构建K颗决策树f1,f2,?fk,对于每个样本i,其最终的预测值是每颗决策树的预测值的和:
predict(X(i))=∑k=1Kfk(X(i))

2, 单颗决策树的构建过程

对于一批训练样本((X(1),y(1)),(X(2),y(2)),?,(X(m),y(m))),首先计算标签的均值作为第一步的预测值:
μ=1m∑i=1my(i)
然后计算每一个样本的残差:
dY(i)=y(i)?μ
这样得到的残差作为第一棵树的学习标准。即:
大数据技术GBDT算法解析
接下来会以((X(1),dY(1)),(X(2),dY(2)),?,(X(m),dY(m)))作为第一颗树的根节点,学习出一颗CART树,具体学习方法见CART树算法详解。当指定最大树的深度,最大叶节点的个数,叶节点包含的最大样本数目后,树会在某一时刻停止训练,此时得到学习器,也就是第一个决策树tree1
对于得到的tree1和所有的样本,根据tree1得到每个样本的预测值predicti,然后跟新每个样本的残差:
dYi=dYi?αk×predict(treek,X(i))
其中αk为学习率,通常设置为定值, X(i)为第i个样本的特征值, predict(treek,X(i))为第k颗决策树对第i个样本的预测值。由此得到更新后的残差值(dY(1),dY(2),?,dY(m)),然后作为第2颗树的学习标准,以此类推,直到训练到第K颗树为止。

3. 损失函数与梯度下降

在GBDT决策树当中,采用的损失函数为:
L(θ)=12∑i=1m(hθ(X(i))?y(i))2m为样本数量
其中hθ(X(i))为前面j颗树对于样本i的预测值之和,即:
hθ(X(i))=μ+∑i=1jpredict(treej,X(i))
因此用L(θ)X(i)求导,得:
?L(θ)?h(X(i))=hθ(X(i))?y(i)
即梯度的方向就是每次训练完成之后样本的残差,然后将此残差作为下一颗树的target值继续学习,整个算法的基本过程为:

对于m个训练样本((X(1),y(1)),(X(2),y(2)),?,(X(m),y(m))),计算均值:
μ=1m∑i=1my(i) 计算样本的残差dYi=y(i)?μ 设树的总颗数为K,对于k∈{1,2,?,K},对于所有的残差dY1,dY2,?,dYm通过CART树学习出一个学习器treek,即:
treek=Train_Learner(X,dY)
然后更新残差:
dYi=dYi?α×treek(X(i))
其中α为学习率,treek(X(i))为第i个样本在第k颗树上的预测值。 不断重复3中的步骤,直到训练到第K颗树为止。 最终的预测结果为,第j个样本的预测值为所有树的预测值的叠加和:
predict(X(j))=μ+α∑k=1Ktreek(X(j))

Loss=∑i=1m(hθ(X(i))?y(i))2

4. gbdt树的打印

4.1. 安装依赖的软件和库

安装GraphViz并配置环境变量 安装pydotplus

4.2. 获取gbdt模型

gbdt_model = grid.best_estimator_

4.3. 打印决策树

from sklearn import tree
import pydotplus

estimators = gbdt_model.estimators.shape[0]
for i in range(estimators):
    dot_data = tree.export_graphviz(gbdt_model)
    graph = pydotplus.graph_from_dot_data(dot_data)
    graph.write_pdf("../data/tree_"+str(i)+".pdf")