欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

MXNET深度学习框架-13-读写存模型

程序员文章站 2022-03-06 10:00:20
...

        如果一个深度网络(几十层的网络)在训练时出现突然断电,内存溢出,电脑蓝屏等情况,是不是会很抓狂?所以模型存储就变得很重要。本章我们在mxnet下学习如何进行模型存储与读写。

1、读写NDArray
        (1) NDArray是mxnet中的一个科学计算库,下面我们来实现以下怎么存储NDArray的参数:

x = nd.ones(shape=3)
y = nd.ones(shape=2)
filename1 = "13-模型存储/test1_array.params"
nd.save(filename1, [x, y])

结果:
MXNET深度学习框架-13-读写存模型
可以看到,参数确实已经被保存在了该文件夹下。
        不仅仅是NDArray,dict也是一样的:

mydict={"x":x,"y":y}
filename2 = "13-模型存储/test2_dict.params"
nd.save(filename2, mydict)

运行结果:
MXNET深度学习框架-13-读写存模型
        (2) 读模型

a, b = nd.load(filename1)  # 模型保存路径和名字
c=nd.load(filename2)
print(a, b,c)

结果:
MXNET深度学习框架-13-读写存模型
        (3) 神经网络模型存储与读写

# 随便定义一个网络
def get_net():
    net=gn.nn.Sequential()  #nn.block
    with net.name_scope():
        net.add(gn.nn.Dense(10,activation="relu"))
        net.add(gn.nn.Dense(2))
    return net
net1=get_net()
net1.initialize()
x=nd.random_normal(shape=(3,10),scale=0.01)  # 输入x
print(net1(x))
# 下面把模型参数存起来
net_filename="13-模型存储/mlp.params"
net.save_parameters(net_filename) # 存模型参数

结果:
MXNET深度学习框架-13-读写存模型
MXNET深度学习框架-13-读写存模型
下面把模型读出来:

net2=get_net() # 重新加载一个网络
net2.load_parameters(net_filename)
print(net2(x))

结果:
MXNET深度学习框架-13-读写存模型
从上图可知,net1和net2的结果是一样的。

附上所有源码:

import mxnet.ndarray as nd
import mxnet.gluon as gn


# '''---模型存储---'''
# x = nd.ones(shape=3)
# y = nd.ones(shape=2)
# filename1 = "13-模型存储/test1_array.params"
# nd.save(filename1, [x, y])
# # 不仅仅是NDArray,dict也是一样的
# mydict={"x":x,"y":y}
# filename2 = "13-模型存储/test2_dict.params"
# nd.save(filename2, mydict)
#
# '''---读取模型---'''
# a, b = nd.load(filename1)  # 模型保存路径和名字
# c=nd.load(filename2)
# print(a, b,c)

'''---读写gluon的参数---'''
# 随便定义一个网络
def get_net():
    net=gn.nn.Sequential()  #nn.block
    with net.name_scope():
        net.add(gn.nn.Dense(10,activation="relu"))
        net.add(gn.nn.Dense(2))
    return net
net1=get_net()
net1.initialize()
x=nd.random_normal(shape=(3,10),scale=0.01)  # 输入x
print(net1(x))
# 下面把模型参数存起来
net_filename="13-模型存储/mlp.params"
# net.save_parameters(net_filename) # 存模型参数
# 下面把模型读出来
net2=get_net() # 重新加载一个网络
net2.load_parameters(net_filename)
print(net2(x))