欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

(绝对详细)CenterNet训练自己的数据(pytorch0.4.1)

程序员文章站 2022-07-11 15:36:23
...

我的任务是在行人头肩数据上训练并测试centernet网络,先证明一下我是真的训练了哈,这是用centernet检测的一张结果(训练了10个epochs的结果,大家放心使用,网络功能还是很强大的):

我参考的这篇博客,对我自己的实验帮助很大:https://blog.csdn.net/weixin_42634342/article/details/97756458

论文作者代码https://github.com/xingyizhou/CenterNet

(绝对详细)CenterNet训练自己的数据(pytorch0.4.1)

这个博客是整个训练的过程,可能会有点长。

1. 准备数据集

    0.(我用的数据是VOC格式的,需要将其转化为COCO格式)

    详细过程在我的另一个博客里,本来想写在这里,发现太长了,就移到另一个里了。

    链接:https://blog.csdn.net/weixin_41765699/article/details/100124689

    1. 当我们生成三个json文件之后,来到CenterNet这个工程里,在data文件夹下新建一个文件夹,名字就是你数据集的名字,如下图:

(绝对详细)CenterNet训练自己的数据(pytorch0.4.1)

     再在这个文件夹里面建两个文件夹(annotations里面存放的是我们之前生成的那三个json文件;images存放的是所有的图片,包括训练测试验证三个,所有的):

(绝对详细)CenterNet训练自己的数据(pytorch0.4.1)

    2. 在src/lib/datasets/dataset里面新建一个“ped. py”,文件内容照着文件夹下coco.py改成自己的

       0. 将COCO类改成自己的名字

       (绝对详细)CenterNet训练自己的数据(pytorch0.4.1)

       1. 第14行num_classes=80改成自己的类别数
       2. 第15行default_resolution(这个参数有两种(300,300)或者(512,512),很明显512的参数计算量大,300计算量小,我用的是512,之后打算训练一个300的对比一下)
       3. 接下来的mean和std改成自己图片数据集的均值和方差,脚本链接:                            https://blog.csdn.net/weixin_41765699/article/details/100118660

       4. 修改数据和图片路径,data_dir 输入的是咱们之前建立的数据集文件夹的名字,img_dir 输入的是 images 图片文件夹

(绝对详细)CenterNet训练自己的数据(pytorch0.4.1)

       5. 修改json文件路径如下:

(绝对详细)CenterNet训练自己的数据(pytorch0.4.1)

       6.  类别名字和类别id改成自己

(绝对详细)CenterNet训练自己的数据(pytorch0.4.1)

       我就改了以上六点内容。

    3. 将数据集加入src/lib/datasets/dataset_factory里面

       1. 在dataset_facto字典里加入自己的数据集名字 (格式为   '你之前创建的Python文件的名字':你自己数据集类的名字,因为要从你创建的py文件里找到你的数据类,名字必须对应上)

