欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页  >  IT编程

pytorch-cpn可视化标注信息

程序员文章站 2022-06-05 19:00:17
pytorch-cpn项目的代码对coco标注进行了重新组装,但是基本的标注内容并没有改变,关键点的坐标依然是以图片左上角为坐标原点,标注格式依然为[x1,y1,v1,x2,y2,v2,…,x17,y17,v17]。在其mscocoMulti.py文件中,以下为核心可视化代码的部分。 plt.figure() c = (np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0] plt.plot(x[v > 0], y[...

pytorch-cpn项目的代码对coco标注进行了重新组装,但是基本的标注内容并没有改变,关键点的坐标依然是以图片左上角为坐标原点,标注格式依然为[x1,y1,v1,x2,y2,v2,…,x17,y17,v17]。在其mscocoMulti.py文件中,以下为核心可视化代码的部分。使用以下代码,即可实现原图和标注在同一个plt画布中显示。

   		plt.figure()
        c = (np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0]
        plt.plot(x[v > 0], y[v > 0], 'o', markersize=10, markerfacecolor=c, markeredgecolor='k', markeredgewidth=2)
        img = Image.open(os.path.join(img_folder, image_name))
        plt.imshow(img)
        plt.show()

完整加入以下代码,即可可视化第一个标注信息。

if __name__ == '__main__':
    import argparse
    parser = argparse.ArgumentParser(description='PyTorch CPN Training')
    parser.add_argument('-j', '--workers', default=12, type=int, metavar='N',
                        help='number of data loading workers (default: 12)')
    parser.add_argument('-g', '--num_gpus', default=1, type=int, metavar='N',
                        help='number of GPU to use (default: 1)')
    parser.add_argument('--epochs', default=32, type=int, metavar='N',
                        help='number of total epochs to run (default: 32)')
    parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                        help='manual epoch number (useful on restarts)')
    parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH',
                        help='path to save checkpoint (default: checkpoint)')
    parser.add_argument('--resume', default='', type=str, metavar='PATH',
                        help='path to latest checkpoint')
    args = parser.parse_args()

    import sys

    sys.path.append('../256.192.model')
    from config import cfg
    from PIL import Image
    import matplotlib.pyplot as plt

    img_folder =  cfg.img_path
    with open(cfg.gt_path) as anno_file:
        anno = json.load(anno_file)
        for i, item in enumerate(anno):
            print(i, item)
            if (i == 5):
                print(item.keys())
                break;

        num_class = 17
        a = anno[0]
        image_name = a['imgInfo']['img_paths']
        points = np.array(a['unit']['keypoints']).reshape(num_class, 3).astype(np.float32)
        gt_bbox = a['unit']['GT_bbox']
        points = points.flatten()
        x = points[0::3]
        y = points[1::3]
        v = points[2::3]
        print(points)
        print(x)
        print(y)
        print(v)
        print(x[v>0])
        print(y[v>0])
        plt.figure()
        c = (np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0]
        plt.plot(x[v > 0], y[v > 0], 'o', markersize=10, markerfacecolor=c, markeredgecolor='k', markeredgewidth=2)
        img = Image.open(os.path.join(img_folder, image_name))
        plt.imshow(img)
        plt.show()

pytorch-cpn可视化标注信息

pytorch-cpn可视化标注信息

本文地址:https://blog.csdn.net/qq_37025073/article/details/107350206