决策树——ID3算法实现
决策树:构建一个基于属性的树形分类器。
1.每个非叶节点表示一个特征属性上的测试(分割),
2.每个分支代表这个特征属性在某个值域上的输出,
3.每个叶节点存放一个类别。
使用决策树进行决策的过程就是从根节点开始,测试待分类项中相应的特征属性,并按照其值选择输出分支,直到到达叶子节点,将叶子节点存放的类别作为决策结果。
采用递归的方法进行建树
递归的结束条件
1.当前结点样本均属于同一类别,无需划分。
2.当前属性集为空。
3.所有样本在当前属性集上取值相同,无法划分。
4.当前结点包含的样本集合为空,不能划分。
决策树的核心
经过属性划分后,不同类样本被更好的分离
理想情况:划分后样本被完美分类。即每个分支的样本都属性同一类。
实际情况:不可能完美划分!尽量使得每个分支某一类样本比例尽量高!即尽量提高划分后子集的纯度。
划分的目标:提升划分后子集的纯度,降低划分后子集的不纯度
决策树算法分类
决策树算法的区别主要在于所采用的纯度判别标准
ID3算法:
使用信息增益作为判别标准
信息熵计算公式:
假设属性��有��可能取值{��^1,��^2,⋯⋯,��^��}, ��^��对应划分后的数据子集为��^��.
信息增益:
信息增益越大,说明当前的划分效果越好
C4.5算法
使用信息增益率作为判别准则
����(��)称为属性��的“固有值”(Intrinsic Value)
信息增益率越大,说明当前划分效果越好
CART算法
使用基尼系数作为判别准则
实验环境
python3.6
macOS 10.12
代码思路
BuildTree函数:在该函数中完成递归建树,递归返回条件的判断,建立存储树所用的字典,打印各类信息
ChooseAttr函数:在该函数中完成选出最佳特征的功能,根据Ent函数计算出的所有样本的信息熵和加权的信息熵计算信息增益,信息增越大的意味着该属性的纯度越高,选取信息增益最大的属性为最佳属性。
Ent函数:计算输入样本的信息熵,通过输入Sample的最后一列统计出该正例与反例出现的概率,根据信息熵公式计算信息熵
SpiltData函数:该函数用于对数据进行拆分,去掉已经判断过的属性对应的样本
CreatePlot函数:用于决策树的可视化
数据集
使用西瓜书上的西瓜数据集2.0
为了方便计算,将西瓜数据集的内容转换为数字
色泽: 0:青绿 1:乌黑 2:浅白
根底: 0:蜷缩 1:少蜷 2:硬挺
敲声: 0:浊响 1:沉闷 2:清脆
纹理: 0:清晰 1:稍糊 2:模糊
脐部: 0:凹陷 1:稍凹 2:平坦
触感: 0:硬滑 1:软黏
好瓜: 0:不是 1:是
上代码
import math
import numpy
import DrawTree
数据集,属性列表
#初始化一个属性列表
AttrArr=["色泽","根蒂","敲声","纹理","脐部","触感","好瓜"]
#此处使用西瓜数据集2.0
data = numpy.array(
[[0,0,0,0,0,0,1],
[1,0,1,0,0,0,1],
[1,0,0,0,0,0,1],
[0,0,1,0,0,0,1],
[2,0,0,0,0,0,1],
[0,1,0,0,1,1,1],
[1,1,0,1,1,1,1],
[1,1,0,0,1,0,1],
[1,1,1,1,1,0,0],
[0,2,2,0,2,1,0],
[2,2,2,2,2,0,0],
[2,0,0,2,2,1,0],
[0,1,0,1,0,0,0],
[2,1,1,1,0,0,0],
[1,1,0,0,1,1,0],
[2,0,0,2,2,0,0],
[0,0,1,1,1,0,0]]
)
BuildTree函数:在该函数中完成递归建树,递归返回条件的判断,建立存储树所用的字典,打印各类信息
#建树的函数
def BuildTree(Sample,Label):
#Sample 为输入的数据
#Label 为对应的标签
#获取输入数据的的大小
[Count, Attr] = Sample.shape;
n = Attr - 1;
m = Count - 1;
print("Sample:")
print(Sample)
#使用classlist存储表示正例与反例所在的列
classList = Sample[:, n];
# 记录第一个类中的个数
classOne = 1;
for i in range(1, Count):
if (classList[i] == classList[0]):
classOne = classOne + 1;
#如果当前结点包含的样本全属于同一个样本,则停止划分
if (classOne == Count):
print("Final")
print(Sample)
if(classList[0]==0):return "no" #通过 classlist 的 0 1 判断 最终的结果
if (classList[0] == 1): return "yes" #通过 classlist 的 0 1 判断 最终的结果
#如果当前属性集为空,无法划分
if (Attr == 0):
print("Final")
print(Sample)
return classList[0]
#使用ChooseAttr函数获取最佳的特征对应编号
bestAttr = ChooseAttr(Sample)
#通过最佳特征的编号获得标签名
name=Label[bestAttr]
#新建一个字典用于存储树
Tree = {name: {}}
#打印出最佳特征
print("最佳特征:", name);
#取出对最佳属性对应的一列 并去掉重复值 用于得出一个属性下所包含的取值
featValue = numpy.unique(Sample[:, bestAttr])
#计算出一个属性下包含的取值
numOfFeatValue = len(featValue);
#最佳属性的每一个评级都打印出来
for i in range(0, numOfFeatValue):
print(name, "评级:", featValue[i])
subLabels = Label[:]
#对现有的树执行 SpiltData 去掉计算过的属性所对应的样本 递归调用buildTree
Tree[name][i] = BuildTree(SpiltData(Sample, bestAttr, featValue[i]),subLabels)
print('-------------------------');
return Tree
ChooseAttr函数:在该函数中完成选出最佳特征的功能,根据Ent函数计算出的所有样本的信息熵和加权的信息熵计算信息增益,信息增越大的意味着该属性的纯度越高,选取信息增益最大的属性为最佳属性。
#Choose函数用于 选出最佳的属性
def ChooseAttr(Sample):
#Sample 为输入的数据
#获取输入数据的大小
[Count, Attr] = Sample.shape
numOfFeature = Attr - 1
#计算整个数据的信息熵
baseEnt = Ent(Sample)
#初始信息增益
bestInfoGain = 0.0
#初始的最佳属性为 -1
bestFeature = -1
#遍历当前所有属性
for j in range(0, numOfFeature):
#记录出每一个属性中的取值 并去掉重复值
featureTemp = numpy.unique(Sample[:, j])
#记录下属性取值的个数
numF = len(featureTemp)
newEnt = 0.0;
#遍历所有的取值
for i in range(0, numF):
#去除掉当前已经判断的样本
subSet = SpiltData(Sample, j, featureTemp[i])
#得到每一个取值的个数
[newCount, newAttr] = subSet.shape
#计算每一个取值出现的概率
prob = newCount / Count
#计算新的信息熵
newEnt = newEnt + prob * Ent(subSet)
#计算信息增益
infoGain = baseEnt - newEnt
#找到信息增益最大的属性 作为当前最佳属性
if (infoGain > bestInfoGain):
bestInfoGain = infoGain
bestFeature = j
return bestFeature
Ent函数:计算输入样本的信息熵,通过输入Sample的最后一列统计出该正例与反例出现的概率,根据信息熵公式计算信息熵
#Ent函数用于计算信息熵
def Ent(Sample):
#Sample为输入的数据
#得到输入数据的大小
[Count, Attr] = Sample.shape
n = Attr - 1
m = Count - 1
#获取正例与反例所在的列
label = Sample[:, n]
#去掉重复的数据
deal = numpy.unique(label)
#得到最后判别情况的个数
numOfLabel = len(deal)
#新建一个概率list 用于存储概率
prob = numpy.zeros([numOfLabel, 2])
for i in range(0, numOfLabel):
#获取正例与反例
prob[i, 0] = deal[i]
for j in range(0, Count):
#对正例 与 反例 进行计数
if (label[j] == deal[i]):
prob[i, 1] = prob[i, 1] + 1
#计算出概率
prob[:, 1] = prob[:, 1] / Count
ent = 0
#根据信息熵公示计算出信息熵
for i in range(0, numOfLabel):
ent = ent - prob[i, 1] * math.log2(prob[i, 1])
return ent
SpiltData函数:该函数用于对数据进行拆分,去掉已经判断过的属性对应的样本
#对数据进行拆分 去掉已经判断过的属性所对应的样本
def SpiltData(Sample, axis, value):
#Sample 代表输入的数据
#axis 表示要删除值所在的行
#value表示要删除的值
[Count, Attr] = Sample.shape
subSet = Sample
k = 0
#对每一个样本都做判断 把已经做过判断的样本删掉
for i in range(0, Count):
if (Sample[i, axis]) != value:
subSet=numpy.delete(subSet,i-k,0)
k = k + 1
return subSet
主函数部分
TreeDict=BuildTree(data,AttrArr)
实验结果
存储树的字典:
{‘纹理’: {0: {‘根蒂’: {0: ‘yes’, 1: {‘色泽’: {0: ‘yes’, 1: {‘触感’: {0: ‘yes’, 1: ‘no’}}}}, 2: ‘no’}}, 1: {‘触感’: {0: ‘no’, 1: ‘yes’}}, 2: ‘no’}}
递归的过程:
Sample:
[[0 0 0 0 0 0 1]
[1 0 1 0 0 0 1]
[1 0 0 0 0 0 1]
[0 0 1 0 0 0 1]
[2 0 0 0 0 0 1]
[0 1 0 0 1 1 1]
[1 1 0 1 1 1 1]
[1 1 0 0 1 0 1]
[1 1 1 1 1 0 0]
[0 2 2 0 2 1 0]
[2 2 2 2 2 0 0]
[2 0 0 2 2 1 0]
[0 1 0 1 0 0 0]
[2 1 1 1 0 0 0]
[1 1 0 0 1 1 0]
[2 0 0 2 2 0 0]
[0 0 1 1 1 0 0]]
最佳特征: 纹理
纹理 评级: 0
Sample:
[[0 0 0 0 0 0 1]
[1 0 1 0 0 0 1]
[1 0 0 0 0 0 1]
[0 0 1 0 0 0 1]
[2 0 0 0 0 0 1]
[0 1 0 0 1 1 1]
[1 1 0 0 1 0 1]
[0 2 2 0 2 1 0]
[1 1 0 0 1 1 0]]
最佳特征: 根蒂
根蒂 评级: 0
Sample:
[[0 0 0 0 0 0 1]
[1 0 1 0 0 0 1]
[1 0 0 0 0 0 1]
[0 0 1 0 0 0 1]
[2 0 0 0 0 0 1]]
Final
[[0 0 0 0 0 0 1]
[1 0 1 0 0 0 1]
[1 0 0 0 0 0 1]
[0 0 1 0 0 0 1]
[2 0 0 0 0 0 1]]
根蒂 评级: 1
Sample:
[[0 1 0 0 1 1 1]
[1 1 0 0 1 0 1]
[1 1 0 0 1 1 0]]
最佳特征: 色泽
色泽 评级: 0
Sample:
[[0 1 0 0 1 1 1]]
Final
[[0 1 0 0 1 1 1]]
色泽 评级: 1
Sample:
[[1 1 0 0 1 0 1]
[1 1 0 0 1 1 0]]
最佳特征: 触感
触感 评级: 0
Sample:
[[1 1 0 0 1 0 1]]
Final
[[1 1 0 0 1 0 1]]
触感 评级: 1
Sample:
[[1 1 0 0 1 1 0]]
Final
[[1 1 0 0 1 1 0]]
根蒂 评级: 2
Sample:
[[0 2 2 0 2 1 0]]
Final
[[0 2 2 0 2 1 0]]
纹理 评级: 1
Sample:
[[1 1 0 1 1 1 1]
[1 1 1 1 1 0 0]
[0 1 0 1 0 0 0]
[2 1 1 1 0 0 0]
[0 0 1 1 1 0 0]]
最佳特征: 触感
触感 评级: 0
Sample:
[[1 1 1 1 1 0 0]
[0 1 0 1 0 0 0]
[2 1 1 1 0 0 0]
[0 0 1 1 1 0 0]]
Final
[[1 1 1 1 1 0 0]
[0 1 0 1 0 0 0]
[2 1 1 1 0 0 0]
[0 0 1 1 1 0 0]]
触感 评级: 1
Sample:
[[1 1 0 1 1 1 1]]
Final
[[1 1 0 1 1 1 1]]
纹理 评级: 2
Sample:
[[2 2 2 2 2 0 0]
[2 0 0 2 2 1 0]
[2 0 0 2 2 0 0]]
Final
[[2 2 2 2 2 0 0]
[2 0 0 2 2 1 0]
[2 0 0 2 2 0 0]]
树的可视化
此处使用的是网上一个常见的可视化代码
可视化函数
#绘制树形图
import matplotlib
# matplotlib.use('qt4agg')
from matplotlib.font_manager import *
import matplotlib.pyplot as plt
myfont = FontProperties(fname='/Users/zhangxuancheng/Library/Fonts/simhei.ttf')
decision_node = dict(boxstyle="sawtooth",fc="0.8")
leaf_node = dict(boxstyle="round4",fc="0.8")
arrow_args = dict(arrowstyle="<-")
plt.rcParams['font.sans-serif'] = ['SimHei']
#获取树的叶子结点个数(确定图的宽度)
def get_leaf_num(tree):
leaf_num = 0
first_key = list(tree.keys())[0]
next_dict = tree[first_key]
for key in next_dict.keys():
if type(next_dict[key]).__name__=="dict":
leaf_num +=get_leaf_num(next_dict[key])
else:
leaf_num +=1
return leaf_num
#获取数的深度(确定图的高度)
def get_tree_depth(tree):
depth = 0
first_key = list(tree.keys())[0]
next_dict = tree[first_key]
for key in next_dict.keys():
if type(next_dict[key]).__name__ == "dict":
thisdepth = 1+ get_tree_depth(next_dict[key])
else:
thisdepth = 1
if thisdepth>depth: depth = thisdepth
return depth
def plotNode(nodeTxt, centerPt, parentPt, nodeType):
createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction',
xytext=centerPt, textcoords='axes fraction',
va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
#在父子节点间填充文本信息
def plotMidText(cntrPt, parentPt, txtString):
xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
createPlot.ax1.text(xMid, yMid, txtString, va="center", ha="center", rotation=30)
def plotTree(myTree, parentPt, nodeTxt):
numLeafs = get_leaf_num(myTree)
depth = get_tree_depth(myTree)
firstStr = list(myTree.keys())[0]
cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalW, plotTree.yOff)
plotMidText(cntrPt, parentPt, nodeTxt)
plotNode(firstStr, cntrPt, parentPt, decision_node)
secondDict = myTree[firstStr]
plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
for key in secondDict.keys():
if type(secondDict[
key]).__name__ == 'dict':
plotTree(secondDict[key], cntrPt, str(key))
else:
plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalW
plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leaf_node)
plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
def createPlot(inTree):
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
plotTree.totalW = float(get_leaf_num(inTree))
plotTree.totalD = float(get_tree_depth(inTree))
plotTree.xOff = -0.5 / plotTree.totalW
plotTree.yOff = 1.0
plotTree(inTree, (0.5, 1.0), '')
plt.show()
调用该函数
DrawTree.createPlot(TreeDict)
可视化结果
上一篇: maven配置生成java doc文档中文乱码问题解决方案
下一篇: ID3