欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

ID3决策树的Python实现以及可视化

程序员文章站 2022-03-28 16:15:12
算法介绍ID3决策树是比较经典的决策树,在周志华的机器学习中,生成决策树的算法为:算法的关键是如何选择最优划分属性,在ID3决策树中,用信息增益来指导决策树选择最优划分属性首先定义信息熵为:再定义信息增益为:一般而言,信息增益越大,意味着使用属性a进行划分所获得的纯度提升越大,因此我们选择最大信息增益的属性作为最优划分属性。Python实现思路树的数据表示既然要实现一棵树,首先要做的就是定义节点的数据结构,在C中,节点一般以结构体的形式存储,所以我们在Python中可以参考这一思路定义...

算法介绍

ID3决策树是比较经典的决策树,在周志华的机器学习中,生成决策树的算法为:
ID3决策树的Python实现以及可视化
算法的关键是如何选择最优划分属性,在ID3决策树中,用信息增益来指导决策树选择最优划分属性
首先定义信息熵为:
ID3决策树的Python实现以及可视化
再定义信息增益为:
ID3决策树的Python实现以及可视化
一般而言,信息增益越大,意味着使用属性a进行划分所获得的纯度提升越大,因此我们选择最大信息增益的属性作为最优划分属性。

Python实现思路

树的数据表示

既然要实现一棵树,首先要做的就是定义节点的数据结构,在C中,节点一般以结构体的形式存储,所以我们在Python中可以参考这一思路定义一个节点类:

class Node():
    """
    ID3决策树的节点
    parent -- 父节点
    sons -- 子节点集合,即在该节点最优划分属性下每个属性值的分支
    attrs -- 该节点下的最优划分属性
    parent_attrs_value -- 表示该节点是父节点哪一个属性的分支
    label -- 如果这个节点是叶子节点,则存放标签
    """
    def __init__(self, parent=None):
        self.parent = parent            
        self.sons = []                  
        self.attr = None                
        self.parent_attrs_value = None  
        self.label = None               

但在实际操作中,使用这一方法给代码的调试增加了难度,同时不利于后面用Graphviz包实现决策树的可视化,因此本文考虑使用另一种数据结构表示树,就是Python中的字典,我们先来看看对于西瓜书中给出的一颗决策树,用字典是如何表示的:
西瓜书中的一颗决策树:
ID3决策树的Python实现以及可视化
对应的Python字典表示:

tree = {'纹理':
            {'清晰':
                {'根蒂':
                    {'蜷缩':
                        {'label':'是'}, 
                    '稍蜷':
                        {'色泽':
                            {'青绿':
                                {'label':'是'}, 
                            '乌黑':
                                {'触感':
                                    {'硬滑':
                                        {'label':'是'}, 
                                    '软粘':
                                        {'label':'否'}}},
                            '浅白':
                                {'label':'是'}}},
                    '硬挺':
                        {'label':'否'}}}, 
            '稍糊':
                {'触感':
                    {'硬滑':
                        {'label':'否'}, 
                    '软粘':
                        {'label':'是'}}}, 
            '模糊':
                {'label':'否'}}}

如何可视化决策树

在本文中,使用Graphviz包进行决策树的可视化,这里是官网文档
只需使用几条简单的代码便可将决策树的节点绘制出来:

g = graphviz.Digraph(name=,filename=, format='png')
g.node(name=, label=, fontname="Microsoft YaHei", shape=)
g.edge(tail_name, head_name, label=, fontname="Microsoft YaHei")
g.view()

要注意,如果决策树的信息是中文的,要在fontname参数中指定中文字体,不然会出现乱码

Python代码

DecesionTree.py

import numpy as np
import scipy.io as sio
from collections import Counter
from graphviz import Digraph

