
Analyzing How GBDT Implements Early Stopping via the Source Code


Reference: the GBDT documentation example "Early stopping of Gradient Boosting".

Source code analysis

Comparing models with and without early stopping

    from sklearn import ensemble

    # n_estimators is defined earlier in the example script
    gbes = ensemble.GradientBoostingClassifier(n_estimators=n_estimators,
                                               validation_fraction=0.2,
                                               n_iter_no_change=5, tol=0.01,
                                               random_state=0)
    gb = ensemble.GradientBoostingClassifier(n_estimators=n_estimators,
                                             random_state=0)

Open the scikit-learn source and find sklearn.ensemble._gb.GradientBoostingClassifier.

I tried using the example script as a debugging entry point: examples/ensemble/plot_gradient_boosting_early_stopping.py

That attempt failed; projects mixing Cython and C++ are hard to step through with a debugger.

So read its base class instead: sklearn.ensemble._gb.BaseGradientBoosting#fit

sklearn/ensemble/_gb.py:424

        if self.n_iter_no_change is not None:
            stratify = y if is_classifier(self) else None
            X, X_val, y, y_val, sample_weight, sample_weight_val = (
                train_test_split(X, y, sample_weight,
                                 random_state=self.random_state,
                                 test_size=self.validation_fraction,
                                 stratify=stratify))
            if is_classifier(self):
                if self.n_classes_ != np.unique(y).shape[0]:
                    # We choose to error here. The problem is that the init
                    # estimator would be trained on y, which has some missing
                    # classes now, so its predictions would not have the
                    # correct shape.
                    raise ValueError(
                        'The training data after the early stopping split '
                        'is missing some classes. Try using another random '
                        'seed.'
                    )
        else:
            X_val = y_val = sample_weight_val = None

This block splits the data into a training set and a held-out validation set for early stopping (stratified for classifiers), and errors out if the split drops a class.

Before the stages are fitted, begin_at_stage handles warm starting: on a fresh fit the state is cleared so it is 0, while a warm start resumes from the estimators already built.

    begin_at_stage = self.estimators_.shape[0]

_fit_stages then runs the entire boosting loop:

    # fit the boosting stages
    n_stages = self._fit_stages(
        X, y, raw_predictions, sample_weight, self._rng, X_val, y_val,
        sample_weight_val, begin_at_stage, monitor, X_idx_sorted)

Step into sklearn.ensemble._gb.BaseGradientBoosting#_fit_stages
sklearn/ensemble/_gb.py:516

    for i in range(begin_at_stage, self.n_estimators):

Each pass through this loop is one boosting iteration:

    # fit next stage of trees
    raw_predictions = self._fit_stage(
        i, X, y, raw_predictions, sample_weight, sample_mask,
        random_state, X_idx_sorted, X_csc, X_csr)

sklearn.ensemble._gb.BaseGradientBoosting#_fit_stage

A look at _fit_stage (note: not _fit_stages; this one fits a single boosting stage). A few key pieces:

            residual = loss.negative_gradient(y, raw_predictions_copy, k=k,
                                              sample_weight=sample_weight)

            # induce regression tree on residuals
            tree = DecisionTreeRegressor(...)

            tree.fit(X, residual, sample_weight=sample_weight,
                     check_input=False, X_idx_sorted=X_idx_sorted)

            # update tree leaves
            loss.update_terminal_regions(
                tree.tree_, X, y, residual, raw_predictions, sample_weight,
                sample_mask, learning_rate=self.learning_rate, k=k)

            # add tree to ensemble
            self.estimators_[i, k] = tree

This is all standard machinery, matching the fit-a-tree-to-the-negative-gradient procedure from the textbooks. The OOB (out-of-bag) handling is something I still need to study.
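
To make the textbook procedure concrete, here is a minimal from-scratch sketch of that loop for least-squares regression, where the negative gradient is simply the residual y - F(x) (the toy data and all names are my own, not sklearn's):

    import numpy as np
    from sklearn.tree import DecisionTreeRegressor

    rng = np.random.RandomState(0)
    X = rng.uniform(size=(200, 3))
    y = X[:, 0] + 0.1 * rng.normal(size=200)

    learning_rate = 0.1
    raw_predictions = np.full_like(y, y.mean())  # init: constant prediction
    trees = []
    for i in range(50):
        residual = y - raw_predictions           # negative gradient of L2 loss
        tree = DecisionTreeRegressor(max_depth=3, random_state=0)
        tree.fit(X, residual)                    # induce tree on residuals
        raw_predictions += learning_rate * tree.predict(X)
        trees.append(tree)
    print("train MSE:", np.mean((y - raw_predictions) ** 2))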

I was curious about the self.estimators_ attribute, so I dug into it.

    def _clear_state(self):
        """Clear the state of the gradient boosting model. """
        if hasattr(self, 'estimators_'):
            self.estimators_ = np.empty((0, 0), dtype=np.object)

    def _resize_state(self):
        self.estimators_ = np.resize(self.estimators_,
                                     (total_n_estimators, self.loss_.K))

So estimators_ is a 2-D object array of shape (n_estimators, K), where K = loss_.K is the number of trees fitted per stage (1 for regression and binary classification, n_classes for multiclass).

Back to the for loop. One interesting bit:

    if monitor is not None:
        early_stopping = monitor(i, self, locals())
        if early_stopping:
            break

monitor is a callable the user passes into fit; if it returns True, boosting stops. In other words, it is a hook for user-defined early stopping.
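
For example, here is a hedged sketch of a custom monitor (my own, not from the article) that stops once the out-of-bag improvement has stalled; subsample < 1.0 is required so that oob_improvement_ gets computed:

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.ensemble import GradientBoostingClassifier

    def oob_monitor(i, est, locals_):
        """Stop when the last 5 OOB improvements are all non-positive."""
        if i < 5:
            return False
        return np.all(est.oob_improvement_[i - 4:i + 1] <= 0)

    X, y = make_classification(n_samples=1000, random_state=0)
    gb = GradientBoostingClassifier(n_estimators=500, subsample=0.8,
                                    random_state=0)
    gb.fit(X, y, monitor=oob_monitor)
    print(gb.estimators_.shape[0], "stages actually fitted")

Further down the same loop sits the built-in validation-loss check: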

    # By calling next(y_val_pred_iter), we get the predictions
    # for X_val after the addition of the current stage
    validation_loss = loss_(y_val, next(y_val_pred_iter),
                            sample_weight_val)

    # Require validation_score to be better (less) than at least
    # one of the last n_iter_no_change evaluations
    if np.any(validation_loss + self.tol < loss_history):
        loss_history[i % len(loss_history)] = validation_loss
    else:
        break

This is the heart of the early stopping logic.

loss_history is a vector of length n_iter_no_change, initially filled with inf via np.full; each new validation loss is then written in with modulo indexing, so the vector acts as a ring buffer holding the most recent n_iter_no_change losses.

It is quite a clever trick.
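
Stripped of the surrounding machinery, the check reduces to this (the hard-coded loss sequence is a made-up illustration):

    import numpy as np

    n_iter_no_change, tol = 5, 0.01
    loss_history = np.full(n_iter_no_change, np.inf)

    validation_losses = [0.9, 0.7, 0.6, 0.55, 0.54,
                         0.54, 0.54, 0.54, 0.54, 0.54]
    for i, validation_loss in enumerate(validation_losses):
        # keep going only if the new loss beats at least one stored loss by tol
        if np.any(validation_loss + tol < loss_history):
            loss_history[i % len(loss_history)] = validation_loss
        else:
            print(f"early stop at iteration {i}")
            break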

Comparing the effect of different parameters

(Setup as in the documentation example "Early stopping of Gradient Boosting".)

Baseline without early stopping (unaffected by the settings below):

    score_gb
    [1.0, 0.9583333333333334, 0.9441666666666667]

With early stopping, under various settings:

    validation_fraction=0.2, n_iter_no_change=5
    score_gbes = [1.0, 0.9638888888888889, 0.915]

    validation_fraction=0.2, n_iter_no_change=10
    score_gbes = [1.0, 0.9638888888888889, 0.9266666666666666]

    validation_fraction=0.2, n_iter_no_change=20
    score_gbes = [1.0, 0.9638888888888889, 0.9529166666666666]

    validation_fraction=0.1, n_iter_no_change=20
    score_gbes = [1.0, 0.9666666666666667, 0.94625]

Implementing early stopping by hand
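
A hedged sketch of one way to do it (the dataset, names, and parameters are my own): grow the ensemble one stage at a time with warm_start and apply the same ring-buffer check against a held-out split.

    import numpy as np
    from sklearn.datasets import make_classification
    from sklearn.ensemble import GradientBoostingClassifier
    from sklearn.metrics import log_loss
    from sklearn.model_selection import train_test_split

    X, y = make_classification(n_samples=2000, random_state=0)
    X_tr, X_val, y_tr, y_val = train_test_split(X, y, test_size=0.2,
                                                random_state=0)

    n_iter_no_change, tol = 5, 0.01
    loss_history = np.full(n_iter_no_change, np.inf)

    # warm_start=True lets us add stages incrementally by bumping n_estimators
    gb = GradientBoostingClassifier(n_estimators=1, warm_start=True,
                                    random_state=0)
    for i in range(500):
        gb.n_estimators = i + 1
        gb.fit(X_tr, y_tr)
        val_loss = log_loss(y_val, gb.predict_proba(X_val))
        if np.any(val_loss + tol < loss_history):
            loss_history[i % n_iter_no_change] = val_loss
        else:
            break
    print("stopped after", gb.n_estimators, "stages")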
