欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

XGBoost介绍及Python实现

程序员文章站 2022-03-18 21:37:04
...

XGBoost介绍及Python实现

XGBoost简单介绍

  • XGBoost 算法是boost 集成算法中的一种,Boosting 算法的思想是将许多弱分类器集成在一起形成一个强分类器。XGBoost 是一种提升树模型,是将许多树模型集成在一起,形成强分类器。XGBoost 中使用的弱分类器为CART (classification and regression tree)回归树。

  • CART 回归树是一种二叉树,而可用于分类也可用于回归,其每一个节点只对是否进行判断,通过自上而下的分裂不断对样本集合的特征进行分裂。XGBoost 是通过不断添加CART ,不断使特征进行分裂完成CART 若分类器的集成。每一颗新的树的学习任务是去拟合原有树的拟合残差,使得生成的树的预测结果与真实值更加接近,可以认为XGBoost 的计算结果为多颗二叉树的计算结果加和。

  • XGBoost的损失函数为:

    • o b j = ∑ i = 1 n L ( y i , y ^ i ) + ∑ i = 1 n Ω ( f i ) obj=\sum_{i=1}^{n}L(y_i,\hat{y}_i)+\sum_{i=1}^{n}\Omega(f_i) obj=i=1nL(yi,y^i)+i=1nΩ(fi)
    • 由预测与真实值的误差和子树的模型复杂度组成
  • XGBoost采用Additive Training的方式进行训练,即在已经训练好的树模型和结构上边每次添加一棵树趋近于目标函数如下:

    • y ^ i ( 0 ) = 0 \hat{y}^{(0)}_i=0 y^i(0)=0
    • y ^ i ( 1 ) = f 1 ( x i ) = y ^ i ( 0 ) = 0 + f 1 ( x i ) \hat{y}^{(1)}_i=f_1(x_i)=\hat{y}^{(0)}_i=0+f_1(x_i) y^i(1)=f1(xi)=y^i(0)=0+f1(xi)
    • y ^ i ( 2 ) = f 2 ( x i ) = y ^ i ( 0 ) = 0 + f 1 ( x i ) + f 2 ( x i ) \hat{y}^{(2)}_i=f_2(x_i)=\hat{y}^{(0)}_i=0+f_1(x_i)+f_2(x_i) y^i(2)=f2(xi)=y^i(0)=0+f1(xi)+f2(xi)
    • ………………
    • y ^ i ( t ) = ∑ k = 1 T f t ( x i ) = y ^ i ( t − 1 ) + f t ( x i ) \hat{y}^{(t)}_i= \sum_{k=1}^{T} f_t(x_i)=\hat{y}^{(t-1)}_i+f_t(x_i) y^i(t)=k=1Tft(xi)=y^i(t1)+ft(xi)
  • 根据XGBoost的训练方法和原则XGBoost的损失函数可以简化为:

    • o b s = ∑ i = 1 n L ( y i , y ^ ( t − 1 ) + f t ( x i ) ) + Ω ( f t ) + c o n s t a n t obs=\sum_{i=1}^{n}L(y_i,\hat{y}^{(t-1)}+f_t(x_i))+\Omega(f_t)+constant obs=i=1nL(yi,y^(t1)+ft(xi))+Ω(ft)+constant
    • 对上式进行泰勒展开,可以得到
    • o b s = ∑ i = 1 n [ L ( y i , y ^ ( t − 1 ) ) + g i f t ( x i ) ) + h i f t 2 ( x i ) ] + Ω ( f t ) + c o n s t a n t obs=\sum_{i=1}^{n}[L(y_i,\hat{y}^{(t-1)})+g_if_t(x_i))+h_if_t^2(x_i)]+\Omega(f_t)+constant obs=i=1n[L(yi,y^(t1))+gift(xi))+hift2(xi)]+Ω(ft)+constant
    • gi 和 hi 是仅仅与最新生成树有关,之前的t-1棵树的损失函数保存在constant中,每一颗新生成的树的损失函数仅仅与其本身有关
    • 将所有的常数去掉可以得到第t步的损失函数
      • ∑ i = 1 n [ g i f t ( x i ) + 1 2 h i f t 2 ( x i ) ] + Ω ( f t ) \sum^n_{i=1}[g_if_t(x_i)+\frac{1}{2}h_if_t^2(x_i)]+\Omega(f_t) i=1n[gift(xi)+21hift2(xi)]+Ω(ft)
    • 模型的复杂度是XGBoost损失函数中的正则化项
      • Ω ( f t ) = γ T + 1 2 λ ∑ j = 1 T w j 2 \Omega(f_t)=\gamma T+\frac{1}{2} \lambda \sum^T_{j=1}w^2_{j} Ω(ft)=γT+21λj=1Twj2

