欢迎您访问程序员文章站!本站旨在为大家分享程序员计算机编程知识!
您现在的位置是: 首页

pytorch纯手工实现logistic回归

程序员文章站 2024-03-18 22:50:46
...

pytorch纯手工实现logistic回归

import torch
import matplotlib.pyplot as plt
import numpy as np
# Build a synthetic two-class dataset:
#   class 0 -> 50 points sampled around (+2, +2)
#   class 1 -> 50 points sampled around (-2, -2)
n_data = torch.ones(50, 2)
cls0_points = torch.normal(2 * n_data, 1)
cls0_labels = torch.zeros(50)
cls1_points = torch.normal(-2 * n_data, 1)
cls1_labels = torch.ones(50)

# Stack both classes into a single (100, 2) feature tensor and a
# (100,) label vector, both float32 for BCELoss compatibility.
x = torch.cat((cls0_points, cls1_points), 0).type(torch.FloatTensor)
y = torch.cat((cls0_labels, cls1_labels), 0).type(torch.FloatTensor)
# Hand-crafted dataset; uncomment to visualize the two clusters.
# plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=y.data.numpy(), s=100, lw=0, cmap='RdYlGn')
# plt.show()
# 首先通过torch.nn进行实现
import torch.utils.data as Data
from torch import nn
def data_iter(batch_size, features, labels):
    """Yield shuffled mini-batches of (features, labels).

    Samples are visited exactly once per pass, in a random order;
    the final batch may be smaller than ``batch_size``.
    """
    total = len(features)
    order = list(range(total))
    np.random.shuffle(order)
    for start in range(0, total, batch_size):
        stop = min(start + batch_size, total)
        picked = torch.LongTensor(order[start:stop])
        # yield acts as a resumable return: iteration continues here.
        yield features.index_select(0, picked), labels.index_select(0, picked)
batch_size = 10
num_inputs = 2
# Weights start near zero (std 0.01) and the bias at exactly zero;
# requires_grad=True makes both leaf tensors tracked by autograd.
w = torch.tensor(np.random.normal(0, 0.01, (num_inputs, 1)),
                 dtype=torch.float32, requires_grad=True)
b = torch.zeros(1, dtype=torch.float32, requires_grad=True)
def logistic_reg(X, w, b):
    """Logistic regression model: sigmoid of the affine map X @ w + b."""
    logits = X.mm(w) + b
    return logits.sigmoid()
def sgd(params, lr, batch_size):
    """One mini-batch SGD step, in place: p <- p - lr * p.grad / batch_size.

    Updates go through ``.data`` so autograd does not record them.
    """
    for p in params:
        p.data.sub_(lr * p.grad / batch_size)
lr = 0.03
num_epochs = 40
net = logistic_reg
# nn.BCELoss defaults to reduction='mean', so the loss (and therefore the
# gradient) is already averaged over the mini-batch.
loss = nn.BCELoss()
for epoch in range(1, num_epochs + 1):
    for X_batch, y_batch in data_iter(batch_size, x, y):
        l = loss(net(X_batch, w, b), y_batch.view(-1, 1))
        l.backward()
        # Pass 1 as the divisor: the mean-reduced loss already divides by
        # the batch size, so letting sgd divide again would shrink the
        # effective learning rate to lr / batch_size.
        sgd([w, b], lr, 1)
        # Clear accumulated gradients before the next step.
        w.grad.data.zero_()
        b.grad.data.zero_()
    # Evaluate on the full training set without building an autograd graph.
    with torch.no_grad():
        train_l = loss(net(x, w, b), y.view(-1, 1))
        preds = net(x, w, b).ge(0.5).float()
        # .item() gives a Python int so the division below is true float
        # division on every PyTorch version.
        correct = (preds.view(-1, 1) == y.view(-1, 1)).sum().item()
    # Report the full-dataset loss (train_l), not the last mini-batch loss.
    print("epoch%d,损失:%f,准确率为:%f" % (epoch, train_l.item(), correct / len(x)))