欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

决策树 ID3学习笔记

程序员文章站 2024-02-16 10:42:04
...

最近开始看机器学习方面的知识,决策树(DT)主要有两种算法 ID3(Iterative Dichotomiser 3)和 C4.5哈。

决策树算法的优点:

1:易于理解,使用白盒模型,相比之下,在一个黑盒子模型(例如人工神经网络),结果可能更难以解释
2:需要准备的数据量不大
3:能够处理数字和数据的类别(需要做相应的转变),而其他算法分析的数据集往往是只有一种类型的变量
4:能够处理多输出的问题


决策树算法的缺点:
1:决策树算法学习者可以创建复杂的树,容易过拟合,为了避免这种问题,出现了剪枝的概念,即设置一个叶子结点所需要的最小数目或者设置树的最大深度
2:决策树的结果可能是不稳定的,因为在数据中一个很小的变化可能导致生成一个完全不同的树,这个问题可以通过使用集成决策树来解决
3:实际决策树学习算法是基于启发式算法,如贪婪算法,寻求在每个节点上的局部最优决策。这样的算法不能保证返回全局最优决策树。

4:决策树学习者很可能在某些类占主导地位时创建有有偏异的树,因此建议用平衡的数据训练决策树

 

ID3算法最早是由罗斯昆(J. Ross Quinlan)于1975年在悉尼大学提出的一种分类预测算法,算法以信息论为基础,其核心是“信息熵”。ID3算法通过计算每个属性的信息增益,认为信息增益高的是好属性,每次划分选取信息增益最高的属性为划分标准,重复这个过程,直至生成一个能完美分类训练样例的决策树。首先介绍两个概念:

决策树 ID3学习笔记

理解了上面两个概念就好办了,下面是一个具体的例子,加入根据 outlook,temp,hum,windy 属性来决定去不去paly。

决策树 ID3学习笔记

直接上代码,写的比较搓,就是个示例,哈:

import copy, numpy as np
import sys
import pdb

#pdb.set_trace()
'''训练样本,前两列对应是属性值,最后一列表示分类,下面的值分别对应
              Outlook, Temperature, Humidity, Windy, Play'''

train_data = [['Outlook', 'Temperature', 'Humidity', 'Windy', 'Play'],
              ['sunny', 'hot', 'high', 'false', 'no'],
              ['sunny', 'hot', 'high', 'true', 'no'],
              ['overcast', 'hot', 'high', 'false', 'yes'], 
              ['rain', 'mild', 'high', 'false', 'yes'], 
              ['rain', 'cool', 'normal', 'false', 'yes'],
              ['rain', 'cool', 'normal', 'true', 'no'],
              ['overcast', 'cool', 'normal', 'true', 'yes'],
              ['sunny', 'mild', 'high', 'false', 'no'],
              ['sunny', 'cool', 'normal', 'false', 'yes'],
              ['rain', 'mild', 'normal', 'false', 'yes'],
              ['sunny', 'mild', 'normal', 'true', 'yes'],
              ['overcast', 'mild', 'high', 'true', 'yes'],
              ['overcast', 'hot', 'normal', 'false', 'yes'],
              ['rain', 'mild', 'high', 'true', 'no']]
#测试样本
test_data = [['sunny', 'mild', 'normal', 'false'], ['overcast', 'hot', 'normal', 'true']]

#计算信息熵,传入的是某个属性下的分类及对应的个数,{1:2, 2:5, 3:9},返回信息熵
def entropy_cal(pro):
	entro = 0
	total = sum(list(pro.values()))
	for key, value in pro.items():
		P = value/total
		entro = entro -P * np.log2(P)
	return entro

#res = entropy_cal({1:1, 2:2, 3:1})
#print(res)


#传入的是训练样本,默认最后一列是标签,返回信息增益最大的属性值
def IG(train_data):
	#print(train_data)
	#不存在属性
	if len(train_data[0]) <= 1:
		return -1
	#训练样本集数量
	train_num = len(train_data)
	#print(train_num)
	#属性个数
	pro_num = len(train_data[0])
    #声明一个空的字典
	pro_dic = {}
	for i in range(pro_num):
		for j in range(1, train_num):
			pro_ind = str(i)

			if pro_ind not in pro_dic.keys():
				pro_dic[pro_ind] = {}
			
			if train_data[j][i] in pro_dic[pro_ind].keys():
				pro_dic[pro_ind][train_data[j][i]]['total'] += 1
			else:
				pro_dic[pro_ind][train_data[j][i]] = {'total':1}
			#记录对应标签
			if train_data[j][-1] in pro_dic[pro_ind][train_data[j][i]].keys():
				pro_dic[pro_ind][train_data[j][i]][train_data[j][-1]] += 1
			else:
				pro_dic[pro_ind][train_data[j][i]][train_data[j][-1]] = 1

	#print('--------')
	#print(pro_dic)
	#print('========')
	MAX_IG = 10000
	MAX_INDEX = -1
	for i in range(pro_num - 1):
		#对每个属性计算熵
		linshi = 0
		for key, value in pro_dic[str(i)].items():
			P = value['total']/train_num
			del value['total']
			linshi += P * entropy_cal(value)
		#print(linshi)
		if linshi < MAX_IG:
			MAX_INDEX = i
			MAX_IG = linshi
	#print('选择的属性是:')
	#print(MAX_IG)
	#print(MAX_INDEX)
	return {'index':MAX_INDEX, 'pro':pro_dic[str(MAX_INDEX)].keys()}


		
def getTree(train, node):
	#print(tree)
	res=IG(train)
	#print(res)
	index = res['index']
	#遍历列表train,如何 index 对应的标签都是同一个,那么就结束
	num = len(train)

	if index != -1:
		#pro = {'sunny':{'yes':3,'no':2, 'data':[]}, 'overcast':{'yes':4, 'data':[]}, 'rain':{'yes':3,'no':2, 'data':[]}}
		pro = {}
		for i in range(num):
			if train[i][index] in pro.keys():
				if train[i][-1] in pro[train[i][index]].keys():
					pro[train[i][index]][train[i][-1]] += 1
				else:
					pro[train[i][index]][train[i][-1]] = 1
			else:
				pro[train[i][index]] = {}
				pro[train[i][index]][train[i][-1]] = 1

			temp = copy.deepcopy(train[i])
			del temp[index]
			if 'data' not in pro[train[i][index]].keys():
				pro[train[i][index]]['data'] = list()
				head = copy.deepcopy(train[0])
				del head[index]
				pro[train[i][index]]['data'].append(head)
			pro[train[i][index]]['data'].append(temp)
	#print(pro)
	name = train[0][index]
	if len(node) == 0:
		#print(tree)
		for i in res['pro']:
			if len(pro[i].keys()) <= 2:
				key = list(pro[i].keys())
				if key[0] == 'data':
					del key[0]
				print([name, i, key[0]])
			else:
				newdata = copy.deepcopy(pro[i]['data'])
				getTree(newdata, [name, i])
	else:
		#print('-----------')
		#print(res)
		#print('===========')
		for i in res['pro']:
			newnode = copy.deepcopy(node)
			if len(pro[i].keys()) <= 2:
				key = list(pro[i].keys())
				if key[0] == 'data':
					del key[0]
				newnode.append(name)
				newnode.append(i)
				newnode.append(key[0])
				print(newnode)
			else:
				newdata = copy.deepcopy(pro[i]['data'])
				newnode.append(name)
				newnode.append(i)
				getTree(newdata, newnode)

def main():
	getTree(train_data, [])


if __name__ == "__main__":
    sys.exit(main())

运行结果:

决策树 ID3学习笔记