欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

决策树算法实现

程序员文章站 2022-05-21 23:30:25
...

decision-tree.py

本文为 落魄陶陶 原创,转载请注明出处
数据来源及源码参见github

  • 学习并参考《机器学习实战》第三章
  • 主要使用Pandas库
  • decision-tree.py为基本算法实现,基于数据fish.xlsx

理解核心:

  1. 数据的有序程度以熵来表示,信息增益越大,表明对数据的划分越有效
  2. 遍历每个字段尝试对数据进行划分后计算信息增益,每次取信息增益最大的划分
  3. 如果划分后某个分组中都属于一类,停止划分,否则递归调用步骤2进一步划分

关键在于第2步,要明确,熵的衡量,每次都是以label列的有序程度来计算,
换句话说,根据Xi的不同取值对数据进行划分,其对应的分类组Y更加有序,说明这是更好的划分

import math

import pandas as pd


# 加载数据
def load_excel(file: str) -> pd.DataFrame:
    return pd.read_excel(file)


def load_csv(file: str, sep: str = ',') -> pd.DataFrame:
    return pd.read_csv(file, sep=sep, header=None)


# 计算熵 H=-∑p(xi)log(p(xi),2)
def calc_entropy(df: pd.DataFrame) -> float:
    total = df.shape[0]
    value_counts = df[df.columns[-1]].value_counts()
    entropy_items = value_counts. \
        apply(lambda x: x / total). \
        apply(lambda prob: prob * math.log2(prob))
    return -entropy_items.sum()


# 划分子集
def split_data_frame(df, col_name, val):
    return df[df[col_name] == val].drop(col_name, axis=1)


# 选择最好子集划分
def choose_best_feature(df: pd.DataFrame) -> str:
    columns = df.columns[:-1]
    best_entropy = calc_entropy(df)
    best_info_gain = 0.
    best_column = None
    for col in columns:
        values = df[col].unique()
        new_entropy = 0.
        for val in values:
            subset = split_data_frame(df, col, val)
            new_entropy += calc_entropy(subset)
        info_gain = best_entropy - new_entropy
        if info_gain > best_info_gain:
            best_info_gain = info_gain
            best_column = col
    return best_column


# 创建决策树
def create_tree(df: pd.DataFrame):
    values = df[df.columns[-1]].unique()
    if values.size < 2:  # 所有label都相同,返回label
        return values[0]
    if df.shape[0] == 2:  # df中只有最后一列数据和label,不可进一步划分,统计label中数量最多的为最终label
        return df[df.columns[-1]].value_counts(ascending=False).values[0]
    best_column = choose_best_feature(df)
    tree = {best_column: {}}
    values_of_column = df[best_column].unique()
    for val in values_of_column:
        tree[best_column][val] = create_tree(split_data_frame(df, best_column, val))
    return tree


if __name__ == '__main__':
    # df = load_excel('fish.xlsx')
    df = load_csv('lenses.txt', '\t')
    # entropy = calc_entropy(df)
    # column = choose_best_feature(df)
    tree = create_tree(df)
    print(tree)

相关标签: 决策树

上一篇: 决策树

下一篇: ACM PKU 2407 Relatives