
PyTorch multi-GPU training: problems encountered and solutions

1. After multi-GPU training, every parameter name in the saved model gains a 'module.' prefix, so the keys have to be adjusted when loading the model.
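
As a quick illustration of the symptom (a minimal sketch with a throwaway nn.Linear, not code from the original article): wrapping a model for multi-GPU training nests it under an attribute called module, so every key in its state_dict gains a 'module.' prefix. DataParallel is used here only because it needs no process group; DistributedDataParallel prefixes the keys the same way.

import torch.nn as nn

net = nn.Linear(4, 2)
print(list(net.state_dict().keys()))                   # ['weight', 'bias']
print(list(nn.DataParallel(net).state_dict().keys()))  # ['module.weight', 'module.bias']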

The following model-loading function shows how to handle this; the core part is marked 'changed'.

import os
import torch


def load_params_from_file_v2(model, filename, opt=None, to_cpu=False):
    if not os.path.isfile(filename):
        raise FileNotFoundError(filename)

    print('==> Loading parameters from checkpoint %s to %s' % (filename, 'CPU' if to_cpu else 'GPU'))
    loc_type = torch.device('cpu') if to_cpu else None

    checkpoint = torch.load(filename, map_location=loc_type)
    model_state_disk = checkpoint['model_state']

    if 'version' in checkpoint:
        print('==> Checkpoint trained from version: %s' % checkpoint['version'])

    update_model_state = {}
    for key, val in model_state_disk.items():
        ###########changed####################
        # The checkpoint was saved from an unwrapped model, while `model` here
        # is wrapped for multi-GPU training, so every key needs the 'module.' prefix.
        key = 'module.' + key
        update_model_state[key] = val
        ###########changed####################

    state_dict = model.state_dict()
    state_dict.update(update_model_state)
    model.load_state_dict(state_dict)

    # Report parameters that were not present in the checkpoint.
    for key in state_dict:
        if key not in update_model_state:
            print('Not updated weight %s: %s' % (key, str(state_dict[key].shape)))

    if opt is not None:
        opt_state_disk = checkpoint['optimizer_state']
        opt.load_state_dict(opt_state_disk)
    return checkpoint['epoch']
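
The function above handles one direction (a checkpoint saved without the prefix, loaded into a wrapped model). For the opposite direction, a minimal sketch: stripping the 'module.' prefix when loading a multi-GPU checkpoint into a plain, unwrapped model. The 'model_state' key matches the checkpoints used in this article; 'ckpt.pth' and model are placeholders.

import torch

def strip_module_prefix(state_dict):
    # Drop a leading 'module.' from every key that has one; other keys pass through.
    return {k[len('module.'):] if k.startswith('module.') else k: v
            for k, v in state_dict.items()}

checkpoint = torch.load('ckpt.pth', map_location='cpu')   # placeholder path
model.load_state_dict(strip_module_prefix(checkpoint['model_state']))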

2. Multi-GPU training: GPU memory usage problem

Problem description: when a 4-GPU job is trained from scratch (GTX 1080 cards, for example), the main GPU uses about 10 GB of memory and the other GPUs about 8 GB each. But after an interruption, loading the checkpoint to resume can blow past the available GPU memory, making training impossible.

The cause is the way the model is loaded.

After a multi-GPU run is interrupted, first load the checkpoint the same way as for single-GPU training (into the unwrapped model), and only then wrap the model for distributed training; GPU memory usage is then the same as training from scratch.

1) Load the model with load_params_from_file (see the code below).

Call it as: load_params_from_file(model, cfg.resume_from, opt=optimizer, to_cpu=False, rank=torch.cuda.current_device())

2) model = MMDistributedDataParallel(model, device_ids=[torch.cuda.current_device()], broadcast_buffers=False)

Code for loading the model (a combined resume sketch follows it):

import os
import torch


def load_params_from_file(model, filename, opt=None, to_cpu=False, rank=0):
    if not os.path.isfile(filename):
        raise FileNotFoundError(filename)

    print('==> Loading parameters from checkpoint %s to %s' % (filename, 'CPU' if to_cpu else 'GPU'))

    ##############changed###############
    # Map the checkpoint onto this process's own GPU (or the CPU if requested).
    # Without an explicit map_location, every rank restores the tensors to the
    # device they were saved from (typically cuda:0), piling all copies onto one GPU.
    loc_type = torch.device('cpu') if to_cpu else 'cuda:{}'.format(rank)
    checkpoint = torch.load(filename, map_location=loc_type)
    ##############changed###############
    model_state_disk = checkpoint['model_state']

    if 'version' in checkpoint:
        print('==> Checkpoint trained from version: %s' % checkpoint['version'])

    update_model_state = {}
    for key, val in model_state_disk.items():
        ###########changed####################
        # The model is not wrapped yet, so the keys are used as-is
        # (no 'module.' prefix is added here).
        #key = 'module.' + key
        update_model_state[key] = val
        ###########changed####################

    state_dict = model.state_dict()
    state_dict.update(update_model_state)
    model.load_state_dict(state_dict)

    # Report parameters that were not present in the checkpoint.
    for key in state_dict:
        if key not in update_model_state:
            print('Not updated weight %s: %s' % (key, str(state_dict[key].shape)))

    print('==> Done (loaded %d/%d)' % (len(update_model_state), len(model.state_dict())))

    if opt is not None:
        opt_state_disk = checkpoint['optimizer_state']
        opt.load_state_dict(opt_state_disk)
    return checkpoint['epoch']
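
Putting steps 1) and 2) together, a minimal resume sketch (cfg, model and optimizer are assumed to already exist as in the snippets above; the MMDistributedDataParallel import path is the one used by mmcv 1.x):

import torch
from mmcv.parallel import MMDistributedDataParallel

start_epoch = 0
if cfg.resume_from:
    # Step 1: load the checkpoint into the *unwrapped* model, mapping tensors
    # onto this process's own GPU so they do not all pile up on cuda:0.
    start_epoch = load_params_from_file(
        model, cfg.resume_from, opt=optimizer, to_cpu=False,
        rank=torch.cuda.current_device())

# Step 2: only after loading, wrap the model for distributed training.
model = MMDistributedDataParallel(
    model,
    device_ids=[torch.cuda.current_device()],
    broadcast_buffers=False)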

3. Multi-GPU issue: during classification training, batches with no samples for some branch interrupt training

model = MMDistributedDataParallel(model, device_ids=[torch.cuda.current_device()], broadcast_buffers=False, find_unused_parameters=True)

Compared with the command in problem 2, the only addition is find_unused_parameters=True. This lets DDP tolerate parameters that receive no gradient in a given iteration (here, the classification branch when a batch contains no samples for it) instead of failing during gradient synchronization.
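
As an illustration of when that flag matters (a minimal hypothetical two-head model, not from the original article): if a batch contains no positive samples, the classification head is skipped, its parameters receive no gradients, and DDP's default gradient reduction fails because it waits on a gradient for every parameter; find_unused_parameters=True makes DDP detect such parameters and skip them.

import torch.nn as nn

class TwoHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)
        self.reg_head = nn.Linear(8, 4)
        self.cls_head = nn.Linear(8, 2)   # only used when the batch has positives

    def forward(self, x, has_positives):
        feat = self.backbone(x)
        loss = self.reg_head(feat).pow(2).mean()
        if has_positives:
            # Skipped for batches without positive samples: cls_head's parameters
            # then take no part in the backward pass on this iteration.
            loss = loss + self.cls_head(feat).pow(2).mean()
        return loss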