欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

机器学习实战 —— 决策树(sklearn api)

程序员文章站 2022-03-30 18:55:57
...

代码

import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn import tree
from sklearn.externals.six import StringIO

# pip install pydotplus
# pip install graphviz
import pydotplus

# Graphviz瞎子地址:http://www.graphviz.org/download/
import os
os.environ["PATH"] += os.pathsep + 'D:/program files (x86)/Graphviz2.38/bin'


def loadData():
    """
    加载文件,生成特征集和目标值集
    :return:
    """
    # 加载文件
    with open('lenses.txt') as fr:
        # 处理文件,去掉每行两头的空白符,以\t分隔每个数据
        lenses = [inst.strip().split('\t') for inst in fr.readlines()]

    # 提取每组数据的类别,保存在列表里
    lenses_targt = []
    for each in lenses:
        # 存储Label到lenses_targt中
        lenses_targt.append([each[-1]])

    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']

    # 保存lenses数据的字典,用于生成pandas
    lenses_dict = {}
    # 提取信息,生成字典
    for each_label in lensesLabels:
        # 保存lenses数据的临时列表
        lenses_list = []
        for each in lenses:
            # index方法用于从列表中找出某个值第一个匹配项的索引位置
            lenses_list.append(each[lensesLabels.index(each_label)])
        lenses_dict[each_label] = lenses_list
    # 生成pandas.DataFrame用于对象的创建
    lenses_pd = pd.DataFrame(lenses_dict)
    print(lenses_targt)
    print(lenses_pd)

    return lenses_pd, lenses_targt


def dataEncoder(data_pd):
    le = LabelEncoder()
    # 为每一列序列化
    for col in data_pd.columns:
        # fit_transform()干了两件事:fit找到数据转换规则,并将数据标准化
        # transform()直接把转换规则拿来用,需要先进行fit
        # transform函数是一定可以替换为fit_transform函数的,fit_transform函数不能替换为transform函数
        data_pd[col] = le.fit_transform(data_pd[col])
    # 打印归一化的结果
    print(data_pd)


def createTree(data_pd, labels):
    # 创建DecisionTreeClassifier()类
    clf = tree.DecisionTreeClassifier(criterion='entropy', max_depth=4)
    # 使用数据构造决策树
    # fit(X,y):Build a decision tree classifier from the training set(X,y)
    # 所有的sklearn的API必须先fit
    clf = clf.fit(data_pd.values.tolist(), labels)
    return clf


def exportTree(clf, feature_names):
    # 保存树
    with open("lenses.dot", 'w') as f:
        tree.export_graphviz(clf, out_file=f)

    # 打印树
    dot_data = StringIO()
    tree.export_graphviz(clf, out_file=dot_data,
                              feature_names=feature_names,
                              class_names=clf.classes_,
                              filled=True, rounded=True,
                              special_characters=True)
    graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
    graph.write_pdf("tree.pdf")


def main():
    # 生成数据集和目标值集
    data_pd, targts = loadData()
    # 数据编码,序列化
    dataEncoder(data_pd)
    # 生成树
    tree = createTree(data_pd, targts)
    # 保存树、打印树
    exportTree(tree, data_pd.keys())

    # 预测
    print(tree.predict([[1, 1, 1, 0]]))


if __name__ == '__main__':
    main()

运行结果

[['no lenses'], ['soft'], ['no lenses'], ['hard'], ['no lenses'], ['soft'], ['no lenses'], ['hard'], ['no lenses'], ['soft'], ['no lenses'], ['hard'], ['no lenses'], ['soft'], ['no lenses'], ['no lenses'], ['no lenses'], ['no lenses'], ['no lenses'], ['hard'], ['no lenses'], ['soft'], ['no lenses'], ['no lenses']]
           age astigmatic prescript tearRate
0        young         no     myope  reduced
1        young         no     myope   normal
2        young        yes     myope  reduced
3        young        yes     myope   normal
4        young         no     hyper  reduced
5        young         no     hyper   normal
6        young        yes     hyper  reduced
7        young        yes     hyper   normal
8          pre         no     myope  reduced
9          pre         no     myope   normal
10         pre        yes     myope  reduced
11         pre        yes     myope   normal
12         pre         no     hyper  reduced
13         pre         no     hyper   normal
14         pre        yes     hyper  reduced
15         pre        yes     hyper   normal
16  presbyopic         no     myope  reduced
17  presbyopic         no     myope   normal
18  presbyopic        yes     myope  reduced
19  presbyopic        yes     myope   normal
20  presbyopic         no     hyper  reduced
21  presbyopic         no     hyper   normal
22  presbyopic        yes     hyper  reduced
23  presbyopic        yes     hyper   normal
    age  astigmatic  prescript  tearRate
0     2           0          1         1
1     2           0          1         0
2     2           1          1         1
3     2           1          1         0
4     2           0          0         1
5     2           0          0         0
6     2           1          0         1
7     2           1          0         0
8     0           0          1         1
9     0           0          1         0
10    0           1          1         1
11    0           1          1         0
12    0           0          0         1
13    0           0          0         0
14    0           1          0         1
15    0           1          0         0
16    1           0          1         1
17    1           0          1         0
18    1           1          1         1
19    1           1          1         0
20    1           0          0         1
21    1           0          0         0
22    1           1          0         1
23    1           1          0         0
['hard']

Process finished with exit code 0

lenses.dot

digraph Tree {
node [shape=box] ;
0 [label="X[3] <= 0.5\nentropy = 1.326\nsamples = 24\nvalue = [4, 15, 5]"] ;
1 [label="X[1] <= 0.5\nentropy = 1.555\nsamples = 12\nvalue = [4, 3, 5]"] ;
0 -> 1 [labeldistance=2.5, labelangle=45, headlabel="True"] ;
2 [label="X[2] <= 0.5\nentropy = 0.65\nsamples = 6\nvalue = [0, 1, 5]"] ;
1 -> 2 ;
3 [label="entropy = 0.0\nsamples = 3\nvalue = [0, 0, 3]"] ;
2 -> 3 ;
4 [label="X[0] <= 0.5\nentropy = 0.918\nsamples = 3\nvalue = [0, 1, 2]"] ;
2 -> 4 ;
5 [label="entropy = 0.0\nsamples = 1\nvalue = [0, 0, 1]"] ;
4 -> 5 ;
6 [label="entropy = 1.0\nsamples = 2\nvalue = [0, 1, 1]"] ;
4 -> 6 ;
7 [label="X[2] <= 0.5\nentropy = 0.918\nsamples = 6\nvalue = [4, 2, 0]"] ;
1 -> 7 ;
8 [label="X[0] <= 1.5\nentropy = 0.918\nsamples = 3\nvalue = [1, 2, 0]"] ;
7 -> 8 ;
9 [label="entropy = 0.0\nsamples = 2\nvalue = [0, 2, 0]"] ;
8 -> 9 ;
10 [label="entropy = 0.0\nsamples = 1\nvalue = [1, 0, 0]"] ;
8 -> 10 ;
11 [label="entropy = 0.0\nsamples = 3\nvalue = [3, 0, 0]"] ;
7 -> 11 ;
12 [label="entropy = 0.0\nsamples = 12\nvalue = [0, 12, 0]"] ;
0 -> 12 [labeldistance=2.5, labelangle=-45, headlabel="False"] ;
}

树图如下
机器学习实战 —— 决策树(sklearn api)