欢迎您访问程序员文章站,本站旨在为大家提供和分享程序员计算机编程知识!
您现在的位置是: 首页

sklearn决策树与随机森林 参数及规则提取 模型可视化(初体验)

程序员文章站 2022-04-08 13:54:37
...

决策树

import os
import pandas as pd
import numpy as np
from sklearn import tree
from sklearn.tree import _tree
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction import DictVectorizer
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import pydotplus


def tree_to_code(tree, feature_names, class_names=None, out_path='code.txt'):
    """Dump a fitted decision tree as nested if/else pseudo-code.

    The rules are appended to ``out_path`` (existing content is kept, so
    repeated calls accumulate output in the same file).

    Args:
        tree: Fitted estimator exposing a ``tree_`` attribute (e.g. a
            ``DecisionTreeClassifier``).
        feature_names: Feature names, indexed by the feature indices stored
            in ``tree.tree_.feature``.
        class_names: Labels indexed by the argmax of each leaf's value
            array. When ``None``, the module-level ``target_name`` is used,
            which is what the original implementation always did.
        out_path: File the pseudo-code is appended to.
    """
    tree_ = tree.tree_
    if class_names is None:
        # Backward-compatible fallback to the script-level label list.
        class_names = target_name

    # One display name per node; leaf nodes carry TREE_UNDEFINED instead of
    # a feature index, so they get a placeholder.
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print('feature_name:', feature_name)

    # Collect every output line first and write the file once at the end,
    # instead of reopening it for each node.
    lines = ["def tree({}):".format(", ".join(feature_names))]

    def recurse(node, depth):
        indent = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            # Internal node: left child is the "<= threshold" branch.
            name = feature_name[node]
            threshold = tree_.threshold[node]
            lines.append("{}if {} <= {}:".format(indent, name, threshold))
            recurse(tree_.children_left[node], depth + 1)
            lines.append("{}else:  # if {} > {}".format(indent, name, threshold))
            recurse(tree_.children_right[node], depth + 1)
        else:
            # Leaf: emit the raw value array and the label at its argmax.
            value = tree_.value[node]
            lines.append("{}return {} -- {}".format(
                indent, value, class_names[np.argmax(value)]))

    recurse(0, 1)

    with open(out_path, 'a+') as f:
        f.write("\n".join(lines) + "\n")

# Decision-tree walkthrough: load the data, train, extract the rules,
# evaluate, then render the confusion matrix and the tree itself.
pwd = os.getcwd()
titanic = pd.read_csv(os.path.join(pwd, 'ta.txt'))
# Fill missing ages with the column mean. Assign the result back instead of
# using chained ``inplace=True``, which may operate on a temporary copy.
titanic['age'] = titanic['age'].fillna(titanic['age'].mean())
# Features used for splitting, and the target column.
x = titanic[['pclass', 'age', 'sex']]
y = titanic['survived']
labels = [0, 1]
target_name = ["deid", "survived"]

# Sorted so the names line up with the column order produced by
# DictVectorizer below (it sorts feature names by default) -- this assumes
# all three columns are numeric, so no one-hot columns are generated.
fea_name = ["sex", "age", "pclass"]
fea_name.sort()

# test_size is the fraction of the whole data set held out for testing.
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3)

dt = DictVectorizer(sparse=False)   # sparse=False: return a dense array

x_train = dt.fit_transform(x_train.to_dict(orient="records"))

# Only transform the test split: re-fitting here would rebuild the feature
# mapping from test data and could change the column layout.
x_test = dt.transform(x_test.to_dict(orient="records"))

# Decision tree with default settings; optional knobs left for reference.
dtc = DecisionTreeClassifier(
                             # class_weight='balanced',   # balance the classes
                             # criterion='entropy',     # split criterion: gini (default) or entropy
                             # max_features='sqrt',
                             )

dtc.fit(x_train, y_train)

dt_predict = dtc.predict(x_test)


tree_to_code(dtc, fea_name)  # extract the decision rules to a text file

print(dtc.score(x_test, y_test))

print(classification_report(y_test, dt_predict, labels=labels, target_names=target_name))

# Confusion matrix for the decision-tree predictions (dt_predict; the
# random-forest predictions do not exist yet at this point of the script).
confmat = confusion_matrix(y_true=y_test, y_pred=dt_predict, labels=labels)
print(confmat)

fig, ax = plt.subplots(figsize=(3, 3))
ax.matshow(confmat, cmap=plt.cm.Blues, alpha=0.3)
# Write each count into its cell.
for i in range(confmat.shape[0]):
    for j in range(confmat.shape[1]):
        ax.text(x=j, y=i, s=confmat[i, j], va='center', ha='center')

