Python实现决策树C4.5算法的示例
程序员文章站
2024-02-06 09:12:22
为什么要改进成c4.5算法
原理
c4.5算法是在id3算法上的一种改进,它与id3算法最大的区别就是特征选择上有所不同,一个是基于信息增益比,一个是基于信息增益。...
为什么要改进成c4.5算法
原理
c4.5算法是在id3算法上的一种改进,它与id3算法最大的区别就是特征选择上有所不同,一个是基于信息增益比,一个是基于信息增益。
之所以这样做是因为信息增益倾向于选择取值比较多的特征(特征越多,条件熵(特征划分后的类别变量的熵)越小,信息增益就越大);因此在信息增益下面加一个分母,该分母是当前所选特征的熵,注意:这里而不是类别变量的熵了。
这样就构成了新的特征选择准则,叫做信息增益比。为什么加了这样一个分母就会消除id3算法倾向于选择取值较多的特征呢?
因为特征取值越多,该特征的熵就越大,分母也就越大,所以信息增益比就会减小,而不是像信息增益那样增大了,一定程度消除了算法对特征取值范围的影响。
实现
在算法实现上,c4.5算法只是修改了信息增益计算的函数calcshannonentoffeature和最优特征选择函数choosebestfeaturetosplit。
calcshannonentoffeature在id3的calcshannonent函数上加了个参数feat,id3中该函数只用计算类别变量的熵,而calcshannonentoffeature可以计算指定特征或者类别变量的熵。
choosebestfeaturetosplit函数在计算好信息增益后,同时计算了当前特征的熵iv,然后相除得到信息增益比,以最大信息增益比作为最优特征。
在划分数据的时候,有可能出现特征取同一个值,那么该特征的熵为0,同时信息增益也为0(类别变量划分前后一样,因为特征只有一个取值),0/0没有意义,可以跳过该特征。
#coding=utf-8 import operator from math import log import time import os, sys import string def createdataset(traindatafile): print traindatafile dataset = [] try: fin = open(traindatafile) for line in fin: line = line.strip() cols = line.split('\t') row = [cols[1], cols[2], cols[3], cols[4], cols[5], cols[6], cols[7], cols[8], cols[9], cols[10], cols[0]] dataset.append(row) #print row except: print 'usage xxx.py traindatafilepath' sys.exit() labels = ['cip1', 'cip2', 'cip3', 'cip4', 'sip1', 'sip2', 'sip3', 'sip4', 'sport', 'domain'] print 'datasetlen', len(dataset) return dataset, labels #calc shannon entropy of label or feature def calcshannonentoffeature(dataset, feat): numentries = len(dataset) labelcounts = {} for feavec in dataset: currentlabel = feavec[feat] if currentlabel not in labelcounts: labelcounts[currentlabel] = 0 labelcounts[currentlabel] += 1 shannonent = 0.0 for key in labelcounts: prob = float(labelcounts[key])/numentries shannonent -= prob * log(prob, 2) return shannonent def splitdataset(dataset, axis, value): retdataset = [] for featvec in dataset: if featvec[axis] == value: reducedfeatvec = featvec[:axis] reducedfeatvec.extend(featvec[axis+1:]) retdataset.append(reducedfeatvec) return retdataset def choosebestfeaturetosplit(dataset): numfeatures = len(dataset[0]) - 1 #last col is label baseentropy = calcshannonentoffeature(dataset, -1) bestinfogainrate = 0.0 bestfeature = -1 for i in range(numfeatures): featlist = [example[i] for example in dataset] uniquevals = set(featlist) newentropy = 0.0 for value in uniquevals: subdataset = splitdataset(dataset, i, value) prob = len(subdataset) / float(len(dataset)) newentropy += prob *calcshannonentoffeature(subdataset, -1) #calc conditional entropy infogain = baseentropy - newentropy iv = calcshannonentoffeature(dataset, i) if(iv == 0): #value of the feature is all same,infogain and iv all equal 0, skip the feature continue infogainrate = infogain / iv if infogainrate > bestinfogainrate: bestinfogainrate = infogainrate bestfeature = i return bestfeature #feature is exhaustive, reture what you want label def majoritycnt(classlist): classcount = {} for vote in classlist: if vote not in classcount.keys(): classcount[vote] = 0 classcount[vote] += 1 return max(classcount) def createtree(dataset, labels): classlist = [example[-1] for example in dataset] if classlist.count(classlist[0]) ==len(classlist): #all data is the same label return classlist[0] if len(dataset[0]) == 1: #all feature is exhaustive return majoritycnt(classlist) bestfeat = choosebestfeaturetosplit(dataset) bestfeatlabel = labels[bestfeat] if(bestfeat == -1): #特征一样,但类别不一样,即类别与特征不相关,随机选第一个类别做分类结果 return classlist[0] mytree = {bestfeatlabel:{}} del(labels[bestfeat]) featvalues = [example[bestfeat] for example in dataset] uniquevals = set(featvalues) for value in uniquevals: sublabels = labels[:] mytree[bestfeatlabel][value] = createtree(splitdataset(dataset, bestfeat, value),sublabels) return mytree def main(): if(len(sys.argv) < 3): print 'usage xxx.py trainset outputtreefile' sys.exit() data,label = createdataset(sys.argv[1]) t1 = time.clock() mytree = createtree(data,label) t2 = time.clock() fout = open(sys.argv[2], 'w') fout.write(str(mytree)) fout.close() print 'execute for ',t2-t1 if __name__=='__main__': main()
以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持。