
sklearn DecisionTree Source Code Analysis


sklearn.tree._classes.BaseDecisionTree#fit
y must be at least 1-dimensional, which is what lets fit handle multilabel (multi-output) targets:

y = np.atleast_1d(y)
if y.ndim == 1:
    y = np.reshape(y, (-1, 1))
self.n_outputs_ = y.shape[1]

# Each output column gets its own label encoding
self.classes_ = []
self.n_classes_ = []
y_encoded = np.zeros(y.shape, dtype=int)
for k in range(self.n_outputs_):
    classes_k, y_encoded[:, k] = np.unique(y[:, k],
                                           return_inverse=True)
    self.classes_.append(classes_k)
    self.n_classes_.append(classes_k.shape[0])

if is_classifier(self):
    self.tree_ = Tree(self.n_features_,
                      self.n_classes_, self.n_outputs_)
else:
    self.tree_ = Tree(self.n_features_,
                      # TODO: tree shouldn't need this in this case
                      np.array([1] * self.n_outputs_, dtype=np.intp),
                      self.n_outputs_)

# Later in fit, single-output classifiers collapse the lists back
if self.n_outputs_ == 1:
    self.n_classes_ = self.n_classes_[0]
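Because the encoding loop runs once per output column, classes_ and n_classes_ become per-output lists when y has more than one column. A quick sketch of that (the data is made up):

import numpy as np
from sklearn.tree import DecisionTreeClassifier

X = np.array([[0.0], [1.0], [2.0], [3.0]])
Y = np.array([[0, 10],
              [0, 20],
              [1, 20],
              [1, 30]])          # two output columns -> n_outputs_ == 2

clf = DecisionTreeClassifier().fit(X, Y)
print(clf.n_outputs_)            # 2
print(clf.classes_)              # [array([0, 1]), array([10, 20, 30])]
print(clf.n_classes_)            # [2, 3]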
For example:

np.unique([3, 2, 2, 3, 3, 4], return_inverse=True)
Out[4]: (array([2, 3, 4]), array([1, 0, 0, 1, 1, 2]))

return_inverse thus plays the same role as sklearn's LabelEncoder.
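A minimal check of that claim (nothing here is from the original post beyond the sample labels):

import numpy as np
from sklearn.preprocessing import LabelEncoder

y = np.array([3, 2, 2, 3, 3, 4])
classes, encoded = np.unique(y, return_inverse=True)

le = LabelEncoder()
assert np.array_equal(le.fit_transform(y), encoded)  # [1 0 0 1 1 2]
assert np.array_equal(le.classes_, classes)          # [2 3 4]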

sklearn.tree._tree.Tree

    def __cinit__(self, int n_features, np.ndarray[SIZE_t, ndim=1] n_classes,
                  int n_outputs):
  1. n_features: the number of features
  2. n_classes: the number of classes (per output)
  3. n_outputs: the number of outputs, i.e. the label dimensionality
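All three constructor arguments are visible on a fitted estimator; a small sketch (the data is made up):

import numpy as np
from sklearn.tree import DecisionTreeClassifier

X = np.array([[0.0], [1.0], [2.0], [3.0]])
y = np.array([0, 0, 1, 1])

clf = DecisionTreeClassifier().fit(X, y)
print(clf.n_features_in_)  # 1 -> n_features (older versions expose n_features_)
print(clf.n_classes_)      # 2 -> n_classes
print(clf.n_outputs_)      # 1 -> n_outputs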
# Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise
if max_leaf_nodes < 0:
    builder = DepthFirstTreeBuilder(splitter, min_samples_split,
                                    min_samples_leaf,
                                    min_weight_leaf,
                                    max_depth,
                                    self.min_impurity_decrease,
                                    min_impurity_split)
else:
    builder = BestFirstTreeBuilder(splitter, min_samples_split,
                                   min_samples_leaf,
                                   min_weight_leaf,
                                   max_depth,
                                   max_leaf_nodes,
                                   self.min_impurity_decrease,
                                   min_impurity_split)
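The check is max_leaf_nodes < 0 rather than a test against None because fit normalizes the parameter first; in the scikit-learn source this is a single line:

max_leaf_nodes = -1 if self.max_leaf_nodes is None else self.max_leaf_nodes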

An introduction to scikit-learn's decision tree class library

Maximum number of leaf nodes: max_leaf_nodes

Limiting the maximum number of leaf nodes helps prevent overfitting. The default is None, meaning the number of leaf nodes is unrestricted. When a limit is set, the algorithm builds the best tree it can within that budget of leaf nodes. With few features this parameter can usually be ignored, but when there are many features it is worth constraining, and a concrete value can be chosen by cross-validation, as sketched below.
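A sketch of that cross-validation (the iris dataset and the grid values are just placeholders):

from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
search = GridSearchCV(DecisionTreeClassifier(random_state=0),
                      {"max_leaf_nodes": [None, 4, 8, 16, 32]},
                      cv=5)
search.fit(X, y)
print(search.best_params_)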

sklearn.tree._tree.DepthFirstTreeBuilder#build

builder.build(self.tree_, X, y, sample_weight, X_idx_sorted)
cpdef build(self, Tree tree, object X, np.ndarray y,
            np.ndarray sample_weight=None,
            np.ndarray X_idx_sorted=None):

Notice that every parameter you would expect is here, except class_weight. Where did it go? It has already been converted into sample_weight earlier in fit:

if self.class_weight is not None:
    expanded_class_weight = compute_sample_weight(
        self.class_weight, y_original)
if expanded_class_weight is not None:
    if sample_weight is not None:
        sample_weight = sample_weight * expanded_class_weight
    else:
        sample_weight = expanded_class_weight
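What compute_sample_weight produces, on made-up data:

import numpy as np
from sklearn.utils.class_weight import compute_sample_weight

y = np.array([0, 0, 0, 1])
print(compute_sample_weight({0: 1, 1: 3}, y))  # [1. 1. 1. 3.]
print(compute_sample_weight("balanced", y))    # [0.667 0.667 0.667 2.0] (approx)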

sklearn/tree/_tree.pyx:203

splitter.init(X, y, sample_weight_ptr, X_idx_sorted)
cdef SIZE_t n_node_samples = splitter.n_samples
rc = stack.push(0, n_node_samples, 0, _TREE_UNDEFINED, 0, INFINITY, 0)

The pushed record is the root node, which contains all of the samples before any split; rc itself is just the return code of the push, used for error checking.

Stack and StackRecord are both data structures that sklearn implements itself, as sketched below.
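A pure-Python sketch of the same idea (the real Stack/StackRecord are Cython structs; the field names below mirror the push call above, and the sample count 100 is made up):

from collections import namedtuple

StackRecord = namedtuple(
    "StackRecord",
    ["start", "end", "depth", "parent",
     "is_left", "impurity", "n_constant_features"])

INFINITY = float("inf")
_TREE_UNDEFINED = -2

stack = [StackRecord(0, 100, 0, _TREE_UNDEFINED, 0, INFINITY, 0)]
while stack:
    record = stack.pop()  # LIFO pop -> depth-first tree growth
    # a real builder would split [record.start, record.end) here and
    # push the right and left child records back onto the stack
    print(record)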

is_leaf = (depth >= max_depth or                           # depth limit reached
           n_node_samples < min_samples_split or           # too few samples to split
           n_node_samples < 2 * min_samples_leaf or        # no split could give each child min_samples_leaf
           weighted_n_node_samples < 2 * min_weight_leaf)  # same check on sample weights
is_leaf = (is_leaf or (impurity <= min_impurity_split))    # node is already pure enough

If any of these conditions holds, the node stops splitting and becomes a leaf.
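These rules are easy to observe from the estimator side (a sketch on iris, which has 150 samples):

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
shallow = DecisionTreeClassifier(max_depth=1).fit(X, y)
strict = DecisionTreeClassifier(min_samples_split=200).fit(X, y)
print(shallow.tree_.node_count)  # 3: depth >= max_depth stops after one split
print(strict.tree_.node_count)   # 1: root has n_node_samples < min_samples_split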

sklearn.tree._splitter.BestSplitter

sklearn.tree._splitter.BestSplitter#node_split


As the scikit-learn user guide notes: "scikit-learn uses an optimised version of the CART algorithm; however, the scikit-learn implementation does not support categorical variables for now."
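A common workaround is to one-hot encode categorical columns before the tree (a sketch with made-up data):

import numpy as np
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import OneHotEncoder
from sklearn.tree import DecisionTreeClassifier

X = np.array([["red"], ["green"], ["blue"], ["red"]])
y = np.array([0, 1, 1, 0])

model = make_pipeline(OneHotEncoder(), DecisionTreeClassifier())
model.fit(X, y)
print(model.predict([["green"]]))  # [1]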