class DecisionTree():
    """
    一个构建ID3决策树的类
    attrs -- 存放属性的字典, 字典中,键为属性名,值为属性的取值,最后一个属性为标签属性
    X -- 训练数据
    y -- 标签
    attr_idx -- 属性列索引
    tree -- 生成的决策树,用字典形式存放
    node_name -- 用于对决策树的可视化,在graphviz中对节点的命名
    """
    def __init__(self):
        self.attrs = None
        self.X = None
        self.y = None
        self.attr_idx = None
        self.tree = {}
        self.node_name = "0"


    def get_attrs(self, data):
        """
        对数据集进行处理,得到属性与对应的属性取值
        args:
        data -- 输入的数据矩阵, shape=(samples+1, features), dtype='<U?', 其中,第一行为属性,最后一列为标签
        returns:
        attrs -- 存放属性的字典, 字典中,键为属性名,值为属性的取值 
        """
        attrs = {}
        for i in range(data.shape[1]):
            attrs_values = sorted(set(data[1:, i]))
            attrs[data[0][i]] = attrs_values

        self.attrs = attrs
        return attrs


    def generate_tree(self, data):
        """
        生成决策树
        args:
        data -- 输入的数据矩阵, shape=(samples+1, features+label), dtype='<U?', 其中,第一行为属性,最后一列为标签
        """
        self.X = data[1:, :-1]
        self.y = data[1:, -1]

        # 先创建一个不含label属性的纯变量属性字典
        pure_attrs = self.attrs.copy()
        del(pure_attrs['label'])
        # 构造一个只含属性名的列表
        attr_names = [attr_name for attr_name in pure_attrs.keys()]
        # 将属性名编号,方便查找其在数据中对应的列
        attr_idx = {}
        for num, attr in enumerate(attr_names):
            attr_idx[attr] = num
        self.attr_idx = attr_idx

        # 生成根节点
        self.tree['root_node'] = {}
        self._generate_tree(self.X, self.y, self.tree['root_node'], pure_attrs, attr_idx)
        self.tree = self.tree['root_node']
        

    def _generate_tree(self, X, y, node, attrs, attr_idx):
        """
        递归生成决策树
        args:
        X -- 输入的数据矩阵, shape=(samples, features), dtype='<U?'
        y -- 标签, shape=(samples, )
        parent_node -- 父节点,此次递归函数是父节点的某一个属性值的递归
        attrs -- 属性字典, 即从父节点分支到现在的节点时,还没有被划分的属性
        attr_idx -- 属性在数据中列索引
        """

        #--------- 如果训练集中样本全属于同一类别 ---------#
        if len(set(y.tolist())) == 1:
            node['label'] = y[0]
            return

        #-------- 如果属性集为空集或者训练集中样本在属性集上取值相同 ---------#
        # 判断训练集样本在属性集中取值是否相同
        same = True
        for i in range(X.shape[1]):
            if len(set(X[:, i].tolist())) > 1:
                same = False
        
        if not attrs or same:
            y_counter = Counter(y)
            most_y = y_counter.most_common()[0][0]
            node['label'] = most_y
            return

        #--------- 选择最优属性生成分支 ---------#
        # 选出最优划分属性
        optimal_attr = self.choose_optimal_attr(X, y, attrs, attr_idx)
        node[optimal_attr] = {}
        node = node[optimal_attr]
        # 对于最优划分属性下每个属性值
        for attr_value in attrs[optimal_attr]:
            # 生成分支
            node[attr_value] = {}
            # 令Dv表示X中在optimal_attr上取值为attr_value的样本子集
            Dv = X.copy()
            attr_value_idx = Dv[:, attr_idx[optimal_attr]] == attr_value
            Dv = Dv[attr_value_idx, :]
            y_Dv = y[attr_value_idx]
            Dv = np.delete(Dv, attr_idx[optimal_attr], 1)
            # 如果Dv为空
            if Dv.size == 0:
                # 将分支节点标记为叶节点,其类别标记为X中样本最多的类,即统计y
                y_counter = Counter(y)
                most_y = y_counter.most_common()[0][0]
                node[attr_value]['label'] = most_y
            else:
                # 更新属性字典
                new_attrs = attrs.copy()
                del(new_attrs[optimal_attr])
                # 更新属性列索引
                new_attr_names = [new_attr_name for new_attr_name in new_attrs.keys()]
                new_attr_idx = {}
                for num, attr in enumerate(new_attr_names):
                    new_attr_idx[attr] = num
                self._generate_tree(Dv, y_Dv, node[attr_value], new_attrs, new_attr_idx)


    def compute_Ent(self, y):
        """
        计算给出属性名列表所对应的所有样本的信息熵
        args:
        y -- 标签数组, shape=(samples, )
        return:
        Ent -- 样本的信息熵
        """
        Ent = 0
        m = np.size(y)
        for label in self.attrs['label']:
            pk = np.sum(y == label)
            pk = pk / m
            log2pk = np.log2(pk + 1e-8) # 防止算得0,导致返回nan
            Ent -= pk * log2pk
        return Ent


    def choose_optimal_attr(self, X, y, attrs, attr_idx):
        """
        选择最优划分属性 划分标准:属性的信息增益
        args: 
        X -- 输入的数据矩阵, shape=(samples, features), dtype='<U?'
        y -- 标签, shape=(samples, )
        attrs -- 属性字典
        attr_idx -- 属性在数据中列索引
        returns:
        max_gain_attr -- 最大的信息增益对应的属性
        """
        # 计算当前所含属性对应所有样本的信息熵
        Ent = self.compute_Ent(y)
        m = np.size(y)
        # 记录当前最大的信息增益以及对应的属性
        max_gain = 0
        max_gain_attr = None
        
        # 计算每一个属性的信息增益
        for attr, idx in attr_idx.items():
            x = X[:, idx]
            gain = Ent
            # 计算一个属性中每个属性值的信息熵
            for attr_value in attrs[attr]:
                _y = y[x==attr_value]
                if _y.size != 0:
                    ent = self.compute_Ent(_y)
                else:
                    ent = 0
                gain -= np.size(_y) / m * ent
            if gain > max_gain:
                max_gain = gain
                max_gain_attr = attr
                
        return max_gain_attr


    def predict(self, predict_x):
        """
        预测样本结果
        args:
        predict_x -- 预测样本数据矩阵 shape=(samples, features)
        returns:
        predict_y -- 样本的预测结果 shape=(samples, )
        """
        s = predict_x.shape[0]
        predict_y = []
        for i in range(s):
            node = self.tree
            while(1):
                if 'label' in node.keys():
                    predict_y.append(node['label'])
                    break
                elif list(node.keys())[0] in self.attrs.keys():
                    attr = list(node.keys())[0]
                    idx = self.attr_idx[attr]
                    node = node[attr]
                else:
                    node = node[predict_x[i, idx]]
        return predict_y


    def tree_traversal(self, g, parent_node, parent_node_name, parent_attr, parent_attr_value):
        """
        对树进行遍历,生成可视化的节点
        g -- 要绘制的有向图
        parent_node -- 父节点
        parent_node_name -- 父节点在有向图中的代号
        parent_attr -- 父节点的属性
        parent_attr_value -- 父节点到该节点的属性值
        """
        if (parent_attr and parent_attr_value) is None:
            if 'label' in parent_node.keys():
                g.node(name=self.node_name, label=parent_node['label'], fontname="Microsoft YaHei")
                return
            else:
                attr = list(parent_node.keys())[0]
                node = parent_node[attr]
                parent_node_name = "0"
                for attr_value in node.keys():
                    self.tree_traversal(g, node[attr_value], parent_node_name, attr, attr_value)
        else:
            if 'label' in parent_node.keys():
                g.node(name=parent_node_name, label=parent_attr, fontname="Microsoft YaHei", shape='box')
                self.node_name = str(int(self.node_name) + 1)
                g.node(name=self.node_name, label=parent_node['label'], fontname="Microsoft YaHei")
                g.edge(parent_node_name, self.node_name, label=parent_attr_value, fontname="Microsoft YaHei")
            else:
                attr = list(parent_node.keys())[0]
                g.node(name=parent_node_name, label=parent_attr, fontname="Microsoft YaHei", shape='box')
                self.node_name = str(int(self.node_name) + 1)
                g.node(name=self.node_name, label=attr, fontname="Microsoft YaHei", shape='box')
                g.edge(parent_node_name, self.node_name, label=parent_attr_value, fontname="Microsoft YaHei")
                node = parent_node[attr]
                parent_node_name = self.node_name
                for attr_value in node.keys():
                    self.tree_traversal(g, node[attr_value], parent_node_name, attr, attr_value)


    def tree_visualize(self, file_name=None):
        """
        将决策树可视化
        args:
        file_name -- 若给出该参数,则将决策树保存为file_name的图片
        """
        if file_name:
            g = Digraph("Decision Tree", filename=file_name, format='png')
        else:
            g = Digraph("Decision Tree")
        self.tree_traversal(g, self.tree, None, None, None)
        g.view()


