【ML--15】在R语言中使用决策树算法做多分类预测
一、算法简介
决策树模型是一种简单易用的非参数分类器。它不需要对数据有任何的先验假设,计算速度较快,结果容易解释,而且稳健性强,不怕噪声数据和缺失数据。决策树模型的基本计算步骤如下:先从n个自变量中挑选一个,寻找最佳分割点,将数据划分为两组。针对分组后数据,将上述步骤重复下去,直到满足某种条件。
在决策树建模中需要解决的重要问题有三个:
如何选择自变量
如何选择分割点
确定停止划分的条件
在R语言中关于决策树建模,最为常用的有两个包,一个是rpart包,另一个是party包。我们来看一下对于上述问题,这两个包分别是怎么处理的。
rpart包的处理方式:
首先对所有自变量和所有分割点进行评估,最佳的选择是使分割后组内的数据更为“一致”(pure)。这里的“一致”是指组内数据的因变量取值变异较小。rpart包对这种“一致”性的默认度量是Gini值。确定停止划分的参数有很多(参见rpart.control),确定这些参数是非常重要而微妙的,因为划分越细,模型越复杂,越容易出现过度拟合的情况,而划分过粗,又会出现拟合不足。处理这个问题通常是使用“剪枝”(prune)方法。即先建立一个划分较细较为复杂的树模型,再根据交叉检验(Cross-Validation)的方法来估计不同“剪枝”条件下,各模型的误差,选择误差最小的树模型。
party包的处理方式:
它的背景理论是“条件推断决策树”(conditional inference trees):它根据统计检验来确定自变量和分割点的选择。即先假设所有自变量与因变量均独立。再对它们进行卡方独立检验,检验P值小于阀值的自变量加入模型,相关性最强的自变量作为第一次分割的自变量。自变量选择好后,用置换检验来选择分割点。用party包建立的决策树不需要剪枝,因为阀值就决定了模型的复杂程度。所以如何决定阀值参数是非常重要的(参见ctree_control)。较为流行的做法是取不同的参数值进行交叉检验,选择误差最小的模型参数。
2、party包实现代码
####################################决策树算法#################
rm(list=ls())
gc()
options(scipen = 200)
library(party)
iris <- iris
##########划分训练和测试集#################
set.seed(2016)
train.indeces<-sample(1:nrow(iris),100)
iris.train<-iris[train.indeces,]
iris.test<-iris[-train.indeces,]
myFormula <- Species~.
# 建立决策树
iris_ctree <- ctree(myFormula, data=iris.train)
# 检测预测值
(preTable<-table(predict(iris_ctree), iris.train$Species))
##准确率
(accurary<-sum(diag(preTable))/sum(preTable))
#打印决策树
print(iris_ctree)
plot(iris_ctree)
plot(iris_ctree, type="simple")
########################测试模型################
# 在测试集上测试决策树
testPred <- predict(iris_ctree, newdata = iris.test)
(preTatable1<-table(testPred, iris.test$Species))
##准确率
(accurary<-sum(diag(preTatable1))/sum(preTatable1))
3、运行结果:
> (preTable<-table(predict(iris_ctree), iris.train$Species))
setosa versicolor virginica
setosa 38 0 0
versicolor 0 29 1
virginica 0 3 29
> (accurary<-sum(diag(preTable))/sum(preTable))
[1] 0.96
> #打印决策树
> (preTatable1<-table(testPred, iris.test$Species))
testPred setosa versicolor virginica
setosa 12 0 0
versicolor 0 16 2
virginica 0 2 18
> (accurary<-sum(diag(preTatable1))/sum(preTatable1))
[1] 0.92
>
> print(iris_ctree)
Conditional inference tree with 3 terminal nodes
Response: Species
Inputs: Sepal.Length, Sepal.Width, Petal.Length, Petal.Width
Number of observations: 100
1) Petal.Width <= 0.6; criterion = 1, statistic = 92.93
2)* weights = 38
1) Petal.Width > 0.6
3) Petal.Width <= 1.5; criterion = 1, statistic = 42.156
4)* weights = 30
3) Petal.Width > 1.5
5)* weights = 32
>
上一篇: Asp.net后台调用js 2种方法
下一篇: java图片识别文字的方法