莫烦PyTorch学习笔记(五)——模型的存取
程序员文章站
2022-07-06 14:12:20
...
本文主要是介绍如何对训练好的神经网络模型进行存取。
编辑器:spyder
1.快速搭建神经网络
这里采用上一节介绍的方法快速搭建一个小的神经网络:
def save():
# save net1
net1 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
loss_func = torch.nn.MSELoss()
for t in range(100):
prediction = net1(x)
loss = loss_func(prediction, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
2.模型的存储
模型的存储十分简单,这里我们提供两种存储方式:
torch.save(net1, 'net.pkl') # 保存整个网络,包括整个计算图
torch.save(net1.state_dict(), 'net_params.pkl') # 只保存网络中的参数 (速度快, 占内存少)
net1.state_dict()
的方式可以获得当前模型的参数。两种方式模型的保存类型都是.pkl
。
3.模型的提取
与模型的存储类似,模型的提取也有两种方式,下面将分别进行介绍。
- 提取整个网络
def restore_net():
# restore entire net1 to net2
net2 = torch.load('net.pkl')
prediction = net2(x)
这种方式会获得原来的整个神经网络的参数及计算图,网络大的时候加载速度比较慢。
- 只提取网络参数
def restore_params():
# 新建 net3
net3 = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
# 将保存的参数复制到 net3
net3.load_state_dict(torch.load('net_params.pkl'))
prediction = net3(x)
这种方法速度快,但是需要建立一个与原来网络结构相同的新网络,才能进行参数的复制。
4.网络提取后的测试效果
毫无疑问,三个网络是一模一样的。