欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

多任务学习(multi-label learning)扫盲

程序员文章站 2022-03-03 14:37:06
...

1,MLL 没有单独的直观的二维混淆矩阵,sklearn的实现是多少label就有多少混淆矩阵,每个label一个2x2的混淆矩阵。
(每个label的值是0/1)

from sklearn.metrics import multilabel_confusion_matrix
import numpy as np
classes = ['green', 'black', 'red', 'blue']
targetSrc = [[0,1,1,1], [0,0,1,0], [1,0,0,1], [1,1,1,0], [1,0,0,0]]
predSrc = [[0,1,0,1], [0,0,1,1], [1,0,0,1], [1,0,1,0], [1,0,0,0]]
target = np.array(targetSrc)
pred = np.array(predSrc)
cm = multilabel_confusion_matrix(target, pred)
print(cm.shape) # (4, 2, 2)
print(cm)
# [[[2 0]
#   [0 3]] # 第一个label没有预测错的
#
# [[3 0]
# [1 1]] # 第二个label, 有一个 black预测成了非black
#
# [[2 0] # 一个red预测成了非red
#  [1 2]]
#
# [[2 1] # 一个非blue预测成了blue
#  [0 2]]]

注意:

  1. sklearn要安装 0.21及以上版本。
  2. 上述例子中是4个label, 5个样本,最终的结果是由4个小的混淆矩阵组成的。
  3. 每个小矩阵纵轴是 true label [0, 1], 横轴是 predict label [0, 1]
  4. 输入可以接受numpy array或者torch tensor

参考:
https://scikit-learn.org/stable/modules/generated/sklearn.metrics.multilabel_confusion_matrix.html
https://*.com/questions/55877681/importerror-cannot-import-name-multilabel-confusion-matrix

相关标签: ML # DL-基础