欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

鸢尾花决策树分类及可视化

程序员文章站 2024-03-19 21:18:28
...

鸢尾花数据集简介

Iris数据集作为入门经典数据集。Iris数据集是常用的分类实验数据集,早在1936年,模式识别的先驱Fisher就在论文The use of multiple measurements in taxonomic problems中使用了它 (直至今日该论文仍然被频繁引用)。
Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集。数据集包含150个数据集,分为3类,每类50个数据,每个数据包含4个属性:花萼长度(sepal length),花萼宽度(sepal width),花瓣长度(petal length),花瓣宽度(petal width),可通过4个属性预测鸢尾花卉属于(Setosa,Versicolour,Virginica)三个种类中的哪一类。在三个类别中,其中有一个类别和其他两个类别是线性可分的。
在sklearn中已内置了此数据集。

代码

import pandas as pd
import pydotplus
import numpy as np
from IPython.display import Image, display
from sklearn import preprocessing
from sklearn import tree
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt
%matplotlib inline
from matplotlib.colors import ListedColormap


# 载入数据
iris = load_iris()
# 构建决策树
clf = tree.DecisionTreeClassifier(max_depth=5)# 深度为4层时有1个结果仍然模糊,5层就足够
clf = clf.fit(iris.data, iris.target)
# 数据可视化
dot_data = tree.export_graphviz(clf,
                                out_file = None,
                                feature_names = iris.feature_names,
                                class_names = iris.target_names,
                                filled=True,
                                rounded=True
                               )
graph = pydotplus.graph_from_dot_data(dot_data)
display(Image(graph.create_png()))

# 整个数据集分类结果可视化
x=iris.data[:,2:4]   # 取出花瓣的长和宽 
y=iris.target       # 取出标签
 #计算散点图的坐标上下界
x_min,x_max=x[:,0].min() -0.5, x[:,0].max()+0.5
y_min, y_max=x[:,1].min()-0.5, x[:,1].max()+0.5 
#绘制边界
cmap_light=ListedColormap(['#AAAAFF','#AAFFAA','#FFAAAA'])
h=0.02
xx,yy=np.meshgrid(np.arange(x_min,x_max,h),np.arange(y_min,y_max,h))

clf = tree.DecisionTreeClassifier(max_depth=5)
clf = clf.fit(x, y)
Z=clf.predict( np.c_[xx.ravel(),yy.ravel()])
Z=Z.reshape(xx.shape)

plt.figure()
plt.pcolormesh(xx,yy,Z,cmap=cmap_light)

plt.scatter(x[:,0],x[:,1],c=y)

plt.xlim( xx.min(), xx.max() )
plt.ylim( yy.min(),yy.max() )
plt.show()

结果展示

鸢尾花决策树分类及可视化
鸢尾花决策树分类及可视化

一点小提示

在Python中使用pydotplus绘制决策树图形时,需要用到GraphViz模块,我用的是Anaconda,直接在Anaconda下安装运行结果报异常,无法找到GraphViz模块,后查找到用下载GraphViz模块,一般下载该模块的stabel版本,并将其添加到环境变量中来。然后在python下安装pydotplus,运行代码正常得到结果。
大家可参考:pydotplus安装和基本入门