R语言分类预测--决策树
1、基本思想:
忽略对数据分布的要求,找出输入变量和输出变量取值间的逻辑对应关系并实现对新数据输出变量的预测。包括分类树和回归树。
- 分类树:通过对特征空间的划分,使得同一区域中样本输出变量尽可能取同一类别值。
- 回归树:通过对特征空间的划分,即同一区域中样本输出变量取值的离散程度应尽可能低。
2、两大问题:
2.1 树的生长,即利用训练样本集完成决策树的建立过程。
- 决策树生长伪代码:暂略
-
找到最佳分裂属性和分裂属性中的最佳分割点
1) 分类树:
测度标准1:Gini系数
其中
p(j|t)是节点t中样本输出变量取第j类的归一化概率
p(j,t)是节点t包含第j类的概率
Nj,t是节点t包含第j类的样本量,Nj是全部样本中输出变量值为第j类的样本量
最佳分组变量和最佳分割点应是使ΔG(t)最大的变量和分割点。测度标准2:信息增益(信息消除随机不确定性的程度)
信息量:
先验熵和后验熵:
加权信息增益:
2)回归树:由于回归树的输出变量为数值型,方差是最理想的指标
t为节点,Nt为节点t所含样本量,yi(t)为节点t中第i个观测的输出变量值
2.2 树的剪枝,即利用测试样本集对所形成的决策树进行精简
- 预修剪: 通过指定树的最大深度、节点所含最小样本量来阻止树的充分生长
- 后修剪:先允许树充分生长,再指定一个允许的最大预测误差值,当决策树在测试样本集上的错误率明显增大时,应停止剪枝
2.2.1 最小代价复杂度剪枝法
R(T):在测试样本集上的预测误差(错判率或均方误差)
| T |:T的叶节点数目;
α:复杂度参数
若中间节点的代价复杂度大于它的子树代价复杂度,否则应该保留子树。否则应该剪掉子树(Tt为子树叶子节点)
2.2.2 N倍交叉验证剪枝
- 模型选择:通过不断调整N的取值,找到在预测误差最小下的N值。该N值下的模型(参数)应是一个最佳模型。
- N个预测模型,得到N个在测试样本集上的预测误差。N个预测误差的平均值,作为模型真实预测误差的估计。
2.4 算法:CART、C4.5、C5.0(后2种是对ID3算法的延伸)
3.决策树的应用及剪枝
3.1仍以992年美国总统选举的数据分析不同背景人群的倾向:
变量包含:总统候选人、年龄、年龄分类、受教育年限、最高学历、性别 。
导入数据,设置参数,进行初步建树(先令cp=0建立充分生长的树)
数据概览:
library(rpart)
library(rpart.plot)
library(rattle)
library(readxl)
df <- read_excel('voter.xlsx',sheet = 1)
head(df)
dim(df)
str(df)
#分类变量转换为因子
df$pres92 = as.factor(df$pres92)
df$agecat = as.factor(df$agecat)
df$educ = as.factor(df$educ)
df$degree = as.factor(df$degree)
df$sex = as.factor(df$sex)
#建立分类树,异质性测度标准为信息增益
rc <- rpart.control(minsplit=10, xval=10, maxdepth=30,cp=0)
myTree <- rpart(pres92~age+educ+degree+sex, data=df, method='class',parms=list(split='information'), control=rc)
查看此时的树,很显然非常茂盛
rpart.plot(myTree,type=1,extra=3,branch=1)
查看树建立过程中cp的变化及交叉验证预测误差
printcp(myTree)
plotcp(myTree)
参数理解:
cp:复杂度,nsplit:样本分组次数,rel error:是预测误差相对值的估计
xerror:交叉验证的预测误差相对值(相对于根节点),xstd:预测误差的标准误
本样本共有观测1847,根节点错判率939/1847 = 0.50839为单位1,可见,经过19次分组(产生20个叶子节点)交叉验证误差最小为0.97*0.50839=0.493。
下面对树进行剪枝,将复杂度cp设为19次分组时的cp=0.00266
#树的剪枝
treeFit <- prune(myTree, cp=0.00266)
rpart.plot(treeFit,type=1,extra=3,branch=1)
printcp(treeFit)
plotcp(treeFit)
得到如下的树
参数如下,可知经过19次分组的预测错误率是最低为49%
换一种方式显示树结构,会清晰点:
plot(treeFit, uniform = T, branch = 0.8, margin = 0.1, main = "Tree Classification", compress = T)
text(treeFit, use.n = T, cex = 0.9)
还可以生成规则
asRules(treeFit)
规则解释:以第一个为例:有22个满足条件(
age>=48.5
educ=11,12,13,14,15,16,17,18,20
sex=1
educ=14,15,16,17,18
age>=62.5
educ=14,16)被预测为类别1(pres92=1),错判率为50%
以上述模型对样本所有观测进行预测,查看各分类错判率:
t1 <- diag(tb1)
rs <-vector()
for(i in 1:3){
a = 1-t1[i]/sum(tb1[,i])
rs = c(rs,a)
};rs
第1类的错判率为0.4962217,第2类为0.2380952,第3类为0.4464661