if __name__ == "__main__":
    pass

主函数,以西瓜树的西瓜数据集为例生成决策树,原数据集是Matlab的cell数组,并以mat文件存放,因此需要预处理一下:

import numpy as np
import scipy.io as sio
from DecisionTree import DecisionTree

def preprocess():
    raw_data = sio.loadmat('watermelon.mat')
    raw_data = raw_data['watermelon']
    data = np.zeros(raw_data.shape, dtype='<U20')

    for i in range(data.shape[0]):
        for j in range(data.shape[1]):
            data[i, j] = raw_data[i, j][0]

    data[0, -1] = 'label'
    return data

def main_1():
    """
    完整决策树
    """
    data = preprocess()
    DTree = DecisionTree()
    attrs = DTree.get_attrs(data)
    DTree.generate_tree(data)
    DTree.tree_visualize('watermelob_tree')

def main_2():
    """
    留出两个样本作为测试集
    """
    data = preprocess()
    train_idx = np.delete(np.arange(0, 18), [8, 17])
    test_idx = [8, 17]
    train_data = data[train_idx, :]
    test_data = data[test_idx, :]
    test_X = test_data[:, :-1]
    test_y = test_data[:, -1]

    DTree = DecisionTree()
    DTree.get_attrs(train_data)
    DTree.generate_tree(train_data)
    predict_y = DTree.predict(test_X)
    print(predict_y)
    DTree.tree_visualize('watermelon_tree_2')

main_1()

最终生成的决策树图片为:
ID3决策树的Python实现以及可视化
到这里我们就成功地用Python实现了ID3决策树!

本文地址:https://blog.csdn.net/SZU_Kwong/article/details/109634151