Pytorch Optimization


Optimizer

torch.optim

Each optimizer keeps a param_groups list that maintains the groups of parameters it updates, together with hyperparameters such as the learning rate. You can inspect (or modify) it via pprint(opt.param_groups):
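
For example, the following sketch builds an optimizer with two parameter groups and prints them. The model itself is an assumption; a Linear(3, 5) and a Conv2d(3, 3, 2) are chosen to match the shapes in the dump below.

from pprint import pprint

import torch.nn as nn
import torch.optim as optim

linear = nn.Linear(3, 5)                 # weight [5, 3], bias [5]
conv = nn.Conv2d(3, 3, kernel_size=2)    # weight [3, 3, 2, 2], bias [3]

# Each dict becomes one entry in opt.param_groups; hyperparameters not given
# here (momentum, weight_decay, ...) fall back to SGD's defaults.
opt = optim.SGD(
    [{'params': linear.parameters()},
     {'params': conv.parameters()}],
    lr=0.01,
)

pprint(opt.param_groups)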

[
 {'dampening': 0,
  'lr': 0.01,
  'momentum': 0,
  'nesterov': False,
  'params': [Parameter containing:
                tensor([[-0.4239,  0.2810,  0.3866],
                        [ 0.1081, -0.3685,  0.4922],
                        [ 0.1043,  0.5353, -0.1368],
                        [ 0.5171,  0.3946, -0.3541],
                        [ 0.2255,  0.4731, -0.4114]], requires_grad=True),
                            Parameter containing:
                tensor([ 0.3145, -0.5053, -0.1401, -0.1902, -0.5681], requires_grad=True)],
  'weight_decay': 0},
 {'dampening': 0,
  'lr': 0.01,
  'momentum': 0,
  'nesterov': False,
  'params': [Parameter containing:
                tensor([[[[ 0.0476,  0.2790],
                        [ 0.0285, -0.1737]],

                        [[-0.0268,  0.2334],
                        [-0.0095, -0.1972]],

                        [[-0.0309,  0.0752],
                        [-0.1166, -0.1442]]],


                        [[[ 0.2219, -0.1128],
                        [ 0.1363,  0.0779]],

                        [[-0.1370, -0.0915],
                        [ 0.0588, -0.0528]],

                        [[ 0.0544,  0.2210],
                        [ 0.2658, -0.2197]]],


                        [[[ 0.0621,  0.2371],
                        [-0.1248, -0.1972]],

                        [[-0.0829, -0.1541],
                        [ 0.2709,  0.0952]],

                        [[-0.1588, -0.1018],
                        [ 0.2712,  0.2416]]]], requires_grad=True),
                            Parameter containing:
                tensor([ 0.0690, -0.2328, -0.0965], requires_grad=True)],
  'weight_decay': 0}
]

Each param_group has its own hyperparameters and corresponds to a different set of Parameters in the model.

  • Basic operations

add_param_group is used to add a new parameter group; see the API documentation.

The argument should be a dict shaped like one element of optim.param_groups, containing keys such as params and lr. Any hyperparameter that is omitted falls back to the defaults used when the Optimizer was constructed.
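
A hedged sketch of adding a group after the optimizer already exists (the new_layer module below is hypothetical):

import torch.nn as nn
import torch.optim as optim

model = nn.Linear(3, 5)
opt = optim.SGD(model.parameters(), lr=0.01)

new_layer = nn.Linear(5, 2)   # hypothetical layer added later, e.g. a new head
opt.add_param_group({'params': new_layer.parameters(), 'lr': 0.001})

print(len(opt.param_groups))  # 2: the original group plus the new one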

optimizer.step() updates all parameter groups.

A parameter is only updated when it meets two conditions (see the sketch below):

  1. it has a gradient, i.e. requires_grad = True and .grad was populated by backward();
  2. it is registered in one of the optimizer's param_groups.
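
A minimal sketch of both conditions in action (the model, data, and loss below are assumptions):

import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(3, 5)
opt = optim.SGD(model.parameters(), lr=0.01)

# A frozen parameter stays in opt.param_groups but never receives a gradient,
# so optimizer.step() simply skips it (condition 1 above).
model.bias.requires_grad_(False)

x, y = torch.randn(8, 3), torch.randn(8, 5)
opt.zero_grad()                               # clear gradients from the previous step
loss = nn.functional.mse_loss(model(x), y)
loss.backward()                               # populates .grad for trainable parameters
opt.step()                                    # updates only registered parameters that have a grad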

Pytorch Lightning

For performing parameter updates manually, refer to the documentation; for configuring a LightningModule's optimizers, refer to the corresponding docs.

If you need to perform parameter updates manually, set self.automatic_optimization = False in your LightningModule's __init__.

Use the following functions and call them manually:

  1. self.optimizers() to access your optimizers (one or multiple). Use self.lr_schedulers() to access your schedulers.

  2. optimizer.zero_grad() to clear the gradients from the previous training step

  3. self.manual_backward(loss) instead of loss.backward()

  4. optimizer.step() to update your model parameters. scheduler.step() to schedule your learning rate.

Official example:

from pytorch_lightning import LightningModule

class MyModel(LightningModule):

    def __init__(self):
        super().__init__()
        # Important: This property activates manual optimization.
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        loss = self.compute_loss(batch)
        self.manual_backward(loss)
        
        # accumulate gradients over `n` batches before stepping
        # (`n` and `compute_loss` are placeholders from the docs example)
        if (batch_idx + 1) % n == 0:
            opt.step()
            opt.zero_grad()
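
The list above also mentions self.lr_schedulers() and scheduler.step(), which the official example does not show. A minimal sketch of what that could look like inside the same MyModel (the SGD/StepLR choice and the epoch-end hook are assumptions; with manual optimization, Lightning does not step schedulers for you):

    # continuation of MyModel; assumes `import torch` at the top of the file

    def configure_optimizers(self):
        # assumed setup: one SGD optimizer and one StepLR scheduler
        optimizer = torch.optim.SGD(self.parameters(), lr=0.01)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
        return [optimizer], [scheduler]

    def on_train_epoch_end(self):
        # in manual optimization the scheduler is not stepped automatically
        sch = self.lr_schedulers()
        sch.step()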