求训练集的熵和信息增益
程序员文章站
2022-07-14 10:20:05
...
求训练集的熵和信息增益
代码如下:
import math
L = ['S','S','I','M','I','M','M','I','M','S']
F = ['S','I','M','M','M','I','S','M','S','S']
H = ['N','Y','Y','Y','Y','N','N','N','Y','Y']
R = ['N','Y','Y','Y','Y','Y','N','Y','Y','N']
def H_R(data): #求R熵的函数
Y,N = 0,0
for i in data:
if i == 'Y':
Y += 1 #统计R中为Y的个数
else:
N += 1 #统计R中为N的个数
RY = Y/(Y+N)
RN = N/(Y+N)
HR = -RY*math.log(RY,2)-RN*math.log(RN,2) #求R的熵
return HR
def H_RX(data1,data2): #求H(R|L)和H(R|F)函数
S,M,I = 0,0,0
for i in data1:
if i == 'S':
S += 1 #统计S的个数
elif i == 'M':
M += 1 #统计M的个数
else:
I += 1 #统计I的个数
RS = S/(S+M+I)
RM = M/(S+M+I)
RI = I/(S+M+I) #分别求S,M,I的概率
Rate = [RS,RM,RI]
TY,TN,HR_X = 0,0,0
count = -1 #这是一个计数器
for i in ['S','M','I']:
wd = [k for k,x in enumerate(data1) if x==i] #分别定位所有S,M,I在列表中的位置
count += 1 #作为计数器标记Rate列表对应元素
for j in wd:
if data2[j] == 'Y':
TY += 1 #统计分别在S,M,I对应的Y的数量
elif data2[j] == 'N':
TN += 1 #统计分别在S,M,I对应的N的数量
RY_X = TY/(TY+TN)
RN_X = TN/(TY+TN) #分别求在S,M,I下Y,N的概率
TY,TN = 0,0 #让上次统计的数量清0,使得下一次统计重新开始,防止与上次叠加
if RY_X==0 or RN_X ==0: #防止出现0log0而无法计算
HR_X += 0
else:
HR_X += -Rate[count]*(RY_X*math.log(RY_X,2)+RN_X*math.log(RN_X,2))
return HR_X
def H_RH(data1,data2):
Y,N = 0,0
for i in data1:
if i == 'Y':
Y += 1
else:
N += 1 #分别统计H中Y和N的数量
RY = Y/(Y+N)
RN = N/(Y+N)#分别求出H中Y、N的概率
Rate = [RY,RN]
TY,TN,HR_H = 0,0,0
count = -1
for i in ['Y','N']:
wd = [k for k,x in enumerate(data1) if x==i] #分别定位H中所有Y和N在列表中的位置
count += 1
for j in wd:
if data2[j] == 'Y':
TY += 1 #统计分别在H中的Y、N对应的Y的数量
elif data2[j] == 'N':
TN += 1 #统计分别在H中的Y、N对应的N的数量
RY_H = TY/(TY+TN)
RN_H = TN/(TY+TN) #分别求在H的Y、N下Y,N的概率
TY,TN = 0,0 #让上次统计的数量清0,使得下一次统计重新开始,防止与上次叠加
if RY_H==0 or RN_H ==0: #防止出现0log0而无法计算
HR_H += 0
else:
HR_H += -Rate[count]*(RY_H*math.log(RY_H,2)+RN_H*math.log(RN_H,2))
return HR_H
def G_HX(HR,HR_X):
GH_X = HR-HR_X
return GH_X
r = H_R(R)
l = H_RX(L,R)
f = H_RX(F,R)
h = H_RH(H,R)
print("信息熵分别为\n HR={:.2f}\n H(R|L)={:.2f}\n H(R|F)={:.2f}\n H(R|H)={:.2f}".format(H_R(R),H_RX(L,R),H_RX(F,R),H_RH(H,R)))
print("信息增益分别为\n g(L)={:.2f}\n g(F)={:.2f}\n g(H)={:.2f}".format(G_HX(r,l),G_HX(r,f),G_HX(r,h)))
运行结果如下:
上一篇: windows驱动开发-内存管理
下一篇: h5+c3进阶(1)
推荐阅读