PyTorch学习笔记(6)逻辑回顾LR
程序员文章站
2022-07-14 20:17:49
...
线性回归是分析自变量x与因变量y(标量)之间关系的方法
逻辑回归是分析自变量x与因变量y(概率)之间关系的方法
机器学习模型训练步骤
数据 模型 损失函数 优化器 迭代训练过程
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
# 生成数据
sample_nums = 100
mean_value = 1.7
bias = 1
n_data = torch.ones(sample_nums,2)
# 使用torch.normal 正太分布 生成 自变量 x
x0 = torch.normal(mean_value * n_data ,1)+bias #类别0 数据 shape = (100,2)
y0 = torch.zeros(sample_nums) #类别0 标签 shape = (100,1)
x1 = torch.normal(-mean_value * n_data ,1)+bias #类别1 数据 shape = (100,2)
y1 = torch.ones(sample_nums) #类别1 标签 shape = (100.1)
# 用cat方法 对自变量 和 因变量进行拼接
train_x = torch.cat((x0,x1),0)
train_y = torch.cat((y0,y1),0)
# 选择模型
class LR(nn.Module):
def __init__(self):
super(LR, self).__init__()
self.features = nn.Linear(2,1)
self.sigmoid = nn.Sigmoid()
# 逻辑回归的前向函数
def forward(self,x):
x = self.features(x)
x = self.sigmoid(x)
return x
# 实例化逻辑回归模型
lr_net = LR()
# 选择损失函数
# 二分类中的交叉熵函数
loss_fn = nn.BCELoss()
# 选择优化器
# 随机梯度下降的方法
lr = 0.01 #学习率
optimizer = torch.optim.SGD(lr_net.parameters(),lr=lr,momentum=0.9)
# 模型的训练
# 训练迭代跟新的过程
for iteration in range(1000):
#前向传播
y_pred = lr_net(train_x)
# 计算loss
loss = loss_fn(y_pred.squeeze(),train_y)
# 反向传播
loss.backward()
# 更新参数
optimizer.step()
# 绘图
# 迭代20次画一次图
if iteration %20 == 0:
mask = y_pred.ge(0.5).float().squeeze() #以0.5作为阈值进行分类
correct = (mask == train_y).sum() #计算正确预测的样本数
acc = correct.item() / train_y.size(0) #计算分类准确率
plt.scatter(x0.data.numpy()[:,0],x0.data.numpy()[:,1],c = 'r',label = 'class 0 ')
plt.scatter(x1.data.numpy()[:,0],x1.data.numpy()[:,1],c = 'r',label = 'class 1 ')
w0,w1 = lr_net.features.weight[0]
w0,w1 = float(w0.item()),float(w1.item())
plot_b = float(lr_net.features.bias[0].item())
plot_x = np.arange(-6,6,0.1)
plot_y = (-w0 * plot_x - plot_b) / w1
plt.xlim(-5,7)
plt.ylim(-7,7)
plt.plot(plot_x,plot_y)
plt.text(-5,5,'Loss = %.4f' % loss.data.numpy(),fontdict={'size':20,'color':'red'})
plt.title("Iteration:{}\n w0:{:.2f} w1{:.2f} b:{:.2f} accuracy:{:.2%}".format(iteration,w0,w1,plot_b,acc))
plt.legend()
plt.show()
plt.pause(0.5)
if acc>0.99:
break
结果
上一篇: R语言:异常数据处理
下一篇: 数据结构与算法分析笔记02:链表