tree_DecisionTreeClassifier—决策树代码演示
程序员文章站
2024-03-19 21:14:28
...
一、安装Graphviz:
1.1 下载安装:官网
1.将graphviz安装目录(C:\Program Files (x86)\Graphviz2.38\)下的bin文件夹添加到Path环境变量中。
2.进入windows命令行界面,输入dot -version,然后按回车,如果显示graphviz的相关版本信息,则安装配置成功。
1.2,安装Python插件
安装插件graphviz: pip install graphviz
安装插件pydotplus:pip install pydotplus
二、 决策树代码演示
from sklearn.datasets import load_iris # 导入类库
from sklearn import tree
import pydotplus
import sys
import os
os.environ["PATH"] += os.pathsep + 'C:/Program Files (x86)/Graphviz2.38/bin/'
# 注意后面的路径是你自己的graphviz的bin目录。
iris = load_iris() # 载入sciki-learn的自带数据
clf = tree.DecisionTreeClassifier() # 载入决策树分类模型
clf = clf.fit(iris.data, iris.target) # 决策树拟合,得到模型
with open("iris.dot",'w') as f: # 将模型存入dot文件iris.dot
f = tree.export_graphviz(clf, out_file=f)
# 第一种方式,用graphviz的dot命令生成决策树的可视化文件
# 在命令行执行 dot -Tpdf iris.dot -o iris.pdf, 在当前目录生成决策树的可视化文件iris.pdf.
# 打开可以看到决策树的模型图。
# 第二种方式,用pydotplus生成iris.pdf
dot_data = tree.export_graphviz(clf, out_file=None)
graph = pydotplus.graph_from_dot_data(dot_data)
graph.write_pdf("iris.pdf")
1.2 参数优化
from sklearn import tree
from sklearn import model_selection
from sklearn.datasets import load_iris
from sklearn.grid_search import GridSearchCV
from sklearn.metrics import confusion_matrix
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import f1_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import classification_report
iris=load_iris()
x=iris.data
y=iris.target
X_train,X_test,y_train,y_test = model_selection\
.train_test_split(x,y,test_size=0.3,
random_state=123456)
parameters={
'criterion':['gini','entropy'],
'max_depth':[1,2,3,4,5,6,7,8]
}
dtree=tree.DecisionTreeClassifier()
grid_search=GridSearchCV(dtree,parameters,scoring='accuracy',cv=5)
grid_search.fit(x,y)
grid_search.best_estimator_ #查看grid_search方法
grid_search.best_score_ #正确率
grid_search.best_params_ #最佳 参数组合
dtree=tree.DecisionTreeClassifier(criterion='gini',max_depth=5)
dtree.fit(X_train,y_train)
pred=dtree.predict(X_test)
print(classification_report(y_test,pred))
下一篇: 二叉树的层序遍历
推荐阅读