【Python】Implementing the ID3 Algorithm
ID3 Implementation
First, define an ID3DTree class to encapsulate the algorithm.
# coding:utf-8
# Python 3.7.2
import math
import copy
import pickle

class ID3DTree(object):
    def __init__(self):
        self.tree = {}      # the learned decision tree (a nested dict)
        self.dataSet = []   # training records
        self.labels = []    # feature names
Data loading function
    def loadDataSet(self, path, labels):
        # read the tab-separated data file; open in text mode so that,
        # under Python 3, the rows split cleanly as strings
        with open(path, 'r', encoding='utf-8') as fp:
            content = fp.read()
        rowlist = content.splitlines()
        recordlist = [row.split('\t') for row in rowlist if row.strip()]
        self.dataSet = recordlist
        self.labels = labels
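For reference, loadDataSet expects one record per line, tab-separated, with the class label in the last column. A hypothetical two-record file (values chosen to match the feature names used later, not taken from the actual dataset) could be written and loaded like this:

with open('sample.dat', 'w', encoding='utf-8') as f:
    f.write('0\t1\t0\t0\tno\n')    # age, revenue, student, credit, class
    f.write('1\t0\t1\t1\tyes\n')
d = ID3DTree()
d.loadDataSet('sample.dat', ['age', 'revenue', 'student', 'credit'])
print(d.dataSet)   # -> [['0', '1', '0', '0', 'no'], ['1', '0', '1', '1', 'yes']]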
Training function
    def train(self):
        labels = copy.deepcopy(self.labels)   # buildTree mutates its label list
        self.tree = self.buildTree(self.dataSet, labels)
Main method: building the decision tree
    def buildTree(self, dataSet, labels):
        cateList = [data[-1] for data in dataSet]     # class labels of this subset
        if cateList.count(cateList[0]) == len(cateList):
            return cateList[0]                        # pure node: all one class
        if len(dataSet[0]) == 1:
            return self.maxCate(cateList)             # no features left: majority vote
        bestFeat = self.getBestFeat(dataSet)
        bestFeatLabel = labels[bestFeat]
        tree = {bestFeatLabel: {}}
        del(labels[bestFeat])
        uniqueVals = set([data[bestFeat] for data in dataSet])
        for value in uniqueVals:
            subLabels = labels[:]
            splitDataset = self.splitDataset(dataSet, bestFeat, value)
            subTree = self.buildTree(splitDataset, subLabels)
            tree[bestFeatLabel][value] = subTree
        return tree
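To see the recursion in miniature, consider a hypothetical two-column dataset (one feature plus the class; the feature name is illustrative, not part of the downloadable data):

toy = [['1', 'yes'], ['0', 'no'], ['1', 'yes']]
t = ID3DTree()
print(t.buildTree(toy, ['student']))
# -> {'student': {'1': 'yes', '0': 'no'}}: splitting on the only feature
#    yields two pure subsets, so both branches terminate in leaves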
Finding the most frequent class label (majority vote)
    def maxCate(self, cateList):
        # map count -> label, then take the label with the highest count
        items = dict([(cateList.count(i), i) for i in cateList])
        return items[max(items.keys())]
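Note that when two labels tie on count, the dict keyed by counts keeps only one of them, chosen by list order. A more conventional sketch of the same majority vote, which could be dropped into the class, uses collections.Counter:

from collections import Counter

def maxCate(self, cateList):
    # most_common(1) returns [(label, count)] for the top label
    return Counter(cateList).most_common(1)[0][0]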
Selecting the best feature
    def getBestFeat(self, dataSet):
        numFeatures = len(dataSet[0]) - 1          # last column is the class
        baseEntropy = self.computeEntropy(dataSet)
        bestInfoGain = 0.0
        bestFeature = -1
        for i in range(numFeatures):
            uniqueVals = set([data[i] for data in dataSet])
            newEntropy = 0.0
            for value in uniqueVals:
                subDataSet = self.splitDataset(dataSet, i, value)
                prob = len(subDataSet) / float(len(dataSet))
                newEntropy += prob * self.computeEntropy(subDataSet)
            infoGain = baseEntropy - newEntropy    # information gain of feature i
            if infoGain > bestInfoGain:
                bestInfoGain = infoGain
                bestFeature = i
        return bestFeature
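As a quick sanity check on the gain computation, take a hypothetical four-row dataset with a single binary feature:

# each row: feature value, class
rows = [['1', 'yes'], ['1', 'yes'], ['0', 'no'], ['0', 'yes']]
# base entropy: 3 'yes', 1 'no' -> -(3/4)log2(3/4) - (1/4)log2(1/4) ≈ 0.811
# split on the feature: value '1' is pure (H = 0), value '0' is 50/50 (H = 1)
# conditional entropy: 0.5*0 + 0.5*1 = 0.5, so gain ≈ 0.811 - 0.5 = 0.311
t = ID3DTree()
print(t.getBestFeat(rows))   # -> 0, the only (and best) feature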
Computing information entropy
    def computeEntropy(self, dataSet):
        datalen = float(len(dataSet))
        cateList = [data[-1] for data in dataSet]
        items = dict([(i, cateList.count(i)) for i in cateList])   # label -> count
        infoEntropy = 0.0
        for key in items:
            prob = float(items[key]) / datalen
            infoEntropy -= prob * math.log(prob, 2)   # Shannon entropy, base 2
        return infoEntropy
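The method computes the usual Shannon entropy H = -Σ p·log2(p) over the class column. Single-element rows below stand in for records whose only column is the class:

t = ID3DTree()
print(t.computeEntropy([['yes'], ['no']]))                    # -> 1.0 (a 50/50 split is one bit)
print(t.computeEntropy([['yes'], ['yes'], ['yes'], ['no']]))  # -> ~0.811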
Splitting the dataset
    def splitDataset(self, dataSet, axis, value):
        rtnList = []
        for featVec in dataSet:
            if featVec[axis] == value:
                rFeatVec = featVec[:axis]            # columns before the split feature
                rFeatVec.extend(featVec[axis+1:])    # columns after it
                rtnList.append(rFeatVec)
        return rtnList
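The method keeps only the rows whose column axis equals value, and drops that column, since a feature is consumed once it has been used for a split:

t = ID3DTree()
rows = [['0', '1', 'no'], ['1', '1', 'yes'], ['0', '0', 'no']]
print(t.splitDataset(rows, 0, '0'))   # -> [['1', 'no'], ['0', 'no']]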
Training the decision tree
Dataset download:
Link: 数据集下载 (extraction code: 1tgp)
dtree = ID3DTree()
dtree.loadDataSet("数据集.dat", ["age", "revenue", "student", "credit"])
dtree.train()
print(dtree.tree)
Output:
{'age': {'1': 'yes', '0': {'student': {'1': 'yes', '0': 'no'}}, '2': {'credit': {'1': 'no', '0': 'yes'}}}}
The output is a nested dictionary. Redrawing this dictionary in tree form shows the same structure as the JSON-like dict above, which indicates the algorithm is correct.
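To inspect the tree without a drawing tool, here is a minimal sketch of a recursive pretty-printer (printTree is a helper added for illustration, not part of the original class):

def printTree(tree, indent=0):
    # recursively print feature names, branch values, and leaf labels
    if not isinstance(tree, dict):
        print(' ' * indent + '-> ' + tree)
        return
    for feat, branches in tree.items():
        print(' ' * indent + feat)
        for value, sub in branches.items():
            print(' ' * (indent + 2) + '= ' + value)
            printTree(sub, indent + 4)

printTree(dtree.tree)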
Persisting the decision tree
The ID3DTree class also provides dedicated methods for saving the decision tree to a file and loading it back into memory.
    def storeTree(self, inputTree, filename):
        # store the tree to a file; pickle requires binary mode under Python 3
        with open(filename, 'wb') as fw:
            pickle.dump(inputTree, fw)

    def grabTree(self, filename):
        # load a previously pickled tree from a file
        with open(filename, 'rb') as fr:
            return pickle.load(fr)
Run it as follows.
dtree.storeTree(dtree.tree,"data.tree")
mytree = dtree.grabTree("data.tree")
print(mytree)
Output:
{'age': {'1': 'yes', '0': {'student': {'1': 'yes', '0': 'no'}}, '2': {'credit': {'1': 'no', '0': 'yes'}}}}
Classifying with the decision tree
Finally, here is the classifier code for the decision tree.
    def predict(self, inputTree, featLabels, testVec):
        root = list(inputTree.keys())[0]     # root feature of this (sub)tree
        secondDict = inputTree[root]         # value -> subtree or class label
        featIndex = featLabels.index(root)   # position of the root feature in the label list
        key = testVec[featIndex]             # the test vector's value for that feature
        valueOfFeat = secondDict[key]
        if isinstance(valueOfFeat, dict):
            classLabel = self.predict(valueOfFeat, featLabels, testVec)   # recurse into subtree
        else:
            classLabel = valueOfFeat         # leaf: class label
        return classLabel
Now take an arbitrary potential customer, i.e., a single row vector, and classify it with the learned tree. The prediction code is as follows.
dtree = ID3DTree()
labels = ["age","revenue","student","credit"]
vector = ['0','1','0','0'] # the full record, with its true label, is ['0','1','0','0','no']
mytree = dtree.grabTree("data.tree")
print "真实输出 ","no","->","决策树输出",dtree.predict(mytree,labels,vector)
Algorithm evaluation
ID3 is a comparatively early machine learning algorithm; Quinlan proposed its core idea back in 1979. It uses information entropy as its splitting criterion, selecting at each node the attribute that carries the most information, i.e., the one that minimizes the resulting entropy, so that the tree it builds drives entropy down as quickly as possible.
However, ID3 also exposes some problems in practice.
- ID3's node-splitting criterion is information gain, which is biased toward features with many distinct values. A feature with many values is not necessarily the best one, so the splitting criterion needs improvement (see the sketch after this list).
- During the recursion, ID3 must evaluate the information gain of every value of every remaining feature, and on large datasets it tends to grow complex trees with many levels and branches. Some of those branches correspond to feature values of very low probability; if they are not pruned away, the tree overfits: it classifies the training data very accurately, but its results on a test set are highly sensitive to those rare branches.
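As a sketch of the first fix, C4.5 replaces raw information gain with the gain ratio, dividing the gain by the split information (the entropy of the feature's own value distribution). Reusing the methods above, a hypothetical drop-in replacement for getBestFeat might look like this (a simplified sketch of C4.5's criterion, not the full algorithm):

    def getBestFeatByGainRatio(self, dataSet):
        # C4.5-style criterion: gain ratio = infoGain / splitInfo
        numFeatures = len(dataSet[0]) - 1
        baseEntropy = self.computeEntropy(dataSet)
        bestRatio, bestFeature = 0.0, -1
        for i in range(numFeatures):
            uniqueVals = set([data[i] for data in dataSet])
            newEntropy, splitInfo = 0.0, 0.0
            for value in uniqueVals:
                subDataSet = self.splitDataset(dataSet, i, value)
                prob = len(subDataSet) / float(len(dataSet))
                newEntropy += prob * self.computeEntropy(subDataSet)
                splitInfo -= prob * math.log(prob, 2)   # entropy of the split itself
            if splitInfo == 0.0:                        # feature has a single value: skip
                continue
            gainRatio = (baseEntropy - newEntropy) / splitInfo
            if gainRatio > bestRatio:
                bestRatio, bestFeature = gainRatio, i
        return bestFeature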