加入 triplet loss 的 ReID PyTorch 实现
程序员文章站
2024-03-14 21:34:53
...
把 triplet loss 加到了 ReID 的实现中，工程目录结构如下图所示（原图缺失；从代码看，工程包含 loss.py、model.py、reid.py，预训练权重 pretrain_model/resnet50.pth，以及数据目录 try_data1/）：
loss.py代码如下:
import torch
import torch.nn as nn
import torch.nn.functional as F
def euclidean_dist(x, y):
    """Return the pairwise Euclidean distance matrix between the rows of x and y.

    Uses the expansion ||a - b||^2 = ||a||^2 + ||b||^2 - 2 * a.b so the whole
    matrix is produced with a single matrix multiply.

    Args:
        x: 2-D tensor of shape (m, d).
        y: 2-D tensor of shape (n, d).

    Returns:
        Tensor of shape (m, n) whose entry [i, j] is ||x[i] - y[j]||.
    """
    m, n = x.size(0), y.size(0)
    xx = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n)
    yy = torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
    dist = xx + yy
    # dist = 1 * dist + (-2) * (x @ y^T). The keyword form replaces the
    # deprecated positional addmm_(beta, alpha, mat1, mat2) overload.
    dist.addmm_(x, y.t(), beta=1, alpha=-2)
    # Rounding can make squared distances slightly negative (and the diagonal
    # of a self-distance matrix is ~0); clamp before sqrt to keep it finite.
    dist = dist.clamp(min=1e-12).sqrt()
    return dist
def cosine_dist(x, y, eps=1e-12):
    """Return the pairwise cosine distance (1 - cosine similarity) between rows.

    Args:
        x: 2-D tensor of shape (bs1, d).
        y: 2-D tensor of shape (bs2, d).
        eps: lower bound for the product of norms, so an all-zero row yields a
            cosine similarity of 0 (distance 1) instead of NaN.

    Returns:
        Tensor of shape (bs1, bs2) whose entry [i, j] is 1 - cos(x[i], y[j]).
    """
    bs1, bs2 = x.size(0), y.size(0)
    frac_up = torch.matmul(x, y.transpose(0, 1))
    x_norm = torch.sqrt(torch.pow(x, 2).sum(dim=1)).view(bs1, 1)
    y_norm = torch.sqrt(torch.pow(y, 2).sum(dim=1)).view(1, bs2)
    # Broadcasting (bs1, 1) * (1, bs2) gives the (bs1, bs2) norm-product matrix.
    frac_down = (x_norm * y_norm).clamp(min=eps)
    cosine = frac_up / frac_down
    return 1 - cosine
def _batch_hard(mat_distance,mat_similarity,indice=False):
sorted_mat_distance, positive_indices = torch.sort(mat_distance + (-100000.0)*(1 - mat_similarity),dim=1, descending=True)
hard_p = sorted_mat_distance[:,0]
hard_p_indice = positive_indices[:,0]
sorted_mat_distance, negative_indices = torch.sort( mat_distance + 100000.0 * mat_similarity,dim = 1,descending=False )
hard_n = sorted_mat_distance[:,0]
hard_n_indice = negative_indices[:,0]
if(indice):
return hard_p, hard_n, hard_p_indice, hard_n_indice
return hard_p, hard_n
class TripletLoss(nn.Module):
    """Batch-hard triplet loss.

    Every sample in the batch acts as an anchor. Its farthest same-label sample
    (hardest positive) and closest different-label sample (hardest negative)
    are mined with `_batch_hard`, and nn.MarginRankingLoss with target 1 then
    penalizes max(0, d_ap - d_an + margin), averaged over anchors.

    Args:
        margin: margin passed to nn.MarginRankingLoss.
        normalize_feature: if True, L2-normalize embeddings before measuring
            distances.
    """

    def __init__(self, margin=0.5, normalize_feature=True):
        super().__init__()
        self.margin = margin
        self.normalize_feature = normalize_feature
        self.margin_loss = nn.MarginRankingLoss(margin=margin)

    def forward(self, emb, label):
        """Compute the loss and the fraction of correctly ranked anchors.

        Args:
            emb: embedding matrix with one row per sample (assumed (N, D)).
            label: per-sample identity labels; assumed 1-D of length N.

        Returns:
            (loss, prec): the scalar triplet loss, and the fraction of anchors
            whose hardest negative is farther away than their hardest positive.
        """
        if self.normalize_feature:
            emb = F.normalize(emb)

        mat_dist = euclidean_dist(emb, emb)
        assert mat_dist.size(0) == mat_dist.size(1)
        n = mat_dist.size(0)

        # tiled[i, j] == label[j], so comparing with its transpose marks every
        # pair (i, j) that shares a label with 1.0.
        tiled = label.expand(n, n)
        mat_sim = tiled.eq(tiled.t()).float()

        dist_ap, dist_an = _batch_hard(mat_dist, mat_sim)
        assert dist_an.size(0) == dist_ap.size(0)

        # Target 1 means "dist_an should rank above dist_ap".
        target = torch.ones_like(dist_ap)
        loss = self.margin_loss(dist_an, dist_ap, target)

        correct = (dist_an.detach() > dist_ap.detach()).sum()
        prec = correct * 1.0 / target.size(0)
        return loss, prec
