python3 || 决策树 ID3算法
参考:
[1]机器学习实战 http://blog.csdn.net/suipingsp/article/details/41927247
不过这本书中数据集是一个列表,利用字典的形式来实现树的结构
[2] http://blog.csdn.net/wzmsltw/article/details/51039928
这篇博客结合【机器学习实战】与【机器学习】内容,增加了对连续变量处理
我的代码:决策树ID3算法的实现
- 主要根据【机器学习】的步骤,参考【机器学习实战】
- 数据集采用数组的形式;
- 利用类的形式来实现树的结构;
- 增加连续变量的处理
决策树的原理就不介绍了,这里采用的是ID3算法,也就是采用信息增益来选择最优属性。C4.5也就是利用信息增益比来选择最优划分属性。
导入数据
采用【机器学习】中西瓜数据集的数据进行分析
西瓜数据集如下:下载链接
import pandas as pd
import numpy as np
data = pd.read_csv('watermelon2.csv')
data = np.array(data)
data = np.delete(data,0,1) #删除第一列编码
由于用numpy
中的loadtxt
中文编码会有问题,所以采用pandas
导入再转为数组
注:类属性必须在数据集中的最后一列
信息增益的计算
(1)计算数据集的经验熵
from math import log
from collections import Counter
def calShannonEnt(dataSet):
"""
计算H(D)
:param dataSet: array类型,类处于最后一列
:return: num类型
"""
numEntries = len(dataSet)
labelsCount = Counter(dataSet[:,-1]) #计算每个特征取值的个数
shannonEnt = 0.0
for key in labelsCount:
prob = labelsCount[key] / numEntries #利用频率来代替概率
shannonEnt -= prob * log(prob,2)
return shannonEnt
#调用
calShannonEnt(data)
#0.9975025463691153
(2)计算特征A对数据集D的经验条件熵
首先需要筛选出特定特征以及特定特征取值的数据集 D|A=
1、离散数据
def splitDataSet(dataSet,axis,value):
"""
需要计算H(D|A),则需要把计算该特征的数据分离出来,这里适用于离散变量
:param dataset:数据集 array
:param axis: 特征,第几列
:param value: 特征所取的值
:return:
"""
numberset = np.where(dataSet[:,axis] == value)
return dataSet[numberset,][0]
def calconEnt(dataSet,axis):
"""
计算离散数据下H(D|A)
"""
feature = Counter(dataSet[:,axis])
LEN = len(dataSet[:,axis])
conEnt = 0.0
for key in feature:
spdata = splitDataSet(dataSet, axis, key)
conEnt += float(feature[key]/LEN) * calShannonEnt(spdata)
return conEnt
#调用 第三列数据中取值为 清晰 的数据
splitDataSet(dataSet,3,'清晰')
2、连续数据
选择合适的划分点,将数据进行二分(最简单)
最优划分点的选取:
(1)给定数据集D和连续特征A,假设A上有n个不同取值,从小到大排序得{
(2)令
(3)生成n-1个候选集
(4)根据信息增益最大选出最优划分点
def splitContinuousDataSet(dataSet,axis,value,direction=0):
"""
采用二分法对连续变量进行划分,
:param dataSet: 数据集 array
:param axis:
:param value:根据value划分为两部分
:param direction:0-1取值,0表示<value 的数据; 1表示>=value的数据
:return:
"""
if direction==0:
numberset = np.where(dataSet[:, axis] < value)
else:
numberset = np.where(dataSet[:, axis] >= value)
return dataSet[numberset,][0]
def bestSplitforCondata(dataSet,axis):
sorteddata = sorted(dataSet[:, axis]) #对该特征取值进行排序
#生成候选集
splitPoint = [(sorteddata[i] + sorteddata[i + 1]) / 2 for i in range(len(sorteddata) - 1)]
LEN = len(dataSet[:,axis])
bestsplitInfoGain = float('-inf')
bestsplitevalue = float('-inf')
Ent = calShannonEnt(dataSet)
for value in splitPoint:
conEnt = 0.0
for i in range(2):
spdata = splitContinuousDataSet(dataSet, axis, value, direction=i)
Len = len(spdata[:,axis])
conEnt += float(Len/LEN) * calShannonEnt(spdata)
InfoEnt = Ent - conEnt
if InfoEnt > bestsplitInfoGain:
bestsplitInfoGain = InfoEnt
bestsplitevalue = value
return bestsplitInfoGain,bestsplitevalue
splitContinuousDataSet(dataSet,axis,0.7,direction=1)
bestSplitforCondata(dataSet,6)
#(0.2624392604045631, 0.38149999999999995)
bestSplitforCondata(dataSet,7)
#(0.34929372233065203, 0.126)
(3)计算信息增益
#计算信息增益
calShannonEnt(dataSet) - calconEnt(dataSet,4) #离散数据
Ent,split = bestSplitforCondata(dataSet,6)
calShannonEnt(dataSet) - Ent #连续数据
## 0.28915878284167895
选择最优划分特征
def chooseBestFeaturetoSplit(dataSet,Method = 'ID3'):
"""
需要考虑连续型变量和离散性变量
:param dataSet:
:param Method:
:return:
"""
numFeatures = np.shape(dataSet)[1] - 1
bestEnt = calShannonEnt(dataSet)
bestInfoGain = 0.0
bestFeature = None
bestsplitevalue = None
for i in range(numFeatures):
if isinstance(dataSet[0,i],float): #判断是否是连续型变量
infoGain,splitevalue = bestSplitforCondata(dataSet,i)
else:
infoGain = bestEnt - calconEnt(dataSet, i)
splitevalue = None
if infoGain > bestInfoGain:
bestInfoGain = infoGain
bestFeature = i
bestsplitevalue = splitevalue
return bestFeature,bestInfoGain,bestsplitevalue
chooseBestFeaturetoSplit(dataSet)
##(3, 0.3805918973682686, None)
每个节点根据投票法确定类
def majorityCnt( dataSet ):
labelsCount = Counter(dataSet[:, -1])
LEN = len(dataSet[:,-1])
label = labelsCount.most_common(1)[0][0]
prob = round(labelsCount[label]/LEN,3)
return label,prob
majorityCnt( dataSet )
#('否', 0.529)
构造决策树类
class decisiontree:
def __init__(self,label,prob,axis = None,value =None,*node):
self.label = label #该节点的类的标记
self.prob = prob #该节点为该类的概率
self.axis = axis #该节点采用哪个特征继续划分
self.value = value #离散:特征的取值 连续:0-1
#后续还会增加其他特征
# node :该节点下的子节点
# splitevalue :连续型数据,最佳二分划分点
生成树
def createTree(dataSet,e=0.001,num=2):
"""
:param dataSet: 数据集 array
:param e: 阈值,信息增益小于该阈值则停止
:param num: 阈值,数据集中数目小于该数目则停止
:return:返回树
"""
label, prob = majorityCnt(dataSet)
number = len(dataSet)
tree = decisiontree(label, prob)
axis, Gain,splitevalue = chooseBestFeaturetoSplit(dataSet)
tree.axis = axis
node = []
if Gain < e:
return tree
elif number <= num:
return tree
else:
if splitevalue == None: #判断是否是连续型变量
for value in set(dataSet[:, axis]):
newset = splitDataSet(dataSet, axis, value)
newtree = createTree(newset)
newtree.value = value
newtree.splitevalue = splitevalue
node.append(newtree)
tree.node = node
else:
for i in range(2):
newset = splitContinuousDataSet(dataSet, axis, splitevalue, direction=i)
newtree = createTree(newset)
newtree.value = i
newtree.splitevalue = splitevalue
node.append(newtree)
tree.node = node
return tree
mytree = createTree(dataSet)
vars(mytree)
"""
{'axis': 3,
'label': '否',
'node': [<__main__.decisiontree at 0x20e0ac04748>,
<__main__.decisiontree at 0x20e0ac04780>,
<__main__.decisiontree at 0x20e0ac04f28>],
'prob': 0.529,
'value': None}
"""
该数据集类别为“否”的比例占0.529,首先采用第四个特征(axis = 3)作为最优划分特征,根据该特征的取值分为三个子树(node存储的列表),可根据node,继续向下搜索子树。
搜索决策树
def FinLabel(tree,sample):
if not hasattr(tree,'node'):
label,prob = tree.label , tree.prob
return label,prob
LEN = len(tree.node)
axis = tree.axis
num = -1
for i in range(LEN):
if tree.node[i].splitevalue == None:
if sample[axis] == tree.node[i].value:
num = i
break
else:
if sample[axis] < tree.node[i].splitevalue and tree.node[i].value == 0:
num = i
if sample[axis] >= tree.node[i].splitevalue and tree.node[i].value == 1:
num = i
break
label, prob = FinLabel(tree.node[num],sample)
return label,prob
sample = np.delete(dataSet[0],-1,0)
#array(['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', 0.6970000000000001, 0.46], dtype=object)
#利用该数据第一行数据来预测类别
FinLabel(mytree,sample) #进行预测
##('是', 1.0)
上一篇: ID3分类决策树算法
下一篇: 决策树算法原理及实现