Plotting a Confusion Matrix in Python

程序员文章站 2022-06-15 15:34:49

For classification problems, a confusion matrix M is commonly used to show at a glance how accurately each class is classified.

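Each row of the confusion matrix M corresponds to a ground-truth (GT) class, and each column to a predicted class. That is, M_ij is the number of samples whose ground-truth class is i and that were predicted as class j.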
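As a concrete illustration of this definition, here is a minimal sketch that builds M from two label lists. The function name confusion_matrix, the num_classes parameter and the sample labels are made up for this example, and it assumes the classes are encoded as integers 0 .. num_classes-1.

```python
import numpy as np

def confusion_matrix(y_true, y_pred, num_classes):
    """Return M where M[i, j] counts samples of true class i predicted as j."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    m = np.zeros((num_classes, num_classes), dtype=np.int64)
    # np.add.at accumulates correctly even when an (i, j) pair repeats
    np.add.at(m, (y_true, y_pred), 1)
    return m

# Illustrative labels only (assumed, not taken from any experiment)
y_true = [0, 0, 1, 1, 2, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0, 2]
M = confusion_matrix(y_true, y_pred, num_classes=3)
print(M)
```

Each row of the result sums to the number of samples in that ground-truth class, which is why both methods below normalize by row.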

Two ways of drawing a confusion matrix are given here.

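Method 1: render the confusion matrix as an image. Each cell is shaded by its row-normalized value: the darker the cell, the larger the value.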

# -*- coding: utf-8 -*-
# By Changxu Cheng, HUST

from __future__ import division
import numpy as np
from skimage import io, color
from PIL import Image, ImageDraw, ImageFont
import os

def drawCM(matrix, savname):
    # Display different color for different elements
    lines, cols = matrix.shape
    sumline = matrix.sum(axis=1).reshape(lines, 1)
    ratiomat = matrix / sumline
    toplot0 = 1 - ratiomat
    toplot = toplot0.repeat(50).reshape(lines, -1).repeat(50, axis=0)
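    # Convert to 8-bit explicitly so that saving does not depend on how
    # skimage handles float images
    toplot = (255 * toplot).astype(np.uint8)
    io.imsave(savname, color.gray2rgb(toplot))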
    # Draw values on every block
    image = Image.open(savname)
    draw = ImageDraw.Draw(image)
    font = ImageFont.truetype(os.path.join(os.getcwd(), "draw/ARIAL.TTF"), 15)
    for i in range(lines):
        for j in range(cols):
            dig = str(matrix[i, j])
            if i == j:
                filled = (255, 181, 197)
            else:
                filled = (46, 139, 87)
            draw.text((50 * j + 10, 50 * i + 10), dig, font=font, fill=filled)
    image.save(savname, 'jpeg')

if __name__ == "__main__":
    drawCM(np.random.randint(16, size=16).reshape(4,4), 'tmp.jpg')

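Note: a font file is required. The code uses ARIAL.TTF, which it loads from draw/ARIAL.TTF under the current working directory. Only with the font available can the counts be written directly onto the image.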

The figure below shows the result of one experiment (it is not the output of the __name__ == "__main__" code above).

[Figure: confusion matrix drawn in Python]
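The shading step of Method 1 can also be looked at in isolation. The sketch below is an illustration, not part of the original code: the helper names row_normalize and to_gray_blocks are hypothetical, np.kron is swapped in for the repeat/reshape chain, and a guard is added for rows that sum to zero (a class with no samples), a case in which the plain division would produce NaN.

```python
import numpy as np

def row_normalize(matrix):
    """Divide each row by its sum; rows that sum to zero stay all-zero."""
    matrix = np.asarray(matrix, dtype=float)
    row_sums = matrix.sum(axis=1, keepdims=True)
    # Replace zero sums with 1 so that empty rows give 0 instead of NaN
    safe_sums = np.where(row_sums == 0, 1.0, row_sums)
    return matrix / safe_sums

def to_gray_blocks(ratios, block=50):
    """Map ratios in [0, 1] to 8-bit gray blocks; larger ratio -> darker."""
    gray = np.round(255 * (1.0 - ratios)).astype(np.uint8)
    # np.kron with a block of ones enlarges every cell to block x block pixels
    return np.kron(gray, np.ones((block, block), dtype=np.uint8))

# Illustrative counts only; the middle class has no samples
M = np.array([[3, 1, 0],
              [0, 0, 0],
              [1, 1, 2]])
ratios = row_normalize(M)
pixels = to_gray_blocks(ratios)
print(ratios)
print(pixels.shape)
```

The resulting uint8 array can be passed to PIL's Image.fromarray or saved with skimage, in the same way as the image in drawCM.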

Method 2: plot with matplotlib.pyplot.matshow

from __future__ import division
import numpy as np
import matplotlib.pyplot as plt

def plotCM(classes, matrix, savname):
    """classes: a list of class names, one per row/column of matrix"""
    # Normalize by row (np.float was removed in NumPy 1.24; use float)
    matrix = matrix.astype(float)
    matrix /= matrix.sum(axis=1, keepdims=True)
    # Plot with a non-interactive backend so no display is needed
    plt.switch_backend('agg')
    fig = plt.figure()
    ax = fig.add_subplot(111)
    cax = ax.matshow(matrix)
    fig.colorbar(cax)
    # One tick per class, so that labels line up with rows and columns
    ax.set_xticks(np.arange(len(classes)))
    ax.set_yticks(np.arange(len(classes)))
    ax.set_xticklabels(classes, rotation=90)
    ax.set_yticklabels(classes)
    # Write the per-class accuracy (in percent) on the diagonal
    for i in range(matrix.shape[0]):
        ax.text(i, i, '%.2f' % (matrix[i, i] * 100), va='center', ha='center')
    # Save
    plt.savefig(savname)

With this method the class names can be written directly on the axes, which is convenient.
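If the off-diagonal cells should be annotated as well, a variant along the following lines may help. This is a sketch under assumptions rather than the original author's code: the function name plot_cm_all_cells, the "Blues" colormap, the axis captions, the sample counts and the output filename are all choices made for this example.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no display is needed
import matplotlib.pyplot as plt

def plot_cm_all_cells(classes, matrix, savname):
    """Hypothetical variant of plotCM that writes a percentage in every cell."""
    matrix = np.asarray(matrix, dtype=float)
    row_sums = matrix.sum(axis=1, keepdims=True)
    # Guard against empty rows so they show 0 instead of NaN
    ratios = matrix / np.where(row_sums == 0, 1.0, row_sums)

    fig, ax = plt.subplots()
    im = ax.matshow(ratios, cmap="Blues", vmin=0.0, vmax=1.0)
    fig.colorbar(im)
    ax.set_xticks(np.arange(len(classes)))
    ax.set_yticks(np.arange(len(classes)))
    ax.set_xticklabels(classes, rotation=90)
    ax.set_yticklabels(classes)
    ax.set_xlabel("Predicted class")
    ax.set_ylabel("Ground-truth class")
    for i in range(ratios.shape[0]):
        for j in range(ratios.shape[1]):
            # ax.text takes (x, y), i.e. (column, row)
            text_color = "white" if ratios[i, j] > 0.5 else "black"
            ax.text(j, i, "%.1f" % (100 * ratios[i, j]),
                    ha="center", va="center", color=text_color)
    fig.savefig(savname, bbox_inches="tight")
    plt.close(fig)
    return ratios

# Illustrative counts only; the last class has no samples
ratios = plot_cm_all_cells(["cat", "dog", "bird"],
                           [[8, 1, 1], [2, 6, 2], [0, 0, 0]],
                           "cm_all_cells.png")
```

Switching the text color at 0.5 is only a readability heuristic for dark cells; any threshold that suits the chosen colormap would do.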
