python画混淆矩阵
程序员文章站
2022-06-15 15:34:49
...
对于分类问题,为了直观表示各类别分类的准确性,一般使用混淆矩阵M.
混淆矩阵M的每一行代表每个真实类(GT),每一列表示预测的类。即:Mij表示GroundTruth类别为i的所有数据中被预测为类别j的数目。
这里给出两种方法画混淆矩阵。
方法一:这里采用画图像的办法,绘制混淆矩阵的表示图。颜色越深,值越大。
# -*- coding: utf-8 -*-
# By Changxu Cheng, HUST
from __future__ import division
import numpy as np
from skimage import io, color
from PIL import Image, ImageDraw, ImageFont
import os
def drawCM(matrix, savname):
# Display different color for different elements
lines, cols = matrix.shape
sumline = matrix.sum(axis=1).reshape(lines, 1)
ratiomat = matrix / sumline
toplot0 = 1 - ratiomat
toplot = toplot0.repeat(50).reshape(lines, -1).repeat(50, axis=0)
io.imsave(savname, color.gray2rgb(toplot))
# Draw values on every block
image = Image.open(savname)
draw = ImageDraw.Draw(image)
font = ImageFont.truetype(os.path.join(os.getcwd(), "draw/ARIAL.TTF"), 15)
for i in range(lines):
for j in range(cols):
dig = str(matrix[i, j])
if i == j:
filled = (255, 181, 197)
else:
filled = (46, 139, 87)
draw.text((50 * j + 10, 50 * i + 10), dig, font=font, fill=filled)
image.save(savname, 'jpeg')
if __name__ == "__main__":
drawCM(np.random.randint(16, size=16).reshape(4,4), 'tmp.jpg')
注意:需要用到字体文件。代码中使用的是ARIAL.TTF。这样才可以在图中直接标注出数目。
某实验结果图如下(不是上述__name == "__main__"代码的执行结果)
方法二:利用matplotlib.pyplot.matshow画图
from __future__ import division
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
def plotCM(classes, matrix, savname):
"""classes: a list of class names"""
# Normalize by row
matrix = matrix.astype(np.float)
linesum = matrix.sum(1)
linesum = np.dot(linesum.reshape(-1, 1), np.ones((1, matrix.shape[1])))
matrix /= linesum
# plot
plt.switch_backend('agg')
fig = plt.figure()
ax = fig.add_subplot(111)
cax = ax.matshow(matrix)
fig.colorbar(cax)
ax.xaxis.set_major_locator(MultipleLocator(1))
ax.yaxis.set_major_locator(MultipleLocator(1))
for i in range(matrix.shape[0]):
ax.text(i, i, str('%.2f' % (matrix[i, i] * 100)), va='center', ha='center')
ax.set_xticklabels([''] + classes, rotation=90)
ax.set_yticklabels([''] + classes)
#save
plt.savefig(savname)
这种方法可以直接标出坐标轴的含义,比较方便。