
Decision Trees in Python 3.4

程序员文章站 2022-03-29 14:32:52
#!/usr/bin/env python3
# coding=utf-8

from io import StringIO  # replaces sklearn.externals.six.StringIO (removed in scikit-learn 0.23)

import numpy as np
import pydot
from sklearn import tree
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split  # sklearn.cross_validation was removed in 0.20

def loadDataSet():
    data = []
    label = []
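    # Adjust this path to wherever fat.txt is stored ('D:python/...' was missing the slash after the drive)
    with open('D:/python/fat.txt') as file:
        for line in file:
            tokens = line.split()  # split on any whitespace
            if not tokens:
                continue  # skip blank lines
            data.append([float(tk) for tk in tokens[:-1]])
            label.append(tokens[-1])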
    x = np.array(data)
    print('x:')
    print(x)
    label = np.array(label)
    y = np.zeros(label.shape)
    y[label == 'fat'] = 1
    print('y:')
    print(y)
    return x, y

def decisionTreeClf():
    x, y = loadDataSet()

    # Split the data into training and test sets
    x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2)
    print('x_train:')
    print(x_train)
    print('x_test:')
    print(x_test)
    print('y_train:')
    print(y_train)
    print('y_test:')
    print(y_test)
    # Use information entropy as the splitting criterion
    clf = tree.DecisionTreeClassifier(criterion='entropy')
    print(clf)
    clf.fit(x_train, y_train)
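    # Write the tree in Graphviz DOT format to a file
    # (export_graphviz returns None when out_file is given, so don't assign its result)
    with open("fat.dot", 'w') as f:
        tree.export_graphviz(clf, out_file=f)
    # Render the same tree to PDF via pydot (needs Graphviz's dot executable on PATH)
    dot_data = StringIO()
    tree.export_graphviz(clf, out_file=dot_data)
    graph = pydot.graph_from_dot_data(dot_data.getvalue())  # returns a list of graphs
    graph[0].write_pdf("ex.pdf")
    # In Jupyter: from IPython.display import Image; Image(graph[0].create_png())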
    # Print how much each feature contributes to the classification
    print(clf.feature_importances_)

    # Print the results on the held-out test set
    answer = clf.predict(x_test)
    print('x_test:')
    print(x_test)
    print('answer:')
    print(answer)
    print('y_test:')
    print(y_test)
    print('Accuracy:')
    print(np.mean(answer == y_test))

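    # Precision and recall (precision_recall_curve expects scores, e.g. class-1 probabilities)
    precision, recall, thresholds = precision_recall_curve(
        y_train, clf.predict_proba(x_train)[:, 1])
    print('precision:', precision)
    print('recall:', recall)
    # classification_report expects discrete class labels, so use predict(), not predict_proba()
    answer = clf.predict(x)
    print(classification_report(y, answer, target_names=['thin', 'fat']))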

if __name__ == '__main__':
    decisionTreeClf()
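To try the same pipeline without the hard-coded path, pydot, or Graphviz, here is a minimal sketch that embeds the fat.txt contents listed below as a string. FAT_TXT and load_data are hypothetical names introduced for illustration. It uses sklearn.model_selection (available since scikit-learn 0.18) and passes stratify=y so the two-sample test set contains one example of each class; random_state=0 is an arbitrary choice for reproducibility.

```python
# Self-contained sketch: same data as fat.txt, embedded as a string (assumption:
# the file contents shown in this post), so no file path or Graphviz is needed.
import io

import numpy as np
from sklearn import tree
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

FAT_TXT = """1.5 50 thin
1.5 60 fat
1.6 40 thin
1.6 60 fat
1.7 60 thin
1.7 80 fat
1.8 60 thin
1.8 90 fat
1.9 70 thin
1.9 80 fat
"""


def load_data(stream):
    """Parse 'feature feature label' lines into X (floats) and y (1 = fat, 0 = thin)."""
    data, labels = [], []
    for line in stream:
        tokens = line.split()
        if not tokens:
            continue  # skip blank lines
        data.append([float(tk) for tk in tokens[:-1]])
        labels.append(1 if tokens[-1] == 'fat' else 0)
    return np.array(data), np.array(labels)


x, y = load_data(io.StringIO(FAT_TXT))

# stratify=y keeps one 'fat' and one 'thin' sample in the 2-sample test set
x_train, x_test, y_train, y_test = train_test_split(
    x, y, test_size=0.2, stratify=y, random_state=0)

clf = tree.DecisionTreeClassifier(criterion='entropy', random_state=0)
clf.fit(x_train, y_train)

train_acc = accuracy_score(y_train, clf.predict(x_train))
test_acc = accuracy_score(y_test, clf.predict(x_test))
print('train accuracy:', train_acc)
print('test accuracy:', test_acc)
```

Because every feature vector in this dataset is distinct, a fully grown tree fits its training data perfectly, so the training accuracy is 1.0 and only the test accuracy says anything about generalization. With just two test samples, even that number is very noisy.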

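The dataset file fat.txt contains the following lines. Each line holds two numeric features (apparently height in metres and weight in kilograms) followed by the class label, thin or fat: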

1.5 50 thin
1.5 60 fat
1.6 40 thin
1.6 60 fat
1.7 60 thin
1.7 80 fat
1.8 60 thin
1.8 90 fat
1.9 70 thin
1.9 80 fat
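The code comment about using information entropy as the splitting criterion refers to criterion='entropy'. The sketch below is a plain-Python illustration, not scikit-learn's optimized implementation: it computes the entropy H = -Σ p·log2(p) of the labels above and scans midpoint thresholds on both features for the split with the largest information gain. The function names are hypothetical. On all ten rows the root entropy is exactly 1 bit (5 thin, 5 fat), and the best split it finds is on the second column at 75; the script fits on a random 80% subset, so scikit-learn's root split there may differ.

```python
import math
from collections import Counter

# Rows copied from fat.txt: (feature 0, feature 1, label)
ROWS = [
    (1.5, 50, 'thin'), (1.5, 60, 'fat'), (1.6, 40, 'thin'), (1.6, 60, 'fat'),
    (1.7, 60, 'thin'), (1.7, 80, 'fat'), (1.8, 60, 'thin'), (1.8, 90, 'fat'),
    (1.9, 70, 'thin'), (1.9, 80, 'fat'),
]


def entropy(labels):
    """Shannon entropy in bits: H = -sum(p * log2(p)) over the label frequencies."""
    n = len(labels)
    return -sum((c / n) * math.log2(c / n) for c in Counter(labels).values())


def information_gain(rows, feature, threshold):
    """Entropy drop from splitting rows on rows[feature] <= threshold."""
    parent = [r[-1] for r in rows]
    left = [r[-1] for r in rows if r[feature] <= threshold]
    right = [r[-1] for r in rows if r[feature] > threshold]
    n = len(rows)
    children = (len(left) / n) * entropy(left) + (len(right) / n) * entropy(right)
    return entropy(parent) - children


def best_split(rows):
    """Try the midpoint between each pair of adjacent distinct values of every feature."""
    best = (0.0, None, None)  # (gain, feature index, threshold)
    for feature in range(len(rows[0]) - 1):
        values = sorted(set(r[feature] for r in rows))
        for lo, hi in zip(values, values[1:]):
            threshold = (lo + hi) / 2
            gain = information_gain(rows, feature, threshold)
            if gain > best[0]:
                best = (gain, feature, threshold)
    return best


print('root entropy:', entropy([r[-1] for r in ROWS]))
gain, feature, threshold = best_split(ROWS)
print('best split: feature', feature, '<=', threshold, 'gain', round(gain, 4))
```

Every height value occurs once with each label, so splits on the first column leave both halves balanced and gain nothing; all of the information comes from the second column.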

Required Python packages:

pygraphviz (1.3.1)

pyparsing (2.1.10)

scikit-learn (0.18.1)

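pygraphviz (1.3.1) is the visualization package. The script itself imports pydot (pyparsing is one of pydot's dependencies), so pydot must be installed as well.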

Download the visualization tool (Graphviz):

graphviz-2.38.msi

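Search online (for example on Baidu) for the Graphviz installer and install it. pydot renders the PDF by calling Graphviz's dot executable, so the Graphviz bin directory must be on the PATH.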
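A quick way to confirm that the installation is visible to Python is a short check like the sketch below; graphviz_dot_path is a hypothetical helper. It uses only the standard library: shutil.which (Python 3.3+) and subprocess.check_output. dot -V prints the Graphviz version to stderr, so stderr is merged into the captured output.

```python
import shutil
import subprocess


def graphviz_dot_path():
    """Return the full path of Graphviz's dot executable, or None if it is not on PATH."""
    return shutil.which('dot')


dot = graphviz_dot_path()
if dot is None:
    print('Graphviz dot not found on PATH; add the Graphviz bin directory to PATH.')
else:
    # dot -V writes its version string to stderr, so merge stderr into the output
    version = subprocess.check_output([dot, '-V'], stderr=subprocess.STDOUT,
                                      universal_newlines=True)
    print(version.strip())
```

If the check fails right after installing, reopening the terminal or IDE is usually needed so the updated PATH is picked up.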