Problems encountered in PyTorch multi-GPU training and their solutions
1. After multi-GPU training, every parameter name in the saved model gains a 'module.' prefix, so the matching key name has to be added when the checkpoint is loaded back.
The model-loading function below illustrates this; the core change is the part marked 'changed', and a short usage sketch follows the function.
import os
import torch


def load_params_from_file_v2(model, filename, opt=None, to_cpu=False):
    if not os.path.isfile(filename):
        raise FileNotFoundError
    print('==> Loading parameters from checkpoint %s to %s' % (filename, 'CPU' if to_cpu else 'GPU'))
    loc_type = torch.device('cpu') if to_cpu else None
    checkpoint = torch.load(filename, map_location=loc_type)
    model_state_disk = checkpoint['model_state']
    if 'version' in checkpoint:
        print('==> Checkpoint trained from version: %s' % checkpoint['version'])

    update_model_state = {}
    for key, val in model_state_disk.items():
        ########### changed ####################
        # The wrapped model expects 'module.'-prefixed parameter names,
        # so add the prefix to every key read from disk.
        key = 'module.' + key
        update_model_state[key] = val
        ########### changed ####################

    state_dict = model.state_dict()
    state_dict.update(update_model_state)
    model.load_state_dict(state_dict)

    # Report any model parameters that were not present in the checkpoint.
    for key in state_dict:
        if key not in update_model_state:
            print('Not updated weight %s: %s' % (key, str(state_dict[key].shape)))

    if opt is not None:
        opt_state_disk = checkpoint['optimizer_state']
        opt.load_state_dict(opt_state_disk)

    return checkpoint['epoch']
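For reference, here is a minimal usage sketch of the function above (the nn.Linear stand-in model, checkpoint path, and optimizer are placeholders, not from the original code): wrap the model in nn.DataParallel first, so its parameter names carry the 'module.' prefix that the loader adds to the checkpoint keys.

import torch
import torch.nn as nn

# Stand-in model; in practice this is your own network.
model = nn.Linear(10, 2).cuda()
model = nn.DataParallel(model)   # parameter names now start with 'module.'
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# load_params_from_file_v2 prefixes every checkpoint key with 'module.',
# so a checkpoint saved from an unwrapped model loads into the wrapped one.
start_epoch = load_params_from_file_v2(model, 'checkpoint.pth',
                                       opt=optimizer, to_cpu=False)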
2. GPU memory usage in multi-GPU training
Problem description: when a 4-GPU job is trained from scratch (taking the GTX 1080 setup as an example), the master GPU uses about 10 GB of memory and the other GPUs about 8 GB each. After an interruption, however, resuming by loading the checkpoint can blow past the available memory and make training impossible.
The cause is the way the model is loaded when resuming.
After a multi-GPU run is interrupted, the checkpoint has to be loaded the way a single-GPU run would load it (onto each process's own device) first, and only then is the model wrapped for distributed training; memory usage is then the same as training from scratch.
1) Load the model with load_params_from_file (see the code below).
Call: load_params_from_file(model, cfg.resume_from, opt=optimizer, to_cpu=False, rank=torch.cuda.current_device())
2) Then wrap it: model = MMDistributedDataParallel(model, device_ids=[torch.cuda.current_device()], broadcast_buffers=False)
Code for loading the model (a combined resume sketch follows it):
def load_params_from_file(model, filename, opt=None, to_cpu=False, rank=0):
    if not os.path.isfile(filename):
        raise FileNotFoundError
    print('==> Loading parameters from checkpoint %s to %s' % (filename, 'CPU' if to_cpu else 'GPU'))
    ############## changed ###############
    # Map the checkpoint onto this process's own GPU instead of letting every
    # rank deserialize it onto GPU 0, which is what exhausts its memory.
    checkpoint = torch.load(filename, map_location='cuda:{}'.format(rank))
    ############## changed ###############
    model_state_disk = checkpoint['model_state']
    if 'version' in checkpoint:
        print('==> Checkpoint trained from version: %s' % checkpoint['version'])

    update_model_state = {}
    for key, val in model_state_disk.items():
        ########### changed ####################
        # The model is not wrapped yet, so the keys are used as-is
        # (no 'module.' prefix is added here).
        update_model_state[key] = val
        ########### changed ####################

    state_dict = model.state_dict()
    state_dict.update(update_model_state)
    model.load_state_dict(state_dict)

    for key in state_dict:
        if key not in update_model_state:
            print('Not updated weight %s: %s' % (key, str(state_dict[key].shape)))
    print('==> Done (loaded %d/%d)' % (len(update_model_state), len(model.state_dict())))

    if opt is not None:
        opt_state_disk = checkpoint['optimizer_state']
        opt.load_state_dict(opt_state_disk)

    return checkpoint['epoch']
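Putting the two steps together, a minimal resume sketch might look like the following (build_model, cfg.resume_from, and the optimizer are placeholders; MMDistributedDataParallel is assumed to come from mmcv.parallel, and the distributed process group is assumed to be initialized already):

import torch
from mmcv.parallel import MMDistributedDataParallel

# Placeholder model and optimizer; in practice these come from your own config.
model = build_model(cfg).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# 1) Load the checkpoint onto this process's own GPU before wrapping.
start_epoch = load_params_from_file(model, cfg.resume_from, opt=optimizer,
                                    to_cpu=False,
                                    rank=torch.cuda.current_device())

# 2) Only then wrap the model for distributed training; memory usage now
#    matches a run started from scratch.
model = MMDistributedDataParallel(model,
                                  device_ids=[torch.cuda.current_device()],
                                  broadcast_buffers=False)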
3. In multi-GPU training of a classifier, a batch with no samples for some branch can interrupt training
model = MMDistributedDataParallel(model, device_ids=[torch.cuda.current_device()], broadcast_buffers=False, find_unused_parameters=True)
Compared with the call in problem 2, the only difference is the extra find_unused_parameters=True.
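For reference, PyTorch's own DistributedDataParallel takes the same flag; a minimal sketch, assuming the process group is already initialized and model sits on the current GPU:

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

# find_unused_parameters=True lets DDP tolerate parameters that receive no
# gradient in a given iteration (e.g. a branch that saw no samples), instead
# of erroring during gradient synchronization.
model = DDP(model,
            device_ids=[torch.cuda.current_device()],
            broadcast_buffers=False,
            find_unused_parameters=True)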