Implementing the C4.5 Decision Tree Algorithm in Python
The C4.5 algorithm selects features by information gain ratio rather than ID3's information gain, which corrects information gain's bias toward features with many distinct values. The gain ratio is defined as:

    GainRatio(D, A) = Gain(D, A) / SplitInfo_A(D)

where Gain(D, A) is the ID3 information gain, and the split information

    SplitInfo_A(D) = - sum_v (|D_v| / |D|) * log2(|D_v| / |D|)

is the entropy of feature A's own value distribution (D_v is the subset of D on which A takes its v-th value).
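As a quick check of the formula, here is a minimal sketch (toy values, not the article's data) that computes the split information for one feature column and divides a hypothetical gain by it:

import math

def split_info(values):
    # SplitInfo(A) = -sum_v p_v * log2(p_v), where p_v is the fraction
    # of samples taking value v on feature A
    n = float(len(values))
    return -sum((values.count(v) / n) * math.log(values.count(v) / n, 2)
                for v in set(values))

feature = ["youth", "youth", "middle", "senior", "senior"]
gain = 0.42  # hypothetical information gain for this feature
print(gain / split_info(feature))  # the C4.5 gain ratio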
The full implementation:

# -*- coding: utf-8 -*-
from numpy import *
import math
import copy
import pickle


class c45dtree(object):
    def __init__(self):
        self.tree = {}      # the generated decision tree
        self.dataset = []   # dataset
        self.labels = []    # feature labels

    # Load the dataset: one tab-separated sample per line,
    # with the class label in the last column
    def loaddataset(self, path, labels):
        fp = open(path, "r")
        content = fp.read()
        fp.close()
        rowlist = content.splitlines()  # one row per line
        # strip() drops blank lines (spaces, tabs, etc.)
        recordlist = [row.split("\t") for row in rowlist if row.strip()]
        self.dataset = recordlist
        self.labels = labels

    # Train: build the decision tree from the loaded dataset
    def train(self):
        labels = copy.deepcopy(self.labels)
        self.tree = self.buildtree(self.dataset, labels)

    # Build the decision tree: the main recursive routine
    def buildtree(self, dataset, labels):
        catelist = [data[-1] for data in dataset]  # extract the class-label column
        # Stopping condition 1: all samples share one class label; return it
        if catelist.count(catelist[0]) == len(catelist):
            return catelist[0]
        # Stopping condition 2: no features left; return the majority class
        if len(dataset[0]) == 1:
            return self.maxcate(catelist)
        # Core step: pick the best feature axis by gain ratio
        bestfeat, featvaluelist = self.getbestfeat(dataset)
        bestfeatlabel = labels[bestfeat]
        tree = {bestfeatlabel: {}}
        del(labels[bestfeat])
        for value in featvaluelist:  # grow the tree recursively
            sublabels = labels[:]  # label list with the used feature removed
            # split the dataset on the best feature axis and this value
            splitdataset = self.splitdataset(dataset, bestfeat, value)
            subtree = self.buildtree(splitdataset, sublabels)  # build the subtree
            tree[bestfeatlabel][value] = subtree
        return tree

    # Return the most frequent class label (ties broken arbitrarily)
    def maxcate(self, catelist):
        items = dict([(catelist.count(i), i) for i in catelist])
        return items[max(items.keys())]

    # Find the best feature by information gain ratio
    def getbestfeat(self, dataset):
        num_feats = len(dataset[0][:-1])
        totality = len(dataset)
        baseentropy = self.computeentropy(dataset)
        conditionentropy = []  # conditional entropy of each feature
        splitinfo = []         # split information, for the C4.5 gain ratio
        allfeatvlist = []
        for f in range(num_feats):
            featlist = [example[f] for example in dataset]
            spliti, featurevaluelist = self.computesplitinfo(featlist)
            allfeatvlist.append(featurevaluelist)
            splitinfo.append(spliti)
            resultgain = 0.0
            for value in featurevaluelist:
                subset = self.splitdataset(dataset, f, value)
                appearnum = float(len(subset))
                subentropy = self.computeentropy(subset)
                resultgain += (appearnum / totality) * subentropy
            conditionentropy.append(resultgain)  # total conditional entropy
        infogainarray = baseentropy * ones(num_feats) - array(conditionentropy)
        # C4.5: divide the gain by the split information
        # (assumes every feature takes at least two values, else splitinfo is 0)
        infogainratio = infogainarray / array(splitinfo)
        bestfeatureindex = argsort(-infogainratio)[0]
        return bestfeatureindex, allfeatvlist[bestfeatureindex]

    # Compute the split information of one feature column
    def computesplitinfo(self, featurevlist):
        numentries = len(featurevlist)
        featurevaluesetlist = list(set(featurevlist))
        valuecounts = [featurevlist.count(featvec) for featvec in featurevaluesetlist]
        plist = [float(item) / numentries for item in valuecounts]
        llist = [item * math.log(item, 2) for item in plist]
        splitinfo = -sum(llist)
        return splitinfo, featurevaluesetlist

    # Compute the Shannon entropy of the class-label column
    def computeentropy(self, dataset):
        datalen = float(len(dataset))
        catelist = [data[-1] for data in dataset]  # class labels
        # dict of class label -> occurrence count
        items = dict([(i, catelist.count(i)) for i in catelist])
        infoentropy = 0.0
        for key in items:
            # Shannon entropy: -sum(p * log2(p))
            prob = float(items[key]) / datalen
            infoentropy -= prob * math.log(prob, 2)
        return infoentropy

    # Split the dataset: keep rows where feature `axis` equals `value`,
    # and drop that feature's column from the result
    def splitdataset(self, dataset, axis, value):
        rtnlist = []
        for featvec in dataset:
            if featvec[axis] == value:
                rfeatvec = featvec[:axis]            # elements 0..axis-1
                rfeatvec.extend(featvec[axis + 1:])  # append the elements after the axis
                rtnlist.append(rfeatvec)
        return rtnlist

    # Persist the tree to a file
    def storetree(self, inputtree, filename):
        fw = open(filename, 'wb')
        pickle.dump(inputtree, fw)
        fw.close()

    # Load the tree back from a file
    def grabtree(self, filename):
        fr = open(filename, 'rb')
        return pickle.load(fr)
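loaddataset expects a plain text file with one sample per row, fields separated by tabs, and the class label in the last column. Assuming the four feature labels used in the driver code below, a few rows of dataset.dat might look like this (illustrative values, not the article's actual data):

youth	high	no	fair	no
youth	high	no	excellent	no
middle_aged	high	no	fair	yes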
Driver code:
# -*- coding: utf-8 -*-
from c45dtree import *

dtree = c45dtree()
dtree.loaddataset("dataset.dat", ["age", "revenue", "student", "credit"])
dtree.train()
dtree.storetree(dtree.tree, "data.tree")
mytree = dtree.grabtree("data.tree")
print(mytree)
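The article stops at building and persisting the tree; to label a new sample you also need to walk the nested dict. A minimal sketch (not part of the original code; the helper name classify is my own) might look like this:

def classify(inputtree, featlabels, testvec):
    # At each internal node, follow the branch matching the test
    # vector's value for that node's feature; a leaf is a class label.
    root = list(inputtree.keys())[0]
    branches = inputtree[root]
    featindex = featlabels.index(root)
    for value in branches:
        if testvec[featindex] == value:
            if isinstance(branches[value], dict):
                return classify(branches[value], featlabels, testvec)
            return branches[value]
    return None  # feature value unseen during training

# e.g. classify(mytree, ["age", "revenue", "student", "credit"],
#               ["youth", "high", "no", "fair"])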
That is all for this article. I hope it helps with your studies, and thank you for your continued support.