pytorch中lr_scheduler的使用
程序员文章站
2022-04-11 08:02:57
...
torch.optim.lr_scheduler.StepLR
- 代码
import torch
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision.models import AlexNet
import matplotlib.pyplot as plt
model = AlexNet(num_classes=2)
optimizer = optim.SGD(params=model.parameters(), lr=0.05)
# lr_scheduler.StepLR()
# Assuming optimizer uses lr = 0.05 for all groups
# lr = 0.05 if epoch < 30
# lr = 0.005 if 30 <= epoch < 60
# lr = 0.0005 if 60 <= epoch < 90
scheduler = lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
plt.figure()
x = list(range(100))
y = []
for epoch in range(100):
scheduler.step()
lr = scheduler.get_lr()
print(epoch, scheduler.get_lr()[0])
y.append(scheduler.get_lr()[0])
plt.plot(x, y)
0<epoch<30, lr = 0.05
30<=epoch<60, lr = 0.005
60<=epoch<90, lr = 0.0005
torch.optim.lr_scheduler.MultiStepLR
与StepLR
相比,MultiStepLR
可以设置指定的区间
- 代码
# ---------------------------------------------------------------
# 可以指定区间
# lr_scheduler.MultiStepLR()
# Assuming optimizer uses lr = 0.05 for all groups
# lr = 0.05 if epoch < 30
# lr = 0.005 if 30 <= epoch < 80
# lr = 0.0005 if epoch >= 80
print()
plt.figure()
y.clear()
scheduler = lr_scheduler.MultiStepLR(optimizer, [30, 80], 0.1)
for epoch in range(100):
scheduler.step()
print(epoch, 'lr={:.6f}'.format(scheduler.get_lr()[0]))
y.append(scheduler.get_lr()[0])
plt.plot(x, y)
plt.show()
torch.optim.lr_scheduler.ExponentialLR
指数衰减
- 代码
scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
print()
plt.figure()
y.clear()
for epoch in range(100):
scheduler.step()
print(epoch, 'lr={:.6f}'.format(scheduler.get_lr()[0]))
y.append(scheduler.get_lr()[0])
plt.plot(x, y)
plt.show()
上一篇: 非关系型数据库Redis
下一篇: Lca相关算法