零基础入门CV赛事-TASK4模型训练与验证学习笔记
程序员文章站
2024-03-19 16:55:40
...
零基础入门CV赛事-TASK4模型训练与验证
这次笔记为参与天池街道字符识别赛事中组对学习的第四篇学习笔记
赛题名称
- 赛题名称:零基础入门CV之街道字符识别
- 赛题目标:通过这道赛题可以引导大家走入计算机视觉的世界,主要针对竞赛选手上手视觉赛题,提高对数据建模能力。
- 赛题任务:赛题以计算机视觉中字符识别为背景,要求选手预测街道字符编码,这是一个典型的字符识别问题。
为了简化赛题难度,赛题数据采用公开数据集SVHN,因此大家可以选择很多相应的paper作为思路参考。
Task4 模型训练与验证
4.1 学习目标
- 理解验证集的作用,并使用训练集和验证集完成训练
- 学会使用Pytorch环境下的模型读取和加载,并了解调参流程
4.2 模型训练与验证
在本次比赛中,官方已将训练集和验证集划分好,之后的目标就是使用Pytorch来完成CNN的训练和验证过程,CNN网络结构与之前的章节中保持一致。我们需要完成的逻辑结构如下:
- 构建训练集和验证集;
- 每轮进行训练和验证,并根据最优验证集精度保存模型。
# 导入训练集和验证集
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=10,
shuffle=True,
num_workers=10, # 在windows下不能设置为大于0
)
val_loader = torch.utils.data.DataLoader(
val_dataset,
batch_size=10,
shuffle=False,
num_workers=10,
)
# 导入模型
model = SVHN_Model1()
criterion = nn.CrossEntropyLoss (size_average=False)
optimizer = torch.optim.Adam(model.parameters(), 0.001)
best_loss = 1000.0
for epoch in range(20):
print('Epoch: ', epoch)
train(train_loader, model, criterion, optimizer, epoch)
val_loss = validate(val_loader, model, criterion)
# 记录下验证集精度
if val_loss < best_loss:
best_loss = val_loss
torch.save(model.state_dict(), './model.pt')
每个Epoch的训练代码:
def train(train_loader, model, criterion, optimizer, epoch):
# 切换模型为训练模式
model.train()
for i, data in enumerate(train_loader):
data[1] = data[1].long()
c0, c1, c2, c3, c4, c5 = model(data[0])
loss = criterion(c0, data[1][:, 0]) + \
criterion(c1, data[1][:, 1]) + \
criterion(c2, data[1][:, 2]) + \
criterion(c3, data[1][:, 3]) + \
criterion(c4, data[1][:, 4]) + \
criterion(c5, data[1][:, 5])
# loss /= 6
optimizer.zero_grad()
loss.backward()
optimizer.step()
每个Epoch的验证代码:
def validate(val_loader, model, criterion):
# 切换模型为预测模型
model.eval()
val_loss = []
# 不记录模型梯度信息
with torch.no_grad():
for i, data in enumerate(val_loader):
data[1] = data[1].long()
c0, c1, c2, c3, c4, c5 = model(data[0])
loss = criterion(c0, data[1][:, 0]) + \
criterion(c1, data[1][:, 1]) + \
criterion(c2, data[1][:, 2]) + \
criterion(c3, data[1][:, 3]) + \
criterion(c4, data[1][:, 4]) + \
criterion(c5, data[1][:, 5])
# loss /= 6
val_loss.append(loss.item())
return np.mean(val_loss)
4.3 模型保存与加载
在Pytorch中模型的保存和加载非常简单,比较常见的做法是保存和加载模型参数:torch.save(model_object.state_dict(), 'model.pt')
model.load_state_dict(torch.load(' model.pt'))
4.4 心得体会
时间过得也是很快,从一周前刚开始接触这个赛事,Pytorch的操作完全无知的情况下到现在自己可以构建CNN模型的代码,虽然现在的水平也不能上得了台面,但现在做个调参侠还是可以将模块自带的模型进行迁移训练来适应当前数据集的变化,进而提高模型验证的精度,总之,通过最近的学习,我在这里也学到很多关于Pytorch和CNN的知识。
上一篇: c语言写一个简单的小游戏-推箱子
下一篇: 高级架构师备考经验分享