模型的保存和加载(pickle)
程序员文章站
2022-07-06 11:04:23
...
模型的保存和加载
前言:
模型训练是一个耗时的过程,一个优秀的机器学习是非常宝贵的。可以模型保存到磁盘中,也可以在需要使用的时候从磁盘中重新加载模型即可。不需要重新训练。
模型保存和加载相关API:
import pickle
pickle.dump(内存对象, 磁盘文件) # 保存模型
model = pickle.load(磁盘文件) # 加载模型
案例:把训练好的模型保存到磁盘中。
注:这里的训练数据就是一些数据而已下面是txt文档里面的截图,这里的内容手动放入到数组当中也可以
import numpy as np
import sklearn.linear_model as lm
import pickle
# 采集数据(这里我是从文件中读取的数据,也可以手动输入一组数字)
x, y = np.loadtxt('../ml_data/single.txt', delimiter=',', usecols=(0,1), unpack=True)
# 讲x改变为n行一列
x = x.reshape(-1, 1)
# 创建模型
model = lm.LinearRegression()
# 训练模型
model.fit(x,y)
with open('linear.pkl','wb') as f:
pickle.dump(model,f)
print('dump sucess')
案例: 加载训练好的模型进行线性回归线的绘制
import numpy as np
import pickle
import matplotlib.pyplot as mp
# 采集数据
x, y = np.loadtxt('../ml_data/single.txt', delimiter=',', usecols=(0,1), unpack=True)
x = x.reshape(-1, 1)
with open('linear.pkl','rb') as f:
model = pickle.load(f)
# 进行预测
pred_y = model.predict(x)
mp.figure('Linear Regression', facecolor='lightgray')
mp.title('Linear Regression', fontsize=20)
mp.xlabel('x', fontsize=14)
mp.ylabel('y', fontsize=14)
mp.tick_params(labelsize=10)
mp.grid(linestyle=':')
mp.scatter(x, y, c='dodgerblue', alpha=0.75, s=60, label='Sample')
mp.plot(x, pred_y, c='orangered', label='Regression')
mp.legend()
mp.show()