欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

PyTorch模型持久化(模型的保存和加载)

程序员文章站 2022-07-06 11:04:29
...
#PyTorch保存整个模型和保存模型的参数
torch.save(mlpl,"data/mlpl.pkl")#保存整个模型
mlplload=torch.load("data/mlpl.pkl")#导入保存的模型
print(mlplload)
#只保存模型的参数 mlpl.state_dict()获取网络中已经训练好的参数
torch.save(mlpl.state_dict(),"data/mlpl_params.pkl")
mlpl_params=torch.load("data/mlpl_params.pkl")
print(mlpl_params)

PyTorch模型持久化(模型的保存和加载)