
Decision Tree Code (Watermelon Dataset as an Example, Hand-Entered by Me)

程序员文章站 2022-03-01 16:38:38


Preface

I won't cover decision-tree theory here; I'm just posting my code. Part of the code comes from *Machine Learning in Action*, with detailed comments added by me. The other part (handling the watermelon dataset) I wrote myself; if you spot errors, corrections are welcome.


I. Usage

1. Source code

The code is as follows (example):

from math import log
import operator

def convert(filename):
    #Read the UTF-8 CSV: the first row is the header, the first column
    #is a running row ID, and the last column is the class label
    fr = open(filename,encoding="utf-8")
    arrayOfLines = fr.readlines()
    fr.close()
    #print(arrayOfLines)
    labels = arrayOfLines[0]
    attributes = labels.strip().split(",")
    del(attributes[0])    #drop the ID column name
    del(attributes[-1])   #drop the class-label column name
    del(arrayOfLines[0])  #drop the header row
    fileLineNumber = len(arrayOfLines)
    for i in range(fileLineNumber):
        arrayOfLines[i] = arrayOfLines[i].strip().split(',')
        del(arrayOfLines[i][0])  #drop the row ID from each sample
    return arrayOfLines,attributes



#calcShannonEnt() computes the Shannon entropy of the sample space
def calcShannonEnt(dataSet):
    numEntries = len(dataSet) #number of samples in the dataset
    labelCounts = {}          #dict recording how many samples belong to each class
    for featureVector in dataSet:     #iterate over every row of dataSet
        currentLabel = featureVector[-1]   #the last element of each row is the class label
        if currentLabel not in labelCounts.keys():      #first time we see this label
            labelCounts[currentLabel] = 0  #initialize its count in the dict
        labelCounts[currentLabel] += 1   #increment the count for this label
    shannonEnt = 0.0                               #initialize the entropy
    for key in labelCounts:                         #for each class label
        prob = float(labelCounts[key])/numEntries   #proportion of this class in the dataset
        shannonEnt -= prob * log(prob,2)
    return shannonEnt

#A toy dataset for testing calcShannonEnt
#(returns two values: the dataset and the feature labels)
'''def createDataSet():
    dataSet = [[1,1,'yes'],
               [1,1,'yes'],
               [1,0,'no'],
               [0,1,'no'],
               [0,1,'no']]
    labels = ['no surfacing','flippers']
    return dataSet,labels'''

#Return the subset of rows whose feature at position axis equals value,
#with that feature column removed
def splitDataSet(dataSet,axis,value):
    returnDataSet = []
    for featureVector in dataSet:
        if featureVector[axis] == value:
            reduceFeatureVector = featureVector[:axis]
            reduceFeatureVector.extend(featureVector[axis+1:])
            returnDataSet.append(reduceFeatureVector)
    return returnDataSet

#Choose the best feature to split on (largest information gain)
def chooseBestFeatureToSplit(dataSet):
    numFeatures = len(dataSet[0]) -1  #number of features: row length minus one (last element is the label)
    baseEntropy = calcShannonEnt(dataSet)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):
        featureList = [example[i] for example in dataSet]  #all values of feature i
        uniqueVals  = set(featureList)
        newEntropy = 0.0
        for value in uniqueVals:
            subDataSet = splitDataSet(dataSet,i,value)
            prob = len(subDataSet)/float(len(dataSet))   #weight of this branch
            newEntropy += prob * calcShannonEnt(subDataSet)
        infoGain = baseEntropy - newEntropy              #information gain of feature i
        if(infoGain > bestInfoGain):
            bestInfoGain = infoGain
            bestFeature = i
    return bestFeature

#Return the class label that appears most often (majority vote)
def majorityCnt(classList):
    classCount = {}
    for vote in classList:
        if vote not in classCount.keys():
            classCount[vote] = 0
        classCount[vote] += 1
    sortedClassCount = sorted(classCount.items(),key = operator.itemgetter(1),reverse = True)
    return sortedClassCount[0][0]

#Recursively build the ID3 decision tree as nested dictionaries
def createTree(dataSet,labels):
    labels = labels[:]  #work on a copy so the caller's label list is not destroyed
    classList = [example[-1] for example in dataSet]
    if classList.count(classList[0]) == len(classList):  #all samples in one class: return a leaf
        return classList[0]
    if len(dataSet[0]) == 1:                             #no features left: majority vote
        return majorityCnt(classList)
    bestFeature = chooseBestFeatureToSplit(dataSet)
    bestFeatureLabel = labels[bestFeature]
    myTree = {bestFeatureLabel:{}}
    del(labels[bestFeature])
    featureValues = [example[bestFeature] for example in dataSet]
    uniqueValues = set(featureValues)
    for value in uniqueValues:
        subLabels = labels[:]   #copy so sibling branches don't share state
        myTree[bestFeatureLabel][value] = createTree(splitDataSet(dataSet,bestFeature,value),subLabels)
    return myTree

dataSet,labels = convert("西瓜集.csv")
#print(dataSet)
#print(labels)
print(createTree(dataSet,labels))
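As a quick sanity check of the functions above, here is a small self-contained sketch (my addition, not part of the original post). It restates the entropy function on the toy dataset from *Machine Learning in Action* (the same one in the commented-out `createDataSet`), where entropy is H(D) = -Σ p_k·log2(p_k), and adds a minimal `classify()` helper in that book's style, since the post builds the tree but never classifies with it:

```python
from math import log

# Minimal restatement of the article's calcShannonEnt, so this
# snippet runs on its own.
def calcShannonEnt(dataSet):
    numEntries = len(dataSet)
    labelCounts = {}
    for featVec in dataSet:
        label = featVec[-1]
        labelCounts[label] = labelCounts.get(label, 0) + 1
    shannonEnt = 0.0
    for key in labelCounts:
        prob = labelCounts[key] / numEntries
        shannonEnt -= prob * log(prob, 2)
    return shannonEnt

# Toy dataset from Machine Learning in Action: 2 'yes', 3 'no'.
dataSet = [[1, 1, 'yes'], [1, 1, 'yes'], [1, 0, 'no'],
           [0, 1, 'no'], [0, 1, 'no']]
print(round(calcShannonEnt(dataSet), 4))  # 0.971

# Sketch of a classifier that walks the nested-dict tree
# produced by createTree (featLabels gives each feature's
# position in the test vector).
def classify(tree, featLabels, testVec):
    featName = next(iter(tree))             # feature tested at this node
    branches = tree[featName]
    featIndex = featLabels.index(featName)  # its index in the vector
    subtree = branches[testVec[featIndex]]
    if isinstance(subtree, dict):           # internal node: recurse
        return classify(subtree, featLabels, testVec)
    return subtree                          # leaf: class label

# The tree createTree builds from the toy dataset above.
toyTree = {'no surfacing': {0: 'no', 1: {'flippers': {0: 'no', 1: 'yes'}}}}
print(classify(toyTree, ['no surfacing', 'flippers'], [1, 1]))  # yes
print(classify(toyTree, ['no surfacing', 'flippers'], [0, 1]))  # no
```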

2. Dataset

Download link:

Link: https://pan.baidu.com/s/1hHdeRBXLgO89SekTT7UZPA
Extraction code: fw54
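If the pan link is unavailable, you can hand-build a file in the layout `convert()` expects: a UTF-8 CSV whose first row is the header, first column a running ID, and last column the class label. The header names and the two sample rows below follow Zhou Zhihua's well-known watermelon dataset 2.0, but they are illustrative; my hand-entered file may differ in detail:

```python
# Write a tiny CSV in the layout convert() expects.
# Header/rows follow the watermelon dataset 2.0 (illustrative).
sample = (
    "编号,色泽,根蒂,敲声,纹理,脐部,触感,好瓜\n"
    "1,青绿,蜷缩,浊响,清晰,凹陷,硬滑,是\n"
    "2,乌黑,蜷缩,沉闷,清晰,凹陷,硬滑,是\n"
)
with open("西瓜集_sample.csv", "w", encoding="utf-8") as f:
    f.write(sample)

# Parse it the same way convert() does: drop the ID column,
# and strip the label name from the attribute list.
with open("西瓜集_sample.csv", encoding="utf-8") as f:
    lines = f.readlines()
attributes = lines[0].strip().split(",")[1:-1]
rows = [line.strip().split(",")[1:] for line in lines[1:]]
print(attributes)   # ['色泽', '根蒂', '敲声', '纹理', '脐部', '触感']
print(rows[0][-1])  # 是  (the class label is each row's last element)
```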


II. Results

(Result screenshot: the nested-dictionary decision tree printed by the program.)

Original post: https://blog.csdn.net/qq_41893089/article/details/110499143