Evaluation Metric mAP for Multi-Label Image Recognition and Implementation Code (Python)
程序员文章站
2022-04-02 10:54:10
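The main idea is to compute the AP (average precision) of each category and then average the APs over all categories to obtain the mAP (mean average precision).
AP is the area under the precision-recall (P-R) curve. The curve of a class is traced by ranking its samples by predicted score and computing precision and recall at each cutoff; the code approximates the area with 11-point interpolation.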
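Written as formulas, for C classes, with p_c(k) and r_c(k) the precision and recall of class c when the top k ranked samples are predicted positive, the 11-point AP and the mAP are as follows (a maximum over an empty set is taken as 0, matching the `else` branch of the code). The code additionally counts a sample as a predicted positive only if its score exceeds `accuracy_th`.

$$
\mathrm{AP}_c = \frac{1}{11}\sum_{t \in \{0,\, 0.1,\, \ldots,\, 1\}} \max_{k:\, r_c(k) \ge t} p_c(k),
\qquad
\mathrm{mAP} = \frac{1}{C}\sum_{c=1}^{C} \mathrm{AP}_c
$$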
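```python
import numpy as np
import torch


def getAplist(pred, label, aplistSavePath, accuracy_th=0.5):
    '''
    pred is the output of sigmoid.
    Calculate the AP of each category and their mean (mAP).
    The number of categories in this example is 5000.
    input:
        pred:  [batch, C] tensor
        label: [batch, C] tensor of 0/1 ground truth
    output:
        aplist: [C] numpy.array, also saved to aplistSavePath
    '''
    num_samples, num_classes = pred.shape
    ap_list = []
    for cls_index in range(num_classes):
        cls_pred = pred[:, cls_index]
        cls_label = label[:, cls_index]
        # N is the number of samples (rows), not the number of classes
        cls_ap = certainClassAP(cls_pred, cls_label, num_samples, accuracy_th)
        ap_list.append(cls_ap)
    ap_list = np.array(ap_list)
    # Classes without any positive sample return nan and are left out of the mean
    print('map is', np.nanmean(ap_list))
    np.save(aplistSavePath, ap_list)
    return ap_list


def certainClassAP(model_pred, labels, N, accuracy_th):
    '''
    Get the 11-point interpolated AP of a certain class.
    model_pred: [batch] tensor of sigmoid scores
    labels: [batch] tensor of 0/1 ground truth
    N: number of samples, i.e. batch (e.g. 5000), int
    accuracy_th: (e.g. 0.5) float, scores <= accuracy_th never count as positive
    '''
    # float64 avoids rounding errors such as 7/10 < 0.7 in float32
    labels = labels.double()
    target_one_num = torch.sum(labels)
    if target_one_num == 0:
        return float('nan')  # recall (and thus AP) is undefined for this class
    # The P-R curve is traced by ranking the samples by score, highest first
    order = torch.argsort(model_pred, descending=True)
    model_pred = model_pred[order]
    labels = labels[order]

    p_list = [0.0] * N
    r_list = [0.0] * N
    for i in range(N):
        # Among the top i+1 ranked samples, those scoring above the threshold are positive
        temp_pred = model_pred[:i + 1]
        temp_label = labels[:i + 1]
        pred_result = (temp_pred > accuracy_th).double()
        pred_one_num = torch.sum(pred_result)
        if pred_one_num == 0:
            continue  # precision and recall stay 0
        true_predict_num = torch.sum(pred_result * temp_label)
        p_list[i] = (true_predict_num / pred_one_num).item()
        r_list[i] = (true_predict_num / target_one_num).item()
    precisions = np.array(p_list)
    recalls = np.array(r_list)
    average_precision = 0.0
    # 11-point interpolation: for each recall level 0, 0.1, ..., 1.0 take the
    # highest precision reached at any recall >= that level, then average
    for threshold in np.arange(11) / 10:
        precisions_at_recall_threshold = precisions[recalls >= threshold]
        if precisions_at_recall_threshold.size > 0:
            max_precision = np.max(precisions_at_recall_threshold)
        else:
            max_precision = 0
        average_precision = average_precision + max_precision / 11
    print('cur class ap:', average_precision)
    return average_precision


# pred: sigmoid outputs [batch, 5000]; label: 0/1 targets [batch, 5000]
getAplist(pred, label, './aplist@0.5.npy')
```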
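The loop in `certainClassAP` recomputes its sums from scratch for every cutoff, which costs O(N²) per class and gets slow with thousands of samples and 5000 classes. Below is a minimal sketch, not the author's code, that computes precision and recall at every cutoff of the score ranking with cumulative sums in NumPy, then applies the same 11-point interpolation and threshold rule. The names `class_ap_11point` and `mean_ap` are hypothetical, and it assumes the scores and labels have already been converted to NumPy arrays (e.g. with `.cpu().numpy()`).

```python
import numpy as np


def class_ap_11point(scores, labels, score_th=0.5):
    """11-point interpolated AP of one class (hypothetical vectorized helper).

    scores: [N] sigmoid outputs; labels: [N] 0/1 targets.
    Samples with score <= score_th never count as predicted positives,
    mirroring accuracy_th in certainClassAP.
    """
    scores = np.asarray(scores, dtype=np.float64)
    labels = np.asarray(labels, dtype=np.float64)
    num_pos = labels.sum()
    if num_pos == 0:
        return float("nan")  # AP is undefined without positive samples
    order = np.argsort(-scores, kind="stable")  # rank by score, highest first
    scores, labels = scores[order], labels[order]
    predicted = (scores > score_th).astype(np.float64)
    pred_one_num = np.cumsum(predicted)                # predicted positives in top i+1
    true_predict_num = np.cumsum(predicted * labels)   # true positives in top i+1
    # Precision is 0 where nothing has been predicted positive yet
    precisions = np.divide(true_predict_num, pred_one_num,
                           out=np.zeros_like(true_predict_num),
                           where=pred_one_num > 0)
    recalls = true_predict_num / num_pos
    ap = 0.0
    for t in np.arange(11) / 10:  # recall levels 0, 0.1, ..., 1.0
        mask = recalls >= t
        ap += (precisions[mask].max() if mask.any() else 0.0) / 11
    return ap


def mean_ap(pred, label, score_th=0.5):
    """mAP over the columns (classes) of [N, C] arrays; nan classes are skipped."""
    pred, label = np.asarray(pred), np.asarray(label)
    aps = np.array([class_ap_11point(pred[:, c], label[:, c], score_th)
                    for c in range(pred.shape[1])])
    return np.nanmean(aps), aps
```

For example, a class with scores `[0.9, 0.8, 0.7, 0.6]` and labels `[1, 0, 1, 0]` yields precision/recall points (1, 0.5), (0.5, 0.5), (2/3, 1), (0.5, 1), so the interpolated precision is 1 at the six recall levels 0 to 0.5 and 2/3 at the five levels 0.6 to 1.0, giving AP = 28/33 ≈ 0.848.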
Original article: https://blog.csdn.net/weixin_42544131/article/details/110925975