欢迎您访问程序员文章站!本站旨在为大家提供、分享程序员计算机编程知识。
您现在的位置是: 首页

决策树实战2-使用决策树预测隐形眼镜类型

程序员文章站 2022-04-02 10:29:08
...

这里使用的是 Python 3.x 版本,对原书代码做了一些修改。
其中画图的函数直接沿用了原代码中的函数,也做了一些修改。

书本配套的数据和 Python 2.7 版本的源码可以在这里获取:https://www.manning.com/books/machine-learning-in-action

from math import log
from ch3.treePlotter import createPlot

def calShannonEntropy(dataset):
    """
    Compute the Shannon entropy of the class labels in a dataset.

    :param dataset: sequence of rows; the last element of each row is its class label
    :return: entropy in bits (0.0 when the dataset is empty)
    """
    total = len(dataset)

    # Tally how many rows carry each class label.
    counts = {}
    for row in dataset:
        tag = row[-1]  # the class label lives in the last column
        counts[tag] = counts.get(tag, 0) + 1

    # H = -sum(p * log2(p)) over the observed labels.
    entropy = 0.0
    for occurrences in counts.values():
        p = float(occurrences / total)
        entropy -= p * log(p, 2)
    return entropy


def splitDate(dataset, axis, value):
    """
    Select the rows whose feature at ``axis`` equals ``value`` and drop that column.

    :param dataset: list of rows, each row being a list
    :param axis: column index of the feature to split on
    :param value: feature value a row must have to be kept
    :return: new list of rows (new list objects) with column ``axis`` removed
    """
    kept_rows = []
    for row in dataset:
        if row[axis] != value:
            continue
        # Rebuild the row from the parts before and after the split column,
        # leaving the input row untouched.
        trimmed = row[:axis]
        trimmed.extend(row[axis + 1:])
        kept_rows.append(trimmed)
    return kept_rows

def keyFeatureSelect(dataset):
    """
    Pick the feature whose split yields the largest information gain.

    :param dataset: list of rows; every column except the last is a feature,
                    the last column is the class label
    :return: column index of the best feature, or -1 when no feature gives a
             strictly positive information gain
    """
    feature_count = len(dataset[0]) - 1
    entropy_before = calShannonEntropy(dataset)
    top_gain = 0
    top_index = -1
    for col in range(feature_count):
        column_values = [row[col] for row in dataset]
        # Entropy after the split: size-weighted sum of each partition's entropy.
        entropy_after = 0
        for val in set(column_values):
            part = splitDate(dataset, col, val)
            weight = len(part) / float(len(dataset))
            entropy_after += weight * calShannonEntropy(part)
        gain = entropy_before - entropy_after
        # Strict comparison: on equal gain the earlier feature is kept.
        if gain > top_gain:
            top_gain, top_index = gain, col
    return top_index


def voteClass(classlist):
    """
    Decide a class label by majority vote.

    :param classlist: non-empty sequence of hashable class labels
    :return: the label that occurs most often; on a tie, the label that
             appeared first in ``classlist``
    :raises ValueError: if ``classlist`` is empty
    """
    if len(classlist) == 0:
        raise ValueError("classlist must not be empty")

    classcount = {}
    for x in classlist:
        # Increment this label's counter (the previous code tried to
        # increment the dict itself, which raises TypeError).
        classcount[x] = classcount.get(x, 0) + 1

    # max() returns the first maximal key; dicts iterate in insertion order,
    # so ties resolve to the label seen first.
    return max(classcount, key=classcount.get)


def createTree(dataset,labels):
    """
    Recursively build a decision tree represented as nested dicts.

    The result is either a class label (leaf) or a dict of the form
    ``{feature_label: {feature_value: subtree_or_class_label, ...}}``.

    :param dataset: non-empty list of rows; the last element of each row is the class label
    :param labels: list of feature labels, one per feature column; not modified
    :return: the tree (nested dict) or a single class label
    """
    classList = [example[-1] for example in dataset]
    if classList.count(classList[0]) == len(classList):  # every row has the same class
        return classList[0]
    if len(dataset[0]) == 1:  # only the class column is left: no feature to split on
        return voteClass(classList)
    bestFeat = keyFeatureSelect(dataset)
    if bestFeat < 0:
        # No feature has positive information gain. Using -1 as an index would
        # pick the last label and split on the class column itself, so fall
        # back to a majority vote instead.
        return voteClass(classList)
    bestFeatLabel = labels[bestFeat]
    # Build a new list without the chosen label rather than deleting in place,
    # so the caller's list (and the list shared by sibling branches) stays intact.
    remainingLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    tree = {bestFeatLabel: {}}  # the tree is a dict keyed by the feature label
    uniqueValue = set(example[bestFeat] for example in dataset)
    for value in uniqueValue:
        tree[bestFeatLabel][value] = createTree(splitDate(dataset, bestFeat, value), remainingLabels)
    return tree