# loss = nn.CrossEntropyLoss()
# an = torch.randn(4,3)
# y = torch.ones(4).long()
# print(an)
# print(y)
# l = loss(an,y)
# print(l)
# l.backward()
# print(an.grad)
里面定义了 triplet loss。这个东西说起来挺简单，但实现时还是有些地方需要仔细琢磨。建议对待代码不要只是看看就好，而要亲手敲一遍：一边敲，一边思考，一边学习其中的招数。之前看 triplet loss 的时候，我总在想：数据集按照 PyTorch 的格式组织的话，怎么确定谁是 anchor、谁是 positive、谁是 negative？这次通过亲手敲代码全弄明白了：只要把从 dataloader 中读取的 data 里的 label 转换一下（对应 loss.py 中的 mat_sim 矩阵），就能知道谁是 positive、谁是 negative——batch 中的每个样本都轮流充当 anchor，与它 label 相同的样本是 positive，label 不同的样本是 negative。
model.py的代码如下:
import torch
import torch.nn as nn
from torchvision import models
from torch.nn import functional as F
class resnet_model(nn.Module):
    """ResNet-50 re-ID backbone with an optional embedding + BN neck and classifier.

    Args:
        cut_at_pooling: if True, forward() returns the globally pooled backbone
            features and no neck/classifier layers are created.
        num_features: output size of an extra Linear embedding after pooling;
            0 means no embedding (the BN neck then acts on the pooled features).
        norm: if True, L2-normalize the BN-neck output during training.
        dropout: dropout probability applied before the classifier; 0 disables it.
        num_classes: number of classifier outputs; 0 means no classifier.
        pretrained_path: path of a ResNet-50 state dict loaded into the backbone;
            None skips loading and keeps torchvision's random initialization.
    """

    def __init__(self, cut_at_pooling=False, num_features=0, norm=False, dropout=0,
                 num_classes=0, pretrained_path='./pretrain_model/resnet50.pth'):
        super(resnet_model, self).__init__()
        self.cut_at_pooling = cut_at_pooling

        resnet = models.resnet50(pretrained=False)
        if pretrained_path is not None:
            # map_location='cpu' lets a checkpoint saved from GPU tensors load on a
            # CPU-only machine; load_state_dict copies the values into the model.
            resnet.load_state_dict(torch.load(pretrained_path, map_location='cpu'))
        # Force stride 1 in the first layer4 block so layer4 does not downsample,
        # which keeps a larger final feature map.
        resnet.layer4[0].conv2.stride = (1, 1)
        resnet.layer4[0].downsample[0].stride = (1, 1)
        # NOTE(review): resnet.relu is not inserted between bn1 and maxpool, unlike
        # torchvision's own ResNet forward. Kept unchanged; confirm it is intended.
        self.base = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.maxpool,
            resnet.layer1,
            resnet.layer2,
            resnet.layer3,
            resnet.layer4,
        )
        self.gap = nn.AdaptiveAvgPool2d(1)

        if not self.cut_at_pooling:
            self.num_features = num_features
            self.norm = norm
            self.dropout = dropout
            self.has_embedding = num_features > 0
            self.num_classes = num_classes
            out_planes = resnet.fc.in_features

            if self.has_embedding:
                self.feat = nn.Linear(out_planes, self.num_features)
                self.feat_bn = nn.BatchNorm1d(self.num_features)
                nn.init.kaiming_normal_(self.feat.weight, mode='fan_out')
                nn.init.constant_(self.feat.bias, 0)
            else:
                # Without an embedding the BN neck acts directly on the pooled features.
                self.num_features = out_planes
                self.feat_bn = nn.BatchNorm1d(self.num_features)
            # The BN-neck shift gets no gradient, so it stays at its initial value.
            self.feat_bn.bias.requires_grad_(False)
            if self.dropout > 0:
                self.drop = nn.Dropout(self.dropout)
            if self.num_classes > 0:
                self.classifier = nn.Linear(self.num_features, self.num_classes, bias=False)
                nn.init.normal_(self.classifier.weight, std=0.001)
            # feat_bn only exists on this branch; initializing it unconditionally
            # raised AttributeError when cut_at_pooling=True.
            nn.init.constant_(self.feat_bn.weight, 1)
            nn.init.constant_(self.feat_bn.bias, 0)

    def forward(self, x, feature_withbn=False):
        """Run the backbone and, depending on configuration and mode, the neck/classifier.

        Returns:
            - cut_at_pooling: the pooled features, shape (batch, channels).
            - eval mode: the L2-normalized BN-neck features.
            - training, num_classes == 0: (pooled features, BN-neck features).
            - training, num_classes > 0: (pooled features, logits), or
              (BN-neck features, logits) when feature_withbn is True.
        """
        x = self.base(x)
        x = self.gap(x)
        x = x.view(x.size(0), -1)
        if self.cut_at_pooling:
            return x

        if self.has_embedding:
            bn_x = self.feat_bn(self.feat(x))
        else:
            bn_x = self.feat_bn(x)

        if not self.training:
            return F.normalize(bn_x)

        if self.norm:
            bn_x = F.normalize(bn_x)
        elif self.has_embedding:
            bn_x = F.relu(bn_x)
        if self.dropout > 0:
            bn_x = self.drop(bn_x)

        if self.num_classes <= 0:
            return x, bn_x
        prob = self.classifier(bn_x)
        if feature_withbn:
            return bn_x, prob
        return x, prob
