欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

利用python画混淆矩阵

程序员文章站 2022-03-14 16:34:38
...

我们在这里提供两种不同的应用场合:

1.如果你已经通过实验后有了矩阵数据,那么就可以直接利用这一方法,源码如下:

from sklearn.metrics import confusion_matrix    # 生成混淆矩阵函数
import matplotlib.pyplot as plt    # 绘图库
import numpy as np

def plot_confusion_matrix(cm, labels_name, title):
    cm = cm / cm.sum(axis=1)[:, np.newaxis]    # 归一化
    plt.imshow(cm, interpolation='nearest')    # 在特定的窗口上显示图像
    plt.title(title)    # 图像标题
    plt.colorbar()
    num_class = np.array(range(len(labels_name)))#获取标签的间隔数    
    plt.xticks(num_class, labels_name, rotation=90)    # 将标签印在x轴坐标上
    plt.yticks(num_class, labels_name)    # 将标签印在y轴坐标上
    plt.ylabel('True label')    
    plt.xlabel('Predicted label')
    plt.show()

#这里cm为你事先已经获得矩阵数据,一般为list类型
cm = [[5, 0, 0, 0],
      [1, 4, 0, 0],
      [0, 0, 5, 0],
      [1, 0, 0, 2]] 
cm = np.array(cm) #将list类型转换成数组类型,如果已经是numpy数组类型,则忽略此步骤。
labels_name = ['1','2','3','4']#这里个横纵坐标标签集合赋值
plot_confusion_matrix(cm,labels_name,"confusion_matrix")#调用函数

结果如图:

利用python画混淆矩阵

2.如果你是在训练网络的过程中想要在训练结束后绘制混淆矩阵,那么就需在上面的代码上稍作修改:
 

from sklearn.metrics import confusion_matrix    # 生成混淆矩阵函数
import matplotlib.pyplot as plt    # 绘图库
import numpy as np
import tensorflow as tf

def plot_confusion_matrix(cm, labels_name, title):
    cm = cm / cm.sum(axis=1)[:, np.newaxis]    # 归一化
    plt.imshow(cm, interpolation='nearest')    # 在特定的窗口上显示图像
    plt.title(title)    # 图像标题
    plt.colorbar()
    num_class = np.array(range(len(labels_name)))#获取标签的间隔数    
    plt.xticks(num_class, labels_name, rotation=90)    # 将标签印在x轴坐标上
    plt.yticks(num_class, labels_name)    # 将标签印在y轴坐标上
    plt.ylabel('True label')    
    plt.xlabel('Predicted label')
    plt.show()

#这里通过list类型的标签数据来生成混淆矩阵
y_true = [.......] #这里要想办法将你的实际标签类别转换成list类型
pred_y = [.......] #这里要想办法将你的预测标签类别转换成list类型,一般网络的最后一层通常是类别,这里根据你的网络来获得
cm = confusion_matrix(y_true,pred_y)
 
cm = np.array(cm) #将list类型转换成数组类型,如果已经是numpy数组类型,则忽略此步骤。
labels_name = ['1','2','3','4']#这里个横纵坐标标签集合赋值
plot_confusion_matrix(cm,labels_name,"confusion_matrix")#调用函数

实验结果与上图类似。

相关标签: 功能代码积累