Hand-Written Decision Tree with Visualization

程序员文章站 2022-04-02 10:42:39

Description

The data is the Lenses Data Set from the UCI Machine Learning Repository (https://archive.ics.uci.edu/ml/datasets/Lenses).

It contains:
24 instances
3 classes:
1: the patient should be fitted with hard contact lenses
2: the patient should be fitted with soft contact lenses
3: the patient should not be fitted with contact lenses

4 attributes:
1: age of the patient: (1) young, (2) pre-presbyopic, (3) presbyopic
2: spectacle prescription: (1) myope, (2) hypermetrope
3: astigmatic: (1) no, (2) yes
4: tear production rate: (1) reduced, (2) normal

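Source data format (the first column is the instance number, the next four columns are the attribute values, and the last column is the class; rows 4 to 19 are omitted here):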

1 1 1 1 1 3
2 1 1 1 2 2
3 1 1 2 1 3

20 3 1 2 2 1
21 3 2 1 1 3
22 3 2 1 2 2
23 3 2 2 1 3
24 3 2 2 2 3
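A minimal sketch of reading one row in this format, assuming the columns are separated by one or more spaces (the helper name parse_line is hypothetical). The instance number is discarded, leaving four attribute values followed by the class label.

```python
def parse_line(line):
    """Parse one lenses.data row into [age, prescription, astigmatic, tear_rate, class].

    Hypothetical helper; assumes whitespace-separated integer columns.
    """
    fields = line.split()  # split() tolerates runs of spaces and the trailing newline
    return [int(v) for v in fields[1:]]  # drop the leading instance number


rows = ["1 1 1 1 1 3\n", "2  1  1  1  2  2\n"]
data = [parse_line(r) for r in rows if r.strip()]  # skip blank lines
print(data)  # [[1, 1, 1, 1, 3], [1, 1, 1, 2, 2]]
```

Using split() without an argument avoids relying on a fixed number of spaces between columns.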

Method

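The decision tree is grown top-down by recursion.
Stopping conditions (the node becomes a leaf if either holds):

1. All instances in the node belong to the same class.
2. All instances in the node have identical attribute values.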
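The two stopping tests can be sketched as below, assuming each instance is a list of attribute values followed by the class label (should_stop and majority_class are hypothetical names). When only condition 2 holds, the instances still disagree on class, so such a leaf would typically be labeled with its most frequent class.

```python
def should_stop(D):
    """Return True if the node holding instances D becomes a leaf."""
    same_class = len({x[-1] for x in D}) == 1          # condition 1: one class
    same_attrs = len({tuple(x[:-1]) for x in D}) == 1  # condition 2: identical attributes
    return same_class or same_attrs


def majority_class(D):
    """Most frequent class label in D (ties broken arbitrarily)."""
    labels = [x[-1] for x in D]
    return max(set(labels), key=labels.count)


print(should_stop([[1, 1, 1, 1, 3], [1, 1, 2, 1, 3]]))  # True: one class
print(should_stop([[1, 1, 1, 1, 3], [1, 1, 1, 1, 2]]))  # True: same attributes
print(should_stop([[1, 1, 1, 1, 3], [1, 1, 1, 2, 2]]))  # False
print(majority_class([[1, 1, 1, 1, 3], [1, 1, 1, 1, 3], [1, 1, 1, 1, 2]]))  # 3
```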

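Main steps:
For each attribute, partition the node's instances by that attribute's values and compute the information gain of the split.
Choose the attribute with the largest information gain and split the node on it.
Recurse on each resulting child node.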
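The gain computation can be sketched as follows, using the standard ID3 definitions Ent(D) = -sum_k p_k log2 p_k and Gain(D, a) = Ent(D) - sum_v |D_v|/|D| * Ent(D_v), where p_k is the fraction of instances in class k and D_v is the subset with value v for attribute a. The function names are hypothetical, and best_attribute assumes the node is not a leaf, so at least one attribute still takes more than one value.

```python
import math
from collections import Counter


def entropy(D):
    """Base-2 entropy of the class labels (last element of each instance) in D."""
    n = len(D)
    return -sum((c / n) * math.log2(c / n) for c in Counter(x[-1] for x in D).values())


def info_gain(D, a):
    """Information gain of splitting D on attribute index a."""
    subsets = {}
    for x in D:
        subsets.setdefault(x[a], []).append(x)
    n = len(D)
    return entropy(D) - sum(len(s) / n * entropy(s) for s in subsets.values())


def best_attribute(D):
    """Attribute index with the largest gain; ties go to the lowest index.

    Only attributes with more than one value are considered, so the chosen
    split always produces at least two children.
    """
    candidates = [a for a in range(len(D[0]) - 1) if len({x[a] for x in D}) > 1]
    return max(candidates, key=lambda a: info_gain(D, a))


# toy data: two attributes, and the class is determined by attribute 1
D = [[1, 1, "no"], [2, 1, "no"], [1, 2, "yes"], [2, 2, "yes"]]
print(info_gain(D, 0), info_gain(D, 1), best_attribute(D))  # 0.0 1.0 1
```

Skipping single-valued attributes matters: their gain is always 0, and choosing one would produce a single child identical to the parent.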

Implementation

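The implementation is written in Python as a class named DT.
Each node is stored as a list: [id, class, samples, ent, attr, children], where id is the node number, class is the predicted class (-1 for internal nodes), samples is the number of instances in the node, ent is their entropy, and attr is the index of the splitting attribute (-1 for leaves).
children is the list of child node ids; an empty list marks a leaf.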
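To make the layout concrete, here is a hypothetical hand-built tree in the [id, class, samples, ent, attr, children] form, with a depth-first walk over it. It assumes node ids equal list positions, which holds when nodes are appended in id order; is_leaf_node and walk are illustrative names.

```python
# Hypothetical tree: the root splits on attribute index 1 into two pure leaves.
nodes = [
    [0, -1, 4, 1.0, 1, [1, 2]],   # internal node: class -1, attr 1
    [1, "no", 2, 0, -1, []],      # leaf: attr -1, no children
    [2, "yes", 2, 0, -1, []],     # leaf
]


def is_leaf_node(node):
    return not node[5]  # an empty children list marks a leaf


def walk(nodes, node_id=0, depth=0):
    """Depth-first traversal returning one indented description line per node."""
    node = nodes[node_id]
    indent = "  " * depth
    if is_leaf_node(node):
        lines = [f"{indent}leaf {node[0]}: class={node[1]}, samples={node[2]}"]
    else:
        lines = [f"{indent}node {node[0]}: split on attr {node[4]}, samples={node[2]}"]
    for child in node[5]:
        lines.extend(walk(nodes, child, depth + 1))
    return lines


print("\n".join(walk(nodes)))
# node 0: split on attr 1, samples=4
#   leaf 1: class=no, samples=2
#   leaf 2: class=yes, samples=2
```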

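Visualization
The tree is visualized with the DOT language: the tree is traversed, each node and edge is written to a .dot file in DOT syntax, and the file is then rendered to an image (for example with Graphviz: dot -Tpng demo.dot -o demo.png).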
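A minimal sketch of the DOT text such a traversal produces, assuming nodes in the [id, class, samples, ent, attr, children] layout; to_dot and its edge_labels argument (a mapping from child id to edge text) are hypothetical.

```python
def to_dot(nodes, edge_labels=None):
    """Return DOT source for nodes in [id, class, samples, ent, attr, children] form."""
    edge_labels = edge_labels or {}
    out = ["digraph dtnodes {"]
    for nid, cls, samples, ent, attr, children in nodes:
        # "\\n" writes DOT's two-character \n escape, a line break inside the label
        out.append(f'  {nid} [label="class:{cls}\\nsamples:{samples}\\nentropy:{ent:.3f}"];')
    for nid, *_, children in nodes:
        for child in children:
            out.append(f'  {nid} -> {child} [label="{edge_labels.get(child, "")}"];')
    out.append("}")
    return "\n".join(out)


nodes = [[0, -1, 4, 1.0, 1, [1, 2]], [1, "no", 2, 0, -1, []], [2, "yes", 2, 0, -1, []]]
dot_src = to_dot(nodes, {1: "attr1=1", 2: "attr1=2"})
print(dot_src)
```

Saving dot_src to demo.dot and running dot -Tpng demo.dot -o demo.png renders the image, provided Graphviz is installed.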

Appendix

import math


class DT():
    def __init__(self):
        # each node: [id, class, samples, ent, attr, children]
        # class is -1 for internal nodes; attr is -1 for leaves
        self.nodes = []
        self.node_index = []
        self.node_n = 0
        # child node id -> value of the parent's split attribute on that branch
        self.branch_value = {}

    def add_node(self, cls, D, ent, attr):
        # append a node (children are filled in later) and return its id
        self.nodes.append([self.node_n, cls, len(D), ent, attr, []])
        self.node_index.append(self.node_n)
        self.node_n += 1
        return self.node_n - 1

    def treeGeneration(self, D):
        labels = [x[-1] for x in D]
        class_n = list(set(labels))
        # stopping condition 1: all instances belong to the same class
        if len(class_n) == 1:
            return self.add_node(class_n[0], D, 0, -1)
        # attributes that still take more than one value in D
        candidates = [a for a in range(len(D[0]) - 1)
                      if len(set(x[a] for x in D)) > 1]
        # stopping condition 2: identical attribute values -> leaf with the majority class
        if not candidates:
            return self.add_node(max(class_n, key=labels.count), D, self.ent(D), -1)
        # split on the attribute with the largest information gain; only attributes
        # with more than one value are considered, so every split makes progress
        best_a = max(candidates,
                     key=lambda a: self.gain(D, self.split(D, a).values()))
        father_node = self.add_node(-1, D, self.ent(D), best_a)
        children_index = []
        for value, subset in sorted(self.split(D, best_a).items()):
            child = self.treeGeneration(subset)
            self.branch_value[child] = value
            children_index.append(child)
        self.nodes[father_node][5] = children_index
        return father_node

    def split(self, D, a):
        # partition D by attribute a: {value: [instances with that value]}
        d = {}
        for x in D:
            d.setdefault(x[a], []).append(x)
        return d

    def ent(self, D):
        # entropy of the class labels (last column) in D
        len_D = len(D)
        len_d = {}
        for x in D:
            len_d[x[-1]] = len_d.get(x[-1], 0) + 1
        return self.ent_([n / len_D for n in len_d.values()])

    def ent_(self, p):
        return -sum(i * math.log(i, 2) for i in p)

    def gain(self, D, d):
        # information gain of splitting D into the subsets in d
        D_len = len(D)
        return self.ent(D) - sum((len(i) / D_len) * self.ent(i) for i in d)

    def fit(self, data):
        self.treeGeneration(data)

    def to_graph(self, filepath, attrs=None, classes=None):
        # write the tree as a DOT digraph
        with open(filepath, "w") as dot_f:
            dot_f.write("digraph dtnodes{\n")
            # node definitions; "\\n" writes DOT's \n line-break escape into the label
            for node in self.nodes:
                cls = node[1]
                if classes is not None and cls != -1:
                    cls = classes[cls]
                dot_f.write("{}[label=\"class:{}\\nsamples:{}\\nentropy:{:.3f}\"];\n".format(
                    node[0], cls, node[2], node[3]))
            # arcs, labelled "attribute=value" for the branch taken
            for node in self.nodes:
                for child in node[5]:
                    name = node[4] if attrs is None else attrs[node[4]]
                    dot_f.write("{}->{}[label=\"{}={}\"];\n".format(
                        node[0], child, name, self.branch_value[child]))
            dot_f.write("}\n")


data = []
with open("./lenses.data", "r") as f:
    for line in f:
        if line.strip():
            # drop the leading instance number; keep the 4 attributes and the class
            data.append([int(j) for j in line.split()[1:]])
# print(data)
dt = DT()
dt.fit(data)
dt.to_graph("./demo.dot", attrs={0: "age", 1: "spectacle prescription", 2: "astigmatic", 3: "tear production rate"}, classes={1: "hard", 2: "soft", 3: "not"})