
PyTorch Study Notes (19): Optimizers (Part 2)

程序员文章站 2022-07-12 23:15:32

Learning rate

Gradient descent
$$\boldsymbol{w}_{i+1} = \boldsymbol{w}_{i} - \boldsymbol{g}\left(\boldsymbol{w}_{i}\right)$$

$$\boldsymbol{w}_{i+1} = \boldsymbol{w}_{i} - \mathrm{LR} \cdot \boldsymbol{g}\left(\boldsymbol{w}_{i}\right)$$
Here g(w_i) is the gradient at w_i. The learning rate (LR) scales the gradient and so controls the size of each update step; the first form is the special case LR = 1.
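
To see why the scaling matters, here is a minimal sketch of a single update on the objective y = 4x^2 (the same function the script in these notes uses), starting from x = 2. The helper names `grad` and `gd_step` are made up for illustration, and the gradient 8x is hard-coded rather than computed by autograd.

```python
def grad(x):
    """Gradient of y = 4 * x**2."""
    return 8.0 * x


def gd_step(x, lr):
    """One gradient-descent update: x_new = x - lr * g(x)."""
    return x - lr * grad(x)


x0 = 2.0
no_lr = gd_step(x0, lr=1.0)    # plain w - g(w): jumps to -14.0, far past the minimum at 0
with_lr = gd_step(x0, lr=0.1)  # scaled step: moves to about 0.4, closer to the minimum

print(no_lr, 4 * no_lr ** 2)      # the loss grows from 16 to 784
print(with_lr, 4 * with_lr ** 2)  # the loss shrinks to about 0.64
```

With the unscaled update the step is so large that the loss increases; multiplying the gradient by a small LR keeps the step inside the region where the loss actually decreases.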

Momentum: combines the current gradient with the information from the previous update to compute the current update.
Exponentially weighted average:
$$v_{t} = \beta \cdot v_{t-1} + (1 - \beta) \cdot \theta_{t}$$
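
The recurrence can be unrolled numerically to see how the average "remembers" older values. The sketch below is plain Python; the function name `ewma`, the starting value v_0 = 0, and the constant sample sequence are illustrative assumptions, not part of the original notes.

```python
def ewma(values, beta, v0=0.0):
    """Return the running exponentially weighted average v_t for each value."""
    v = v0
    out = []
    for theta in values:
        v = beta * v + (1.0 - beta) * theta
        out.append(v)
    return out


# A constant signal of 1.0: starting from v0 = 0, v_t equals 1 - beta**t.
beta = 0.9
averages = ewma([1.0] * 50, beta)
print(averages[0], averages[9], averages[49])

# The value seen k steps ago carries weight (1 - beta) * beta**k,
# so a larger beta spreads the average over a longer history.
weights = [(1.0 - beta) * beta ** k for k in range(5)]
print(weights)
```

The geometrically decaying weights are what make momentum favour recent gradients while still keeping a trace of older ones.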

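The optimizer optim.SGD

Main parameters:
params: the parameter groups managed by the optimizer
lr: initial learning rate
momentum: momentum coefficient (β)
weight_decay: L2 regularization coefficient
nesterov: whether to use NAG (Nesterov accelerated gradient); usually left at False, which is the default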
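
As a rough illustration of how these arguments fit together, the sketch below minimizes the toy objective y = 4x^2 with `torch.optim.SGD`. The hyperparameter values (lr=0.01, momentum=0.9, 100 steps) are arbitrary choices for the demo. Note that PyTorch's momentum buffer accumulates gradients as v = momentum * v + g (with the default dampening of 0), which scales differently from the (1 - β) form of the weighted average.

```python
import torch

x = torch.tensor([2.0], requires_grad=True)

optimizer = torch.optim.SGD(
    [x],               # params: iterable of tensors (or dicts defining parameter groups)
    lr=0.01,           # learning rate
    momentum=0.9,      # momentum coefficient
    weight_decay=0.0,  # L2 penalty coefficient (disabled here)
    nesterov=False,    # set True for Nesterov momentum (requires momentum > 0)
)

losses = []
for _ in range(100):
    optimizer.zero_grad()            # clear the gradient from the previous step
    y = torch.pow(2 * x, 2).sum()    # y = 4 * x^2
    y.backward()                     # compute dy/dx
    optimizer.step()                 # apply the SGD-with-momentum update
    losses.append(y.item())

print(x.item(), losses[0], losses[-1])
```

Letting the optimizer own the update replaces the manual `x.data.sub_(lr * x.grad)` and `x.grad.zero_()` pair used in a hand-written loop.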



# -*- coding:utf-8 -*-
import torch
import numpy as np
import matplotlib.pyplot as plt
torch.manual_seed(1)


def func(x_t):
    """
    y = (2x)^2 = 4*x^2      dy/dx = 8x
    """
    return torch.pow(2*x_t, 2)


# init
x = torch.tensor([2.], requires_grad=True)


# ------------------------------ plot data ------------------------------
flag = 0
# flag = 1
if flag:

    x_t = torch.linspace(-3, 3, 100)
    y = func(x_t)
    plt.plot(x_t.numpy(), y.numpy(), label="y = 4*x^2")
    plt.grid()
    plt.xlabel("x")
    plt.ylabel("y")
    plt.legend()
    plt.show()


# ------------------------------ gradient descent ------------------------------
flag = 0
# flag = 1
if flag:
    iter_rec, loss_rec, x_rec = list(), list(), list()

    lr = 0.01    # other values tried: 1., .5, .2, .1, .125
    max_iteration = 20   # iterations per lr: 4 for lr=1. and lr=.5; 20 or 200 for lr=.2

    for i in range(max_iteration):

        y = func(x)
        y.backward()

        print("Iter:{}, X:{:8}, X.grad:{:8}, loss:{:10}".format(
            i, x.detach().numpy()[0], x.grad.detach().numpy()[0], y.item()))

        x_rec.append(x.item())

        x.data.sub_(lr * x.grad)    # in-place update: x = x - lr * x.grad (learning rates to try: 0.5, 0.2, 0.1, 0.125)
        x.grad.zero_()

        iter_rec.append(i)
        loss_rec.append(y.item())    # store a plain float; a tensor that requires grad cannot be plotted directly

    plt.subplot(121).plot(iter_rec, loss_rec, '-ro')
    plt.xlabel("Iteration")
    plt.ylabel("Loss value")

    x_t = torch.linspace(-3, 3, 100)
    y = func(x_t)
    plt.subplot(122).plot(x_t.numpy(), y.numpy(), label="y = 4*x^2")
    plt.grid()
    y_rec = [func(torch.tensor(i)).item() for i in x_rec]
    plt.subplot(122).plot(x_rec, y_rec, '-ro')
    plt.legend()
    plt.show()

# ------------------------------ multi learning rate ------------------------------

# flag = 0
flag = 1
if flag:
    iteration = 100
    num_lr = 10
    lr_min, lr_max = 0.01, 0.2  # lr_max values to try: .5, .3, .2

    lr_list = np.linspace(lr_min, lr_max, num=num_lr).tolist()
    loss_rec = [[] for l in range(len(lr_list))]
    iter_rec = list()

    for i, lr in enumerate(lr_list):
        x = torch.tensor([2.], requires_grad=True)
        for _ in range(iteration):  # loop index unused; avoids shadowing the built-in iter()

            y = func(x)
            y.backward()
            x.data.sub_(lr * x.grad)  # x.data -= lr * x.grad
            x.grad.zero_()

            loss_rec[i].append(y.item())

    for i, loss_r in enumerate(loss_rec):
        plt.plot(range(len(loss_r)), loss_r, label="LR: {}".format(lr_list[i]))
    plt.legend()
    plt.xlabel('Iterations')
    plt.ylabel('Loss value')
    plt.show()
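
The behaviour across learning rates can be checked by hand for this particular objective. Because dy/dx = 8x for y = 4x^2, each update multiplies x by a constant factor:

$$x_{k+1} = x_{k} - \mathrm{LR} \cdot 8 x_{k} = (1 - 8\,\mathrm{LR})\, x_{k} \quad\Rightarrow\quad x_{k} = (1 - 8\,\mathrm{LR})^{k} x_{0}, \qquad y_{k} = 4\,(1 - 8\,\mathrm{LR})^{2k} x_{0}^{2}$$

The loss therefore shrinks only when |1 - 8 LR| < 1, that is, for 0 < LR < 0.25. LR = 0.125 makes the factor zero and reaches the minimum in a single step; LR = 0.2 gives a factor of -0.6, so x alternates in sign while converging; LR = 0.25 gives -1, so x bounces between x_0 and -x_0 forever; LR = 0.5 gives -3 and the loss grows by a factor of 9 per step. This threshold is specific to y = 4x^2 and should not be read as a general rule.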