sklearn初探(三):决策树及其可视化
sklearn初探(三):决策树及其可视化
前言
这是一个完整的工程,涵盖了从数据获得与处理,最后到构建决策树模型的全部过程。完整的源代码我会在文章最后给出。
概述
DecisionTreeClassifier
是能够在数据集上执行多分类的类,与其他分类器一样,DecisionTreeClassifier
采用输入两个数组:数组X,用 [n_samples, n_features]
的方式来存放训练样本。整数值数组Y,用 [n_samples]
来保存训练样本的类标签:
>>> from sklearn import tree
>>> X = [[0, 0], [1, 1]]
>>> Y = [0, 1]
>>> clf = tree.DecisionTreeClassifier()
>>> clf = clf.fit(X, Y)
执行通过之后,可以使用该模型来预测样本类别:
>>> clf.predict([[2., 2.]])
array([1])
另外,也可以预测每个类的概率,这个概率是叶中相同类的训练样本的分数:
>>> clf.predict_proba([[2., 2.]])
array([[ 0., 1.]])
经过训练,我们可以使用 export_graphviz
导出器以 Graphviz 格式导出决策树.
任务
读入一个银行数据集,训练决策树模型,预测客户是否会购买某种产品,并将决策树可视化。由于给出的数据全部为有标数据,在训练的过程中可以进行k折交叉验证(本次实验为10折交叉验证)。实验基于sklearn库展开。
数据处理
数据存储的文件为train_set.csv,表头如下所示。
我将从中构建两个数据集,并对它们的效果进行比较。
使用pandas对数据进行读入与切割,如下:
bank_data = pd.read_csv("../datas/train_set.csv")
# first set of balance and duration
first_set = bank_data[['age', 'balance']]
这样,我就把age与balance两列切割下来,作为输入。
我还需要每组输入数据对应的结果,也就是顾客是否会购买该产品,所以:
labels = bank_data['y']
训练决策树模型
sklearn的模型构建很简单,fit一下就行。
clf_treeAB_d4 = tree.DecisionTreeClassifier(criterion="entropy", max_depth=4)
clf_treeAB_d4 = clf_treeAB_d4.fit(first_set, labels)
criterion是你希望每个节点根据什么标准进行分支,有信息熵和基尼指数两个选项。max_depth是你希望构建的决策树的最大深度。这个设定有讲究,过浅则欠拟合,过深则过拟合。
fit过后,你就得到这个模型了,就可以用它做预测了。
但是,我这里的数据都是有标数据,如果想要做10折交叉验证,可以直接这么写:
scores_treeAB_d4 = cross_val_score(clf_treeAB_d4, first_set, labels, cv=10) # 10-means cross validate
print(scores_treeAB_d4)
cv这个参数就是你要几折交叉验证。
可视化
可视化需要graphviz这个包。
首先,生成dot文件
feature_name = ['age', 'balance']
class_name = ['not_buy', 'buy']
treeAB_d4_dot = tree.export_graphviz(
clf_treeAB_d4
, out_file=None
, feature_names=feature_name
, class_names=class_name
)
graph = graphviz.Source(treeAB_d4_dot)
生成的文件大概是这样的:
digraph Tree {
node [shape=box] ;
0 [label="age <= 60.5\nentropy = 0.521\nsamples = 25317\nvalue = [22356, 2961]\nclass = not_buy"] ;
1 [label="balance <= 291.5\nentropy = 0.495\nsamples = 24642\nvalue = [21972, 2670]\nclass = not_buy"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
...
29 [label="entropy = 1.0\nsamples = 2\nvalue = [1, 1]\nclass = not_buy"] ;
28 -> 29 ;
30 [label="entropy = 0.0\nsamples = 2\nvalue = [0, 2]\nclass = buy"] ;
28 -> 30 ;
}
然后render一下:
graph.render("../output/TreeForAgeAndBalanceD4")
效果如图:
另一个数据集
这第二个数据集是这样的:
second_set = bank_data[['duration', 'campaign', 'pdays', 'previous']]
我使用了三个模型,深度分别为3,4,8,最后评分是深度为4效果最好。深度为3则欠拟合,深度为8则过拟合。
总的来说,第二个数据集比第一个表现的好。
完整源代码
import pandas as pd
from sklearn import tree
from sklearn import preprocessing
import graphviz
from sklearn.model_selection import cross_val_score
bank_data = pd.read_csv("../datas/train_set.csv")
# first set of balance and duration
first_set = bank_data[['age', 'balance']]
labels = bank_data['y']
# print(first_set)
# decision tree begins
# decision tree max_depth=4
clf_treeAB_d4 = tree.DecisionTreeClassifier(criterion="entropy", max_depth=4)
clf_treeAB_d4 = clf_treeAB_d4.fit(first_set, labels)
scores_treeAB_d4 = cross_val_score(clf_treeAB_d4, first_set, labels, cv=10) # 10-means cross validate
print(scores_treeAB_d4)
# visualize
feature_name = ['age', 'balance']
class_name = ['not_buy', 'buy']
treeAB_d4_dot = tree.export_graphviz(
clf_treeAB_d4
, out_file=None
, feature_names=feature_name
, class_names=class_name
)
graph = graphviz.Source(treeAB_d4_dot)
graph.render("../output/TreeForAgeAndBalanceD4")
# decision tree max_depth=8
clf_treeAB_d8 = tree.DecisionTreeClassifier(criterion="entropy", max_depth=8)
clf_treeAB_d8 = clf_treeAB_d8.fit(first_set, labels)
scores_treeAB_d8 = cross_val_score(clf_treeAB_d8, first_set, labels, cv=10) # 10-means cross validate
print(scores_treeAB_d8)
# visualize
treeAB_d8_dot = tree.export_graphviz(
clf_treeAB_d8
, out_file=None
, feature_names=feature_name
, class_names=class_name
)
graph = graphviz.Source(treeAB_d8_dot)
graph.render("../output/TreeForAgeAndBalanceD8")
# second set of duration, campaign, pdays, as well as previous
# performs much more better than the 1st one
second_set = bank_data[['duration', 'campaign', 'pdays', 'previous']]
# decision tree max_depth=3
clf_treeDrtCpnPdsPrev_d3 = tree.DecisionTreeClassifier(criterion="entropy", max_depth=3)
clf_treeDrtCpnPdsPrev_d3 = clf_treeDrtCpnPdsPrev_d3.fit(second_set, labels)
scores_treeDrtCpnPdsPrev_d3 = cross_val_score(
clf_treeDrtCpnPdsPrev_d3, second_set, labels, cv=10
) # 10-means cross validate
print(scores_treeDrtCpnPdsPrev_d3)
# visualize
feature_name_2 = ['duration', 'campaign', 'pdays', 'previous']
treeDrtCpnPdsPrev_d3_dot = tree.export_graphviz(
clf_treeDrtCpnPdsPrev_d3
, out_file=None
, feature_names=feature_name_2
, class_names=class_name
)
graph = graphviz.Source(treeDrtCpnPdsPrev_d3_dot)
graph.render("../output/TreeForDrtCpnPdsPrevD3")
# decision tree max_depth=4
clf_treeDrtCpnPdsPrev_d4 = tree.DecisionTreeClassifier(criterion="entropy", max_depth=4)
clf_treeDrtCpnPdsPrev_d4 = clf_treeDrtCpnPdsPrev_d4.fit(second_set, labels)
scores_treeDrtCpnPdsPrev_d4 = cross_val_score(
clf_treeDrtCpnPdsPrev_d4, second_set, labels, cv=10
) # 10-means cross validate
print(scores_treeDrtCpnPdsPrev_d4)
# visualize
treeDrtCpnPdsPrev_d4_dot = tree.export_graphviz(
clf_treeDrtCpnPdsPrev_d4
, out_file=None
, feature_names=feature_name_2
, class_names=class_name
)
graph = graphviz.Source(treeDrtCpnPdsPrev_d4_dot)
graph.render("../output/TreeForDrtCpnPdsPrevD4")
# decision tree max_depth=8
clf_treeDrtCpnPdsPrev_d8 = tree.DecisionTreeClassifier(criterion="entropy", max_depth=8)
clf_treeDrtCpnPdsPrev_d8 = clf_treeDrtCpnPdsPrev_d8.fit(second_set, labels)
scores_treeDrtCpnPdsPrev_d8 = cross_val_score(
clf_treeDrtCpnPdsPrev_d8, second_set, labels, cv=10
) # 10-means cross validate
print(scores_treeDrtCpnPdsPrev_d8)
# visualize
treeDrtCpnPdsPrev_d8_dot = tree.export_graphviz(
clf_treeDrtCpnPdsPrev_d8
, out_file=None
, feature_names=feature_name_2
, class_names=class_name
)
graph = graphviz.Source(treeDrtCpnPdsPrev_d8_dot)
graph.render("../output/TreeForDrtCpnPdsPrevD8")
# decision tree ends
数据集
目前还在审核。我在下一篇文章中把链接发出来吧。
上一篇: SQL关联查询————LEFT JOIN关键字的使用
下一篇: 斐波那契数列(二)--矩阵优化算法
推荐阅读
-
sklearn构造决策树模型 树的可视化 pydotplus和GraphViz的安装
-
决策树的sklearn实现及其GraphViz可视化
-
Sklearn决策树可视化
-
sklearn中决策树回归器DecisionTreeRegressor的实际应用及可视化
-
sklearn初探(三):决策树及其可视化
-
利用sklearn中 ID3算法实现简单的课程销量预测+决策树可视化
-
机器学习:决策树(四) —— sklearn决策树的使用及其可视化
-
通俗地说决策树算法(三)sklearn决策树实战
-
决策树可视化(sklearn、graphviz)——python数据分析与挖掘实战 5-2 决策树预测销售量高低
-
通俗地说决策树算法(三)sklearn决策树实战