plt.xticks(range(len(confmat)), labels)
plt.yticks(range(len(confmat)), labels)
plt.xlabel('predicted label')
plt.ylabel('true label')
plt.savefig('confusion_matrix.png')
plt.show()

# Render the decision tree.

# Placeholder: replace the string with the path of graphviz's bin directory
# if the graphviz executables are not found (e.g. when run from PyCharm).
os.environ["PATH"] += os.pathsep + 'graphviz的bin路径'
dot_data = tree.export_graphviz(dtc, out_file=None, feature_names=fea_name, class_names=target_name,
                                filled=True,
                                rounded=True,
                                )
graph = pydotplus.graph_from_dot_data(dot_data)
graph.write_pdf("descion_tree.pdf")

随机森林

# Data loading and preprocessing are the same as in the decision-tree
# section; this part only swaps the model for a random forest.

# n_estimators is set explicitly: scikit-learn 0.20 warned when it was left
# unset because the default changed to 100 in version 0.22.
rfc = RandomForestClassifier(n_estimators=100, max_depth=6)

rfc.fit(x_train, y_train)

rfc_y_predict = rfc.predict(x_test)

print(rfc.score(x_test, y_test))

print(classification_report(y_test, rfc_y_predict, labels=labels, target_names=target_name))

# Create the output directory if needed (no error when it already exists)
# and switch into it so the PDFs below are written there.
forest_dir = os.path.join(pwd, 'forest')
os.makedirs(forest_dir, exist_ok=True)
os.chdir(forest_dir)

# Export every tree of the forest to its own PDF.
for idx, estimator in enumerate(rfc.estimators_):
    filename = 'forest_' + str(idx) + '.pdf'
    dot_data = tree.export_graphviz(estimator,
                                    out_file=None,
                                    feature_names=fea_name,
                                    class_names=target_name,
                                    rounded=True,
                                    proportion=False,
                                    precision=2,
                                    filled=True)

    graph = pydotplus.graph_from_dot_data(dot_data)

    graph.write_pdf(filename)

本地文件 ta.txt 为原始数据文件,
其中性别(sex)和舱位等级(pclass)两列已分别做了数值化处理。

提取的规则代码块

