欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

机器学习系列之决策树

程序员文章站 2022-05-22 08:24:24
...
最近想把每个机器学习的算法,重新学习一遍。最好能自己编写一遍,但是一方面编程能力欠缺,另一方面时间有限。所以大本分代码都是跟着别人的技术博客,照葫芦画瓢。
无论是编程能力,还是机器学习算法,都有待进一步提升。请注意下面的代码不完整,完整代码请参照下面分享的大牛的技术博客。


#!/usr/bin/env python
# -*- coding:utf-8 -*-
__author__ = 'Great'

"""
输入:数据集
输出:决策树(分类结果)

#伪代码
def 创建决策树:
    
    if 数据集样本分类一致:
       创建带类标签的叶子节点
    else:
       寻找划分数据集,信息熵增益最大的特征
       据此划分数据集
       for 每个划分后的数据集:
           创建之树(递归)

def 加载数据集
def 计算熵
def 数据集划分
def 根据熵增益选择最佳划分
def 递归构建决策树
def 样本分类
def matplotlib 显示
def 决策树存储
"""

"""计算信息熵
H(x) = -∑[p(x)log2(p(x))]
熵是基于每种类别的概率计算
"""
from math import log
def entropy_cal(data):
    label = {}
    for one in data:
        label_data = one[-1]
        if label_data not in label.keys():
            label[label_data] =0
        label[label_data] += 1

    length = len(data)
    h_entropy = 0
    for item in label:
        px = label[item]/length
        h_entropy -= float(px)*log(px, 2)

    return h_entropy

"""数据集"""
def get_data():
    dataset = [[1, 1, 'yes'],
               [1, 1, 'yes'],
               [1, 0, 'no'],
               [0, 1, 'no'],
               [0, 1, 'no']]
    return dataset

"""测试"""

data = get_data()
test = entropy_cal(data)
print(test)

#数据集的划分
def splitdata(data, axis, value):
    next_data = []
    for one in data:
        if one[axis] == value:
            data_next = data[:axis]
            data_next.extend(data[axis+1:])
            next_data.append(data_next)
    return next_data
#计算信息熵增益
G = H(D) - H(Di|xi)
H(Di|xi) = -∑(Di/D)*∑[p(x)log2(p(x))]

#最佳数据集划分特征
def bestfeature(data):

    best_gain = 0
    base_entropy = entropy_cal(data)
    best_feat = -1

    h_f_entropy = 0
    length = len(data[0]) - 1
    for i in range(length):
        feat_list = [item[i] for item in data]
        unique_feat = set(feat_list)

        new_entropy = 0
        for value in unique_feat:
            sub_data = splitdata(data, i, value)
            prob = len(sub_data)/float(len(data))
            new_entropy += prob*entropy_cal(sub_data)

        get_gain = base_entropy - new_entropy
        if (get_gain > best_gain):
            best_gain = get_gain
            best_feat = i
    return best_feat

'''测试2'''

#data = get_data()
feat = bestfeature(data)
#print(feat)

不完整代码,未完成,待续。。。
参考:统计学习方法
https://www.cnblogs.com/muchen/p/6141978.html
https://www.cnblogs.com/luozeng/p/8604997.html
https://www.cnblogs.com/lianjiehere/p/6862890.html
 

相关标签: 决策树