How GBDT implements early stopping: a walkthrough of the scikit-learn source
GBDT docs: Early stopping of Gradient Boosting
Source code analysis
Comparing with and without early stopping
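from sklearn import ensemble

# n_estimators is defined earlier in the example script.
# With early stopping: hold out 20% of the training data as a validation set and
# stop when the validation loss does not improve by at least tol=0.01 for 5 iterations.
gbes = ensemble.GradientBoostingClassifier(n_estimators=n_estimators,
                                           validation_fraction=0.2,
                                           n_iter_no_change=5, tol=0.01,
                                           random_state=0)
# Without early stopping: always fits all n_estimators stages.
gb = ensemble.GradientBoostingClassifier(n_estimators=n_estimators,
                                         random_state=0)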
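Open the scikit-learn source and go to sklearn.ensemble._gb.GradientBoostingClassifier.
I tried to use the example examples/ensemble/plot_gradient_boosting_early_stopping.py as an entry point for debugging, but the attempt failed. Projects that mix in Cython and C++ are just hard to work with this way.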
So I read the fit method of its base class, sklearn.ensemble._gb.BaseGradientBoosting#fit (sklearn/ensemble/_gb.py:424):
if self.n_iter_no_change is not None:
    stratify = y if is_classifier(self) else None
    X, X_val, y, y_val, sample_weight, sample_weight_val = (
        train_test_split(X, y, sample_weight,
                         random_state=self.random_state,
                         test_size=self.validation_fraction,
                         stratify=stratify))
    if is_classifier(self):
        if self.n_classes_ != np.unique(y).shape[0]:
            # We choose to error here. The problem is that the init
            # estimator would be trained on y, which has some missing
            # classes now, so its predictions would not have the
            # correct shape.
            raise ValueError(
                'The training data after the early stopping split '
                'is missing some classes. Try using another random '
                'seed.'
            )
else:
    X_val = y_val = sample_weight_val = None
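This block sets aside a validation set. When n_iter_no_change is set, a validation_fraction share of the training data is split off (stratified by class for classifiers) and used only for early stopping; otherwise X_val, y_val and sample_weight_val are all None.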
The begin_at_stage argument comes from the warm-start path. A fresh fit starts at stage 0, but with warm_start=True on an already-fitted model, boosting resumes after the stages that already exist:
begin_at_stage = self.estimators_.shape[0]
_fit_stages then runs the whole boosting process:
# fit the boosting stages
n_stages = self._fit_stages(
    X, y, raw_predictions, sample_weight, self._rng, X_val, y_val,
    sample_weight_val, begin_at_stage, monitor, X_idx_sorted)
Stepping into sklearn.ensemble._gb.BaseGradientBoosting#_fit_stages (sklearn/ensemble/_gb.py:516), the main loop is:
for i in range(begin_at_stage, self.n_estimators):
    # each pass is one boosting iteration
    ...
    # fit next stage of trees
    raw_predictions = self._fit_stage(
        i, X, y, raw_predictions, sample_weight, sample_mask,
        random_state, X_idx_sorted, X_csc, X_csr)
Next, sklearn.ensemble._gb.BaseGradientBoosting#_fit_stage. Unlike _fit_stages, this function fits a single boosting stage. A few key excerpts:
for k in range(loss.K):  # one tree per class (K = 1 for regression and binary classification)
    ...
    residual = loss.negative_gradient(y, raw_predictions_copy, k=k,
                                      sample_weight=sample_weight)

    # induce regression tree on residuals
    tree = DecisionTreeRegressor(...)  # constructor arguments omitted here
    tree.fit(X, residual, sample_weight=sample_weight,
             check_input=False, X_idx_sorted=X_idx_sorted)

    # update tree leaves
    loss.update_terminal_regions(
        tree.tree_, X, y, residual, raw_predictions, sample_weight,
        sample_mask, learning_rate=self.learning_rate, k=k)

    # add tree to ensemble
    self.estimators_[i, k] = tree
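These are all standard steps and match the textbook description: compute the negative gradient, fit a regression tree to it, update the leaf values (update_terminal_regions also adds the tree's output, scaled by learning_rate, to raw_predictions), and store the tree. The OOB (out-of-bag) handling is something I still need to study.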
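To make the per-stage steps concrete, here is a minimal pure-Python sketch of the same loop for squared-error regression on a single feature. It is not scikit-learn's code: `fit_stump`, `gradient_boost`, the depth-1 trees and the choice of squared loss are assumptions made for illustration. With the loss (y - F)^2 / 2 the negative gradient is simply the residual y - F, and the leaf means of a least-squares tree are already the optimal leaf values, so the "update tree leaves" step reduces to adding learning_rate times the tree output.

```python
def fit_stump(xs, residuals):
    """Least-squares regression stump (a depth-1 tree) on one feature."""
    pts = sorted(zip(xs, residuals))
    best = None
    for j in range(1, len(pts)):
        if pts[j - 1][0] == pts[j][0]:
            continue  # no split point between equal x values
        left = [r for _, r in pts[:j]]
        right = [r for _, r in pts[j:]]
        lmean = sum(left) / len(left)
        rmean = sum(right) / len(right)
        sse = (sum((r - lmean) ** 2 for r in left)
               + sum((r - rmean) ** 2 for r in right))
        if best is None or sse < best[0]:
            best = (sse, (pts[j - 1][0] + pts[j][0]) / 2, lmean, rmean)
    if best is None:  # all x equal: fall back to a single constant leaf
        mean = sum(residuals) / len(residuals)
        return lambda x: mean
    _, threshold, lmean, rmean = best
    return lambda x: lmean if x <= threshold else rmean


def gradient_boost(xs, ys, n_stages, learning_rate):
    """Boosting loop for squared loss; returns (predict, stages)."""
    init = sum(ys) / len(ys)       # initial model: the mean of y
    raw = [init] * len(ys)         # raw predictions F on the training set
    stages = []
    for _ in range(n_stages):
        # negative gradient of (y - F)^2 / 2 with respect to F: the residual y - F
        residual = [y - f for y, f in zip(ys, raw)]
        tree = fit_stump(xs, residual)   # fit a regression tree to it
        # leaf means are already optimal for squared loss,
        # so the update is just a shrunken additive step
        raw = [f + learning_rate * tree(x) for f, x in zip(raw, xs)]
        stages.append(tree)

    def predict(x):
        return init + learning_rate * sum(tree(x) for tree in stages)

    return predict, stages


xs = [1, 2, 3, 4, 5, 6]
ys = [1.0, 1.0, 1.0, 5.0, 5.0, 5.0]
predict, stages = gradient_boost(xs, ys, n_stages=20, learning_rate=0.5)
```

On this toy data each stage halves the remaining error, so after 20 stages predict(1) and predict(6) are within about 2e-6 of 1.0 and 5.0. The small learning_rate is the shrinkage that makes many stages necessary, which is exactly what gives early stopping something to cut short.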
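I was curious about the member variable self.estimators_, so I looked into it. It is a 2-D NumPy object array of shape (number of stages, K), where self.estimators_[i, k] holds the tree fitted for output k at stage i: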
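def _clear_state(self):
    """Clear the state of the gradient boosting model. """
    if hasattr(self, 'estimators_'):
        # np.object was deprecated in NumPy 1.20 and removed in 1.24;
        # on current NumPy this must be written dtype=object
        self.estimators_ = np.empty((0, 0), dtype=np.object)
    ...

def _resize_state(self):
    ...  # total_n_estimators is computed here
    self.estimators_ = np.resize(self.estimators_,
                                 (total_n_estimators, self.loss_.K))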
Back in the for loop of _fit_stages, there is an interesting piece:
if monitor is not None:
    early_stopping = monitor(i, self, locals())
    if early_stopping:
        break
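monitor is an optional callable that the user passes to fit. It is called with the stage index, the estimator and the local variables of _fit_stages; if it returns True, boosting stops after the current stage.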
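The `monitor` hook lets users bolt on their own stopping rule without touching the validation-based one. Below is a hedged sketch of one such rule, a wall-clock time budget. `make_time_budget_monitor` and `run_stages` are hypothetical helpers written for this post; `run_stages` only imitates the shape of the loop above (fit stage i, then call the monitor, then break on True) and fits nothing. The injectable `clock` is there so the behaviour can be checked deterministically.

```python
import time


def make_time_budget_monitor(max_seconds, clock=time.monotonic):
    """Build a monitor callable: (i, estimator, locals_) -> True to stop."""
    start = clock()

    def monitor(i, est, locals_):
        # Returning True makes the `if early_stopping: break` branch fire.
        return clock() - start > max_seconds

    return monitor


def run_stages(n_estimators, monitor=None, begin_at_stage=0):
    """Hypothetical stand-in for the _fit_stages loop; returns stages fitted."""
    i = begin_at_stage - 1
    for i in range(begin_at_stage, n_estimators):
        # ... a real implementation fits stage i here, before the check ...
        if monitor is not None and monitor(i, None, locals()):
            break
    return i + 1
```

With a real model the callable is passed straight to fit, e.g. `gb.fit(X_train, y_train, monitor=make_time_budget_monitor(30.0))`. Because the check runs after a stage is fitted, the stage that triggers the stop is still counted.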
Further down in the same loop, when n_iter_no_change is set, comes the built-in check on the validation set:
if self.n_iter_no_change is not None:
    # By calling next(y_val_pred_iter), we get the predictions
    # for X_val after the addition of the current stage
    validation_loss = loss_(y_val, next(y_val_pred_iter),
                            sample_weight_val)

    # Require validation_score to be better (less) than at least
    # one of the last n_iter_no_change evaluations
    if np.any(validation_loss + self.tol < loss_history):
        loss_history[i % len(loss_history)] = validation_loss
    else:
        break
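This is the core of the early stopping logic. y_val_pred_iter is a generator, created before the loop from self._staged_raw_predict(X_val), and each next() call returns the raw predictions on X_val with the stage just fitted included.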
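The generator trick works because a Python generator body runs only when next() is called. So a generator created before any tree exists can read each stage lazily, after the loop has fitted it, and update the validation predictions incrementally instead of re-predicting every stage from scratch. The sketch below assumes that structure. `staged_raw_predict`, `init_value` and the lambda "trees" are hypothetical stand-ins, not scikit-learn's private `_staged_raw_predict`.

```python
def staged_raw_predict(stages, x_val, init_value, learning_rate):
    """Yield raw predictions on x_val after each stage, reading stages lazily."""
    raw = [init_value] * len(x_val)
    for i in range(len(stages)):
        tree = stages[i]  # looked up only when next() reaches this stage
        for j, x in enumerate(x_val):
            raw[j] += learning_rate * tree(x)
        yield list(raw)   # copy, since raw keeps being updated in place


stages = [None] * 3       # pre-sized slots, filled as boosting proceeds
val_iter = staged_raw_predict(stages, [1.0, 2.0], init_value=1.0, learning_rate=0.5)
seen = []
for i in range(3):
    stages[i] = lambda x, m=i + 1: m * x   # "fit" stage i (hypothetical tree)
    seen.append(next(val_iter))           # predictions now include stage i
```

After the loop, `seen` is `[[1.5, 2.0], [2.5, 4.0], [4.0, 7.0]]`: each entry adds exactly one more stage, which is what the validation loss in the loop is computed on.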
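loss_history is an array of length n_iter_no_change, filled with inf by np.full before the loop. Each accepted validation loss is written to slot i % n_iter_no_change, so the array acts as a ring buffer that holds the losses of the last n_iter_no_change stages. The condition np.any(validation_loss + self.tol < loss_history) holds as long as the new loss beats at least one of those losses (equivalently, the largest of them) by more than tol; otherwise the loop breaks. Because the buffer starts out as inf, the first n_iter_no_change stages always pass.
This modulo trick is quite neat.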
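The rule can be replayed on a plain list of validation losses to see exactly when it stops. This is a pure-Python sketch of the snippet above, with a list instead of a NumPy array. `n_stages_kept` is a hypothetical helper, and returning `i + 1` assumes, as in the scikit-learn version quoted here, that the stage that triggers the break has already been fitted and is counted.

```python
import math


def n_stages_kept(val_losses, n_iter_no_change, tol):
    """Replay the n_iter_no_change / tol rule over a sequence of validation losses.

    Returns i + 1 for the stage i at which the loop breaks,
    or len(val_losses) if it never breaks.
    """
    loss_history = [math.inf] * n_iter_no_change   # like np.full(n, np.inf)
    for i, validation_loss in enumerate(val_losses):
        # any(v + tol < h) is the same as v + tol < max(loss_history)
        if any(validation_loss + tol < h for h in loss_history):
            # ring buffer: slot i % n holds the oldest of the last n losses
            loss_history[i % n_iter_no_change] = validation_loss
        else:
            return i + 1
    return len(val_losses)


losses = [1.0, 0.6, 0.5, 0.52, 0.51, 0.3]
print(n_stages_kept(losses, n_iter_no_change=2, tol=0.05))  # 5
print(n_stages_kept(losses, n_iter_no_change=2, tol=0.0))   # 6
```

With n_iter_no_change=2 and tol=0.05 the loop stops at stage index 4. At that point the buffer holds 0.5 and 0.52, and 0.51 + 0.05 beats neither, even though 0.51 is lower than the previous loss. With tol=0 the same sequence runs to the end. This shows how tol and n_iter_no_change together decide how much stagnation is tolerated.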
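Comparing different parameter settings
Test accuracy from the GBDT docs example (Early stopping of Gradient Boosting). Each list holds one score per dataset, in the example's order (Iris, Digits, Hastie).

Without early stopping (score_gb): [1.0, 0.9583333333333334, 0.9441666666666667]

With early stopping (score_gbes):
validation_fraction=0.2, n_iter_no_change=5: [1.0, 0.9638888888888889, 0.915]
validation_fraction=0.2, n_iter_no_change=10: [1.0, 0.9638888888888889, 0.9266666666666666]
validation_fraction=0.2, n_iter_no_change=20: [1.0, 0.9638888888888889, 0.9529166666666666]
validation_fraction=0.1, n_iter_no_change=20: [1.0, 0.9666666666666667, 0.94625]

On the third dataset, raising n_iter_no_change from 5 to 20 lifts the early-stopped score from 0.915 to about 0.953, above the 0.944 obtained without early stopping.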