决策树 ID3学习笔记
最近开始看机器学习方面的知识,决策树(DT)主要有两种算法 ID3(Iterative Dichotomiser 3)和 C4.5哈。
决策树算法的优点:
1:易于理解,使用白盒模型,相比之下,在一个黑盒子模型(例如人工神经网络),结果可能更难以解释
2:需要准备的数据量不大
3:能够处理数字和数据的类别(需要做相应的转变),而其他算法分析的数据集往往是只有一种类型的变量
4:能够处理多输出的问题
决策树算法的缺点:
1:决策树算法学习者可以创建复杂的树,容易过拟合,为了避免这种问题,出现了剪枝的概念,即设置一个叶子结点所需要的最小数目或者设置树的最大深度
2:决策树的结果可能是不稳定的,因为在数据中一个很小的变化可能导致生成一个完全不同的树,这个问题可以通过使用集成决策树来解决
3:实际决策树学习算法是基于启发式算法,如贪婪算法,寻求在每个节点上的局部最优决策。这样的算法不能保证返回全局最优决策树。
4:决策树学习者很可能在某些类占主导地位时创建有有偏异的树,因此建议用平衡的数据训练决策树
ID3算法最早是由罗斯昆(J. Ross Quinlan)于1975年在悉尼大学提出的一种分类预测算法,算法以信息论为基础,其核心是“信息熵”。ID3算法通过计算每个属性的信息增益,认为信息增益高的是好属性,每次划分选取信息增益最高的属性为划分标准,重复这个过程,直至生成一个能完美分类训练样例的决策树。首先介绍两个概念:
理解了上面两个概念就好办了,下面是一个具体的例子,加入根据 outlook,temp,hum,windy 属性来决定去不去paly。
直接上代码,写的比较搓,就是个示例,哈:
import copy, numpy as np
import sys
import pdb
#pdb.set_trace()
'''训练样本,前两列对应是属性值,最后一列表示分类,下面的值分别对应
Outlook, Temperature, Humidity, Windy, Play'''
train_data = [['Outlook', 'Temperature', 'Humidity', 'Windy', 'Play'],
['sunny', 'hot', 'high', 'false', 'no'],
['sunny', 'hot', 'high', 'true', 'no'],
['overcast', 'hot', 'high', 'false', 'yes'],
['rain', 'mild', 'high', 'false', 'yes'],
['rain', 'cool', 'normal', 'false', 'yes'],
['rain', 'cool', 'normal', 'true', 'no'],
['overcast', 'cool', 'normal', 'true', 'yes'],
['sunny', 'mild', 'high', 'false', 'no'],
['sunny', 'cool', 'normal', 'false', 'yes'],
['rain', 'mild', 'normal', 'false', 'yes'],
['sunny', 'mild', 'normal', 'true', 'yes'],
['overcast', 'mild', 'high', 'true', 'yes'],
['overcast', 'hot', 'normal', 'false', 'yes'],
['rain', 'mild', 'high', 'true', 'no']]
#测试样本
test_data = [['sunny', 'mild', 'normal', 'false'], ['overcast', 'hot', 'normal', 'true']]
#计算信息熵,传入的是某个属性下的分类及对应的个数,{1:2, 2:5, 3:9},返回信息熵
def entropy_cal(pro):
entro = 0
total = sum(list(pro.values()))
for key, value in pro.items():
P = value/total
entro = entro -P * np.log2(P)
return entro
#res = entropy_cal({1:1, 2:2, 3:1})
#print(res)
#传入的是训练样本,默认最后一列是标签,返回信息增益最大的属性值
def IG(train_data):
#print(train_data)
#不存在属性
if len(train_data[0]) <= 1:
return -1
#训练样本集数量
train_num = len(train_data)
#print(train_num)
#属性个数
pro_num = len(train_data[0])
#声明一个空的字典
pro_dic = {}
for i in range(pro_num):
for j in range(1, train_num):
pro_ind = str(i)
if pro_ind not in pro_dic.keys():
pro_dic[pro_ind] = {}
if train_data[j][i] in pro_dic[pro_ind].keys():
pro_dic[pro_ind][train_data[j][i]]['total'] += 1
else:
pro_dic[pro_ind][train_data[j][i]] = {'total':1}
#记录对应标签
if train_data[j][-1] in pro_dic[pro_ind][train_data[j][i]].keys():
pro_dic[pro_ind][train_data[j][i]][train_data[j][-1]] += 1
else:
pro_dic[pro_ind][train_data[j][i]][train_data[j][-1]] = 1
#print('--------')
#print(pro_dic)
#print('========')
MAX_IG = 10000
MAX_INDEX = -1
for i in range(pro_num - 1):
#对每个属性计算熵
linshi = 0
for key, value in pro_dic[str(i)].items():
P = value['total']/train_num
del value['total']
linshi += P * entropy_cal(value)
#print(linshi)
if linshi < MAX_IG:
MAX_INDEX = i
MAX_IG = linshi
#print('选择的属性是:')
#print(MAX_IG)
#print(MAX_INDEX)
return {'index':MAX_INDEX, 'pro':pro_dic[str(MAX_INDEX)].keys()}
def getTree(train, node):
#print(tree)
res=IG(train)
#print(res)
index = res['index']
#遍历列表train,如何 index 对应的标签都是同一个,那么就结束
num = len(train)
if index != -1:
#pro = {'sunny':{'yes':3,'no':2, 'data':[]}, 'overcast':{'yes':4, 'data':[]}, 'rain':{'yes':3,'no':2, 'data':[]}}
pro = {}
for i in range(num):
if train[i][index] in pro.keys():
if train[i][-1] in pro[train[i][index]].keys():
pro[train[i][index]][train[i][-1]] += 1
else:
pro[train[i][index]][train[i][-1]] = 1
else:
pro[train[i][index]] = {}
pro[train[i][index]][train[i][-1]] = 1
temp = copy.deepcopy(train[i])
del temp[index]
if 'data' not in pro[train[i][index]].keys():
pro[train[i][index]]['data'] = list()
head = copy.deepcopy(train[0])
del head[index]
pro[train[i][index]]['data'].append(head)
pro[train[i][index]]['data'].append(temp)
#print(pro)
name = train[0][index]
if len(node) == 0:
#print(tree)
for i in res['pro']:
if len(pro[i].keys()) <= 2:
key = list(pro[i].keys())
if key[0] == 'data':
del key[0]
print([name, i, key[0]])
else:
newdata = copy.deepcopy(pro[i]['data'])
getTree(newdata, [name, i])
else:
#print('-----------')
#print(res)
#print('===========')
for i in res['pro']:
newnode = copy.deepcopy(node)
if len(pro[i].keys()) <= 2:
key = list(pro[i].keys())
if key[0] == 'data':
del key[0]
newnode.append(name)
newnode.append(i)
newnode.append(key[0])
print(newnode)
else:
newdata = copy.deepcopy(pro[i]['data'])
newnode.append(name)
newnode.append(i)
getTree(newdata, newnode)
def main():
getTree(train_data, [])
if __name__ == "__main__":
sys.exit(main())
运行结果:
上一篇: vue小白教程-5 网络应用——axios发送请求
下一篇: Java学习笔记总结