(绝对详细)CenterNet训练自己的数据(pytorch0.4.1)
我参考的这篇博客,对我自己的实验帮助很大:https://blog.csdn.net/weixin_42634342/article/details/97756458
论文作者代码:https://github.com/xingyizhou/CenterNet
这个博客是整个训练的过程,可能会有点长。
1. 准备数据集
0.(我用的数据是VOC格式的,需要将其转化为COCO格式)
详细过程在我的另一个博客里,本来想写在这里,发现太长了,就移到另一个里了。
链接:https://blog.csdn.net/weixin_41765699/article/details/100124689
1. 当我们生成三个json文件之后,来到CenterNet这个工程里,在data文件夹下新建一个文件夹,名字就是你数据集的名字,如下图:
再在这个文件夹里面建两个文件夹(annotations里面存放的是我们之前生成的那三个json文件;images存放的是所有的图片,包括训练测试验证三个,所有的):
2. 在src/lib/datasets/dataset
里面新建一个“ped. py”,文件内容照着文件夹下coco.py改成自己的
0. 将COCO类改成自己的名字
1. 第14行num_classes=80改成自己的类别数
2. 第15行default_resolution(这个参数有两种(300,300)或者(512,512),很明显512的参数计算量大,300计算量小,我用的是512,之后打算训练一个300的对比一下)
3. 接下来的mean和std改成自己图片数据集的均值和方差,脚本链接: https://blog.csdn.net/weixin_41765699/article/details/100118660
4. 修改数据和图片路径,data_dir 输入的是咱们之前建立的数据集文件夹的名字,img_dir 输入的是 images 图片文件夹
5. 修改json文件路径如下:
6. 类别名字和类别id改成自己
我就改了以上六点内容。
3. 将数据集加入src/lib/datasets/dataset_factory
里面
1. 在dataset_facto字典里加入自己的数据集名字 (格式为 '你之前创建的Python文件的名字':你自己数据集类的名字,因为要从你创建的py文件里找到你的数据类,名字必须对应上)
4. 修改/src/lib/opts.py
1.第一步,将自己的数据集设为默认数据集,加入到help里面
-
self.parser.add_argument('--dataset', default='ped',
-
help='coco | kitti | coco_hp | pascal | ped)
2.修改ctdet任务使用的默认数据集为新添加的数据集,如下(修改分辨率,类别数,均值,方差,数据集名字):
3. 修改src/lib/utils/debugger.py文件(变成自己数据的类别和名字,前后数据集名字一定保持一致)
再加上自己数据的类别,不包括背景__background__
到这里,准备数据集的工作就告一段落了!
2. 搭建pytorch0.4.1+cuda90+cudnn7.6.1(版本不能改,还有就是numpy的版本必须在1.13以上,建议装最新的)
我搭建这个环境也费老大劲了,pytorch1.2貌似直接pip安装就自动装上了cuda和cudnn,0.4.1版本的我没看见有自动安装的,所以就苦哈哈自己动手装了,关于这个,我也记录了一下,大家也可以自己上网查查别的方法
cuda和cudnn安装链接:https://blog.csdn.net/weixin_41765699/article/details/99966617
torch0.4.1安装链接:https://blog.csdn.net/weixin_41765699/article/details/99756697
3. 克隆工程并运行demo
关于工程里面这个作者写的很详细了,我是按照一步步来的,没有出错。https://github.com/xingyizhou/CenterNet/blob/master/readme/INSTALL.md
程序里面在运行demo.py之前,会下载一个预训练权重(比如dla34,resnet18,resnet101之类的),这个不用管,等他下载完,因为我们训练的时候也要用的。(下载的时候可能会很慢,如果是在龟速的话,将他下载的网址用QQ浏览器打开自己下载,下载完放到这个它自动创建的文件夹里就可以了,QQ浏览器下载确实比其他的稍快一些)
改完之后在MODEL_ZOO.md里面下载参数,ctdet_coco_dla_2x,下载完毕后放在models文件夹里面。
到这里,环境基本搭建成功,接下来可以跑一下代码了
(模型文件下载貌似要Google drive,这是我下载的:
链接:https://pan.baidu.com/s/1QOmIwy8lXJBuLv5hH5j3ag
提取码:vwk4 )
运行demo.py
python demo.py ctdet --demo /home/CenterNet/images/ --load_model /home/CenterNet/models/ctdet_coco_dla_2x.pth
嘿嘿嘿
要注意的是,当弹出第一站图片的时候,按esc除外的任意键可以继续检测下一张图,想要保存检测结果的话,只需要在src/lib/detectors/cdet.py文件中:
-
def show_results(self, debugger, image, results): # demo文件會調用這個函數,本函`python main.py ctdet --exp_id coco_dla --batch_size 32 --master_batch 1 --lr 1.25e-4 --gpus 0,1`數是demo時顯示圖片並保存圖片
-
debugger.add_img(image, img_id='ctdet')
-
for j in range(1, self.num_classes + 1):
-
for bbox in results[j]:
-
if bbox[4] > self.opt.vis_thresh:
-
debugger.add_coco_bbox(bbox[:4], j - 1, bbox[4], img_id='ctdet')
-
debugger.show_all_imgs(pause=self.pause)
-
debugger.save_all_imgs(path='/home/czb/CenterNet-master/output/', genID=True)
加上一行代码,就是最后一行debugger.save_all_imgs(path='/home/CenterNet/output/', genID=True) ,path是输出路径,需要在CenterNet文件夹下新建一个文件夹output,然后再运行一遍发现检测后的图片就会保存在这个文件夹里面了。当然,去掉倒数第二行show_all_imgs,那么运行的时候就不会弹出照片了。
4. 训练阶段
1. 定位一下发现前面自己建立的ped.py文件(修改下面的代码):
-
if split == 'val':
-
self.annot_path = os.path.join(
-
self.data_dir, 'annotations',
-
'val.json').format(split) # 修改test的json文件位置
-
else:
-
if opt.task == 'exdet':
-
self.annot_path = os.path.join(
-
self.data_dir, 'annotations',
-
'train.json').format(split)
-
else:
-
self.annot_path = os.path.join(
-
self.data_dir, 'annotations',
-
'train.json').format(split) # 这才是train文件
要把第一行if split 改为 ==‘val’,这样validate时就会调用val.json文件。把最后一行要调用的文件改为‘train.json’,这样训练的时候才会调用train.json文件。改完之后数据集导入就正常了。
2. 运行main.py
python main.py ctdet --exp_id coco_dla --batch_size 32 --master_batch 1 --lr 1.25e-4 --gpus 0,1
(如果显示显存不够之类的那种错误,需要在opts.py文件中将--num_workers改成0,batch_size改成16或者更小
)
这时候会下载一个预训练模型,可能会很慢,我是。。下载的,这是百度盘链接,需要的可以用:
链接:https://pan.baidu.com/s/1I1oW_l2Xe2-LV1gIjViPTg
提取码:2pt0
下载完之后放在/root/.torch/models里面
(我的是在这个里,你也可以看看他自动下载的那个在哪个文件夹里,然后把权重放在那个文件夹下)
没有意外的话,经过上面的步骤,就开始训练了::::
5. 测试部分
当训练完之后(我训练了两天,泰坦X,140个epochs,有点憨批,其实最好的模型是在第55个epoch出现的),在./exp/ctdet/coco_dla/文件夹下会出现如下文件
其中,model_last是最后一次epoch的模型;model_best是val最好的模型,我选的是model_best模型;
然后开始测试。。。。。。
1. 在我们之前建立的ped.py中修改如下部分,加入if split == ‘test’:…,作用是当test时指定标签文件为之前建立的测试文件 test.json
2. 运行test.py
python test.py --exp_id coco_dla --not_prefetch_test ctdet --load_model /root/CenterNet/exp/ctdet/coco_dla/model_best.pth
不出意外的话会出现下面的画面(出现一系列AP值),其中,一般使用的是第二行,也就是IOU=0.5,全区域的AP值,其他的分别是不同IOU以及不同目标尺寸区域的结果。
完事了。。。
2019.9.6
附加1:
我想换个骨干网络试试,作者的源代码支持resnet和hourglass,我尝试替换成resnet18,记录一下替换方法:
在原来的训练命令行命令里添加上两个参数:(顺便把exp_id 改一下,保证每个模型不乱)
python main.py ctdet --exp_id coco_res_18 --batch_size 32 --master_batch 1 --lr 1.25e-4 --gpus 0,1 --arch res_18 --head_conv 64
开始训练时也会下载相应的预训练模型,如果下载速度慢,也参照上面说的方法下载。
训练之后,在测试和运行demo的命令行代码里也要加上两个参数:--arch res_18 --head_conv 64
附加2:
训练完成的时候,我们需要绘制出loss值得曲线,以下代码可以实现该功能:
训练生成的日志文件一般在exp/ctdet/../../logs.txt,找到这个文件,打开之后会出现如下:
我们需要读取这些loss值并可视化(一般情况下,该代码只需要改变日志文件的路径即可):
-
import matplotlib.pyplot as plt
-
import numpy as np
-
-
-
def plot_loss_curve(log_file):
-
-
loss_data = open(log_file)
-
all_lines = loss_data.readlines()
-
print(all_lines[4].split(' '))
-
# losses
-
total_loss = [] # 4
-
hm_loss = [] # 7
-
wh_loss = [] # 10
-
off_loss = [] # 13
-
val_loss = [] # 19
-
spend_time = [] # 16
-
num_lines = len(all_lines)
-
for line in range(num_lines):
-
total_loss1 = all_lines[line].split(' ')[4]
-
hm_loss1 = all_lines[line].split(' ')[7]
-
wh_loss1 = all_lines[line].split(' ')[10]
-
off_loss1 = all_lines[line].split(' ')[13]
-
spend_time1 = all_lines[line].split(' ')[16]
-
-
total_loss.append(float(total_loss1))
-
hm_loss.append(float(hm_loss1))
-
wh_loss.append(float(wh_loss1))
-
off_loss.append(float(off_loss1))
-
spend_time.append(float(spend_time1))
-
-
index_val = np.linspace(0, 140, 29) - 1
-
index_val = np.delete(index_val, 0, 0)
-
-
for id in index_val:
-
-
val_loss1 = all_lines[int(id)].split(' ')[19]
-
val_loss.append(float(val_loss1))
-
return val_loss, total_loss
-
-
-
if __name__ == '__main__':
-
# 标准图形绘制
-
# sns.set()
-
vloss_res18, loss_res18 = plot_loss_curve('logres18.txt') # 读取训练时生成的日志文件
-
# vloss_resdcn18, loss_resdcn18 = plot_loss_curve('logresdcn18.txt')
-
# vloss_dla, loss_dla = plot_loss_curve('logdla34.txt')
-
# vloss_res101, loss_res101 = plot_loss_curve('logres101.txt')
-
# vloss_dla34p, loss_dla34p = plot_loss_curve('logdla34p.txt')
-
# vloss_hg, loss_hg = plot_loss_curve('loghourglass.txt')
-
-
fig = plt.figure(figsize=(10, 4))
-
ax = fig.add_subplot(111)
-
ax.plot(range(len(loss_res18)), loss_res18, 'c', label='res_18_train_loss', linewidth=1) # 这个label是图线自己的标签;
-
# ax.plot(range(len(loss_resdcn18)), loss_resdcn18, 'y', label='resdcn_18_train_loss', linewidth=1)
-
# ax.plot(range(len(loss_dla)), loss_dla, 'b', label='dla_34_train_loss', linewidth=1)
-
# ax.plot(range(len(loss_res101)), loss_res101, 'g', label='res_101_train_loss', linewidth=1)
-
# ax.plot(range(len(loss_dla34p)), loss_dla34p, 'r', label='dla_34_train_loss', linewidth=1)
-
# ax.plot(range(len(loss_hg)), loss_hg, 'r', label='hourglass_train_loss', linewidth=1)
-
-
# ax.plot(index_val+1, val_loss, label='val_loss')
-
# ax.plot(np.random.randn(1000).cumsum(), label='line2')
-
# ax.set_xlim([0, 800]) # 设置刻度;
-
# ax.set_xticks(range(0, 500, 100)) # 设置显示的刻度;
-
# ax.set_yticklabels(['jan', 'feb', 'mar']) # 设置刻度标签;
-
ax.set_xlabel('epochs') # 设置坐标轴标签;
-
ax.set_ylabel('loss_value')
-
# ax.text(8750, 20, "海拔", color='red') # 加入文本
-
ax.set_title('loss_of_CenterNet')
-
ax.legend(loc='best') # 将图例摆放在不遮挡图线的位置即可
-
ax.grid() # 添加网格
-
plt.savefig('loss_of_CenterNet.png') # 保存文件到指定文件夹
-
plt.show()
-
-
fig1 = plt.figure(figsize=(12, 8))
-
ax1 = fig1.add_subplot(111)
-
ax1.plot(range(len(vloss_res18)), vloss_res18, 'c', label='res_18_val_loss', linewidth=2) # 这个label是图线自己的标签;
-
# ax1.plot(range(len(vloss_resdcn18)), vloss_resdcn18, 'y', label='resdcn_18_val_loss', linewidth=2)
-
# ax1.plot(range(len(vloss_dla)), vloss_dla, 'b', label='dla_34_val_loss', linewidth=2)
-
# ax1.plot(range(len(vloss_res101)), vloss_res101, 'g', label='res_101_val_loss', linewidth=2)
-
# ax1.plot(range(len(vloss_dla34p)), vloss_dla34p, 'r', label='dla_34_val_loss_p', linewidth=2)
-
# ax.plot(index_val+1, val_loss, label='val_loss')
-
# ax.plot(np.random.randn(1000).cumsum(), label='line2')
-
# ax.set_xlim([0, 800]) # 设置刻度;
-
# ax.set_xticks(range(0, 500, 100)) # 设置显示的刻度;
-
# ax.set_yticklabels(['jan', 'feb', 'mar']) # 设置刻度标签;
-
ax1.set_xlabel('epochs') # 设置坐标轴标签;
-
ax1.set_ylabel('loss_value')
-
# ax.text(8750, 20, "海拔", color='red') # 加入文本
-
ax1.set_title('val_loss_of_CenterNet')
-
ax1.legend(loc='best') # 将图例摆放在不遮挡图线的位置即可
-
ax1.grid() # 添加网格
-
plt.savefig('val_loss_of_CenterNet.png') # 保存文件到指定文件夹
-
plt.show()
<div class="person-messagebox">
<div class="left-message"><a href="https://blog.csdn.net/weixin_41765699">
<img src="https://profile.csdnimg.cn/1/2/9/3_weixin_41765699" class="avatar_pic" username="weixin_41765699">
</a></div>
<div class="middle-message">
<div class="title"><span class="tit "><a href="https://blog.csdn.net/weixin_41765699" data-report-click="{"mod":"popu_379","ab":"new"}" target="_blank">linbior</a></span>
<!-- 等级,level -->
<img class="identity-icon" src="https://csdnimg.cn/identity/blog5.png"> </div>
<div class="text"><span>原创文章 46</span><span>获赞 80</span><span>访问量 14万+</span></div>
</div>
<div class="right-message">
<a class="btn btn-sm bt-button personal-watch" data-report-click="{"mod":"popu_379","ab":"new"}">关注</a>
<a href="https://im.csdn.net/im/main.html?userName=weixin_41765699" target="_blank" class="btn btn-sm bt-button personal-letter">私信
</a>
</div>
</div>
</div>