Machine Learning in Action: Decision Trees (sklearn API)
Code
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn import tree
# sklearn.externals.six was removed from newer scikit-learn releases; the standard library StringIO is a drop-in replacement
from io import StringIO
# pip install pydotplus
# pip install graphviz
import pydotplus
# Graphviz download: http://www.graphviz.org/download/
import os
# Add the Graphviz bin directory to PATH so pydotplus can find the dot executable (adjust to your install location)
os.environ["PATH"] += os.pathsep + 'D:/program files (x86)/Graphviz2.38/bin'
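# Portable sanity check (an addition, not in the original listing): warn early if the
# Graphviz `dot` binary is still not reachable after the PATH tweak above, since
# pydotplus needs it to render the graph.
import shutil
if shutil.which('dot') is None:
    print('Warning: Graphviz "dot" was not found on PATH; rendering tree.pdf will fail')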
def loadData():
    """
    Load the data file and build the feature set and the target list.
    :return: (feature DataFrame, target list)
    """
    # Load the file
    with open('lenses.txt') as fr:
        # Strip whitespace from both ends of each line and split the fields on \t
        lenses = [inst.strip().split('\t') for inst in fr.readlines()]
    # Collect the class of every sample into a list
    lenses_target = []
    for each in lenses:
        # Store the label in lenses_target
        lenses_target.append([each[-1]])
    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
    # Dictionary holding the lenses data, used to build the pandas DataFrame
    lenses_dict = {}
    # Extract the values column by column and fill the dictionary
    for each_label in lensesLabels:
        # Temporary list for the current column
        lenses_list = []
        for each in lenses:
            # index() returns the position of the first match in the list
            lenses_list.append(each[lensesLabels.index(each_label)])
        lenses_dict[each_label] = lenses_list
    # Build the pandas.DataFrame
    lenses_pd = pd.DataFrame(lenses_dict)
    print(lenses_target)
    print(lenses_pd)
    return lenses_pd, lenses_target
def dataEncoder(data_pd):
    le = LabelEncoder()
    # Encode every column to integers
    for col in data_pd.columns:
        # fit() learns the category-to-integer mapping, transform() applies it;
        # fit_transform() does both in one step. Because the same encoder is
        # re-fitted on every column, only the last column's mapping survives in le.
        data_pd[col] = le.fit_transform(data_pd[col])
    # Print the encoded result
    print(data_pd)
def createTree(data_pd, labels):
    # Create a DecisionTreeClassifier
    clf = tree.DecisionTreeClassifier(criterion='entropy', max_depth=4)
    # Build the decision tree from the data
    # fit(X, y): build a decision tree classifier from the training set (X, y)
    # Every sklearn estimator must be fitted before it can be used for prediction
    clf = clf.fit(data_pd.values.tolist(), labels)
    return clf
def exportTree(clf, feature_names):
    # Save the tree as a .dot file
    with open("lenses.dot", 'w') as f:
        tree.export_graphviz(clf, out_file=f)
    # Render the tree as a PDF via pydotplus
    dot_data = StringIO()
    tree.export_graphviz(clf, out_file=dot_data,
                         feature_names=feature_names,
                         class_names=clf.classes_,
                         filled=True, rounded=True,
                         special_characters=True)
    graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
    graph.write_pdf("tree.pdf")
def main():
    # Build the feature set and the target list
    data_pd, targets = loadData()
    # Encode the features as integers
    dataEncoder(data_pd)
    # Build the tree (named clf so the sklearn tree module is not shadowed)
    clf = createTree(data_pd, targets)
    # Save and render the tree
    exportTree(clf, data_pd.keys())
    # Predict a new sample; the values follow the encoding printed by dataEncoder
    print(clf.predict([[1, 1, 1, 0]]))

if __name__ == '__main__':
    main()
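One caveat worth noting: dataEncoder() re-fits a single LabelEncoder on every column, so after the loop only the last column's mapping is still available. The sketch below is not from the original article (the helper name encode_with_mapping and the encoders dictionary are made up for illustration); it keeps one encoder per column, which lets an encoded sample such as [1, 1, 1, 0] be translated back to category names with inverse_transform():

from sklearn.preprocessing import LabelEncoder

def encode_with_mapping(data_pd):
    """Encode each column with its own LabelEncoder and keep the fitted encoders."""
    encoders = {}
    for col in data_pd.columns:
        le = LabelEncoder()
        data_pd[col] = le.fit_transform(data_pd[col])
        encoders[col] = le  # remember this column's category-to-integer mapping
    return encoders

# Usage, in place of dataEncoder(data_pd) in main():
# encoders = encode_with_mapping(data_pd)
# sample = [1, 1, 1, 0]
# decoded = {col: encoders[col].inverse_transform([code])[0]
#            for col, code in zip(data_pd.columns, sample)}
# print(decoded)
# -> {'age': 'presbyopic', 'astigmatic': 'yes', 'prescript': 'myope', 'tearRate': 'normal'}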
Output
[['no lenses'], ['soft'], ['no lenses'], ['hard'], ['no lenses'], ['soft'], ['no lenses'], ['hard'], ['no lenses'], ['soft'], ['no lenses'], ['hard'], ['no lenses'], ['soft'], ['no lenses'], ['no lenses'], ['no lenses'], ['no lenses'], ['no lenses'], ['hard'], ['no lenses'], ['soft'], ['no lenses'], ['no lenses']]
age astigmatic prescript tearRate
0 young no myope reduced
1 young no myope normal
2 young yes myope reduced
3 young yes myope normal
4 young no hyper reduced
5 young no hyper normal
6 young yes hyper reduced
7 young yes hyper normal
8 pre no myope reduced
9 pre no myope normal
10 pre yes myope reduced
11 pre yes myope normal
12 pre no hyper reduced
13 pre no hyper normal
14 pre yes hyper reduced
15 pre yes hyper normal
16 presbyopic no myope reduced
17 presbyopic no myope normal
18 presbyopic yes myope reduced
19 presbyopic yes myope normal
20 presbyopic no hyper reduced
21 presbyopic no hyper normal
22 presbyopic yes hyper reduced
23 presbyopic yes hyper normal
age astigmatic prescript tearRate
0 2 0 1 1
1 2 0 1 0
2 2 1 1 1
3 2 1 1 0
4 2 0 0 1
5 2 0 0 0
6 2 1 0 1
7 2 1 0 0
8 0 0 1 1
9 0 0 1 0
10 0 1 1 1
11 0 1 1 0
12 0 0 0 1
13 0 0 0 0
14 0 1 0 1
15 0 1 0 0
16 1 0 1 1
17 1 0 1 0
18 1 1 1 1
19 1 1 1 0
20 1 0 0 1
21 1 0 0 0
22 1 1 0 1
23 1 1 0 0
['hard']
Process finished with exit code 0
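Not part of the original article: with only 24 samples the tree fits the training data perfectly, so a more honest sanity check is to hold out a few samples and measure accuracy on them, e.g. with train_test_split and the classifier's score() method. A minimal sketch, assuming the data_pd and targets produced in main() (the helper name evaluate_tree is made up here):

from sklearn import tree
from sklearn.model_selection import train_test_split

def evaluate_tree(data_pd, targets):
    """Hold out a quarter of the (already encoded) lenses data and report test accuracy."""
    y = [t[0] for t in targets]  # flatten [['hard'], ...] into ['hard', ...]
    X_train, X_test, y_train, y_test = train_test_split(
        data_pd.values.tolist(), y, test_size=0.25, random_state=0)
    clf = tree.DecisionTreeClassifier(criterion='entropy', max_depth=4)
    clf.fit(X_train, y_train)
    # score() returns the mean accuracy on the held-out samples
    return clf.score(X_test, y_test)

# Usage, inside main() after dataEncoder(data_pd):
# print(evaluate_tree(data_pd, targets))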
lenses.dot
digraph Tree {
node [shape=box] ;
0 [label="X[3] <= 0.5\nentropy = 1.326\nsamples = 24\nvalue = [4, 15, 5]"] ;
1 [label="X[1] <= 0.5\nentropy = 1.555\nsamples = 12\nvalue = [4, 3, 5]"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
2 [label="X[2] <= 0.5\nentropy = 0.65\nsamples = 6\nvalue = [0, 1, 5]"] ;
1 -> 2 ;
3 [label="entropy = 0.0\nsamples = 3\nvalue = [0, 0, 3]"] ;
2 -> 3 ;
4 [label="X[0] <= 0.5\nentropy = 0.918\nsamples = 3\nvalue = [0, 1, 2]"] ;
2 -> 4 ;
5 [label="entropy = 0.0\nsamples = 1\nvalue = [0, 0, 1]"] ;
4 -> 5 ;
6 [label="entropy = 1.0\nsamples = 2\nvalue = [0, 1, 1]"] ;
4 -> 6 ;
7 [label="X[2] <= 0.5\nentropy = 0.918\nsamples = 6\nvalue = [4, 2, 0]"] ;
1 -> 7 ;
8 [label="X[0] <= 1.5\nentropy = 0.918\nsamples = 3\nvalue = [1, 2, 0]"] ;
7 -> 8 ;
9 [label="entropy = 0.0\nsamples = 2\nvalue = [0, 2, 0]"] ;
8 -> 9 ;
10 [label="entropy = 0.0\nsamples = 1\nvalue = [1, 0, 0]"] ;
8 -> 10 ;
11 [label="entropy = 0.0\nsamples = 3\nvalue = [3, 0, 0]"] ;
7 -> 11 ;
12 [label="entropy = 0.0\nsamples = 12\nvalue = [0, 12, 0]"] ;
0 -> 12 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
}
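A note on reading lenses.dot (this explanation is not in the original article): the file comes from the first export_graphviz() call, which passes no feature_names, so splits appear as X[0]..X[3], i.e. indices into data_pd's columns (here age, astigmatic, prescript, tearRate). Likewise, value in each node counts samples per class in the order of clf.classes_, which is alphabetical for string labels, so value = [4, 15, 5] at the root means 4 'hard', 15 'no lenses' and 5 'soft' samples. A small sketch to print both mappings, assuming the clf and data_pd created in main() (the helper name describe_tree_inputs is made up here):

def describe_tree_inputs(clf, data_pd):
    """Show how the X[i] indices and the value columns in lenses.dot map back to names."""
    for i, name in enumerate(data_pd.columns):
        print('X[%d] -> %s' % (i, name))
    # the value array in every node follows this class order
    print('class order in value:', list(clf.classes_))

# Usage: describe_tree_inputs(clf, data_pd)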
The rendered tree diagram is shown below.
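If installing Graphviz and pydotplus is inconvenient, scikit-learn 0.21+ can draw the same tree directly with matplotlib via tree.plot_tree. A minimal alternative sketch, again assuming the clf and data_pd from main() (the helper name plot_lenses_tree and the output file name tree.png are made up here):

import matplotlib.pyplot as plt
from sklearn import tree

def plot_lenses_tree(clf, data_pd):
    """Render the fitted tree with matplotlib instead of Graphviz."""
    plt.figure(figsize=(10, 8))
    tree.plot_tree(clf,
                   feature_names=list(data_pd.columns),
                   class_names=list(clf.classes_),
                   filled=True, rounded=True)
    plt.savefig('tree.png')

# Usage: plot_lenses_tree(clf, data_pd)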