(绝对详细)CenterNet训练自己的数据(pytorch0.4.1)

    4. 修改/src/lib/opts.py

         1.第一步,将自己的数据集设为默认数据集,加入到help里面

  1. self.parser.add_argument('--dataset', default='ped',  
  2.                                  help='coco | kitti | coco_hp | pascal | ped)

         2.修改ctdet任务使用的默认数据集为新添加的数据集,如下(修改分辨率,类别数,均值,方差,数据集名字):

   (绝对详细)CenterNet训练自己的数据(pytorch0.4.1)

         3. 修改src/lib/utils/debugger.py文件(变成自己数据的类别和名字,前后数据集名字一定保持一致)

       (绝对详细)CenterNet训练自己的数据(pytorch0.4.1)

            再加上自己数据的类别,不包括背景__background__

            (绝对详细)CenterNet训练自己的数据(pytorch0.4.1)

      到这里,准备数据集的工作就告一段落了!

2. 搭建pytorch0.4.1+cuda90+cudnn7.6.1(版本不能改,还有就是numpy的版本必须在1.13以上,建议装最新的)

     我搭建这个环境也费老大劲了,pytorch1.2貌似直接pip安装就自动装上了cuda和cudnn,0.4.1版本的我没看见有自动安装的,所以就苦哈哈自己动手装了,关于这个,我也记录了一下,大家也可以自己上网查查别的方法

cuda和cudnn安装链接:https://blog.csdn.net/weixin_41765699/article/details/99966617

torch0.4.1安装链接:https://blog.csdn.net/weixin_41765699/article/details/99756697

3. 克隆工程并运行demo

    关于工程里面这个作者写的很详细了,我是按照一步步来的,没有出错。https://github.com/xingyizhou/CenterNet/blob/master/readme/INSTALL.md

    程序里面在运行demo.py之前,会下载一个预训练权重(比如dla34,resnet18,resnet101之类的),这个不用管,等他下载完,因为我们训练的时候也要用的。(下载的时候可能会很慢,如果是在龟速的话,将他下载的网址用QQ浏览器打开自己下载,下载完放到这个它自动创建的文件夹里就可以了,QQ浏览器下载确实比其他的稍快一些)

    (绝对详细)CenterNet训练自己的数据(pytorch0.4.1)

   改完之后在MODEL_ZOO.md里面下载参数,ctdet_coco_dla_2x,下载完毕后放在models文件夹里面。
到这里,环境基本搭建成功,接下来可以跑一下代码了

(模型文件下载貌似要Google drive,这是我下载的:

链接:https://pan.baidu.com/s/1QOmIwy8lXJBuLv5hH5j3ag 
提取码:vwk4 )

    运行demo.py

python demo.py ctdet --demo /home/CenterNet/images/ --load_model /home/CenterNet/models/ctdet_coco_dla_2x.pth

(绝对详细)CenterNet训练自己的数据(pytorch0.4.1)嘿嘿嘿

     要注意的是,当弹出第一站图片的时候,按esc除外的任意键可以继续检测下一张图,想要保存检测结果的话,只需要在src/lib/detectors/cdet.py文件中:

  1.     def show_results(self, debugger, image, results):  # demo文件會調用這個函數,本函`python main.py ctdet --exp_id coco_dla --batch_size 32 --master_batch 1 --lr 1.25e-4  --gpus 0,1`數是demo時顯示圖片並保存圖片
  2.         debugger.add_img(image, img_id='ctdet')
  3.         for j in range(1, self.num_classes + 1):
  4.             for bbox in results[j]:
  5.                 if bbox[4] > self.opt.vis_thresh:
  6.                     debugger.add_coco_bbox(bbox[:4], j - 1, bbox[4], img_id='ctdet')
  7.         debugger.show_all_imgs(pause=self.pause)
  8.         debugger.save_all_imgs(path='/home/czb/CenterNet-master/output/', genID=True)

      加上一行代码,就是最后一行debugger.save_all_imgs(path='/home/CenterNet/output/', genID=True) ,path是输出路径,需要在CenterNet文件夹下新建一个文件夹output,然后再运行一遍发现检测后的图片就会保存在这个文件夹里面了。当然,去掉倒数第二行show_all_imgs,那么运行的时候就不会弹出照片了。

4. 训练阶段

     1. 定位一下发现前面自己建立的ped.py文件(修改下面的代码):

  1.   if split == 'val':
  2.             self.annot_path = os.path.join(
  3.                 self.data_dir, 'annotations',
  4.                 'val.json').format(split) # 修改test的json文件位置
  5.         else:
  6.             if opt.task == 'exdet':
  7.                 self.annot_path = os.path.join(
  8.                     self.data_dir, 'annotations',
  9.                     'train.json').format(split)
  10.             else:
  11.                 self.annot_path = os.path.join(
  12.                     self.data_dir, 'annotations',
  13.                     'train.json').format(split) # 这才是train文件

     要把第一行if split 改为 ==‘val’,这样validate时就会调用val.json文件。把最后一行要调用的文件改为‘train.json’,这样训练的时候才会调用train.json文件。改完之后数据集导入就正常了。
     2. 运行main.py

python main.py ctdet --exp_id coco_dla --batch_size 32 --master_batch 1 --lr 1.25e-4  --gpus 0,1
(如果显示显存不够之类的那种错误,需要在opts.py文件中将--num_workers改成0,batch_size改成16或者更小

(绝对详细)CenterNet训练自己的数据(pytorch0.4.1)

   (绝对详细)CenterNet训练自己的数据(pytorch0.4.1)

(绝对详细)CenterNet训练自己的数据(pytorch0.4.1)

(绝对详细)CenterNet训练自己的数据(pytorch0.4.1)

     这时候会下载一个预训练模型,可能会很慢,我是。。下载的,这是百度盘链接,需要的可以用:

    链接:https://pan.baidu.com/s/1I1oW_l2Xe2-LV1gIjViPTg 
    提取码:2pt0 

     下载完之后放在/root/.torch/models里面

(我的是在这个里,你也可以看看他自动下载的那个在哪个文件夹里,然后把权重放在那个文件夹下)

(绝对详细)CenterNet训练自己的数据(pytorch0.4.1)

    没有意外的话,经过上面的步骤,就开始训练了::::

(绝对详细)CenterNet训练自己的数据(pytorch0.4.1)

5. 测试部分

     当训练完之后(我训练了两天,泰坦X,140个epochs,有点憨批,其实最好的模型是在第55个epoch出现的),在./exp/ctdet/coco_dla/文件夹下会出现如下文件

(绝对详细)CenterNet训练自己的数据(pytorch0.4.1)

     其中,model_last是最后一次epoch的模型;model_best是val最好的模型,我选的是model_best模型;

然后开始测试。。。。。。

   1. 在我们之前建立的ped.py中修改如下部分,加入if split == ‘test’:…,作用是当test时指定标签文件为之前建立的测试文件       test.json

(绝对详细)CenterNet训练自己的数据(pytorch0.4.1)

   2. 运行test.py

       python test.py --exp_id coco_dla --not_prefetch_test ctdet --load_model /root/CenterNet/exp/ctdet/coco_dla/model_best.pth

   不出意外的话会出现下面的画面(出现一系列AP值),其中,一般使用的是第二行,也就是IOU=0.5,全区域的AP值,其他的分别是不同IOU以及不同目标尺寸区域的结果。

(绝对详细)CenterNet训练自己的数据(pytorch0.4.1)

  完事了。。。

2019.9.6

附加1:

我想换个骨干网络试试,作者的源代码支持resnet和hourglass,我尝试替换成resnet18,记录一下替换方法:

在原来的训练命令行命令里添加上两个参数:(顺便把exp_id 改一下,保证每个模型不乱)

python main.py ctdet --exp_id coco_res_18 --batch_size 32 --master_batch 1 --lr 1.25e-4  --gpus 0,1 --arch res_18 --head_conv 64

开始训练时也会下载相应的预训练模型,如果下载速度慢,也参照上面说的方法下载。

训练之后,在测试和运行demo的命令行代码里也要加上两个参数:--arch res_18 --head_conv 64

附加2:

训练完成的时候,我们需要绘制出loss值得曲线,以下代码可以实现该功能:

训练生成的日志文件一般在exp/ctdet/../../logs.txt,找到这个文件,打开之后会出现如下:

(绝对详细)CenterNet训练自己的数据(pytorch0.4.1)

我们需要读取这些loss值并可视化(一般情况下,该代码只需要改变日志文件的路径即可):

  1. import matplotlib.pyplot as plt
  2. import numpy as np
  3. def plot_loss_curve(log_file):
  4. loss_data = open(log_file)
  5. all_lines = loss_data.readlines()
  6. print(all_lines[4].split(' '))
  7. # losses
  8. total_loss = [] # 4
  9. hm_loss = [] # 7
  10. wh_loss = [] # 10
  11. off_loss = [] # 13
  12. val_loss = [] # 19
  13. spend_time = [] # 16
  14. num_lines = len(all_lines)
  15. for line in range(num_lines):
  16. total_loss1 = all_lines[line].split(' ')[4]
  17. hm_loss1 = all_lines[line].split(' ')[7]
  18. wh_loss1 = all_lines[line].split(' ')[10]
  19. off_loss1 = all_lines[line].split(' ')[13]
  20. spend_time1 = all_lines[line].split(' ')[16]
  21. total_loss.append(float(total_loss1))
  22. hm_loss.append(float(hm_loss1))
  23. wh_loss.append(float(wh_loss1))
  24. off_loss.append(float(off_loss1))
  25. spend_time.append(float(spend_time1))
  26. index_val = np.linspace(0, 140, 29) - 1
  27. index_val = np.delete(index_val, 0, 0)
  28. for id in index_val:
  29. val_loss1 = all_lines[int(id)].split(' ')[19]
  30. val_loss.append(float(val_loss1))
  31. return val_loss, total_loss
  32. if __name__ == '__main__':
  33. # 标准图形绘制
  34. # sns.set()
  35. vloss_res18, loss_res18 = plot_loss_curve('logres18.txt') # 读取训练时生成的日志文件
  36. # vloss_resdcn18, loss_resdcn18 = plot_loss_curve('logresdcn18.txt')
  37. # vloss_dla, loss_dla = plot_loss_curve('logdla34.txt')
  38. # vloss_res101, loss_res101 = plot_loss_curve('logres101.txt')
  39. # vloss_dla34p, loss_dla34p = plot_loss_curve('logdla34p.txt')
  40. # vloss_hg, loss_hg = plot_loss_curve('loghourglass.txt')
  41. fig = plt.figure(figsize=(10, 4))
  42. ax = fig.add_subplot(111)
  43. ax.plot(range(len(loss_res18)), loss_res18, 'c', label='res_18_train_loss', linewidth=1) # 这个label是图线自己的标签;
  44. # ax.plot(range(len(loss_resdcn18)), loss_resdcn18, 'y', label='resdcn_18_train_loss', linewidth=1)
  45. # ax.plot(range(len(loss_dla)), loss_dla, 'b', label='dla_34_train_loss', linewidth=1)
  46. # ax.plot(range(len(loss_res101)), loss_res101, 'g', label='res_101_train_loss', linewidth=1)
  47. # ax.plot(range(len(loss_dla34p)), loss_dla34p, 'r', label='dla_34_train_loss', linewidth=1)
  48. # ax.plot(range(len(loss_hg)), loss_hg, 'r', label='hourglass_train_loss', linewidth=1)
  49. # ax.plot(index_val+1, val_loss, label='val_loss')
  50. # ax.plot(np.random.randn(1000).cumsum(), label='line2')
  51. # ax.set_xlim([0, 800]) # 设置刻度;
  52. # ax.set_xticks(range(0, 500, 100)) # 设置显示的刻度;
  53. # ax.set_yticklabels(['jan', 'feb', 'mar']) # 设置刻度标签;
  54. ax.set_xlabel('epochs') # 设置坐标轴标签;
  55. ax.set_ylabel('loss_value')
  56. # ax.text(8750, 20, "海拔", color='red') # 加入文本
  57. ax.set_title('loss_of_CenterNet')
  58. ax.legend(loc='best') # 将图例摆放在不遮挡图线的位置即可
  59. ax.grid() # 添加网格
  60. plt.savefig('loss_of_CenterNet.png') # 保存文件到指定文件夹
  61. plt.show()
  62. fig1 = plt.figure(figsize=(12, 8))
  63. ax1 = fig1.add_subplot(111)
  64. ax1.plot(range(len(vloss_res18)), vloss_res18, 'c', label='res_18_val_loss', linewidth=2) # 这个label是图线自己的标签;
  65. # ax1.plot(range(len(vloss_resdcn18)), vloss_resdcn18, 'y', label='resdcn_18_val_loss', linewidth=2)
  66. # ax1.plot(range(len(vloss_dla)), vloss_dla, 'b', label='dla_34_val_loss', linewidth=2)
  67. # ax1.plot(range(len(vloss_res101)), vloss_res101, 'g', label='res_101_val_loss', linewidth=2)
  68. # ax1.plot(range(len(vloss_dla34p)), vloss_dla34p, 'r', label='dla_34_val_loss_p', linewidth=2)
  69. # ax.plot(index_val+1, val_loss, label='val_loss')
  70. # ax.plot(np.random.randn(1000).cumsum(), label='line2')
  71. # ax.set_xlim([0, 800]) # 设置刻度;
  72. # ax.set_xticks(range(0, 500, 100)) # 设置显示的刻度;
  73. # ax.set_yticklabels(['jan', 'feb', 'mar']) # 设置刻度标签;
  74. ax1.set_xlabel('epochs') # 设置坐标轴标签;
  75. ax1.set_ylabel('loss_value')
  76. # ax.text(8750, 20, "海拔", color='red') # 加入文本
  77. ax1.set_title('val_loss_of_CenterNet')
  78. ax1.legend(loc='best') # 将图例摆放在不遮挡图线的位置即可
  79. ax1.grid() # 添加网格
  80. plt.savefig('val_loss_of_CenterNet.png') # 保存文件到指定文件夹
  81. plt.show()

 

        <div class="person-messagebox">
            <div class="left-message"><a href="https://blog.csdn.net/weixin_41765699">
                <img src="https://profile.csdnimg.cn/1/2/9/3_weixin_41765699" class="avatar_pic" username="weixin_41765699">
            </a></div>
            <div class="middle-message">
                                    <div class="title"><span class="tit "><a href="https://blog.csdn.net/weixin_41765699" data-report-click="{&quot;mod&quot;:&quot;popu_379&quot;,&quot;ab&quot;:&quot;new&quot;}" target="_blank">linbior</a></span>
                    <!-- 等级,level -->
                                            <img class="identity-icon" src="https://csdnimg.cn/identity/blog5.png">                                            </div>
                <div class="text"><span>原创文章 46</span><span>获赞 80</span><span>访问量 14万+</span></div>
            </div>
                            <div class="right-message">
                                        <a class="btn btn-sm  bt-button personal-watch" data-report-click="{&quot;mod&quot;:&quot;popu_379&quot;,&quot;ab&quot;:&quot;new&quot;}">关注</a>
                                                            <a href="https://im.csdn.net/im/main.html?userName=weixin_41765699" target="_blank" class="btn btn-sm bt-button personal-letter">私信
                    </a>
                                </div>
                        </div>
                    
    </div>