
Adding Triplet Loss to a ReID Implementation in PyTorch


I have added triplet loss to the ReID implementation. The project consists of three files: loss.py, model.py, and reid.py.


The code for loss.py is as follows:

import torch
import torch.nn as nn
import torch.nn.functional as F

def euclidean_dist(x, y):
    # Pairwise Euclidean distance via the expansion ||a-b||^2 = ||a||^2 + ||b||^2 - 2*a.b
    m, n = x.size(0), y.size(0)
    xx = torch.pow(x, 2).sum(1, keepdim=True).expand(m, n)
    yy = torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
    dist = xx + yy
    dist.addmm_(x, y.t(), beta=1, alpha=-2)  # dist = dist - 2*x@y.t(); the old addmm_(1, -2, x, y.t()) signature is deprecated
    dist = dist.clamp(min=1e-12).sqrt()      # clamp guards against sqrt of tiny negatives from floating-point error
    return dist

def cosine_dist(x, y):
    # Cosine distance 1 - cos(x_i, y_j) for every pair: dot products divided
    # by the outer product of the row norms.
    bs1, bs2 = x.size(0), y.size(0)
    frac_up = torch.matmul(x, y.transpose(0, 1))
    frac_down = (torch.sqrt(torch.pow(x, 2).sum(dim=1)).view(bs1, 1).repeat(1, bs2)) * \
                (torch.sqrt(torch.pow(y, 2).sum(dim=1)).view(1, bs2).repeat(bs1, 1))
    cosine = frac_up / frac_down
    cos_d = 1 - cosine
    return cos_d

def _batch_hard(mat_distance, mat_similarity, indice=False):
    # Batch-hard mining: for each anchor, pick the farthest positive and the
    # nearest negative. mat_similarity is 1 where labels match, 0 otherwise;
    # adding/subtracting a large constant pushes the wrong group out of
    # contention when sorting.
    sorted_mat_distance, positive_indices = torch.sort(mat_distance + (-100000.0) * (1 - mat_similarity), dim=1, descending=True)
    hard_p = sorted_mat_distance[:, 0]        # hardest positive: largest distance among same-label pairs
    hard_p_indice = positive_indices[:, 0]
    sorted_mat_distance, negative_indices = torch.sort(mat_distance + 100000.0 * mat_similarity, dim=1, descending=False)
    hard_n = sorted_mat_distance[:, 0]        # hardest negative: smallest distance among different-label pairs
    hard_n_indice = negative_indices[:, 0]
    if indice:
        return hard_p, hard_n, hard_p_indice, hard_n_indice
    return hard_p, hard_n

class TripletLoss(nn.Module):
    def __init__(self, margin=0.5, normalize_feature=True):
        super(TripletLoss, self).__init__()
        self.margin = margin
        self.normalize_feature = normalize_feature
        self.margin_loss = nn.MarginRankingLoss(margin=margin)

    def forward(self, emb, label):
        if self.normalize_feature:
            emb = F.normalize(emb)
        mat_dist = euclidean_dist(emb, emb)
        assert mat_dist.size(0) == mat_dist.size(1)
        N = mat_dist.size(0)
        # mat_sim[i][j] = 1 if sample i and sample j share a label, else 0
        mat_sim = label.expand(N, N).eq(label.expand(N, N).t()).float()
        dist_ap, dist_an = _batch_hard(mat_dist, mat_sim)
        assert dist_an.size(0) == dist_ap.size(0)
        # y = 1 asks MarginRankingLoss to enforce dist_an > dist_ap + margin
        y = torch.ones_like(dist_ap)
        loss = self.margin_loss(dist_an, dist_ap, y)

        # prec: fraction of anchors whose hardest negative is already farther
        # away than their hardest positive
        prec = (dist_an.data > dist_ap.data).sum() * 1.0 / y.size(0)
        return loss, prec


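The euclidean_dist function relies on the identity ||a - b||^2 = ||a||^2 + ||b||^2 - 2*a.b, which is exactly what the addmm_ call accumulates. As a quick sanity check (my own addition, not one of the project files; it assumes loss.py is importable from the working directory), it should agree with PyTorch's torch.cdist:

import torch
from loss import euclidean_dist

x = torch.randn(8, 128)
y = torch.randn(6, 128)
d1 = euclidean_dist(x, y)
d2 = torch.cdist(x, y, p=2)               # PyTorch's built-in pairwise Euclidean distance
print(torch.allclose(d1, d2, atol=1e-4))  # expected: True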


This file defines the triplet loss. The idea sounds simple enough, but several implementation details deserve careful thought, so rather than just skimming the code I recommend typing it out yourself: typing forces you to think through each trick as you go. When I first studied triplet loss, I kept wondering how, with a dataset in the standard PyTorch format, you decide which sample is the anchor, which is the positive, and which is the negative. Working through this code made it click: you only need to transform the labels read from the dataloader, and the positives and negatives fall out immediately, as the small sketch below shows.
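For illustration, here is a toy batch (my own sketch, not from the original project) showing how the label tensor alone determines positives and negatives:

import torch

label = torch.tensor([0, 0, 1, 1])   # a batch with two identities, two images each
N = label.size(0)
mat_sim = label.expand(N, N).eq(label.expand(N, N).t()).float()
print(mat_sim)
# tensor([[1., 1., 0., 0.],
#         [1., 1., 0., 0.],
#         [0., 0., 1., 1.],
#         [0., 0., 1., 1.]])
# For anchor i, every j with mat_sim[i][j] == 1 is a positive (including i
# itself), and every j with mat_sim[i][j] == 0 is a negative.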

The code for model.py is as follows:

import torch
import torch.nn as nn
from torchvision import models
from torch.nn import functional as F

class resnet_model(nn.Module):
    def __init__(self,cut_at_pooling=False, num_features=0, norm=False, dropout=0, num_classes=0 ):
        super(resnet_model,self).__init__()
        self.cut_at_pooling = cut_at_pooling
        resnet = models.resnet50(pretrained=False)
        resnet.load_state_dict(torch.load('./pretrain_model/resnet50.pth'))
        # Standard ReID trick: change the stride of the last stage from 2 to 1
        # so the final feature map keeps a higher spatial resolution.
        resnet.layer4[0].conv2.stride = (1,1)
        resnet.layer4[0].downsample[0].stride = (1,1)
        self.base = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,      # torchvision's ResNet applies ReLU between bn1 and maxpool
            resnet.maxpool,
            resnet.layer1,
            resnet.layer2,
            resnet.layer3,
            resnet.layer4,
        )
        self.gap = nn.AdaptiveAvgPool2d(1)

        if not self.cut_at_pooling:
            self.num_features = num_features
            self.norm = norm
            self.dropout = dropout
            self.has_embedding = num_features > 0
            self.num_classes = num_classes

            out_planes = resnet.fc.in_features

            if self.has_embedding:
                self.feat = nn.Linear(out_planes, self.num_features)
                self.feat_bn = nn.BatchNorm1d(self.num_features)
                nn.init.kaiming_normal_(self.feat.weight,mode='fan_out')
                nn.init.constant_(self.feat.bias,0)
            else:
                self.num_features = out_planes
                self.feat_bn = nn.BatchNorm1d(self.num_features)
            self.feat_bn.bias.requires_grad_(False)  # BNNeck-style: the BN layer keeps no learnable bias
            if self.dropout > 0:
                self.drop = nn.Dropout(self.dropout)
            if self.num_classes > 0:
                self.classifier = nn.Linear(self.num_features,self.num_classes, bias=False)
                nn.init.normal_(self.classifier.weight, std=0.001)
            # self.feat_bn only exists when cut_at_pooling is False, so its
            # initialization has to stay inside this branch.
            nn.init.constant_(self.feat_bn.weight, 1)
            nn.init.constant_(self.feat_bn.bias, 0)

    def forward(self,x,feature_withbn = False):
        x = self.base(x)

        x = self.gap(x)
        x = x.view(x.size(0), -1)

        if self.cut_at_pooling:
            return x

        if self.has_embedding:
            bn_x = self.feat_bn(self.feat(x))
        else:
            bn_x = self.feat_bn(x)

        # At inference time, return the L2-normalized post-BN feature
        if not self.training:
            bn_x = F.normalize(bn_x)
            return bn_x

        if self.norm:
            bn_x = F.normalize(bn_x)
        elif self.has_embedding:
            bn_x = F.relu(bn_x)

        if self.dropout > 0:
            bn_x = self.drop(bn_x)

        if self.num_classes > 0:
            prob = self.classifier( bn_x )
        else:
            return x, bn_x

        if feature_withbn:
            return bn_x, prob
        return x, prob

There is not much to say about model.py: it uses ResNet-50 as the backbone and adds a linear classifier on top. During training the forward pass returns both the pooled feature (used by the triplet loss) and the classifier logits; at test time it returns the normalized post-BN feature. A quick shape check follows below.
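As a smoke test (my own sketch, not part of the original project; it assumes the ./pretrain_model/resnet50.pth file from the constructor is present), you can verify the output shapes:

import torch
from model import resnet_model

net = resnet_model(num_classes=10)   # hypothetical class count, just for the test
net.train()
x = torch.randn(4, 3, 256, 128)      # fake batch matching the (256,128) input size used in reid.py
feat, logits = net(x)
print(feat.shape)    # torch.Size([4, 2048]) -- pooled backbone feature
print(logits.shape)  # torch.Size([4, 10])   -- classifier logits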

The code for reid.py is as follows:

import torch
import torch.nn as nn
from torchvision  import datasets, transforms
from model import resnet_model
from torch.optim import lr_scheduler
import loss

transform_list = [
    transforms.Resize((256,128), interpolation=3),  # 3 = PIL's BICUBIC; newer torchvision prefers transforms.InterpolationMode.BICUBIC
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.Pad(10),
    transforms.RandomCrop((256,128)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],[0.229, 0.224, 0.225])
]

transform_compose = transforms.Compose(transform_list)

try_dataset1 = datasets.ImageFolder('./try_data1', transform_compose)  # one sub-folder per identity
try_dataloader1 = torch.utils.data.DataLoader(try_dataset1, batch_size=32, shuffle=True)

try_data1_len = len(try_dataset1)
try_data1_class_name = try_dataset1.classes

# The classifier needs one output per identity, i.e. the number of classes,
# not the number of images in the dataset.
net = resnet_model(num_classes=len(try_data1_class_name))
net.cuda()

# Collect all trainable parameters, each with the same learning rate and weight decay
params = []
for key, value in net.named_parameters():
    if not value.requires_grad:
        continue
    params += [{'params': [value], 'lr': 0.00035, 'weight_decay': 5e-4}]

optimizer = torch.optim.Adam( params )
exp_lr_scheduler = lr_scheduler.StepLR( optimizer, step_size=10, gamma=0.1 )

criterion_ce = nn.CrossEntropyLoss()
criterion_triple = loss.TripletLoss()

for epoch in range(30):

    print("epoch: {} / 30".format(epoch + 1))

    # Reset the running statistics each epoch so the averages printed below
    # are per-epoch rather than cumulative over all epochs so far.
    triplet_loss_list = []
    pre_loss_list = []
    loss_list = []
    acc_list = []

    for data in try_dataloader1:
        inputs, labels = data        # named `inputs` to avoid shadowing the builtin `input`
        inputs = inputs.cuda()
        labels = labels.cuda()

        features, pres = net(inputs)

        tri_loss, _ = criterion_triple(features, labels)
        # Cross-entropy is computed on the classifier logits, not on the features
        ce_loss = criterion_ce(pres, labels)

        # Named total_loss so the imported `loss` module is not shadowed
        total_loss = tri_loss + ce_loss
        triplet_loss_list.append(tri_loss.item())
        pre_loss_list.append(ce_loss.item())
        loss_list.append(total_loss.item())

        _, pid = torch.max(pres.data, dim=1)
        acc = torch.sum(pid == labels.data).float() / pid.size(0)  # .float() guards against integer division on older PyTorch
        acc_list.append(acc.item())

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

    exp_lr_scheduler.step()
    all_acc = sum(acc_list)/(len(acc_list))
    all_triple_loss = sum(triplet_loss_list)/(len(triplet_loss_list))
    all_pre_loss = sum(pre_loss_list)/(len(pre_loss_list))
    all_loss = sum(loss_list)/(len(loss_list))

    print('accuracy: {:.4f}'.format(all_acc))
    print('triplet loss: {:.4f}'.format(all_triple_loss))
    print('predict loss: {:.4f}'.format(all_pre_loss))
    print('loss: {:.4f}'.format(all_loss))
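One caveat about the training script above: batch-hard mining only works if each batch actually contains several images of the same identity; with a plain shuffled ImageFolder loader, an anchor may have no positive other than itself. A common remedy is a P x K identity-balanced sampler. Below is a minimal sketch under that assumption (the PKSampler name and its parameters are my own invention, not part of the original project):

import random
from collections import defaultdict
from torch.utils.data import Sampler

class PKSampler(Sampler):
    """Yields indices so each batch holds P identities with K images each."""
    def __init__(self, data_source, p=8, k=4):
        self.index_by_pid = defaultdict(list)
        for idx, (_, pid) in enumerate(data_source.samples):  # ImageFolder stores (path, class) pairs
            self.index_by_pid[pid].append(idx)
        self.pids = list(self.index_by_pid.keys())
        self.p, self.k = p, k

    def __iter__(self):
        pids = self.pids[:]
        random.shuffle(pids)
        for i in range(0, len(pids) - self.p + 1, self.p):
            for pid in pids[i:i + self.p]:
                # sample K images of this identity, with replacement if it has fewer than K
                yield from random.choices(self.index_by_pid[pid], k=self.k)

    def __len__(self):
        return (len(self.pids) // self.p) * self.p * self.k

# usage: pass the sampler (with a matching batch_size) instead of shuffle=True
# try_dataloader1 = torch.utils.data.DataLoader(try_dataset1, batch_size=32, sampler=PKSampler(try_dataset1, p=8, k=4))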

This post is mainly a record for myself; if you spot a problem, feel free to leave a comment. My goal is a code walkthrough that is as simple and easy to follow as possible while still being complete.