
The mAP evaluation metric for multi-label image recognition, with implementation code (Python)

程序员文章站 (Programmer Article Station), 2022-04-02 10:54:10

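The main idea is to compute the AP (average precision) of each category and then average it over all categories to obtain the mAP (mean average precision).
AP is the area under the precision-recall (P-R) curve. The code below approximates that area with 11-point interpolation, as used in PASCAL VOC 2007: it takes the best precision reached at recall levels 0, 0.1, ..., 1.0 and averages those 11 values.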
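Written out as formulas, assuming the N samples of class c are ranked by descending score, TP_c(k) is the number of positives among the top k, N_c^+ is the number of positive samples of class c (assumed nonzero), and C is the number of classes (5000 in this post's example). The notation is mine, not the original post's, and the maximum over an empty set is taken as 0:

```latex
P_c(k) = \frac{\mathrm{TP}_c(k)}{k}, \qquad
R_c(k) = \frac{\mathrm{TP}_c(k)}{N_c^{+}}, \qquad k = 1, \dots, N

\mathrm{AP}_c = \sum_{k=1}^{N} \bigl(R_c(k) - R_c(k-1)\bigr)\, P_c(k), \qquad R_c(0) = 0

\mathrm{AP}_c^{(11)} = \frac{1}{11} \sum_{t \in \{0,\, 0.1,\, \dots,\, 1\}} \; \max_{k:\, R_c(k) \ge t} P_c(k)

\mathrm{mAP} = \frac{1}{C} \sum_{c=1}^{C} \mathrm{AP}_c^{(11)}
```

The first AP line is the exact step-wise area under the P-R curve; the second is the 11-point approximation that the code averages into mAP.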

import numpy as np
import torch


def getAplist(pred, label, aplistSavePath):
    '''
    pred is the output of sigmoid.
    Calculate the AP of each category and average them to get mAP.
    The number of categories (5000 in this example) is read from pred.shape[1].

    input:
    pred  [num_samples, C] tensor
    label [num_samples, C] tensor of 0/1 ground truth

    output:
    aplist: [C] numpy.array, also saved to aplistSavePath
    '''
    num_classes = pred.shape[1]
    ap_list = []
    for cls_index in range(num_classes):
        cls_pred = pred[:, cls_index]
        cls_label = label[:, cls_index]
        cls_ap = certainClassAP(cls_pred, cls_label)
        ap_list.append(cls_ap)
    ap_list = np.array(ap_list)
    print('map is', ap_list.mean())
    np.save(aplistSavePath, ap_list)
    return ap_list

def certainClassAP(model_pred, labels):
    '''
    get the 11-point interpolated ap of a certain class

    model_pred: [num_samples] tensor of scores
    labels: [num_samples] tensor of 0/1 ground truth
    '''
    labels = labels.double()
    target_one_num = labels.sum()
    if target_one_num == 0:
        # no positive sample for this class: recall is undefined, report AP = 0
        return 0.0
    # Rank samples by descending score; treating the top-k samples as predicted
    # positives for k = 1..N traces out the P-R curve.
    order = torch.argsort(model_pred, descending=True)
    sorted_labels = labels[order]
    # true positives among the top-k samples, for every k
    true_predict_num = torch.cumsum(sorted_labels, dim=0)
    pred_one_num = torch.arange(1, len(sorted_labels) + 1,
                                dtype=torch.float64, device=sorted_labels.device)
    precisions = (true_predict_num / pred_one_num).cpu().numpy()
    recalls = (true_predict_num / target_one_num).cpu().numpy()
    average_precision = 0.0
    # recall levels 0, 0.1, ..., 1.0, computed as k / 10 so that e.g. 0.3 is the
    # nearest double and an exact recall of 0.3 is not skipped by rounding
    for threshold in np.arange(11) / 10:
        precisions_at_recall_threshold = precisions[recalls >= threshold]
        if precisions_at_recall_threshold.size > 0:
            max_precision = np.max(precisions_at_recall_threshold)
        else:
            max_precision = 0.0
        average_precision += max_precision / 11
    print('cur class ap:', average_precision)
    return float(average_precision)

# pred, label: [num_samples, 5000] tensors prepared elsewhere (e.g. sigmoid outputs on a test set)
getAplist(pred, label, './aplist.npy')
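To sanity-check the numbers on small inputs without PyTorch or a GPU, here is a minimal NumPy-only sketch. It assumes per-class scores and 0/1 labels of shape [num_samples], or [num_samples, C] arrays for mAP; the helper names `pr_at_each_rank`, `ap_11_point`, `ap_all_point` and `mean_ap` are hypothetical and not part of the original post. `ap_11_point` applies the same 11-point interpolation as `certainClassAP` to a P-R curve traced by ranking samples by descending score. `ap_all_point` sums precision times recall increments, i.e. the exact step-wise area under the P-R curve; for scores without ties it should agree with scikit-learn's `average_precision_score`.

```python
import numpy as np


def pr_at_each_rank(scores, labels):
    """Precision and recall after each of the top-k ranked samples, k = 1..N."""
    scores = np.asarray(scores, dtype=np.float64)
    labels = np.asarray(labels, dtype=np.float64)
    # descending score; ties keep their input order (stable sort)
    order = np.argsort(-scores, kind="stable")
    tp = np.cumsum(labels[order])  # true positives among the top-k samples
    k = np.arange(1, len(scores) + 1)
    return tp / k, tp / labels.sum()


def ap_11_point(scores, labels):
    """11-point interpolated AP; returns 0.0 if the class has no positives."""
    if np.sum(labels) == 0:
        return 0.0
    precisions, recalls = pr_at_each_rank(scores, labels)
    ap = 0.0
    for t in np.arange(11) / 10:  # recall levels 0, 0.1, ..., 1.0
        mask = recalls >= t
        ap += (precisions[mask].max() if mask.any() else 0.0) / 11
    return float(ap)


def ap_all_point(scores, labels):
    """Non-interpolated AP: sum over ranks of (R_k - R_{k-1}) * P_k."""
    if np.sum(labels) == 0:
        return 0.0
    precisions, recalls = pr_at_each_rank(scores, labels)
    recall_steps = np.diff(np.concatenate(([0.0], recalls)))
    return float(np.sum(recall_steps * precisions))


def mean_ap(pred, label, ap_fn=ap_11_point):
    """Average the per-class AP over the columns of [num_samples, C] arrays."""
    pred, label = np.asarray(pred), np.asarray(label)
    per_class = [ap_fn(pred[:, c], label[:, c]) for c in range(pred.shape[1])]
    return float(np.mean(per_class))
```

For a toy class with scores [0.9, 0.8, 0.7, 0.6] and labels [1, 0, 1, 0], precision after each rank is 1, 1/2, 2/3, 1/2 and recall is 1/2, 1/2, 1, 1. So `ap_11_point` gives (6 * 1 + 5 * 2/3) / 11 = 28/33 ≈ 0.848, while `ap_all_point` gives 1/2 * 1 + 1/2 * 2/3 = 5/6 ≈ 0.833. The two variants generally differ slightly, so it is worth stating which one a reported mAP uses.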

Original article: https://blog.csdn.net/weixin_42544131/article/details/110925975