sklearn决策树与随机森林 参数及规则提取 模型可视化(初体验)
程序员文章站
2022-04-08 13:54:37
...
决策树
import os
import pandas as pd
import numpy as np
from sklearn import tree
from sklearn.tree import _tree
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction import DictVectorizer
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import pydotplus
def tree_to_code(tree, feature_names, target_names=None, out_path='code.txt'):  # extract decision-tree rules
    """Append a fitted decision tree to ``out_path`` as nested if/else pseudo-code.

    Each split node becomes ``if <feature> <= <threshold>:`` followed by
    ``else: # if <feature> > <threshold>``. Each leaf becomes
    ``return <tree_.value[node]> -- <class name of argmax>``. The result is
    human-readable pseudo-code, not runnable Python.

    Args:
        tree: a fitted estimator exposing a low-level ``tree_`` structure
            (e.g. ``DecisionTreeClassifier``).
        feature_names: names indexed by the tree's feature ids, i.e. in the
            column order of the matrix the tree was trained on.
        target_names: class names indexed by class position. When omitted,
            the module-level ``target_name`` list is used (original behavior).
        out_path: file the rules are appended to. Append mode is kept, so
            repeated calls accumulate output in the same file.
    """
    tree_ = tree.tree_
    left = tree_.children_left
    right = tree_.children_right
    if target_names is None:
        # Backward-compatible fallback to the script-level class-name list.
        target_names = target_name

    def is_split(node):
        # scikit-learn gives leaves no children (both child ids are equal),
        # so this avoids depending on the private sklearn.tree._tree module.
        return left[node] != right[node]

    # Per-node split feature name; leaves have no split feature.
    feature_name = [
        feature_names[tree_.feature[node]] if is_split(node) else "undefined!"
        for node in range(len(tree_.feature))
    ]
    print('feature_name:', feature_name)

    lines = ["def tree({}):".format(", ".join(feature_names))]

    def recurse(node, depth):
        indent = " " * depth
        if is_split(node):
            name = feature_name[node]
            threshold = tree_.threshold[node]
            lines.append("{}if {} <= {}:".format(indent, name, threshold))
            recurse(left[node], depth + 1)
            lines.append("{}else: # if {} > {}".format(indent, name, threshold))
            recurse(right[node], depth + 1)
        else:
            value = tree_.value[node]
            lines.append("{}return {} -- {}".format(
                indent, value, target_names[np.argmax(value)]))

    recurse(0, 1)
    # Open the file once for the whole tree instead of once per node.
    with open(out_path, 'a', encoding='utf-8') as f:
        f.write("\n".join(lines) + "\n")
# Load the Titanic sample from ta.txt in the current working directory.
pwd = os.getcwd()
titanic = pd.read_csv(os.path.join(pwd, 'ta.txt'))
# Fill missing ages with the mean age. Assign the result back rather than calling
# fillna(inplace=True) on a column selection: that is chained assignment and no
# longer updates the DataFrame under pandas copy-on-write.
titanic['age'] = titanic['age'].fillna(titanic['age'].mean())
# Features used for splitting, and the target column.
x = titanic[['pclass', 'age', 'sex']]
y = titanic['survived']
labels = [0, 1]
# Class names, positionally matching `labels` (0 -> "deid", 1 -> "survived").
target_name = ["deid", "survived"]
# test_size is the fraction of rows held out for testing.
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3)
# sparse=False makes the vectorizer return a dense array instead of a sparse matrix.
dt = DictVectorizer(sparse=False)
# Learn the feature mapping on the training split only and reuse it for the test
# split; refitting on test data can yield a different column layout.
# orient="records" (plural) is the spelling pandas >= 2 accepts.
x_train = dt.fit_transform(x_train.to_dict(orient="records"))
x_test = dt.transform(x_test.to_dict(orient="records"))
# Feature names in the vectorizer's column order, which is what the tree's
# feature indices refer to (stays correct if a column gets one-hot encoded).
fea_name = list(dt.feature_names_)
# Train a decision tree with scikit-learn's default hyperparameters.
# Options worth experimenting with:
#   class_weight='balanced'  re-weight classes for an imbalanced dataset
#   criterion='entropy'      split on information gain instead of the default 'gini'
#   max_features='sqrt'      consider only sqrt(n_features) features per split
dtc = DecisionTreeClassifier()
dt_predict = dtc.fit(x_train, y_train).predict(x_test)
# Dump the learned rules as pseudo-code.
tree_to_code(dtc, fea_name)
# Mean accuracy on the held-out split, followed by per-class precision/recall/F1.
print(dtc.score(x_test, y_test))
print(classification_report(
    y_test,
    dt_predict,
    labels=labels,
    target_names=target_name,
))
# Confusion matrix of the decision tree's test-set predictions.
# This must use dt_predict: rfc_y_predict is only defined later, in the
# random-forest section, so referencing it here fails with NameError.
confmat = confusion_matrix(y_true=y_test, y_pred=dt_predict, labels=labels)
print(confmat)
# Draw the matrix as a light heat map with each cell's count centered on it.
fig, ax = plt.subplots(figsize=(3, 3))
ax.matshow(confmat, cmap=plt.cm.Blues, alpha=0.3)
for (i, j), count in np.ndenumerate(confmat):
    ax.text(x=j, y=i, s=count, va='center', ha='center')  # x = column, y = row
plt.xticks(range(len(labels)), labels)
plt.yticks(range(len(labels)), labels)
plt.xlabel('predicted label')
plt.ylabel('true label')
plt.savefig('confusion_matrix.png')
plt.show()
# Render the fitted decision tree to a PDF via Graphviz.
# When running from an IDE such as PyCharm the Graphviz executables may not be
# found on PATH; replace the placeholder below with your Graphviz "bin" directory.
os.environ["PATH"] = os.environ["PATH"] + os.pathsep + 'graphviz的bin路径'
dot_data = tree.export_graphviz(
    dtc,
    out_file=None,  # return the DOT source as a string instead of writing a file
    feature_names=fea_name,
    class_names=target_name,
    rounded=True,
    filled=True,  # color nodes by majority class
)
graph = pydotplus.graph_from_dot_data(dot_data)
graph.write_pdf("descion_tree.pdf")
随机森林
# Data loading and preprocessing are the same as in the decision-tree section.
# Random forest of 100 trees, each limited to depth 6. n_estimators is set
# explicitly: scikit-learn 0.20 warned when it was left unset because its
# default changed from 10 to 100 in 0.22.
rfc = RandomForestClassifier(n_estimators=100, max_depth=6)
rfc.fit(x_train, y_train)
rfc_y_predict = rfc.predict(x_test)
print(rfc.score(x_test, y_test))
print(classification_report(y_test, rfc_y_predict, labels=labels, target_names=target_name))
# Write one PDF per tree into ./forest/, creating the directory if needed.
# Note: this changes the working directory for the rest of the script.
forest_dir = os.path.join(pwd, 'forest')
os.makedirs(forest_dir, exist_ok=True)
os.chdir(forest_dir)
for idx, estimator in enumerate(rfc.estimators_):
    # Build the DOT source in memory and render it straight to PDF.
    filename = 'forest_{}.pdf'.format(idx)
    dot_data = tree.export_graphviz(estimator,
                                    out_file=None,
                                    feature_names=fea_name,
                                    class_names=target_name,
                                    rounded=True,
                                    proportion=False,
                                    precision=2,
                                    filled=True)
    graph = pydotplus.graph_from_dot_data(dot_data)
    graph.write_pdf(filename)
本地数据文件 ta.txt 的原始内容
其中性别（sex）和舱位等级（pclass）已分别做了数值化处理。
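提取的规则代码块（即 tree_to_code 写入 code.txt 的内容；这是生成的伪代码，不能直接作为 Python 运行。注意：按当前代码 target_name = ["deid", "survived"]，叶子 [[0. 12.]] 按 argmax 应输出 survived，下方结果中的类别名可能是在 target_name 顺序相反时生成的）：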
def tree(age, pclass, sex):
if sex <= 1.5:
if age <= 10.0:
if pclass <= 2.5:
return [[ 0. 12.]] deid
else: # if pclass > 2.5
if age <= 0.583299994468689:
return [[1. 0.]] survived
else: # if age > 0.583299994468689
if age <= 4.0:
return [[0. 3.]] deid
else: # if age > 4.0
if age <= 7.5:
return [[2. 0.]] survived
else: # if age > 7.5
return [[1. 2.]] deid
else: # if age > 10.0
if pclass <= 1.5:
if age <= 54.5:
if age <= 29.0:
if age <= 17.5:
return [[0. 2.]] deid
else: # if age > 17.5
if age <= 24.5:
if age <= 20.0:
return [[2. 0.]] survived
else: # if age > 20.0
if age <= 23.5:
if age <= 21.5:
return [[0. 1.]] deid
else: # if age > 21.5
if age <= 22.5:
return [[1. 0.]] survived
else: # if age > 22.5
return [[0. 1.]] deid
else: # if age > 23.5
return [[2. 0.]] survived
else: # if age > 24.5
if age <= 26.0:
return [[1. 2.]] deid
else: # if age > 26.0
if age <= 27.5:
return [[0. 1.]] deid
else: # if age > 27.5
return [[1. 2.]] deid
else: # if age > 29.0
if age <= 33.5:
if age <= 31.09709072113037:
return [[4. 0.]] survived
else: # if age > 31.09709072113037
if age <= 32.09709072113037:
return [[29. 10.]] survived
else: # if age > 32.09709072113037
return [[2. 0.]] survived
else: # if age > 33.5
if age <= 36.5:
if age <= 35.5:
return [[0. 2.]] deid
else: # if age > 35.5
return [[1. 4.]] deid
else: # if age > 36.5
if age <= 47.5:
if age <= 38.5:
if age <= 37.5:
return [[1. 1.]] survived
else: # if age > 37.5
return [[1. 1.]] survived
else: # if age > 38.5
if age <= 45.5:
if age <= 41.5:
if age <= 39.5:
return [[3. 1.]] survived
else: # if age > 39.5
return [[2. 0.]] survived
else: # if age > 41.5
if age <= 43.0:
return [[2. 1.]] survived
else: # if age > 43.0
if age <= 44.5:
return [[1. 0.]] survived
else: # if age > 44.5
return [[3. 1.]] survived
else: # if age > 45.5
if age <= 46.5:
return [[5. 0.]] survived
else: # if age > 46.5
return [[3. 1.]] survived
else: # if age > 47.5
if age <= 48.5:
return [[1. 2.]] deid
else: # if age > 48.5
if age <= 51.5:
if age <= 49.5:
return [[2. 1.]] survived
else: # if age > 49.5
return [[3. 0.]] survived
else: # if age > 51.5
if age <= 53.0:
return [[1. 1.]] survived
else: # if age > 53.0
return [[1. 1.]] survived
else: # if age > 54.5
return [[14. 0.]] survived
else: # if pclass > 1.5
if age <= 29.5:
if age <= 25.5:
if age <= 23.5:
if age <= 18.5:
return [[17. 0.]] survived
else: # if age > 18.5
if age <= 19.5:
if pclass <= 2.5:
return [[1. 0.]] survived
else: # if pclass > 2.5
return [[4. 1.]] survived
else: # if age > 19.5
if age <= 20.5:
return [[8. 0.]] survived
else: # if age > 20.5
if age <= 22.5:
if age <= 21.5:
if pclass <= 2.5:
return [[5. 0.]] survived
else: # if pclass > 2.5
return [[4. 1.]] survived
else: # if age > 21.5
if pclass <= 2.5:
return [[3. 1.]] survived
else: # if pclass > 2.5
return [[3. 0.]] survived
else: # if age > 22.5
return [[7. 0.]] survived
else: # if age > 23.5
if age <= 24.5:
if pclass <= 2.5:
return [[1. 1.]] survived
else: # if pclass > 2.5
return [[6. 1.]] survived
else: # if age > 24.5
if pclass <= 2.5:
return [[4. 0.]] survived
else: # if pclass > 2.5
return [[4. 1.]] survived
else: # if age > 25.5
return [[23. 0.]] survived
else: # if age > 29.5
if age <= 45.5:
if age <= 44.5:
if age <= 32.5:
if age <= 31.59709072113037:
if pclass <= 2.5:
if age <= 30.59709072113037:
return [[8. 0.]] survived
else: # if age > 30.59709072113037
return [[32. 4.]] survived
else: # if pclass > 2.5
if age <= 30.59709072113037:
return [[1. 1.]] survived
else: # if age > 30.59709072113037
return [[220. 32.]] survived
else: # if age > 31.59709072113037
if pclass <= 2.5:
return [[3. 2.]] survived
else: # if pclass > 2.5
return [[5. 0.]] survived
else: # if age > 32.5
if age <= 35.5:
return [[11. 0.]] survived
else: # if age > 35.5
if age <= 36.5:
if pclass <= 2.5:
return [[1. 0.]] survived
else: # if pclass > 2.5
return [[0. 1.]] deid
else: # if age > 36.5
if pclass <= 2.5:
if age <= 40.5:
return [[3. 0.]] survived
else: # if age > 40.5
if age <= 41.5:
return [[1. 1.]] survived
else: # if age > 41.5
return [[3. 0.]] survived
else: # if pclass > 2.5
return [[11. 0.]] survived
else: # if age > 44.5
if pclass <= 2.5:
return [[2. 0.]] survived
else: # if pclass > 2.5
return [[1. 1.]] survived
else: # if age > 45.5
return [[13. 0.]] survived
else: # if sex > 1.5
if pclass <= 2.5:
if pclass <= 1.5:
if age <= 62.5:
if age <= 36.5:
if age <= 35.5:
if age <= 24.5:
return [[ 0. 19.]] deid
else: # if age > 24.5
if age <= 26.0:
return [[1. 0.]] survived
else: # if age > 26.0
if age <= 31.09709072113037:
return [[0. 6.]] deid
else: # if age > 31.09709072113037
if age <= 32.09709072113037:
return [[ 1. 23.]] deid
else: # if age > 32.09709072113037
return [[0. 5.]] deid
else: # if age > 35.5
return [[1. 3.]] deid
else: # if age > 36.5
return [[ 0. 31.]] deid
else: # if age > 62.5
if age <= 63.5:
return [[1. 1.]] survived
else: # if age > 63.5
return [[0. 1.]] deid
else: # if pclass > 1.5
if age <= 17.5:
return [[0. 9.]] deid
else: # if age > 17.5
if age <= 22.5:
if age <= 21.5:
if age <= 18.5:
return [[1. 3.]] deid
else: # if age > 18.5
return [[0. 5.]] deid
else: # if age > 21.5
return [[2. 0.]] survived
else: # if age > 22.5
if age <= 26.5:
return [[0. 5.]] deid
else: # if age > 26.5
if age <= 27.5:
return [[1. 1.]] survived
else: # if age > 27.5
if age <= 29.5:
return [[0. 5.]] deid
else: # if age > 29.5
if age <= 30.5:
return [[1. 2.]] deid
else: # if age > 30.5
if age <= 46.0:
if age <= 43.0:
if age <= 39.0:
if age <= 37.0:
if age <= 31.59709072113037:
if age <= 31.09709072113037:
return [[0. 2.]] deid
else: # if age > 31.09709072113037
return [[ 3. 17.]] deid
else: # if age > 31.59709072113037
return [[0. 9.]] deid
else: # if age > 37.0
return [[1. 0.]] survived
else: # if age > 39.0
return [[0. 4.]] deid
else: # if age > 43.0
return [[1. 0.]] survived
else: # if age > 46.0
return [[0. 5.]] deid
else: # if pclass > 2.5
if age <= 19.5:
if age <= 12.0:
if age <= 5.5:
if age <= 1.0833500027656555:
return [[0. 1.]] deid
else: # if age > 1.0833500027656555
if age <= 3.5:
return [[1. 0.]] survived
else: # if age > 3.5
return [[0. 1.]] deid
else: # if age > 5.5
return [[2. 0.]] survived
else: # if age > 12.0
if age <= 17.5:
if age <= 15.5:
return [[0. 1.]] deid
else: # if age > 15.5
if age <= 16.5:
return [[1. 3.]] deid
else: # if age > 16.5
return [[0. 1.]] deid
else: # if age > 17.5
if age <= 18.5:
return [[2. 3.]] deid
else: # if age > 18.5
return [[0. 1.]] deid
else: # if age > 19.5
if age <= 21.5:
return [[3. 0.]] survived
else: # if age > 21.5
if age <= 23.5:
if age <= 22.5:
return [[1. 2.]] deid
else: # if age > 22.5
return [[0. 1.]] deid
else: # if age > 23.5
if age <= 32.5:
if age <= 31.59709072113037:
if age <= 25.5:
return [[1. 1.]] survived
else: # if age > 25.5
if age <= 29.0:
return [[2. 0.]] survived
else: # if age > 29.0
if age <= 30.59709072113037:
return [[1. 1.]] survived
else: # if age > 30.59709072113037
return [[75. 40.]] survived
else: # if age > 31.59709072113037
return [[1. 0.]] survived
else: # if age > 32.5
if age <= 37.0:
return [[0. 3.]] deid
else: # if age > 37.0
if age <= 42.5:
return [[2. 0.]] survived
else: # if age > 42.5
return [[1. 1.]] survived