欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

Python实现数据挖掘基本决策树算法(信息增益)

程序员文章站 2022-07-08 15:55:46
算法实现思路首先,事务集可以采用手动输入,也可以预定义,我根据习题8.7使用列表字典结构预定义事务集。其次,分别编写获取候选属性、获取数据集中的元组分类所需的期望信息、获取根据指定候选属性分类所需的期望信息、计算信息增益、根据指定属性对数据集进行分类、判断数据集是否属于同一类,共6种方法;初始数据集采用列表保存,列表中的每一项为字典,保存属性及对应的属性值。获取候选属性方法:获取初始数据集中某一项所有键的列表,去除其中的数量“count”和类别“status”获取数据集中的元组分类所需的期望信息:...

算法实现思路

首先,事务集可以采用手动输入,也可以预定义,我在这里使用列表包含字典结构预定义事务集。
其次,分别编写获取候选属性、获取数据集中的元组分类所需的期望信息、获取根据指定候选属性分类所需的期望信息、计算信息增益、根据指定属性对数据集进行分类、判断数据集是否属于同一类,共6种方法;
初始数据集采用列表保存,列表中的每一项为字典,保存属性及对应的属性值。
获取候选属性方法:获取初始数据集中某一项所有键的列表,去除其中的数量“count”和类别“status”
获取数据集中的元组分类所需的期望信息:遍历数据集,获取每一类别的数量和总数,使用math.log2()方法及计算公式得出对数据集中的元组分类所需的期望信息
获取根据指定候选属性分类所需的期望信息:根据指定属性,获取此属性的分类,并根据此分类计算所需期望信息
计算信息增益:调用以上两种方法,取其差
根据指定属性对数据集进行分类:遍历数据集,将指定属性不同的项添加至不同的分类中
判断数据集是否属于同一类:遍历数据集,获取数据集中所有项的分类,如果存储分类列表长度为一,则属于同一类,否则不属于同一类
最后,编写主函数。如果数据集不属于同一类,获取候选数据集,若候选属性不为空,执行以下操作,计算所有候选属性的信息增益,取能获得最大信息增益的属性,获取此数据根据能获得最大信息增益的属性进行分裂的所有子集,对每一子集进行递归操作。

源代码

# -*- coding: utf-8 -*-

"""
@Time        : 2020/12/7
@Author      : lixinci
@File        : 8_7_决策树算法
@Description :
"""
import copy
import math


dataset = [
    {"department": "sales", "age": "31...35", "salary": "46K...50K", "count": 30, "status": "senior"},
    {"department": "sales", "age": "26...30", "salary": "26K...30K", "count": 40, "status": "junior"},
    {"department": "sales", "age": "31...35", "salary": "31K...35K", "count": 40, "status": "junior"},
    {"department": "systems", "age": "21...25", "salary": "46K...50K", "count": 20, "status": "junior"},
    {"department": "systems", "age": "31...35", "salary": "66K...70K", "count": 5, "status": "senior"},
    {"department": "systems", "age": "26...30", "salary": "46K...50K", "count": 3, "status": "junior"},
    {"department": "systems", "age": "41...45", "salary": "66K...70K", "count": 3, "status": "senior"},
    {"department": "marketing", "age": "36...40", "salary": "46K...50K", "count": 10, "status": "senior"},
    {"department": "marketing", "age": "31...35", "salary": "41K...45K", "count": 4, "status": "junior"},
    {"department": "secretary", "age": "46...50", "salary": "36K...40K", "count": 4, "status": "senior"},
    {"department": "secretary", "age": "26...30", "salary": "26K...30K", "count": 6, "status": "junior"}
]


def get_candidate_attributes(dataset):
    """
    获取候选属性
    """
    data = dataset[0]
    candidate_attributes = list(data.keys())
    candidate_attributes.remove('status')
    candidate_attributes.remove('count')
    return candidate_attributes


def original_information_requirement(dataset):
    """
    统计分类,计算对数据集中的元组分类所需的期望信息
    """
    labels = {}
    info = 0.0
    for data in dataset:
        labels[data["status"]] = labels.get(data["status"], 0) + data["count"]
    # 获取数据集中总数
    count = sum(list(labels.values()))
    # 求期望信息
    for num in labels.values():
        info -= (num / count) * math.log2(num / count)
    return info


def new_information_requirements(dataset, label):
    """
    计算新的信息需求
    """
    attribute_category = {}
    info = 0.0
    for data in dataset:
        attribute_category[data[label]] = attribute_category.get(data[label], 0) + data["count"]
    # 获取数据集中的总数
    count = sum(list(attribute_category.values()))
    for category, num in attribute_category.items():
        tmp_info = 0.0
        tmp_labels = {}
        for data in dataset:
            if data[label].__eq__(category):
                tmp_labels[data["status"]] = tmp_labels.get(data["status"], 0) + data["count"]
        tmp_count = sum(list(tmp_labels.values()))
        for tmp_num in tmp_labels.values():
            tmp_info -= (tmp_num / tmp_count) * math.log2(tmp_num / tmp_count)
        info += (num / count) * tmp_info
    return info


def information_gain(dataset, label):
    """
    计算信息增益
    """
    info = original_information_requirement(
        dataset) - new_information_requirements(dataset, label)
    return info


def classification_by_attribute(dataset, label):
    """
    根据选中属性对数据集进行分类
    """
    # 获取某属性的几个类别
    category = []
    for data in dataset:
        if data[label] not in category:
            category.append(data[label])
    # 根据属性存储分类后的数据集
    datas = []
    for i in range(len(category)):
        tmp_data = []
        for data in dataset:
            if data[label].__eq__(category[i]):
                tmp = copy.deepcopy(data)
                del tmp[label]
                tmp_data.append(tmp)
        datas.append({category[i]: tmp_data})
    return datas


def is_same_class(dataset):
    """
    判断是否属于同一类
    """
    category = []
    for data in dataset:
        if data['status'] not in category:
            category.append(data['status'])
    if len(category) == 1:
        return True
    return False


def main(dataset, count, branch=''):
    if not is_same_class(dataset):
        # 获取候选属性集
        candidate_attributes = get_candidate_attributes(dataset)
        if not candidate_attributes:
            print("在第{}层{}分支的数据集因为无候选集无法再次进行分解".format(count, branch))
            return
        # 保存最佳分裂属性
        optimal_split_attribute = ""
        # 保存最高分裂属性的信息增益
        max_information_gain = 0
        for label in candidate_attributes:
            current_information_gain = information_gain(dataset, label)
            if current_information_gain > max_information_gain:
                optimal_split_attribute = label

        # 生成的子集
        datas = classification_by_attribute(dataset, optimal_split_attribute)
        for data in datas:
            for key, sub_data in data.items():
                main(sub_data, count + 1, key)

        if branch.__eq__(''):
            print("第{}层根据{}来分类".format(count, optimal_split_attribute))
        else:
            print(
                "在第{}层{}分支的数据集根据{}来分类".format(
                    count,
                    branch,
                    optimal_split_attribute))
    else:
        if branch.__eq__(''):
            return
        print("在第{}层{}分支的数据集为同一类".format(count, branch))


if __name__ == '__main__':
    main(dataset, 1)

本文地址:https://blog.csdn.net/qq_44924544/article/details/110876313