欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

检测器评价指标计算流程voc_eval简化程序

程序员文章站 2022-07-02 09:15:11
最原始的代码是py-fasterrcnn作者写的,这个使用过程有点繁杂,主要是传递的参数比较多杂,我在原始基础上做了封装和部分重写,方便自己使用,在此也分享一下:import xml.etree.ElementTree as ETimport os,glob,jsonimport numpy as npdef pathjoin(prepath,*filenames): fullpath=prepath for filename in filenames: fu...

最原始的代码是py-fasterrcnn作者写的,这个使用过程有点繁杂,主要是传递的参数比较多杂,我在原始基础上做了封装和部分重写,方便自己使用,在此也分享一下:


import xml.etree.ElementTree as ET
import os,glob,json
import numpy as np

def pathjoin(prepath,*filenames):
    fullpath=prepath
    for filename in filenames:
        fullpath=os.path.join(fullpath,filename)
    return fullpath

def getxextnamefilefromdir(datadir,fileextname='.xml'):
    filepaths = glob.glob('{}/*{}'.format(datadir, fileextname))
    filepaths.sort()
    return filepaths


class VOCDetIndicator:
    def __init__(self,testsetdirs,detresulttxtdir,databasename,clsnames,mergeclsnamesdict={},iouthresh=0.5,use_07_metric=False):
        self.testsetdirs=testsetdirs
        self.clsnames=clsnames
        self.mergeclsnamesdict=mergeclsnamesdict
        self.detresulttxtdir=detresulttxtdir
        self.iouthresh=iouthresh
        self.use_07_metric=use_07_metric

        def getannotgroundtruth():
            def parse_rec(filename):
                """ Parse a PASCAL VOC xml file """
                tree = ET.parse(filename)
                objects = []
                for obj in tree.findall('object'):
                    obj_struct = {}
                    cls = obj.find('name').text
                    if cls in self.mergeclsnamesdict.keys():
                        cls=mergeclsnamesdict[cls]
                    obj_struct['name'] = cls
                    # obj_struct['pose'] = obj.find('pose').text
                    obj_struct['truncated'] = int(obj.find('truncated').text)
                    obj_struct['difficult'] = int(obj.find('difficult').text)
                    bbox = obj.find('bndbox')
                    obj_struct['bbox'] = [int(float(bbox.find('xmin').text)),
                                          int(float(bbox.find('ymin').text)),
                                          int(float(bbox.find('xmax').text)),
                                          int(float(bbox.find('ymax').text))]
                    objects.append(obj_struct)
                # print 'voc_eval.py-->parse_rec()-->objects='+str(objects)+'\n'
                return objects

            recs = {}
            annotcachefilepath = pathjoin(detresulttxtdir, databasename + '_annot.json')
            if os.path.exists(annotcachefilepath):
                print('reading annotation cache file')
                with open(annotcachefilepath, 'r') as f:
                    recs = json.load(f)
            else:
                i = 0
                for testsetdir in testsetdirs:
                    annotpaths = getxextnamefilefromdir(testsetdir, fileextname='.xml')
                    for annotpath in annotpaths:
                        filename = os.path.basename(annotpath)[:-4]
                        recs[filename] = parse_rec(annotpath)
                        i += 1
                        if i % 500 == 0:
                            print('Reading annotation for {:d}'.format(i))
                # save

                print('Saving cached annotations to {:s}'.format(annotcachefilepath))
                with open(annotcachefilepath, 'w') as fp:
                    json.dump(recs, fp,indent=4)

            return recs

        self.annotgroundtruth=getannotgroundtruth()



    def voc_ap(self,rec, prec, use_07_metric=False):
        """ ap = voc_ap(rec, prec, [use_07_metric])
        Compute VOC AP given precision and recall.
        If use_07_metric is true, uses the
        VOC 07 11 point method (default:False).
        """
        if use_07_metric:
            # 11 point metric
            ap = 0.
            for t in np.arange(0., 1.1, 0.1):
                if np.sum(rec >= t) == 0:
                    p = 0
                else:
                    p = np.max(prec[rec >= t])
                ap = ap + p / 11.
        else:
            # correct AP calculation
            # first append sentinel values at the end
            mrec = np.concatenate(([0.], rec, [1.]))
            mpre = np.concatenate(([0.], prec, [0.]))

            # compute the precision envelope
            #
            for i in range(mpre.size - 1, 0, -1):
                mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])

            # to calculate area under PR curve, look for points
            # where X axis (recall) changes value
            i = np.where(mrec[1:] != mrec[:-1])[0]

            # and sum (\Delta recall) * prec
            ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
        return ap


    def voc_eval(self,clsname):
        recs = self.annotgroundtruth
        # extract gt objects for this class
        class_recs = {}
        npos = 0  #
        for imagename in recs.keys():
            R = [obj for obj in recs[imagename] if obj['name'] == clsname]

            bbox = np.array([x['bbox'] for x in R])
            difficult = np.array([x['difficult'] for x in R]).astype(np.bool)
            det = [False] * len(R)
            npos = npos + sum(~difficult)
            class_recs[imagename] = {'bbox': bbox,
                                     'difficult': difficult,
                                     'det': det}

        # read dets
        detresulttxtname = '{}.txt'.format(clsname)
        detfile =pathjoin(self.detresulttxtdir,detresulttxtname)
        with open(detfile, 'r') as f:
            lines = f.readlines()

        splitlines = [x.strip().split(' ') for x in lines]
        image_ids = [x[0] for x in splitlines]
        confidence = np.array([float(x[1]) for x in splitlines])
        BB = np.array([[float(z) for z in x[2:]] for x in splitlines])

        # sort by confidence
        sorted_ind = np.argsort(-confidence)
        sorted_scores = np.sort(-confidence)
        BB = BB[sorted_ind, :]
        image_ids = [image_ids[x] for x in sorted_ind]

        # go down dets and mark TPs and FPs
        nd = len(image_ids)  #
        tp = np.zeros(nd)  #
        fp = np.zeros(nd)  #

        for d in range(nd):
            R = class_recs[image_ids[d]]  #
            bb = BB[d, :].astype(float)  #
            ovmax = -np.inf
            BBGT = R['bbox'].astype(float)  #

            if BBGT.size > 0:
                # compute overlaps
                # intersection
                ixmin = np.maximum(BBGT[:, 0], bb[0])
                iymin = np.maximum(BBGT[:, 1], bb[1])
                ixmax = np.minimum(BBGT[:, 2], bb[2])
                iymax = np.minimum(BBGT[:, 3], bb[3])
                iw = np.maximum(ixmax - ixmin + 1., 0.)
                ih = np.maximum(iymax - iymin + 1., 0.)
                inters = iw * ih

                # union
                uni = ((bb[2] - bb[0] + 1.) * (bb[3] - bb[1] + 1.) +
                       (BBGT[:, 2] - BBGT[:, 0] + 1.) *
                       (BBGT[:, 3] - BBGT[:, 1] + 1.) - inters)

                overlaps = inters / uni
                ovmax = np.max(overlaps)  #
                jmax = np.argmax(overlaps)

            if ovmax > self.iouthresh:  #
                if not R['difficult'][jmax]:  #
                    if not R['det'][jmax]:  #
                        tp[d] = 1.  #
                        R['det'][jmax] = 1  #
                    else:
                        fp[d] = 1.  #
            else:
                fp[d] = 1.  #


        # compute precision recall
        fp = np.cumsum(fp)  #
        tp = np.cumsum(tp)  #
        rec = tp / float(npos)  #
        # avoid divide by zero in case the first detection matches a difficult
        # ground truth
        prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)  #
        ap = self.voc_ap(rec, prec, self.use_07_metric)  #

        return rec, prec, ap



    def gettestsetdetindicator(self):
        precisions,recalls,aps={},{},{}
        for clsname in self.clsnames:
            rec, prec, ap = self.voc_eval(clsname)

            precisions[clsname]=round(prec[-1],3)
            recalls[clsname]=round(rec[-1],3)
            aps[clsname]=round(ap,3)

        map=0
        for cls in aps.keys():
            map+=aps[cls]
        map=map/len(aps)
        map=round(map,3)

        precision_recall_ap_map = {'precision':precisions,'recall':recalls,'ap':aps,'map':map}

        return precision_recall_ap_map


if __name__ == '__main__':
    testsetdirs=['testset1','testset2']#maybe you have mutiply testset and they are in different locale
    detresulttxtdir='detect result txt save dir'
    databasename='your database name'#its convenient for remember the data
    clsnames=('class1','class2','...')#class names
    DetIndicator=VOCDetIndicator(testsetdirs, detresulttxtdir, databasename, clsnames)

 

本文地址:https://blog.csdn.net/lingyunxianhe/article/details/109628455

相关标签: Python