python深度学习之多标签分类器及pytorch实现源码
程序员文章站
2022-03-06 08:08:50
目录多标签分类器多标签分类器损失函数代码实现多标签分类器多标签分类任务与多分类任务有所不同,多分类任务是将一个实例分到某个类别中,多标签分类任务是将某个实例分到多个类别中。多标签分类任务有有两大特点:...
多标签分类器
多标签分类任务与多分类任务有所不同,多分类任务是将一个实例分到某个类别中,多标签分类任务是将某个实例分到多个类别中。多标签分类任务有有两大特点:
- 类标数量不确定,有些样本可能只有一个类标,有些样本的类标可能高达几十甚至上百个
- 类标之间相互依赖,例如包含蓝天类标的样本很大概率上包含白云
如下图所示,即为一个多标签分类学习的一个例子,一张图片里有多个类别,房子,树,云等,深度学习模型需要将其一一分类识别出来。
多标签分类器损失函数
代码实现
针对图像的多标签分类器pytorch的简化代码实现如下所示。因为图像的多标签分类器的数据集比较难获取,所以可以通过对mnist数据集中的每个图片打上特定的多标签,例如类别1的多标签可以为[1,1,0,1,0,1,0,0,1],然后再利用重新打标后的数据集训练出一个mnist的多标签分类器。
from torchvision import datasets, transforms from torch.utils.data import dataloader, dataset import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as f import os class cnn(nn.module): def __init__(self): super().__init__() self.sq1 = nn.sequential( nn.conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2), # (16, 28, 28) # output: (16, 28, 28) nn.relu(), nn.maxpool2d(kernel_size=2), # (16, 14, 14) ) self.sq2 = nn.sequential( nn.conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1, padding=2), # (32, 14, 14) nn.relu(), nn.maxpool2d(2), # (32, 7, 7) ) self.out = nn.linear(32 * 7 * 7, 100) def forward(self, x): x = self.sq1(x) x = self.sq2(x) x = x.view(x.size(0), -1) x = self.out(x) ## sigmoid activation output = f.sigmoid(x) # 1/(1+e**(-x)) return output def loss_fn(pred, target): return -(target * torch.log(pred) + (1 - target) * torch.log(1 - pred)).sum() def multilabel_generate(label): y1 = f.one_hot(label, num_classes = 100) y2 = f.one_hot(label+10, num_classes = 100) y3 = f.one_hot(label+50, num_classes = 100) multilabel = y1+y2+y3 return multilabel # def multilabel_generate(label): # multilabel_dict = {} # multi_list = [] # for i in range(label.shape[0]): # multi_list.append(multilabel_dict[label[i].item()]) # multilabel_tensor = torch.tensor(multi_list) # return multilabel def train(): epoches = 10 mnist_net = cnn() mnist_net.train() opitimizer = optim.sgd(mnist_net.parameters(), lr=0.002) mnist_train = datasets.mnist("mnist-data", train=true, download=true, transform=transforms.totensor()) train_loader = torch.utils.data.dataloader(mnist_train, batch_size= 128, shuffle=true) for epoch in range(epoches): loss = 0 for batch_x, batch_y in train_loader: opitimizer.zero_grad() outputs = mnist_net(batch_x) loss = loss_fn(outputs, multilabel_generate(batch_y)) / batch_x.shape[0] loss.backward() opitimizer.step() print(loss) if __name__ == '__main__': train()
以上就是python深度学习之多标签分类器及pytorch源码的详细内容,更多关于多标签分类器pytorch源码的资料请关注其它相关文章!