鸢尾花决策树分类及可视化
程序员文章站
2024-03-19 21:18:28
...
鸢尾花数据集简介
Iris数据集作为入门经典数据集。Iris数据集是常用的分类实验数据集,早在1936年,模式识别的先驱Fisher就在论文The use of multiple measurements in taxonomic problems中使用了它 (直至今日该论文仍然被频繁引用)。
Iris也称鸢尾花卉数据集,是一类多重变量分析的数据集。数据集包含150个数据集,分为3类,每类50个数据,每个数据包含4个属性:花萼长度(sepal length),花萼宽度(sepal width),花瓣长度(petal length),花瓣宽度(petal width),可通过4个属性预测鸢尾花卉属于(Setosa,Versicolour,Virginica)三个种类中的哪一类。在三个类别中,其中有一个类别和其他两个类别是线性可分的。
在sklearn中已内置了此数据集。
代码
import pandas as pd
import pydotplus
import numpy as np
from IPython.display import Image, display
from sklearn import preprocessing
from sklearn import tree
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt
%matplotlib inline
from matplotlib.colors import ListedColormap
# 载入数据
iris = load_iris()
# 构建决策树
clf = tree.DecisionTreeClassifier(max_depth=5)# 深度为4层时有1个结果仍然模糊,5层就足够
clf = clf.fit(iris.data, iris.target)
# 数据可视化
dot_data = tree.export_graphviz(clf,
out_file = None,
feature_names = iris.feature_names,
class_names = iris.target_names,
filled=True,
rounded=True
)
graph = pydotplus.graph_from_dot_data(dot_data)
display(Image(graph.create_png()))
# 整个数据集分类结果可视化
x=iris.data[:,2:4] # 取出花瓣的长和宽
y=iris.target # 取出标签
#计算散点图的坐标上下界
x_min,x_max=x[:,0].min() -0.5, x[:,0].max()+0.5
y_min, y_max=x[:,1].min()-0.5, x[:,1].max()+0.5
#绘制边界
cmap_light=ListedColormap(['#AAAAFF','#AAFFAA','#FFAAAA'])
h=0.02
xx,yy=np.meshgrid(np.arange(x_min,x_max,h),np.arange(y_min,y_max,h))
clf = tree.DecisionTreeClassifier(max_depth=5)
clf = clf.fit(x, y)
Z=clf.predict( np.c_[xx.ravel(),yy.ravel()])
Z=Z.reshape(xx.shape)
plt.figure()
plt.pcolormesh(xx,yy,Z,cmap=cmap_light)
plt.scatter(x[:,0],x[:,1],c=y)
plt.xlim( xx.min(), xx.max() )
plt.ylim( yy.min(),yy.max() )
plt.show()
结果展示
一点小提示
在Python中使用pydotplus绘制决策树图形时,需要用到GraphViz模块,我用的是Anaconda,直接在Anaconda下安装运行结果报异常,无法找到GraphViz模块,后查找到用下载GraphViz模块,一般下载该模块的stabel版本,并将其添加到环境变量中来。然后在python下安装pydotplus,运行代码正常得到结果。
大家可参考:pydotplus安装和基本入门
推荐阅读
-
鸢尾花决策树分类及可视化
-
Graphviz安装及使用:决策树可视化
-
sklearn中决策树回归器DecisionTreeRegressor的实际应用及可视化
-
决策树(Decision Tree)分类算法原理及应用
-
python分类分析--决策树算法原理及案例
-
决策树 鸢尾花分类 数据挖掘Python
-
分类算法学习(四)——决策树算法的原理及简单实现
-
分别采用线性LDA、k-means和SVM算法对鸢尾花数据集和月亮数据集进行二分类可视化分析
-
Py之matplotlib&seaborn :高级图可视化之Q-Q分位数图probplot、boxplot箱线图、stripplot分类散点图案例应用及代码实现
-
Python-鸢尾花数据集/月亮数据集的线性LDA、k-means和SVM算法二分类可视化分析