How to save a PyTorch model and reload it from a pickle file
程序员文章站
2023-12-30 16:21:34
...
Suppose you train a neural network on a machine with a GPU, but need to use it in a working environment where the machine has no GPU or TPU, so you cannot train the network from scratch there. In that case, here is an approach I recommend: save the trained model to a file and reload it on the target machine.
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.autograd import Variable
# define the network first; the fake training data is generated after the model definitions
class Net(nn.Module):
    """Small fully connected network: linear -> ReLU -> linear.

    Args:
        n_input: Number of input features.
        n_hidden: Number of units in the hidden layer.
        n_output: Number of output features.
    """

    def __init__(self, n_input, n_hidden, n_output):
        super(Net, self).__init__()
        # The attribute names become the state_dict key prefixes
        # ("hidden.*", "predict.*"); any network that loads the saved
        # parameters must use the same names.
        self.hidden = nn.Linear(n_input, n_hidden)
        self.predict = nn.Linear(n_hidden, n_output)

    def forward(self, x):
        """Return the network output for the input batch ``x``."""
        hidden_out = F.relu(self.hidden(x))
        # No activation on the output layer.
        return self.predict(hidden_out)
# Step one: build the networks by choosing the number of neurons in each
# layer (input, hidden, output).
net1 = Net(1, 10, 1)

# The same layer sizes expressed with nn.Sequential.
net2 = nn.Sequential(
    nn.Linear(1, 10),
    nn.ReLU(),
    nn.Linear(10, 1),
)

# Toy data: 100 evenly spaced x values in [-1, 1] as a (100, 1) column,
# and y = x^2 plus uniform noise scaled by 0.2, also of shape (100, 1).
x = torch.linspace(-1, 1, 100).unsqueeze(1)
y = x ** 2 + 0.2 * torch.rand(x.size())
def save():
    """Train ``net1`` on the toy data, plot the fit, and save it to disk.

    Two files are written to the current working directory:
    ``net1.pkl`` holds the whole pickled module, and ``net_params.pkl``
    holds only its parameters (the state_dict).
    """
    optimizer = torch.optim.SGD(net1.parameters(), lr=0.5)
    loss_func = nn.MSELoss()

    for _ in range(1000):
        prediction = net1(x)
        loss = loss_func(prediction, y)
        optimizer.zero_grad()  # clear gradients left over from the previous step
        loss.backward()
        optimizer.step()
        # To monitor training, print loss.item() every 100 iterations here.

    # Left panel of a 1x3 figure: the data and the fit of the trained net1.
    plt.figure(1, figsize=(10, 3))
    plt.subplot(131)
    plt.title("Net1")
    plt.scatter(x.numpy(), y.numpy())
    # detach() drops the autograd graph so the tensor can be converted to numpy.
    plt.plot(x.numpy(), prediction.detach().numpy(), 'r-', lw=5)

    # Save the entire network. Loading it back unpickles the module, so the
    # Net class must be importable wherever the file is loaded.
    torch.save(net1, "net1.pkl")
    # Save only the parameters of the network.
    torch.save(net1.state_dict(), "net_params.pkl")
def restore_net():
    """Reload the whole network from ``net1.pkl`` and plot its prediction.

    The restored model is bound to ``net3`` and drawn in the middle panel of
    the figure. ``net1.pkl`` must already exist.
    """
    # Restore the entire network net1 into net3.
    # - map_location="cpu" lets a model saved on a GPU machine load on a
    #   machine without one.
    # - weights_only=False is required to unpickle a full nn.Module on recent
    #   PyTorch releases, where torch.load defaults to weights_only=True.
    #   Unpickling can execute arbitrary code: only load files you trust.
    net3 = torch.load("net1.pkl", map_location="cpu", weights_only=False)
    prediction = net3(x)

    # Plot result. The title names the model actually shown (net3).
    plt.subplot(132)
    plt.title("Net3")
    plt.scatter(x.numpy(), y.numpy())
    plt.plot(x.numpy(), prediction.detach().numpy(), 'r-', lw=5)
def restore_param():
    """Rebuild the network, load ``net_params.pkl`` into it, plot and show.

    A state_dict holds only parameters, so a network with the same structure
    and attribute names must exist before the parameters can be loaded.
    The result is drawn in the right panel, then the whole figure is shown.
    """
    # To load saved parameters you need a network to load them into.
    class Net(nn.Module):
        """Same architecture as the network whose parameters were saved."""

        def __init__(self, n_input, n_hidden, n_output):
            super(Net, self).__init__()
            self.hidden = nn.Linear(n_input, n_hidden)
            self.predict = nn.Linear(n_hidden, n_output)

        def forward(self, x):
            x = F.relu(self.hidden(x))
            x = self.predict(x)
            return x

    net4 = Net(1, 10, 1)
    # map_location="cpu" lets parameters saved on a GPU machine load on a
    # machine without one.
    net4.load_state_dict(torch.load("net_params.pkl", map_location="cpu"))

    plt.subplot(133)
    plt.title("Net4")
    prediction = net4(x)
    plt.scatter(x.numpy(), y.numpy())
    plt.plot(x.numpy(), prediction.detach().numpy(), 'r-', lw=5)
    plt.show()
# Train and save first: the two restore steps read the files save() writes.
save()
# Reload the whole pickled network and plot its prediction.
restore_net()
# Reload only the parameters into a freshly built network and plot it.
restore_param()
Then you can train once and run everywhere.