def tree(age, pclass, sex):
  if sex <= 1.5:
    if age <= 10.0:
      if pclass <= 2.5:
        return [[ 0. 12.]] deid
      else:  # if pclass > 2.5
        if age <= 0.583299994468689:
          return [[1. 0.]] survived
        else:  # if age > 0.583299994468689
          if age <= 4.0:
            return [[0. 3.]] deid
          else:  # if age > 4.0
            if age <= 7.5:
              return [[2. 0.]] survived
            else:  # if age > 7.5
              return [[1. 2.]] deid
    else:  # if age > 10.0
      if pclass <= 1.5:
        if age <= 54.5:
          if age <= 29.0:
            if age <= 17.5:
              return [[0. 2.]] deid
            else:  # if age > 17.5
              if age <= 24.5:
                if age <= 20.0:
                  return [[2. 0.]] survived
                else:  # if age > 20.0
                  if age <= 23.5:
                    if age <= 21.5:
                      return [[0. 1.]] deid
                    else:  # if age > 21.5
                      if age <= 22.5:
                        return [[1. 0.]] survived
                      else:  # if age > 22.5
                        return [[0. 1.]] deid
                  else:  # if age > 23.5
                    return [[2. 0.]] survived
              else:  # if age > 24.5
                if age <= 26.0:
                  return [[1. 2.]] deid
                else:  # if age > 26.0
                  if age <= 27.5:
                    return [[0. 1.]] deid
                  else:  # if age > 27.5
                    return [[1. 2.]] deid
          else:  # if age > 29.0
            if age <= 33.5:
              if age <= 31.09709072113037:
                return [[4. 0.]] survived
              else:  # if age > 31.09709072113037
                if age <= 32.09709072113037:
                  return [[29. 10.]] survived
                else:  # if age > 32.09709072113037
                  return [[2. 0.]] survived
            else:  # if age > 33.5
              if age <= 36.5:
                if age <= 35.5:
                  return [[0. 2.]] deid
                else:  # if age > 35.5
                  return [[1. 4.]] deid
              else:  # if age > 36.5
                if age <= 47.5:
                  if age <= 38.5:
                    if age <= 37.5:
                      return [[1. 1.]] survived
                    else:  # if age > 37.5
                      return [[1. 1.]] survived
                  else:  # if age > 38.5
                    if age <= 45.5:
                      if age <= 41.5:
                        if age <= 39.5:
                          return [[3. 1.]] survived
                        else:  # if age > 39.5
                          return [[2. 0.]] survived
                      else:  # if age > 41.5
                        if age <= 43.0:
                          return [[2. 1.]] survived
                        else:  # if age > 43.0
                          if age <= 44.5:
                            return [[1. 0.]] survived
                          else:  # if age > 44.5
                            return [[3. 1.]] survived
                    else:  # if age > 45.5
                      if age <= 46.5:
                        return [[5. 0.]] survived
                      else:  # if age > 46.5
                        return [[3. 1.]] survived
                else:  # if age > 47.5
                  if age <= 48.5:
                    return [[1. 2.]] deid
                  else:  # if age > 48.5
                    if age <= 51.5:
                      if age <= 49.5:
                        return [[2. 1.]] survived
                      else:  # if age > 49.5
                        return [[3. 0.]] survived
                    else:  # if age > 51.5
                      if age <= 53.0:
                        return [[1. 1.]] survived
                      else:  # if age > 53.0
                        return [[1. 1.]] survived
        else:  # if age > 54.5
          return [[14.  0.]] survived
      else:  # if pclass > 1.5
        if age <= 29.5:
          if age <= 25.5:
            if age <= 23.5:
              if age <= 18.5:
                return [[17.  0.]] survived
              else:  # if age > 18.5
                if age <= 19.5:
                  if pclass <= 2.5:
                    return [[1. 0.]] survived
                  else:  # if pclass > 2.5
                    return [[4. 1.]] survived
                else:  # if age > 19.5
                  if age <= 20.5:
                    return [[8. 0.]] survived
                  else:  # if age > 20.5
                    if age <= 22.5:
                      if age <= 21.5:
                        if pclass <= 2.5:
                          return [[5. 0.]] survived
                        else:  # if pclass > 2.5
                          return [[4. 1.]] survived
                      else:  # if age > 21.5
                        if pclass <= 2.5:
                          return [[3. 1.]] survived
                        else:  # if pclass > 2.5
                          return [[3. 0.]] survived
                    else:  # if age > 22.5
                      return [[7. 0.]] survived
            else:  # if age > 23.5
              if age <= 24.5:
                if pclass <= 2.5:
                  return [[1. 1.]] survived
                else:  # if pclass > 2.5
                  return [[6. 1.]] survived
              else:  # if age > 24.5
                if pclass <= 2.5:
                  return [[4. 0.]] survived
                else:  # if pclass > 2.5
                  return [[4. 1.]] survived
          else:  # if age > 25.5
            return [[23.  0.]] survived
        else:  # if age > 29.5
          if age <= 45.5:
            if age <= 44.5:
              if age <= 32.5:
                if age <= 31.59709072113037:
                  if pclass <= 2.5:
                    if age <= 30.59709072113037:
                      return [[8. 0.]] survived
                    else:  # if age > 30.59709072113037
                      return [[32.  4.]] survived
                  else:  # if pclass > 2.5
                    if age <= 30.59709072113037:
                      return [[1. 1.]] survived
                    else:  # if age > 30.59709072113037
                      return [[220.  32.]] survived
                else:  # if age > 31.59709072113037
                  if pclass <= 2.5:
                    return [[3. 2.]] survived
                  else:  # if pclass > 2.5
                    return [[5. 0.]] survived
              else:  # if age > 32.5
                if age <= 35.5:
                  return [[11.  0.]] survived
                else:  # if age > 35.5
                  if age <= 36.5:
                    if pclass <= 2.5:
                      return [[1. 0.]] survived
                    else:  # if pclass > 2.5
                      return [[0. 1.]] deid
                  else:  # if age > 36.5
                    if pclass <= 2.5:
                      if age <= 40.5:
                        return [[3. 0.]] survived
                      else:  # if age > 40.5
                        if age <= 41.5:
                          return [[1. 1.]] survived
                        else:  # if age > 41.5
                          return [[3. 0.]] survived
                    else:  # if pclass > 2.5
                      return [[11.  0.]] survived
            else:  # if age > 44.5
              if pclass <= 2.5:
                return [[2. 0.]] survived
              else:  # if pclass > 2.5
                return [[1. 1.]] survived
          else:  # if age > 45.5
            return [[13.  0.]] survived
  else:  # if sex > 1.5
    if pclass <= 2.5:
      if pclass <= 1.5:
        if age <= 62.5:
          if age <= 36.5:
            if age <= 35.5:
              if age <= 24.5:
                return [[ 0. 19.]] deid
              else:  # if age > 24.5
                if age <= 26.0:
                  return [[1. 0.]] survived
                else:  # if age > 26.0
                  if age <= 31.09709072113037:
                    return [[0. 6.]] deid
                  else:  # if age > 31.09709072113037
                    if age <= 32.09709072113037:
                      return [[ 1. 23.]] deid
                    else:  # if age > 32.09709072113037
                      return [[0. 5.]] deid
            else:  # if age > 35.5
              return [[1. 3.]] deid
          else:  # if age > 36.5
            return [[ 0. 31.]] deid
        else:  # if age > 62.5
          if age <= 63.5:
            return [[1. 1.]] survived
          else:  # if age > 63.5
            return [[0. 1.]] deid
      else:  # if pclass > 1.5
        if age <= 17.5:
          return [[0. 9.]] deid
        else:  # if age > 17.5
          if age <= 22.5:
            if age <= 21.5:
              if age <= 18.5:
                return [[1. 3.]] deid
              else:  # if age > 18.5
                return [[0. 5.]] deid
            else:  # if age > 21.5
              return [[2. 0.]] survived
          else:  # if age > 22.5
            if age <= 26.5:
              return [[0. 5.]] deid
            else:  # if age > 26.5
              if age <= 27.5:
                return [[1. 1.]] survived
              else:  # if age > 27.5
                if age <= 29.5:
                  return [[0. 5.]] deid
                else:  # if age > 29.5
                  if age <= 30.5:
                    return [[1. 2.]] deid
                  else:  # if age > 30.5
                    if age <= 46.0:
                      if age <= 43.0:
                        if age <= 39.0:
                          if age <= 37.0:
                            if age <= 31.59709072113037:
                              if age <= 31.09709072113037:
                                return [[0. 2.]] deid
                              else:  # if age > 31.09709072113037
                                return [[ 3. 17.]] deid
                            else:  # if age > 31.59709072113037
                              return [[0. 9.]] deid
                          else:  # if age > 37.0
                            return [[1. 0.]] survived
                        else:  # if age > 39.0
                          return [[0. 4.]] deid
                      else:  # if age > 43.0
                        return [[1. 0.]] survived
                    else:  # if age > 46.0
                      return [[0. 5.]] deid
    else:  # if pclass > 2.5
      if age <= 19.5:
        if age <= 12.0:
          if age <= 5.5:
            if age <= 1.0833500027656555:
              return [[0. 1.]] deid
            else:  # if age > 1.0833500027656555
              if age <= 3.5:
                return [[1. 0.]] survived
              else:  # if age > 3.5
                return [[0. 1.]] deid
          else:  # if age > 5.5
            return [[2. 0.]] survived
        else:  # if age > 12.0
          if age <= 17.5:
            if age <= 15.5:
              return [[0. 1.]] deid
            else:  # if age > 15.5
              if age <= 16.5:
                return [[1. 3.]] deid
              else:  # if age > 16.5
                return [[0. 1.]] deid
          else:  # if age > 17.5
            if age <= 18.5:
              return [[2. 3.]] deid
            else:  # if age > 18.5
              return [[0. 1.]] deid
      else:  # if age > 19.5
        if age <= 21.5:
          return [[3. 0.]] survived
        else:  # if age > 21.5
          if age <= 23.5:
            if age <= 22.5:
              return [[1. 2.]] deid
            else:  # if age > 22.5
              return [[0. 1.]] deid
          else:  # if age > 23.5
            if age <= 32.5:
              if age <= 31.59709072113037:
                if age <= 25.5:
                  return [[1. 1.]] survived
                else:  # if age > 25.5
                  if age <= 29.0:
                    return [[2. 0.]] survived
                  else:  # if age > 29.0
                    if age <= 30.59709072113037:
                      return [[1. 1.]] survived
                    else:  # if age > 30.59709072113037
                      return [[75. 40.]] survived
              else:  # if age > 31.59709072113037
                return [[1. 0.]] survived
            else:  # if age > 32.5
              if age <= 37.0:
                return [[0. 3.]] deid
              else:  # if age > 37.0
                if age <= 42.5:
                  return [[2. 0.]] survived
                else:  # if age > 42.5
                  return [[1. 1.]] survived

sklearn决策树与随机森林 参数及规则提取 模型可视化(初体验)

sklearn决策树与随机森林 参数及规则提取 模型可视化(初体验)
sklearn决策树与随机森林 参数及规则提取 模型可视化(初体验)

相关标签: 机器学习