XGBoost的python简单实现

  • XGBoost的简单实现
from xgboost import XGBRegressor as XGBR
# 加载波士顿数据集
from sklearn.datasets import load_boston
# 从sklearn中加载测试方法
from sklearn.model_selection import KFold, cross_val_score as CVS, train_test_split as TTS
from sklearn.metrics import mean_squared_error as RMSE

data = load_boston()

X = data.data 
#数据格式为 numpy.array
y = data.target

# 将数据随机打乱,分割为训练数据集和检测数据集
Xtrain,Xtest,Ytrain,Ytest = TTS(X,y,test_size=0.3,random_state=420)

reg = XGBR(n_estimators=100).fit(Xtrain,Ytrain) #训练
reg.predict(Xtest) #预测

reg.score(Xtest,Ytest) #预测结果的评价,采用的指标为相关系数

RMSE(Ytest,reg.predict(Xtest)) #计算均方根误差
# 进行交叉检验
CVS(reg,Xtrain,Ytrain,cv=5).mean()
  • 上述流程可以概述为:读入数据–>设置参数–>训练模型–>效果检测,其中读入数据流程中XGBoost模块中提供了一个xgb.DMatrix()格式数据,可以读入CSV数据等二进制数据,但是也可以直接使用numpy.array数据格式放入模型中进行训练,其他流程实现较为简单,可以结合sklearn实现,因此xgboost建立过程中需要了解的为参数设置
  • XGBoost参数介绍:
  • XGBoost的参数分为三个部分:
    • 通用参数:关于我们使用什么模型进行boost
    • Booster参数:集成参数
    • 学习任务参数:针对不同学习任务制定不同参数
  • 通用参数:
    • 在这里主要介绍通用参数中的重点参数
    • booster [default= gbtree ]
      • booster的基础树模型,包括三种 :gbtree, gblinear or dart,其中gbtree与dart的应用效果较为一致,dglinear应用效果稍差【参考:菜菜的sklearn】
    • verbosity [default=1]
      • 训练中是否打印信息,0 (silent), 1 (warning), 2 (info), 3 (debug)
  • Booster参数:
    • eta [default=0.3, alias: learning_rate]; range[0,1]
      • 相当于神经网络中的学习率,与二叉树的分裂有关,可以增加XGBoost的稳定性(鲁棒性)
    • gamma [default=0, alias: min_split_loss]
      • 子叶节点分裂所需要的最小损失,gamma越大,算法越保守(鲁棒性)
    • max_depth [default=6]
      • 树的深度,深度越大,XGBoost的越复杂越容易过拟合
    • subsample [default=1]
      • 训练每一颗新的树时候的采样率,有点类似与神经网络中的批量训练,每次训练新的二叉树时使用数据在全数据集中的占比
    • sampling_method [default= uniform]
      • 采样采用的概率模型:uniform 正态分布 gradient_based
    • lambda,alpha:
      • L2和L1正则化参数
        参考:
        XGBoost官方文档
        菜菜的sklearn:xgboost篇
        blog: https://blog.csdn.net/luanfenlian0992/article/details/106448500/