pytorch-cpn可视化标注信息
程序员文章站
2022-06-05 19:00:17
pytorch-cpn项目的代码对coco标注进行了重新组装,但是基本的标注内容并没有改变,关键点的坐标依然是以图片左上角为坐标原点,标注格式依然为[x1,y1,v1,x2,y2,v2,…,x17,y17,v17]。在其mscocoMulti.py文件中,以下为核心可视化代码的部分。 plt.figure() c = (np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0] plt.plot(x[v > 0], y[...
pytorch-cpn项目的代码对coco标注进行了重新组装,但是基本的标注内容并没有改变,关键点的坐标依然是以图片左上角为坐标原点,标注格式依然为[x1,y1,v1,x2,y2,v2,…,x17,y17,v17]。在其mscocoMulti.py文件中,以下为核心可视化代码的部分。使用以下代码,即可实现原图和标注在同一个plt画布中显示。
plt.figure()
c = (np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0]
plt.plot(x[v > 0], y[v > 0], 'o', markersize=10, markerfacecolor=c, markeredgecolor='k', markeredgewidth=2)
img = Image.open(os.path.join(img_folder, image_name))
plt.imshow(img)
plt.show()
完整加入以下代码,即可可视化第一个标注信息。
if __name__ == '__main__':
import argparse
parser = argparse.ArgumentParser(description='PyTorch CPN Training')
parser.add_argument('-j', '--workers', default=12, type=int, metavar='N',
help='number of data loading workers (default: 12)')
parser.add_argument('-g', '--num_gpus', default=1, type=int, metavar='N',
help='number of GPU to use (default: 1)')
parser.add_argument('--epochs', default=32, type=int, metavar='N',
help='number of total epochs to run (default: 32)')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
help='manual epoch number (useful on restarts)')
parser.add_argument('-c', '--checkpoint', default='checkpoint', type=str, metavar='PATH',
help='path to save checkpoint (default: checkpoint)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
help='path to latest checkpoint')
args = parser.parse_args()
import sys
sys.path.append('../256.192.model')
from config import cfg
from PIL import Image
import matplotlib.pyplot as plt
img_folder = cfg.img_path
with open(cfg.gt_path) as anno_file:
anno = json.load(anno_file)
for i, item in enumerate(anno):
print(i, item)
if (i == 5):
print(item.keys())
break;
num_class = 17
a = anno[0]
image_name = a['imgInfo']['img_paths']
points = np.array(a['unit']['keypoints']).reshape(num_class, 3).astype(np.float32)
gt_bbox = a['unit']['GT_bbox']
points = points.flatten()
x = points[0::3]
y = points[1::3]
v = points[2::3]
print(points)
print(x)
print(y)
print(v)
print(x[v>0])
print(y[v>0])
plt.figure()
c = (np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0]
plt.plot(x[v > 0], y[v > 0], 'o', markersize=10, markerfacecolor=c, markeredgecolor='k', markeredgewidth=2)
img = Image.open(os.path.join(img_folder, image_name))
plt.imshow(img)
plt.show()
本文地址:https://blog.csdn.net/qq_37025073/article/details/107350206