Machine Learning 5 (sklearn Decision Tree: Multi-class Classification)
程序员文章站
2022-03-30 18:56:03
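1. Preface
This post classifies the iris dataset bundled with sklearn using a decision tree with the entropy (information gain) criterion, the same splitting measure that ID3 uses. Note that scikit-learn's tree implementation is an optimized version of CART rather than ID3 itself. The dataset groups iris flowers into 3 species according to their features, labeled 0, 1, and 2.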
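To make the entropy criterion concrete, here is a minimal sketch in plain Python. The helper names `entropy` and `information_gain` are made up for this illustration and are not part of sklearn, and the split shown is a hypothetical one, not one taken from a fitted tree.

```python
import math
from collections import Counter


def entropy(labels):
    """Shannon entropy (in bits) of a sequence of class labels."""
    total = len(labels)
    counts = Counter(labels)
    return -sum((n / total) * math.log2(n / total) for n in counts.values())


def information_gain(parent, left, right):
    """Entropy reduction achieved by splitting `parent` into `left` and `right`."""
    total = len(parent)
    weighted = (len(left) / total) * entropy(left) + (len(right) / total) * entropy(right)
    return entropy(parent) - weighted


# Toy labels mimicking iris: 50 samples each of classes 0, 1 and 2.
labels = [0] * 50 + [1] * 50 + [2] * 50

# Hypothetical split that isolates class 0 from the other two classes.
left, right = labels[:50], labels[50:]

print(entropy(labels))                        # log2(3), about 1.585 bits
print(information_gain(labels, left, right))  # about 0.918 bits
```

At each node the tree picks the split with the largest entropy reduction, which is what `criterion="entropy"` selects in `DecisionTreeClassifier`.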
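2. Preparing to draw the decision tree
(1) Download and install Graphviz
https://graphviz.gitlab.io/_pages/Download/Download_windows.html
(2) Install the graphviz Python package in PyCharm
File -> Settings -> Project (Project Interpreter) -> green + on the right -> search for graphviz and install it
(3) Parameter reference for the decision tree class
http://scikit-learn.org/stable/modules/generated/sklearn.tree.DecisionTreeClassifier.html#sklearn.tree.DecisionTreeClassifier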
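Before running anything, it can help to confirm that both halves of the setup are in place: the Graphviz `dot` executable from step (1) and the `graphviz` Python package from step (2). The following is a rough check using only the standard library; the function name `check_graphviz_setup` is invented for this sketch.

```python
import importlib.util
import shutil


def check_graphviz_setup():
    """Return (dot_path, package_found) for the current environment."""
    # Full path of the Graphviz `dot` binary, or None if it is not on PATH.
    dot_path = shutil.which("dot")
    # True if the `graphviz` Python package can be imported.
    package_found = importlib.util.find_spec("graphviz") is not None
    return dot_path, package_found


dot_path, package_found = check_graphviz_setup()
print("dot executable:", dot_path if dot_path else "not found on PATH")
print("graphviz package:", "installed" if package_found else "not installed")
```

If `dot` is reported as missing even though Graphviz is installed, its `bin` directory most likely still needs to be added to `PATH`.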
2. Python code
(1) Run the following code, tree_class.py:
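from sklearn.datasets import load_iris
from sklearn import tree
from sklearn.model_selection import train_test_split
import graphviz
import os


def multi_class_tree():
    # Load the iris features (x) and class labels (y)
    iris = load_iris()
    x = iris['data']
    y = iris['target']
    # Use entropy (information gain) as the split criterion
    dtc = tree.DecisionTreeClassifier(criterion="entropy")
    # Hold out 10% of the samples for testing
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.1)
    clf = dtc.fit(x_train, y_train)
    # Print the predictions, then the true labels, for comparison
    print(clf.predict(x_test))
    print(y_test)
    # Export the fitted tree as DOT source
    dot_data = tree.export_graphviz(clf, out_file=None)
    graph = graphviz.Source(dot_data)
    # Make the Graphviz executables visible; adjust to the local install path
    os.environ["PATH"] += os.pathsep + 'F:/Program Files/Graphviz2.38/bin/'
    # Render the tree to iris.pdf and open it
    graph.render("iris", view=True)


multi_class_tree()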
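The script above gives a different split on every run and labels tree nodes by feature index. A possible variant is sketched below. The `random_state=0` seed and `max_depth=3` limit are illustrative choices, not values from the original script, and it stops at the DOT source so it runs without the Graphviz binaries.

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn import tree

iris = load_iris()

# random_state=0 is an arbitrary seed chosen here so the split is repeatable.
x_train, x_test, y_train, y_test = train_test_split(
    iris.data, iris.target, test_size=0.1, random_state=0
)

# max_depth=3 is an illustrative limit, not a value from the original script.
clf = tree.DecisionTreeClassifier(criterion="entropy", max_depth=3, random_state=0)
clf.fit(x_train, y_train)

# Fraction of held-out samples classified correctly.
print(clf.score(x_test, y_test))

# Label nodes with feature and class names instead of feature indices.
dot_data = tree.export_graphviz(
    clf,
    out_file=None,
    feature_names=iris.feature_names,
    class_names=list(iris.target_names),
    filled=True,
)
print(dot_data[:200])
```

The resulting `dot_data` string can be passed to `graphviz.Source` and rendered exactly as in tree_class.py.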
3. Verifying the results
(1) Comparing the predictions with the test set labels
Predicted: [0 1 2 2 2 1 1 1 0 2 1 2 0 1 0]
Actual: [0 1 2 2 2 1 1 2 0 2 1 2 0 1 0]
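The two arrays can be tallied with a few lines of plain Python. This sketch copies the values from the run shown here; a fresh run will generally produce a different split and therefore different numbers.

```python
# Arrays copied from the run shown above.
predicted = [0, 1, 2, 2, 2, 1, 1, 1, 0, 2, 1, 2, 0, 1, 0]
actual = [0, 1, 2, 2, 2, 1, 1, 2, 0, 2, 1, 2, 0, 1, 0]

# Positions where the prediction disagrees with the true label.
mismatches = [i for i, (p, a) in enumerate(zip(predicted, actual)) if p != a]
correct = len(actual) - len(mismatches)
accuracy = correct / len(actual)

print(mismatches)            # [7]
print(correct, len(actual))  # 14 15
print(round(accuracy, 3))    # 0.933
```

The single miss is at index 7, where a class 2 sample was predicted as class 1.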
(2) Decision tree (rendered to iris.pdf by graph.render)
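When Graphviz is not available, the tree can also be inspected as plain text with `sklearn.tree.export_text`. The sketch below fits on the full dataset purely for illustration, and `random_state=0` is an arbitrary seed, so the tree it prints is not the one from the run above.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()

# Fit on the full dataset for illustration only; random_state=0 is an arbitrary seed.
clf = DecisionTreeClassifier(criterion="entropy", random_state=0)
clf.fit(iris.data, iris.target)

# One line per node: split thresholds on the way down, predicted class at each leaf.
text_tree = export_text(clf, feature_names=iris.feature_names)
print(text_tree)
```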