Lesson 1: Based on the MNIST Dataset
import torch
import torch.utils.data as td
from linear_net.load import load_data

# load_data reads the four gzipped MNIST files from data_folder
# (the module itself is listed at the end of this post); the folder holds:
# 'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz',
# 't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz'
(train_images, train_labels), (test_images, test_labels) = load_data()

# scale pixel values from [0, 255] down to [0, 1]
X_train = torch.tensor(train_images / 255)
y_train = torch.tensor(train_labels).view(train_labels.shape[0], 1)
X_test = torch.tensor(test_images / 255)
y_test = torch.tensor(test_labels)
# optional one-hot encoding of the labels (not needed for CrossEntropyLoss):
# y_train = torch.zeros(y_train.shape[0], 10).scatter(dim=1, index=y_train.long(), src=torch.ones_like(y_train).float())
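In case you do want one-hot labels, a simpler equivalent of that commented-out scatter trick is the built-in torch.nn.functional.one_hot (a small aside; CrossEntropyLoss takes class indices directly, so it is not needed here):

import torch.nn.functional as F
# one_hot expects 1-D class indices, hence the squeeze; result is (N, 10)
y_onehot = F.one_hot(y_train.long().squeeze(1), num_classes=10).float()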
class MyDataset(td.Dataset):
    def __init__(self, d_t, d_l):
        self.data_tensor = d_t
        self.target_tensor = d_l

    def __getitem__(self, item):
        return self.data_tensor[item], self.target_tensor[item]

    def __len__(self):
        return self.data_tensor.size(0)
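As an aside, this hand-rolled class does the same job as the built-in torch.utils.data.TensorDataset, so the following one-liner would be interchangeable:

# equivalent to MyDataset(X_train, y_train)
train_alt = td.TensorDataset(X_train, y_train)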
class FlattenLayer(torch.nn.Module):
    def __init__(self):
        super(FlattenLayer, self).__init__()

    def forward(self, x):
        # flatten each image into a 784-dim row; view(-1, 28 * 28) handles
        # both a single (28, 28) sample and a (batch, 28, 28) batch,
        # whereas the original view(1, -1) only worked one sample at a time
        return x.view(-1, 28 * 28)
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.flatten = FlattenLayer()
        self.linear_1 = torch.nn.Linear(28 * 28, 5)
        self.relu = torch.nn.ReLU()
        self.linear_2 = torch.nn.Linear(5, 10)
        # defined but deliberately unused in forward: CrossEntropyLoss
        # already applies LogSoftmax internally
        self.softmax = torch.nn.Softmax(dim=1)
        # (weights and biases are initialized below via torch.nn.init)

    def forward(self, x):
        x = self.flatten(x)
        x = self.relu(self.linear_1(x))
        x = self.linear_2(x)  # raw logits, no softmax
        return x
# An equivalent definition with torch.nn.Sequential
# (needs: from collections import OrderedDict):
'''net = torch.nn.Sequential(
    OrderedDict([
        ('flatten', FlattenLayer()),
        ('linear_1', torch.nn.Linear(28 * 28, 5)),
        ('relu', torch.nn.ReLU()),
        ('linear_2', torch.nn.Linear(5, 10)),
        # ('softmax', torch.nn.Softmax(dim=1))
    ]))'''
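A quick shape sanity check on the model (a small sketch; the name dummy is just for illustration): one 28x28 input should come out as a single row of 10 logits.

dummy = torch.zeros(28, 28)  # one fake image
print(Net()(dummy).shape)    # expected: torch.Size([1, 10])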
net_t = Net()
train = MyDataset(X_train, y_train)
test = MyDataset(X_test.float(), y_test.long())
torch.nn.init.normal_(net_t.linear_1.weight, 0, std=0.1)
torch.nn.init.normal_(net_t.linear_2.weight, 0, std=0.1)
torch.nn.init.constant_(net_t.linear_1.bias, 0)  # constant_ replaces the deprecated constant
torch.nn.init.constant_(net_t.linear_2.bias, 0)
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net_t.parameters(), lr=0.01)
# note: the training loop below iterates over samples one by one and never
# uses this DataLoader; a minibatch variant that does is sketched after it
data_loader = td.DataLoader(train, batch_size=200, shuffle=True)
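For reference, CrossEntropyLoss bundles LogSoftmax and NLLLoss, which is exactly why forward returns raw logits without the softmax; a minimal sketch verifying the equivalence:

import torch.nn.functional as F
logits = torch.randn(4, 10)           # raw scores: 4 samples, 10 classes
targets = torch.tensor([3, 0, 7, 1])  # class indices
ce = torch.nn.CrossEntropyLoss()(logits, targets)
manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)
print(torch.allclose(ce, manual))     # True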
# accuracy before training, as a baseline (with random weights this should
# land near chance level, about 0.1 for 10 classes)
correct = 0
for i, data in enumerate(test):
    inputs, label = data
    outputs = net_t(inputs)
    _, index = torch.max(outputs, 1)
    if index[0].long() == y_test[i].long():
        correct += 1
print(correct / y_test.shape[0])
for epoch in range(2):
    running_loss = 0.0
    for i, data in enumerate(train):
        inputs, label = data
        optimizer.zero_grad()
        outputs = net_t(inputs.float())
        # the loss already tracks gradients through the network parameters,
        # so the original l.requires_grad_() / l.retain_grad() calls were unnecessary
        l = loss(outputs, label.long())
        l.backward()
        running_loss += l.item()
        optimizer.step()
    print(running_loss / len(train))  # average loss per sample
    print(running_loss)               # total loss over the epoch
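For completeness, here is a minibatch variant of the same loop that actually uses the data_loader defined earlier; a sketch, assuming the view(-1, 28 * 28) flatten above so that batched (200, 28, 28) inputs pass through, and noting that the labels arrive as (batch, 1) and must be squeezed for CrossEntropyLoss:

for epoch in range(2):
    running_loss = 0.0
    for inputs, labels in data_loader:
        optimizer.zero_grad()
        outputs = net_t(inputs.float())
        l = loss(outputs, labels.view(-1).long())  # (batch, 1) -> (batch,)
        l.backward()
        optimizer.step()
        running_loss += l.item()
    print(running_loss / len(data_loader))  # average loss per batch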
# accuracy after training
correct = 0
for i, data in enumerate(test):
    inputs, label = data
    outputs = net_t(inputs)
    _, index = torch.max(outputs, 1)
    if index[0].long() == y_test[i].long():
        correct += 1
print(correct / y_test.shape[0])
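The same evaluation can also be done in one vectorized forward pass over the whole test set, which is much faster than the per-sample loop (a sketch, again relying on the batch-friendly flatten):

with torch.no_grad():  # no gradients needed for evaluation
    preds = net_t(X_test.float()).argmax(dim=1)
    print((preds == y_test.long()).float().mean().item())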
The load_data helper imported at the top lives in linear_net/load.py:

import gzip
import os
import numpy as np

def load_data(data_folder=''):
    # the four gzipped IDX files that make up MNIST
    files = [
        'train-labels-idx1-ubyte.gz', 'train-images-idx3-ubyte.gz',
        't10k-labels-idx1-ubyte.gz', 't10k-images-idx3-ubyte.gz'
    ]
    paths = [os.path.join(data_folder, fname) for fname in files]
    # labels: skip the 8-byte IDX header (magic number + item count)
    with gzip.open(paths[0], 'rb') as lbpath:
        y_train = np.frombuffer(lbpath.read(), np.uint8, offset=8)
    # images: skip the 16-byte IDX header (magic, count, rows, cols)
    with gzip.open(paths[1], 'rb') as imgpath:
        x_train = np.frombuffer(
            imgpath.read(), np.uint8, offset=16).reshape(len(y_train), 28, 28)
    with gzip.open(paths[2], 'rb') as lbpath:
        y_test = np.frombuffer(lbpath.read(), np.uint8, offset=8)
    with gzip.open(paths[3], 'rb') as imgpath:
        x_test = np.frombuffer(
            imgpath.read(), np.uint8, offset=16).reshape(len(y_test), 28, 28)
    return (x_train, y_train), (x_test, y_test)
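A quick sanity check of the loader (assuming the four .gz files sit next to the script; the shapes follow from the IDX headers):

(tr_x, tr_y), (te_x, te_y) = load_data()
print(tr_x.shape, tr_y.shape)  # (60000, 28, 28) (60000,)
print(te_x.shape, te_y.shape)  # (10000, 28, 28) (10000,)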
This took a whole afternoon and I hit plenty of pitfalls, but it was well worth it; at first I couldn't even get the dataset to load.
This is task 1.
Task 2 is here: https://blog.csdn.net/qq_16899143/article/details/104319470