sklearn DecisionTree Source Code Analysis
sklearn.tree._classes.BaseDecisionTree#fit
y is required to be at least 1-dimensional, which means multilabel (multi-output) data can be handled:
y = np.atleast_1d(y)
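A quick illustration of np.atleast_1d (toy values for demonstration):

import numpy as np

np.atleast_1d(3.0)               # array([3.]) — a scalar is promoted to 1-D
np.atleast_1d(np.array([1, 2]))  # array([1, 2]) — already 1-D, unchanged
np.atleast_1d(np.ones((2, 2)))   # 2-D input is left unchanged as well

Right after this, fit reshapes a 1-D y to (n_samples, 1), so the rest of the code can treat every y as 2-D with one column per output.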
if is_classifier(self):
    self.tree_ = Tree(self.n_features_,
                      self.n_classes_, self.n_outputs_)
else:
    self.tree_ = Tree(self.n_features_,
                      # TODO: tree shouldn't need this in this case
                      np.array([1] * self.n_outputs_, dtype=np.intp),
                      self.n_outputs_)
self.n_outputs_ = y.shape[1]  # after the reshape, each column of y is one output
self.n_classes_ = self.n_classes_[0]  # later in fit: unwrap when n_outputs_ == 1
self.n_classes_ = []
for k in range(self.n_outputs_):
    classes_k, y_encoded[:, k] = np.unique(y[:, k],
                                           return_inverse=True)
    self.classes_.append(classes_k)
    self.n_classes_.append(classes_k.shape[0])
>>> np.unique([3, 2, 2, 3, 3, 4], return_inverse=True)
(array([2, 3, 4]), array([1, 0, 0, 1, 1, 2]))
return_inverse works like LabelEncoder: the second array gives, for each element, the index of its value among the sorted unique classes, as shown below.
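A side-by-side check that the two produce identical integer codes (toy data; assumes sklearn is installed):

import numpy as np
from sklearn.preprocessing import LabelEncoder

y = np.array([3, 2, 2, 3, 3, 4])
classes_, inverse = np.unique(y, return_inverse=True)
codes = LabelEncoder().fit_transform(y)

print(np.array_equal(inverse, codes))  # True — same encoding
print(classes_)                        # [2 3 4] — same classes as le.classes_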
sklearn.tree._tree.Tree
def __cinit__(self, int n_features, np.ndarray[SIZE_t, ndim=1] n_classes,
              int n_outputs):
- n_features: the number of features
- n_classes: the number of classes for each output
- n_outputs: the number of outputs (the label dimension)
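To make those arguments concrete, here is a sketch that instantiates the internal Tree directly (a private API that may change between sklearn versions; the numbers are illustrative):

import numpy as np
from sklearn.tree._tree import Tree

n_features = 4
n_classes = np.array([3], dtype=np.intp)  # 3 classes for the single output
n_outputs = 1
tree = Tree(n_features, n_classes, n_outputs)
print(tree.node_count)  # 0 — empty until a builder populates it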
# Use BestFirst if max_leaf_nodes given; use DepthFirst otherwise
if max_leaf_nodes < 0:
    builder = DepthFirstTreeBuilder(splitter, min_samples_split,
                                    min_samples_leaf,
                                    min_weight_leaf,
                                    max_depth,
                                    self.min_impurity_decrease,
                                    min_impurity_split)
else:
    builder = BestFirstTreeBuilder(splitter, min_samples_split,
                                   min_samples_leaf,
                                   min_weight_leaf,
                                   max_depth,
                                   max_leaf_nodes,
                                   self.min_impurity_decrease,
                                   min_impurity_split)
max_leaf_nodes — the maximum number of leaf nodes. Limiting it helps prevent overfitting. The default is None, i.e., no limit; internally None is converted to -1, which is why the check above is max_leaf_nodes < 0. When a limit is given, the algorithm builds the best tree it can within that number of leaves. With few features this parameter can usually be left alone, but when there are many features it is worth constraining; a concrete value can be found by cross-validation, as sketched below.
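A sketch of picking max_leaf_nodes by cross-validation (the dataset and grid are illustrative assumptions):

from sklearn.datasets import load_iris
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
search = GridSearchCV(DecisionTreeClassifier(random_state=0),
                      param_grid={"max_leaf_nodes": [None, 4, 8, 16, 32]},
                      cv=5)
search.fit(X, y)
print(search.best_params_)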
sklearn.tree._tree.DepthFirstTreeBuilder#build
builder.build(self.tree_, X, y, sample_weight, X_idx_sorted)

cpdef build(self, Tree tree, object X, np.ndarray y,
            np.ndarray sample_weight=None,
            np.ndarray X_idx_sorted=None):
One thing stands out: all the expected parameters are here, but where is class_weight? The suspicion is that it has already been folded into sample_weight, and the following snippet from fit confirms it:
if self.class_weight is not None:
    expanded_class_weight = compute_sample_weight(
        self.class_weight, y_original)

if expanded_class_weight is not None:
    if sample_weight is not None:
        sample_weight = sample_weight * expanded_class_weight
    else:
        sample_weight = expanded_class_weight
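A small demonstration of compute_sample_weight, the public helper used above (toy labels; the "balanced" mode shown is one of its options):

import numpy as np
from sklearn.utils.class_weight import compute_sample_weight

y = np.array([0, 0, 0, 1])
# "balanced" weighs each sample by n_samples / (n_classes * count(class))
print(compute_sample_weight("balanced", y))
# [0.66666667 0.66666667 0.66666667 2.        ]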
sklearn/tree/_tree.pyx:203
splitter.init(X, y, sample_weight_ptr, X_idx_sorted)
cdef SIZE_t n_node_samples = splitter.n_samples
rc = stack.push(0, n_node_samples, 0, _TREE_UNDEFINED, 0, INFINITY, 0)
This push creates the root node record, which covers all samples before the first split; rc itself is only the return code of the push, used to detect allocation failure.
Stack and StackRecord are data structures written by sklearn itself, sketched below.
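A rough Python analogue of that pair, with field names taken from the push call above (-2 stands in for _TREE_UNDEFINED; this is an illustrative sketch, not the real Cython code):

from collections import namedtuple

StackRecord = namedtuple("StackRecord",
                         ["start", "end", "depth", "parent",
                          "is_left", "impurity", "n_constant_features"])

stack = []  # a plain Python list works as the LIFO stack
# root record: samples [0, 100), depth 0, no parent, impurity not yet known
stack.append(StackRecord(0, 100, 0, -2, 0, float("inf"), 0))

while stack:
    record = stack.pop()
    # ... decide is_leaf, split record.start:record.end,
    # then push the left and right child ranges ...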
is_leaf = (depth >= max_depth or
           n_node_samples < min_samples_split or
           n_node_samples < 2 * min_samples_leaf or
           weighted_n_node_samples < 2 * min_weight_leaf)

is_leaf = (is_leaf or (impurity <= min_impurity_split))
If any of these conditions holds, splitting stops and the node becomes a leaf.
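These stopping conditions map directly onto the public estimator parameters; a hedged sketch with arbitrary values:

from sklearn.tree import DecisionTreeClassifier

clf = DecisionTreeClassifier(
    max_depth=5,                    # depth >= max_depth
    min_samples_split=10,           # n_node_samples < min_samples_split
    min_samples_leaf=4,             # n_node_samples < 2 * min_samples_leaf
    min_weight_fraction_leaf=0.01,  # becomes min_weight_leaf internally
)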
sklearn.tree._splitter.BestSplitter
sklearn.tree._splitter.BestSplitter#node_split
scikit-learn uses an optimised version of the CART algorithm; however, the scikit-learn implementation does not support categorical variables for now.
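Because categorical variables are not supported natively, they are typically encoded before reaching the tree. A minimal sketch using OrdinalEncoder in a pipeline (the toy data and column layout are assumptions for illustration):

import numpy as np
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import OrdinalEncoder
from sklearn.tree import DecisionTreeClassifier

X = np.array([["red", 1.0], ["blue", 2.0],
              ["red", 3.0], ["green", 0.5]], dtype=object)
y = np.array([0, 1, 0, 1])

pre = ColumnTransformer([("cat", OrdinalEncoder(), [0])],
                        remainder="passthrough")
model = make_pipeline(pre, DecisionTreeClassifier(random_state=0))
model.fit(X, y)
print(model.predict(X))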