python3.4之决策树
程序员文章站
2023-12-31 21:20:58
#!/usr/bin/env python
# coding=utf-8
import numpy as np
from sklearn import tree
f...
#!/usr/bin/env python
# coding=utf-8
"""Train an entropy-based decision tree on fat.txt and evaluate it.

Each non-blank line of the data file holds whitespace-separated numeric
features followed by a text label; the label 'fat' maps to class 1 and any
other label to class 0. The fitted tree is written as Graphviz DOT source
(iris.dot) and rendered to ex.pdf through pydot (rendering presumably needs
the Graphviz `dot` program on PATH -- see the installation notes).
"""
import numpy as np
import pydot
from sklearn import tree
from sklearn.metrics import classification_report

try:
    from sklearn.model_selection import train_test_split
except ImportError:  # scikit-learn < 0.18 only ships the old module
    from sklearn.cross_validation import train_test_split


def loadDataSet(path='D:python/fat.txt'):
    """Load the feature matrix and binary labels from a whitespace-separated file.

    Args:
        path: data file; every non-blank line is ``f1 f2 ... label``.
            NOTE(review): on Windows 'D:python/...' is relative to the current
            directory of drive D, not to the drive root; 'D:/python/fat.txt'
            may have been intended -- confirm.

    Returns:
        (x, y): ``x`` is a float array with one row per sample; ``y`` is a
        float array holding 1.0 where the label is 'fat' and 0.0 otherwise.
    """
    data = []
    label = []
    with open(path, encoding='utf-8') as file:
        for line in file:
            # split() with no argument tolerates repeated spaces/tabs and
            # returns [] for blank lines, which would otherwise add an empty
            # feature row and an empty label.
            tokens = line.split()
            if not tokens:
                continue
            data.append([float(tk) for tk in tokens[:-1]])
            label.append(tokens[-1])
    x = np.array(data)
    print('x:')
    print(x)
    label = np.array(label)
    y = np.zeros(label.shape)
    y[label == 'fat'] = 1
    print('y:')
    print(y)
    return x, y


def _export_tree(clf, dot_path='iris.dot', pdf_path='ex.pdf'):
    """Write the fitted tree as Graphviz DOT source and render it to PDF."""
    # With out_file=None, export_graphviz returns the DOT source as a string,
    # so a single export serves both the .dot file and the PDF rendering.
    dot_data = tree.export_graphviz(clf, out_file=None)
    with open(dot_path, 'w') as f:
        f.write(dot_data)
    graphs = pydot.graph_from_dot_data(dot_data)
    # Newer pydot returns a list of graphs; older releases return one graph.
    graph = graphs[0] if isinstance(graphs, list) else graphs
    graph.write_pdf(pdf_path)


def decisionTreeClf(random_state=None):
    """Fit a decision tree on a train/test split of the data set and evaluate it.

    Args:
        random_state: forwarded to ``train_test_split``; pass an int for a
            reproducible split. The default None keeps the original behaviour
            (a different random split on every run).
    """
    x, y = loadDataSet()
    # Split the samples into a training set and a held-out test set (20%).
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=0.2, random_state=random_state)
    print('x_train:')
    print(x_train)
    print('x_test:')
    print(x_test)
    print('y_train:')
    print(y_train)
    print('y_test:')
    print(y_test)

    # Use information entropy as the split criterion.
    clf = tree.DecisionTreeClassifier(criterion='entropy')
    print(clf)
    clf.fit(x_train, y_train)

    _export_tree(clf)

    # Relative importance of each feature in the fitted tree.
    print(clf.feature_importances_)

    # Evaluate on the held-out test set: accuracy on the training data does
    # not measure how well the tree generalises.
    answer = clf.predict(x_test)
    print('x_test:')
    print(x_test)
    print('answer:')
    print(answer)
    print('y_test:')
    print(y_test)
    print('计算正确率:')
    print(np.mean(answer == y_test))

    # Per-class precision/recall from hard class predictions (the report
    # needs labels, not probabilities). labels=[0, 1] keeps both target
    # names valid even when the tiny test split contains only one class.
    print(classification_report(y_test, answer, labels=[0, 1],
                                target_names=['thin', 'fat']))


if __name__ == '__main__':
    decisionTreeClf()
数据集fat.txt文件内容如下:
（每行一个样本：身高 体重 标签，以空格分隔）
1.5 50 thin
1.5 60 fat
1.6 40 thin
1.6 60 fat
1.7 60 thin
1.7 80 fat
1.8 60 thin
1.8 90 fat
1.9 70 thin
1.9 80 fat
所需要的Python包有:
pygraphviz (1.3.1)
pyparsing (2.1.10)
scikit-learn (0.18.1)
pygraphviz (1.3.1)包是可视化包。注意：上面的代码实际导入的是 pydot（pyparsing 是 pydot 的依赖）和 numpy，运行前也需要安装这两个包。
下载可视化工具:
graphviz-2.38.msi
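可从 Graphviz 官网（graphviz.org）下载并安装该可视化工具。安装后需把 Graphviz 的 bin 目录加入系统 PATH 环境变量，否则 pydot 找不到 dot 程序，生成 PDF（write_pdf）会失败。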