
A First Look at sklearn (3): Decision Trees and Their Visualization

程序员文章站 2024-03-19 20:27:52

Preface

This is a complete project, covering everything from obtaining and processing the data to building the decision tree model. The full source code is given at the end of the article.

Overview

DecisionTreeClassifier is a class capable of performing multi-class classification on a dataset. Like other classifiers, DecisionTreeClassifier takes two arrays as input: an array X of shape [n_samples, n_features] holding the training samples, and an integer array Y of shape [n_samples] holding the class labels of the training samples:

>>> from sklearn import tree
>>> X = [[0, 0], [1, 1]]
>>> Y = [0, 1]
>>> clf = tree.DecisionTreeClassifier()
>>> clf = clf.fit(X, Y)

Once fitted, the model can be used to predict the class of new samples:

>>> clf.predict([[2., 2.]])
array([1])

It can also predict the probability of each class, which is the fraction of training samples of that class in the leaf the sample falls into:

>>> clf.predict_proba([[2., 2.]])
array([[ 0.,  1.]])

After training, we can export the decision tree in Graphviz format using the export_graphviz exporter.
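Continuing the toy example above, a minimal sketch of the export step: with out_file=None, export_graphviz returns the DOT source as a string rather than writing a file.

```python
from sklearn import tree

# Re-train the toy classifier from the snippet above
X = [[0, 0], [1, 1]]
Y = [0, 1]
clf = tree.DecisionTreeClassifier().fit(X, Y)

# out_file=None makes export_graphviz return the DOT source as a string
dot_source = tree.export_graphviz(clf, out_file=None)
print(dot_source.splitlines()[0])  # first line of the DOT source
```

This string is exactly what gets rendered later in the article.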

Task

Read in a bank dataset, train a decision tree model to predict whether a customer will buy a certain product, and visualize the tree. Since all of the given data is labeled, k-fold cross-validation can be performed during training (10-fold in this experiment). The experiment is built on the sklearn library.

Data processing

The data is stored in train_set.csv; its header is shown below.
(screenshot of the CSV header omitted)
I will build two feature sets from it and compare how well each performs.
Read and slice the data with pandas:

bank_data = pd.read_csv("../datas/train_set.csv")
# first feature set: age and balance
first_set = bank_data[['age', 'balance']]

This slices out the age and balance columns to use as input.
I also need the target for each input row, i.e. whether the customer buys the product, so:

labels = bank_data['y']
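Before training, it is worth checking how the labels are distributed; heavily skewed classes make plain accuracy misleading. A quick sketch with pandas value_counts, using a tiny made-up frame since the real train_set.csv isn't bundled here:

```python
import pandas as pd

# A tiny stand-in for train_set.csv (the real file isn't bundled here)
bank_data = pd.DataFrame({
    'age':     [25, 40, 33, 58, 61, 29],
    'balance': [100, 2500, 0, 1200, 300, 80],
    'y':       [0, 1, 0, 0, 1, 0],
})

# Count how many rows fall in each class of the target column
counts = bank_data['y'].value_counts()
print(counts.to_dict())  # {0: 4, 1: 2}
```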

Training the decision tree model

Building a model in sklearn is simple: just call fit.

clf_treeAB_d4 = tree.DecisionTreeClassifier(criterion="entropy", max_depth=4)
clf_treeAB_d4 = clf_treeAB_d4.fit(first_set, labels)

criterion is the splitting rule used at each node; the options are information entropy ("entropy") and the Gini index ("gini"). max_depth is the maximum depth of the tree you want to build. This setting requires care: too shallow and the model underfits, too deep and it overfits.
After fit, you have a trained model and can use it to make predictions.
However, since all my data here is labeled, 10-fold cross-validation can be run directly like this:

scores_treeAB_d4 = cross_val_score(clf_treeAB_d4, first_set, labels, cv=10)  # 10-fold cross-validation
print(scores_treeAB_d4)

The cv parameter is the number of cross-validation folds.
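cross_val_score returns one accuracy score per fold; the usual way to read the result is the mean and standard deviation across folds. A sketch on synthetic stand-in data (the real CSV isn't available here):

```python
from sklearn import tree
from sklearn.model_selection import cross_val_score

# Synthetic stand-in data: 40 samples, two features, a separable label
X = [[i, i % 7] for i in range(40)]
y = [0 if i < 20 else 1 for i in range(40)]

clf = tree.DecisionTreeClassifier(criterion="entropy", max_depth=4)
scores = cross_val_score(clf, X, y, cv=10)  # one accuracy score per fold

# Summarize the ten fold scores as mean +/- standard deviation
print(f"accuracy: {scores.mean():.3f} +/- {scores.std():.3f}")
```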

Visualization

Visualization requires the graphviz package.
First, generate the DOT source:

feature_name = ['age', 'balance']
class_name = ['not_buy', 'buy']
treeAB_d4_dot = tree.export_graphviz(
    clf_treeAB_d4
    , out_file=None
    , feature_names=feature_name
    , class_names=class_name
)
graph = graphviz.Source(treeAB_d4_dot)

The generated DOT source looks roughly like this:

digraph Tree {
node [shape=box] ;
0 [label="age <= 60.5\nentropy = 0.521\nsamples = 25317\nvalue = [22356, 2961]\nclass = not_buy"] ;
1 [label="balance <= 291.5\nentropy = 0.495\nsamples = 24642\nvalue = [21972, 2670]\nclass = not_buy"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
...
29 [label="entropy = 1.0\nsamples = 2\nvalue = [1, 1]\nclass = not_buy"] ;
28 -> 29 ;
30 [label="entropy = 0.0\nsamples = 2\nvalue = [0, 2]\nclass = buy"] ;
28 -> 30 ;
}

Then render it:

graph.render("../output/TreeForAgeAndBalanceD4")

The result looks like this:
(rendered tree figure omitted)
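If Graphviz isn't installed, sklearn itself can print a readable text version of the tree via tree.export_text, a lighter alternative to the Graphviz route used above. A sketch on toy stand-in data, since the real CSV isn't bundled here:

```python
from sklearn import tree
from sklearn.tree import export_text

# Toy rows standing in for the bank features (real CSV not bundled here)
X = [[20, 100], [30, 5000], [65, 200], [70, 8000]]
y = [0, 1, 0, 1]

clf = tree.DecisionTreeClassifier(criterion="entropy", max_depth=4).fit(X, y)

# export_text needs no Graphviz install; it returns an ASCII tree
report = export_text(clf, feature_names=['age', 'balance'])
print(report)
```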

A second feature set

The second feature set is:

second_set = bank_data[['duration', 'campaign', 'pdays', 'previous']]

I trained three models with depths 3, 4, and 8; by cross-validation score, depth 4 performed best. Depth 3 underfits, while depth 8 overfits.
Overall, the second feature set performs better than the first.
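The depth-3/4/8 comparison above can be written as a small selection loop that picks the depth with the best mean cross-validated accuracy. A sketch on synthetic stand-in data (the real CSV isn't available here):

```python
from sklearn import tree
from sklearn.model_selection import cross_val_score

# Synthetic stand-in for second_set/labels (the real CSV isn't bundled)
X = [[i % 10, i % 3, i, i % 5] for i in range(60)]
y = [1 if i % 10 >= 7 else 0 for i in range(60)]

# Compare the candidate depths by mean cross-validated accuracy,
# mirroring the depth-3/4/8 comparison in the text
best_depth, best_score = None, -1.0
for depth in (3, 4, 8):
    clf = tree.DecisionTreeClassifier(criterion="entropy", max_depth=depth)
    score = cross_val_score(clf, X, y, cv=10).mean()
    print(depth, round(score, 3))
    if score > best_score:
        best_depth, best_score = depth, score

print("best depth:", best_depth)
```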

Full source code

import pandas as pd
from sklearn import tree
import graphviz
from sklearn.model_selection import cross_val_score

bank_data = pd.read_csv("../datas/train_set.csv")
# first feature set: age and balance
first_set = bank_data[['age', 'balance']]
labels = bank_data['y']
# print(first_set)
# decision tree begins
# decision tree max_depth=4
clf_treeAB_d4 = tree.DecisionTreeClassifier(criterion="entropy", max_depth=4)
clf_treeAB_d4 = clf_treeAB_d4.fit(first_set, labels)
scores_treeAB_d4 = cross_val_score(clf_treeAB_d4, first_set, labels, cv=10)  # 10-fold cross-validation
print(scores_treeAB_d4)
# visualize
feature_name = ['age', 'balance']
class_name = ['not_buy', 'buy']
treeAB_d4_dot = tree.export_graphviz(
    clf_treeAB_d4
    , out_file=None
    , feature_names=feature_name
    , class_names=class_name
)
graph = graphviz.Source(treeAB_d4_dot)
graph.render("../output/TreeForAgeAndBalanceD4")
# decision tree max_depth=8
clf_treeAB_d8 = tree.DecisionTreeClassifier(criterion="entropy", max_depth=8)
clf_treeAB_d8 = clf_treeAB_d8.fit(first_set, labels)
scores_treeAB_d8 = cross_val_score(clf_treeAB_d8, first_set, labels, cv=10)  # 10-fold cross-validation
print(scores_treeAB_d8)
# visualize
treeAB_d8_dot = tree.export_graphviz(
    clf_treeAB_d8
    , out_file=None
    , feature_names=feature_name
    , class_names=class_name
)
graph = graphviz.Source(treeAB_d8_dot)
graph.render("../output/TreeForAgeAndBalanceD8")

# second feature set: duration, campaign, pdays, and previous
# performs much better than the first one
second_set = bank_data[['duration', 'campaign', 'pdays', 'previous']]
# decision tree max_depth=3
clf_treeDrtCpnPdsPrev_d3 = tree.DecisionTreeClassifier(criterion="entropy", max_depth=3)
clf_treeDrtCpnPdsPrev_d3 = clf_treeDrtCpnPdsPrev_d3.fit(second_set, labels)
scores_treeDrtCpnPdsPrev_d3 = cross_val_score(
    clf_treeDrtCpnPdsPrev_d3, second_set, labels, cv=10
)  # 10-fold cross-validation
print(scores_treeDrtCpnPdsPrev_d3)
# visualize
feature_name_2 = ['duration', 'campaign', 'pdays', 'previous']
treeDrtCpnPdsPrev_d3_dot = tree.export_graphviz(
    clf_treeDrtCpnPdsPrev_d3
    , out_file=None
    , feature_names=feature_name_2
    , class_names=class_name
)
graph = graphviz.Source(treeDrtCpnPdsPrev_d3_dot)
graph.render("../output/TreeForDrtCpnPdsPrevD3")
# decision tree max_depth=4
clf_treeDrtCpnPdsPrev_d4 = tree.DecisionTreeClassifier(criterion="entropy", max_depth=4)
clf_treeDrtCpnPdsPrev_d4 = clf_treeDrtCpnPdsPrev_d4.fit(second_set, labels)
scores_treeDrtCpnPdsPrev_d4 = cross_val_score(
    clf_treeDrtCpnPdsPrev_d4, second_set, labels, cv=10
)  # 10-fold cross-validation
print(scores_treeDrtCpnPdsPrev_d4)
# visualize
treeDrtCpnPdsPrev_d4_dot = tree.export_graphviz(
    clf_treeDrtCpnPdsPrev_d4
    , out_file=None
    , feature_names=feature_name_2
    , class_names=class_name
)
graph = graphviz.Source(treeDrtCpnPdsPrev_d4_dot)
graph.render("../output/TreeForDrtCpnPdsPrevD4")
# decision tree max_depth=8
clf_treeDrtCpnPdsPrev_d8 = tree.DecisionTreeClassifier(criterion="entropy", max_depth=8)
clf_treeDrtCpnPdsPrev_d8 = clf_treeDrtCpnPdsPrev_d8.fit(second_set, labels)
scores_treeDrtCpnPdsPrev_d8 = cross_val_score(
    clf_treeDrtCpnPdsPrev_d8, second_set, labels, cv=10
)  # 10-fold cross-validation
print(scores_treeDrtCpnPdsPrev_d8)
# visualize
treeDrtCpnPdsPrev_d8_dot = tree.export_graphviz(
    clf_treeDrtCpnPdsPrev_d8
    , out_file=None
    , feature_names=feature_name_2
    , class_names=class_name
)
graph = graphviz.Source(treeDrtCpnPdsPrev_d8_dot)
graph.render("../output/TreeForDrtCpnPdsPrevD8")
# decision tree ends

Dataset

The dataset is still under review; I'll post the link in the next article.