def decTreeClassify(inputTree, featLables, testVec):
    """
    Classify one sample by walking a decision tree stored as nested dicts.

    :param inputTree: tree of the form ``{feature_label: {feature_value: subtree_or_class}}``
    :param featLables: list of feature labels; a label's position is the index
                       of that feature's value in ``testVec``
    :param testVec: feature values of the sample to classify
    :return: the class label stored at the leaf that was reached
    :raises ValueError: if the node's feature label is not in ``featLables``, or
                        the sample's value for that feature has no branch in the tree
    """
    firstStr = list(inputTree)[0]           # feature label tested at this node
    secondDict = inputTree[firstStr]        # branches: feature value -> subtree or class
    featIndex = featLables.index(firstStr)  # position of this feature in testVec
    featValue = testVec[featIndex]
    if featValue not in secondDict:
        # Previously this fell through and failed with UnboundLocalError.
        raise ValueError(
            "no branch for value %r of feature %r" % (featValue, firstStr)
        )
    branch = secondDict[featValue]
    if isinstance(branch, dict):            # internal node: keep descending
        return decTreeClassify(branch, featLables, testVec)
    return branch                           # leaf: the class label


def storeTree(inputTree, filename):
    """
    Serialize a trained tree to disk with pickle.

    :param inputTree: the trained tree
    :param filename: path of the file to write (overwritten if it exists)
    :return: None
    """
    import pickle
    # The context manager closes the file even if pickling fails.
    with open(filename, 'wb') as fw:
        pickle.dump(inputTree, fw)
    print("tree save as", filename)


def grabTree(filename):
    """
    Load a tree previously written with pickle.

    :param filename: path of the pickle file to read
    :return: the unpickled tree
    """
    print("load tree from disk...")
    import pickle
    # NOTE: pickle can execute arbitrary code while loading; only open files
    # that come from a trusted source.
    # The context manager closes the file, which was previously left open.
    with open(filename, "rb") as fr:
        return pickle.load(fr)



if __name__== '__main__':
    # Demo: build a tree from lenses.txt, plot it, round-trip it through
    # pickle, then classify one held-out row.

    # Rows are tab-separated. Close the file once it is read, and skip blank
    # lines so they do not turn into bogus one-element rows.
    with open('lenses.txt') as fr:
        stripped = (inst.strip() for inst in fr)
        lense = [inst.split('\t') for inst in stripped if inst]
    train_set = lense[1:]   # every row but the first is used for training
    test_set = lense[0]     # the first row is held out as the sample to classify
    lenseLabels = ['age', 'prescript', 'astigmatic', 'tearRate']
    lenseTree = createTree(train_set, lenseLabels)
    createPlot(lenseTree)
    storeTree(lenseTree, 'lenseTree.txt')
    restoreTree = grabTree('lenseTree.txt')
    print(restoreTree)
    predict = decTreeClassify(restoreTree,lenseLabels,test_set)
    print(predict)

画出来的图:
决策树实战2-使用决策树预测隐形眼镜类型
运行结果:

{'tearRate': {'reduced': 'no lenses', 'normal': {'astigmatic': {'yes': {'prescript': {'myope': 'hard', 'hyper': {'age': {'pre': 'no lenses', 'presbyopic': 'no lenses', 'young': 'hard'}}}}, 'no': {'age': {'pre': 'soft', 'presbyopic': {'prescript': {'myope': 'no lenses', 'hyper': 'soft'}}, 'young': 'soft'}}}}}}

预测结果:

no lenses

参考《机器学习实战》

相关标签: 决策树