MXNET深度学习框架-13-读写存模型
程序员文章站
2022-03-06 10:00:20
...
如果一个深度网络(几十层的网络)在训练时出现突然断电,内存溢出,电脑蓝屏等情况,是不是会很抓狂?所以模型存储就变得很重要。本章我们在mxnet下学习如何进行模型存储与读写。
1、读写NDArray
(1) NDArray是mxnet中的一个科学计算库,下面我们来实现以下怎么存储NDArray的参数:
x = nd.ones(shape=3)
y = nd.ones(shape=2)
filename1 = "13-模型存储/test1_array.params"
nd.save(filename1, [x, y])
结果:
可以看到,参数确实已经被保存在了该文件夹下。
不仅仅是NDArray,dict也是一样的:
mydict={"x":x,"y":y}
filename2 = "13-模型存储/test2_dict.params"
nd.save(filename2, mydict)
运行结果:
(2) 读模型
a, b = nd.load(filename1) # 模型保存路径和名字
c=nd.load(filename2)
print(a, b,c)
结果:
(3) 神经网络模型存储与读写
# 随便定义一个网络
def get_net():
net=gn.nn.Sequential() #nn.block
with net.name_scope():
net.add(gn.nn.Dense(10,activation="relu"))
net.add(gn.nn.Dense(2))
return net
net1=get_net()
net1.initialize()
x=nd.random_normal(shape=(3,10),scale=0.01) # 输入x
print(net1(x))
# 下面把模型参数存起来
net_filename="13-模型存储/mlp.params"
net.save_parameters(net_filename) # 存模型参数
结果:
下面把模型读出来:
net2=get_net() # 重新加载一个网络
net2.load_parameters(net_filename)
print(net2(x))
结果:
从上图可知,net1和net2的结果是一样的。
附上所有源码:
import mxnet.ndarray as nd
import mxnet.gluon as gn
# '''---模型存储---'''
# x = nd.ones(shape=3)
# y = nd.ones(shape=2)
# filename1 = "13-模型存储/test1_array.params"
# nd.save(filename1, [x, y])
# # 不仅仅是NDArray,dict也是一样的
# mydict={"x":x,"y":y}
# filename2 = "13-模型存储/test2_dict.params"
# nd.save(filename2, mydict)
#
# '''---读取模型---'''
# a, b = nd.load(filename1) # 模型保存路径和名字
# c=nd.load(filename2)
# print(a, b,c)
'''---读写gluon的参数---'''
# 随便定义一个网络
def get_net():
net=gn.nn.Sequential() #nn.block
with net.name_scope():
net.add(gn.nn.Dense(10,activation="relu"))
net.add(gn.nn.Dense(2))
return net
net1=get_net()
net1.initialize()
x=nd.random_normal(shape=(3,10),scale=0.01) # 输入x
print(net1(x))
# 下面把模型参数存起来
net_filename="13-模型存储/mlp.params"
# net.save_parameters(net_filename) # 存模型参数
# 下面把模型读出来
net2=get_net() # 重新加载一个网络
net2.load_parameters(net_filename)
print(net2(x))