决策树代码(数据集以西瓜集为例我自己手录)
程序员文章站
2022-03-01 16:38:38
决策树实验@[TOC](决策树实验)摘要:决策树理论这里不讲,只把代码贴出来。代码一部分来源于《机器学习实战》,详细的注释是我自己加的;另一部分源码(处理西瓜集的部分)是我自己写的,如有错误欢迎指正。
前言
决策树理论这里不讲,只把我的代码贴出来。代码一部分来源于《机器学习实战》,详细的注释是我自己加的;另一部分源码(处理西瓜集的部分)是我自己写的,如有错误欢迎指正。
一、使用步骤
1.源码
代码如下(示例):
from math import log
import operator
def convert(filename):
fr = open(filename,encoding="utf-8")
arrayOfLines = fr.readlines()
#print(arrayOfLines)
labels = arrayOfLines[0]
attrubute = labels.strip().split(",")
del(attrubute[0])
del(attrubute[-1])
del(arrayOfLines[0])
fileLineNumber = len(arrayOfLines)
for i in range(fileLineNumber):
arrayOfLines[i] = arrayOfLines[i].strip().split(',')
del(arrayOfLines[i][0])
return arrayOfLines,attrubute
#定义函数CalShannonEnt()用于计算样本空间的信息熵
def calcShannonEnt(dataSet):
numEntries = len(dataSet) #numEntries 计算数据集中样本数
labelCounts = {} #创建字典用来存储记录不同样本的个数
for featureVector in dataSet: #遍历结构体dataSet中的每一行
currentLabel = featureVector[-1] #获取每一行的最后一个元素,也就是标签
if currentLabel not in labelCounts.keys(): #判断当前标签是否在字典中的key值中
labelCounts[currentLabel] = 0 #初始化字典中标签和数量的键值对
labelCounts[currentLabel] += 1 #如果dataSet中的标签在字典的key值中value值加一
shannonEnt = 0.0 #初始化信息熵
for key in labelCounts: #判断key在
prob = float(labelCounts[key])/numEntries #计算不同样本在数据中的占比
shannonEnt -= prob * log(prob,2)
return shannonEnt
#创建一个数据集用于测试calcShannonEnt
#返回值有两个,数据集和标签
'''def createDataSet():
dataSet = [[1,1,'yes'],
[1,1,'yes'],
[1,0,'no'],
[0,1,'no'],
[0,1,'no']]
labels = ['no surfacing','flippers']
return dataSet,labels'''
def splitDataSet(dataSet,axis,value):
returnDataSet = []
for featureVector in dataSet:
if featureVector[axis] == value:
reduceFeatureVector = featureVector[:axis]
reduceFeatureVector.extend(featureVector[axis+1:])
returnDataSet.append(reduceFeatureVector)
return returnDataSet
#选择最好的划分数据集方式
def chooseBestFeatureToSplit(dataSet):
numFeatures = len(dataSet[0]) -1 #计算特征向量的长度,等于矩阵行长减去一
#numFeature = dataSet.shape[1] - 1
baseEntropy = calcShannonEnt(dataSet)
bestInfoGain = 0.0
bestFeature = -1
for i in range(numFeatures):
featureList = [example[i]for example in dataSet]
uniqueVals = set(featureList)
newEntropy = 0.0
for value in uniqueVals:
subDataSet = splitDataSet(dataSet,i,value)
prob = len(subDataSet)/float(len(dataSet))
newEntropy += prob * calcShannonEnt(subDataSet)
infoGain = baseEntropy - newEntropy
if(infoGain > bestInfoGain):
bestInfoGain = infoGain
bestFeature = i
return bestFeature
def majorityCnt(classList):
classCount = {}
for vote in classList:
if vote not in classCount.keys():
classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(),key = operator.itemgetter(1),reverse = True)
return sortedClassCount[0][0]
def createTree(dataSet,labels):
classList = [example[-1] for example in dataSet]
if classList.count(classList[0]) == len(classList):
return classList[0]
if len(dataSet[0]) == 1:
return majorityCnt(classList)
bestFeature = chooseBestFeatureToSplit(dataSet)
bestFeatureLabel = labels[bestFeature]
myTree = {bestFeatureLabel:{}}
del(labels[bestFeature])
featureValues = [example[bestFeature] for example in dataSet]
uniqueValues = set(featureValues)
for value in uniqueValues:
subLabels = labels[:]
myTree[bestFeatureLabel][value] = createTree(splitDataSet(dataSet,bestFeature,value),subLabels)
return myTree
dataSet,labels = convert("西瓜集.csv")
#print(dataSet)
#print(labels)
print(createTree(dataSet,labels))
2.数据集
数据集下载链接如下:
链接:https://pan.baidu.com/s/1hHdeRBXLgO89SekTT7UZPA
提取码:fw54
二、结果
本文地址:https://blog.csdn.net/qq_41893089/article/details/110499143