arcface笔记
程序员文章站
2022-03-20 08:22:05
...
ArcFace可以说是现在人脸识别损失函数中应用最为成功也最为有效的损失之一,在各大人脸数据集上可谓屠榜。
论文:
ArcFace: Additive Angular Margin Loss for Deep Face Recognition
常规交叉熵损失如下:
当我们将W和X都通过L2归一化,设置偏执b为0,则可将矩阵相乘的结果logit看成一个余弦距离与缩放因子S的乘积,这两者在结果上是等价的,那么logit就与w和x无关了,仅与θ和s有关,如果固定s,那仅与θ有关。
为了增大类间距离,减少类内距离,我们相想法给目标类logit加学习困难。随着网络的收敛,对应类别的logit会越来越大,θ即样本与类中心距离越来越小,分类性能越来越好。我们再角度空间手动加一个边际m,使得更难训练,但训练出来的网络特征更加紧凑,区分性更好。
作者给出的伪代码,理解起来也非常容易。
首先对x,w归一化,然后计算出logit
依据gt id取出目标logit
依据logit计算出角度θ
计算θ增加margin后的margin_logit
构造one-hot,取出目标logit,并更新目标最终的logit。其中仅对目标类别的logit做更新,但其余类别不用更新
缩放logit得到最终输出logit
将这个logit输入到交叉熵损失函数来计算最终损失。
class ArcMarginModel(nn.Module):
def __init__(self, m=0.5,s=64,easy_margin=False,emb_size=512):
super(ArcMarginModel, self).__init__()
self.weight = Parameter(torch.FloatTensor(num_classes, emb_size))
# num_classes 训练集中总的人脸分类数
# emb_size 特征向量长度
nn.init.xavier_uniform_(self.weight)
# 使用均匀分布来初始化weight
self.easy_margin = easy_margin
self.m = m
# 夹角差值 0.5 公式中的m
self.s = s
# 半径 64 公式中的s
# 二者大小都是论文中推荐值
self.cos_m = math.cos(self.m)
self.sin_m = math.sin(self.m)
# 差值的cos和sin
self.th = math.cos(math.pi - self.m)
# 阈值,避免theta + m >= pi
self.mm = math.sin(math.pi - self.m) * self.m
def forward(self, input, label):
x = F.normalize(input)
W = F.normalize(self.weight)
# 正则化
cosine = F.linear(x, W)
# cos值
sine = torch.sqrt(1.0 - torch.pow(cosine, 2))
# sin
phi = cosine * self.cos_m - sine * self.sin_m
# cos(theta + m) 余弦公式
if self.easy_margin:
phi = torch.where(cosine > 0, phi, cosine)
# 如果使用easy_margin,那么近对cos>0的部分增加学习难度,而cos<0的部分不做改变,因为cos<0,则表明样本预测本就偏离较大
else:
# 这块结合cosine和sine曲线来理解比较好,数形结合,我们目的是
phi = torch.where(cosine > self.th, phi, cosine - self.mm)
one_hot = torch.zeros(cosine.size(), device=device)
one_hot.scatter_(1, label.view(-1, 1).long(), 1)
# 将样本的标签映射为one hot形式 例如N个标签,映射为(N,num_classes)
output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
# 对于正确类别(1*phi)即公式中的cos(theta + m),对于错误的类别(1*cosine)即公式中的cos(theta)
# 这样对于每一个样本,比如[0,0,0,1,0,0]属于第四类,则最终结果为[cosine, cosine, cosine, phi, cosine, cosine]
# 再乘以半径,经过交叉熵,正好是ArcFace的公式
output *= self.s
# 乘以半径
return output