model.py 没有太多要说的：以 ResNet-50 为 backbone，把 layer4 第一个 block 的 stride 改为 1，在全局平均池化之后接一个 BatchNorm 层，另外加入了一个线性分类器。
reid.py的代码如下:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from model import resnet_model
from torch.optim import lr_scheduler
import loss
transform_list = [
    transforms.Resize((256, 128), interpolation=3),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.Pad(10),
    transforms.RandomCrop((256, 128)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
transform_compose = transforms.Compose(transform_list)

try_dataset1 = datasets.ImageFolder('./try_data1', transform_compose)
# NOTE(review): batch-hard triplet mining needs several images of the same
# identity in each batch; plain shuffling does not guarantee that (an
# identity-balanced P x K sampler would).
try_dataloader1 = torch.utils.data.DataLoader(try_dataset1, batch_size=32, shuffle=True)
try_data1_len = len(try_dataset1)
try_data1_class_name = try_dataset1.classes

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The classifier predicts identities, so its output size is the number of
# ImageFolder classes (assumed one sub-folder per identity), not the number
# of images.
net = resnet_model(num_classes=len(try_data1_class_name))
net.to(device)

# One parameter group per trainable tensor; frozen tensors (e.g. the BN-neck
# bias) are skipped.
params = [
    {'params': [value], 'lr': 0.00035, 'weight_decay': 5e-4}
    for value in net.parameters()
    if value.requires_grad
]
optimizer = torch.optim.Adam(params)
# Stepped once per epoch below: lr is multiplied by 0.1 every 10 epochs.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
criterion_ce = nn.CrossEntropyLoss()
criterion_triple = loss.TripletLoss()

num_epochs = 30
for epoch in range(num_epochs):
    print("epoch: {} / {}".format(epoch + 1, num_epochs))
    net.train()
    # Reset every epoch so the printed averages describe this epoch only.
    triplet_loss_list = []
    pre_loss_list = []
    loss_list = []
    acc_list = []

    for inputs, labels in try_dataloader1:
        inputs = inputs.to(device)
        labels = labels.to(device)
        # In training mode with num_classes > 0 the model returns
        # (pooled features, classifier logits).
        features, pres = net(inputs)
        tri_loss, _ = criterion_triple(features, labels)
        # Cross-entropy belongs on the classifier logits, not on the features.
        ce_loss = criterion_ce(pres, labels)
        # Named total_loss so it does not shadow the imported `loss` module.
        total_loss = tri_loss + ce_loss

        triplet_loss_list.append(tri_loss.item())
        pre_loss_list.append(ce_loss.item())
        loss_list.append(total_loss.item())

        _, pid = torch.max(pres.detach(), dim=1)
        # Cast to float so this is a true ratio even where integer tensor
        # division floors (older PyTorch).
        acc = (pid == labels).float().mean()
        acc_list.append(acc.item())

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()
    exp_lr_scheduler.step()

    all_acc = sum(acc_list) / len(acc_list)
    all_triple_loss = sum(triplet_loss_list) / len(triplet_loss_list)
    all_pre_loss = sum(pre_loss_list) / len(pre_loss_list)
    all_loss = sum(loss_list) / len(loss_list)
    print('accuracy: {:.4f}'.format(all_acc))
    print('triplet loss: {:.4f}'.format(all_triple_loss))
    print('predict loss: {:.4f}'.format(all_pre_loss))
    print('loss : {:.4f}'.format(all_loss))
权当记录一下吧，有问题欢迎留言。力求做到全网最简单易懂、同时又最完整的代码展示。