How to save a PyTorch model and reload it from the pickle file

Suppose you want to train a neural network on a GPU machine, but use it in a production setting where the machine has no GPU or TPU, so you cannot train the network from scratch there. In that case, training the model once, saving it, and reloading it later is a good approach, and here is how I recommend doing it.
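The full example below runs on the CPU to keep things simple. If you do train on a GPU machine as described, the only extra step is moving the network and the data onto the device before training. Here is a minimal sketch of that pattern; the nn.Sequential model and the file name gpu_trained_params.pkl are placeholders of my own rather than part of the example below, and the code falls back to the CPU when no CUDA device is present:

import torch
import torch.nn as nn

# pick the GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# a tiny placeholder network; any nn.Module is moved to the device the same way
net = nn.Sequential(nn.Linear(1, 10), nn.ReLU(), nn.Linear(10, 1)).to(device)

# inputs and targets must live on the same device as the network
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1).to(device)
y = x.pow(2) + 0.2 * torch.rand_like(x)

optimizer = torch.optim.SGD(net.parameters(), lr=0.5)
loss_func = nn.MSELoss()
for t in range(200):
    loss = loss_func(net(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# the saved state_dict keeps the tensors' device, so loading it on a
# CPU-only machine later needs map_location (shown at the end of this post)
torch.save(net.state_dict(), "gpu_trained_params.pkl")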
import torch 
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

# first, define the network; below it is built in two equivalent ways

class Net(nn.Module):
    def __init__(self,n_input,n_hidden,n_output):
        super(Net,self).__init__()
        self.hidden=nn.Linear(n_input,n_hidden)
        self.predict=nn.Linear(n_hidden,n_output)
    def forward(self,x):
        x=self.hidden(x)
        x=F.relu(x)
        x=self.predict(x)
        return x
# create the network by specifying the number of neurons
# in each layer (input, hidden, output)

net1=Net(1,10,1)

# net2 is an equivalent way to build the same architecture with nn.Sequential
net2=nn.Sequential(
    nn.Linear(1,10),
    nn.ReLU(),
    nn.Linear(10,1)
)

# make some fake training data: a noisy quadratic
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)  # x data (tensor), shape=(100, 1)
y = x.pow(2) + 0.2*torch.rand(x.size())  # noisy y data (tensor), shape=(100, 1)



def save():
    # train net1, plot the fit, then save it in two different ways
    optimizer=torch.optim.SGD(net1.parameters(),lr=0.5)
    loss_func=nn.MSELoss()
    
    for t in range(1000):
        prediction=net1(x)
        loss=loss_func(prediction,y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # if t % 100 == 0:
        #     print(loss.data.numpy())

    plt.figure(1,figsize=(10, 3))
    plt.subplot(131)
    plt.title("Net1")
    plt.scatter(x.data.numpy(),y.data.numpy())
    plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)
    
    torch.save(net1,"net1.pkl")#save the entire network
    torch.save(net1.state_dict(),"net_params.pkl")#save the only parameters of the network.
    
def restore_net():
    # restore the entire network net1 into a new variable net3
    # (on recent PyTorch versions you may need torch.load('net1.pkl', weights_only=False))
    net3=torch.load('net1.pkl')
    prediction=net3(x)
    
    # plot the result
    plt.subplot(132)
    plt.title("Net3")
    plt.scatter(x.data.numpy(),y.data.numpy())
    plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)

    
def restore_param():
    # to load only the saved parameters, you first need to build a network with the same architecture
    class Net(nn.Module):
        def __init__(self,n_input,n_hidden,n_output):
            super(Net,self).__init__()
            self.hidden=nn.Linear(n_input,n_hidden)
            self.predict=nn.Linear(n_hidden,n_output)
        def forward(self,x):
            x=F.relu(self.hidden(x))
            x=self.predict(x)
            return x
    net4=Net(1,10,1)
    # net4 has exactly the same architecture as the network defined above
    net4.load_state_dict(torch.load("net_params.pkl"))
    plt.subplot(133)
    plt.title("Net4")
    prediction=net4(x)
    plt.scatter(x.data.numpy(),y.data.numpy())
    plt.plot(x.data.numpy(),prediction.data.numpy(),'r-',lw=5)
    plt.show()
    
save()
restore_net()
restore_param()
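One caveat for the GPU-to-CPU scenario from the introduction: a checkpoint written on a GPU machine stores CUDA tensors, and a plain torch.load of it on a machine without a GPU raises an error. Passing map_location tells PyTorch where to place the tensors instead. Here is a minimal sketch, reusing the Net class and x from above, with the placeholder file gpu_trained_params.pkl from the earlier sketch:

# loading GPU-trained weights on a CPU-only machine
net_cpu = Net(1, 10, 1)
state = torch.load("gpu_trained_params.pkl", map_location=torch.device("cpu"))  # remap CUDA tensors onto the CPU
net_cpu.load_state_dict(state)
net_cpu.eval()             # switch to inference mode (matters for dropout / batch-norm layers)
with torch.no_grad():      # no gradients are needed for pure inference
    prediction = net_cpu(x)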

With that, you can realize "train once, run everywhere".
