基于Pytorch简单的sotfmax分类算法实现
程序员文章站
2022-07-14 18:50:04
...
使用pytorch搭建简单的softmax分类网络
目的
之前一段时间有很多次想学习深度学习内容,但很多次都失败了。分析其中的一个原因就是每次都陷入某一个定理或公式的推导中,当受挫了就放弃了。下次就又卡在相同的地方了,导致自己每次看都觉得自己搞不定这个。现在,换一个角度去学这个东西,先从整体去看它,如果有不会的地方就略过,先会用,当有了一个整体的认知后,再去想它的实现细节,以及为什么使用这个。
文章基于《动手学深度学习 Pytorch版》
一个简单的softmax模型的实现
一个简单的softmax模型怎么搭建呢?结合书中的实现过程,其实主要有5个步骤:
- 加载数据
- 定义模型和初始化模型参数
- 定义损失函数
- 定义优化算法
- 训练
整个模型的代码:
import torch
from torch import nn
from d2l import torch as d2l
# 加载数据
batch_size = 256
train_iter,test_iter = d2l.load_data_fashion_mnist(batch_size)
# 定义模型及初始化模型参数
net = nn.Sequential(nn.Flatten(),nn.Linear(784,10))
def init_weights(m):
if type(m) == nn,Linear:
nn.init.normal_(m.weight,std=0.01)
net.apply(init_weight)
# 定义损失函数
loss = nn.CrossEntropyLoss()
# 定义优化算法
trainer = torch.optim.SGD(net.parameters(),lr=0.1)
# 训练
num_epochs = 10
d2l.train_ch3(net,train_iter,test_iter,loss,num_epochs,trainer)
训练结果:
至此,一个softmax分类算法完成了!放鞭炮!放鞭炮!放鞭炮!
整个过程就5步,至于每一步里面的具体细节,之后再慢慢去抠,先对整体框架有个初步认识,不会的就略过,反正都有实现好的东西可以用。等学的多了之后,该会的就都会了。
上一篇: java课题--day4基础回顾
下一篇: day4作业-------Java基础班