多元组损失整理(二元组损失、三元组损失、四元组损失)
Contrastive Loss
主要用于处理成对的数据
公式
L ( W , Y , X ⃗ 1 , X ⃗ 2 ) = ( 1 − Y ) 1 2 ( D W ) 2 + ( Y ) 1 2 { max ( 0 , m − D W ) } 2 L\left(W, Y, \vec{X}_{1}, \vec{X}_{2}\right)=(1-Y) \frac{1}{2}\left(D_{W}\right)^{2}+(Y) \frac{1}{2}\left\{\max \left(0, m-D_{W}\right)\right\}^{2} L(W,Y,X 1,X 2)=(1−Y)21(DW)2+(Y)21{max(0,m−DW)}2
说明
D w D_w Dw 是数据对特征的欧式距离; Y Y Y 用于表示是否同类,同类表示为0, 不同类表示为1; m m m 我理解为是一个阈值
含义理解
当数据对同类的时候, Y Y Y 为0,公式变为 L ( W , Y , X ⃗ 1 , X ⃗ 2 ) = 1 2 ( D W ) 2 L\left(W, Y, \vec{X}_{1}, \vec{X}_{2}\right)=\frac{1}{2}\left(D_{W}\right)^{2} L(W,Y,X 1,X 2)=21(DW)2, 这是想让同类之间的距离尽量小
当数据对不同类的时候, Y Y Y 为1,公式变为 L ( W , Y , X ⃗ 1 , X ⃗ 2 ) = 1 2 { max ( 0 , m − D W ) } 2 L\left(W, Y, \vec{X}_{1}, \vec{X}_{2}\right)=\frac{1}{2}\left\{\max \left(0, m-D_{W}\right)\right\}^{2} L(W,Y,X 1,X 2)=21{max(0,m−DW)}2, 这是想让不同类之间的距离尽量大,但是当 D w ( n e g a t i v e ) > m D_w(negative) > m Dw(negative)>m 的时候,距离大的很明显了, 这种easy negative pair就不处理了
Triplet Loss
和contrastive loss的不同在于,构建了一个 < a , p , n > <a, p, n> <a,p,n> 三元组来计算损失。其中 a a a 代表anchor, n n n 代表negative 负匹配, p p p 代表positive 正匹配
公式
∥ f ( x i a ) − f ( x i p ) ∥ 2 2 + α < ∥ f ( x i a ) − f ( x i n ) ∥ 2 2 \left\|f\left(x_{i}^{a}\right)-f\left(x_{i}^{p}\right)\right\|_{2}^{2}+\alpha<\left\|f\left(x_{i}^{a}\right)-f\left(x_{i}^{n}\right)\right\|_{2}^{2} ∥f(xia)−f(xip)∥22+α<∥f(xia)−f(xin)∥22
L = ∑ i N [ ∥ f ( x i a ) − f ( x i p ) ∥ 2 2 − ∥ f ( x i a ) − f ( x i n ) ∥ 2 2 + α ] + L=\sum_{i}^{N}\left[\left\|f\left(x_{i}^{a}\right)-f\left(x_{i}^{p}\right)\right\|_{2}^{2}-\left\|f\left(x_{i}^{a}\right)-f\left(x_{i}^{n}\right)\right\|_{2}^{2}+\alpha\right]_{+} L=∑iN[∥f(xia)−f(xip)∥22−∥f(xia)−f(xin)∥22+α]+
说明
f ( x ) f(x) f(x) 表示embedding,将image x x x 编码到一个 d d d 维的欧式空间中。 x i a x^a_i xia 表示anchor, x i p x^p_i xip 表示positive样本, x i p x^p_i xip 表示neigative样本, α \alpha α 表示为 margin,是一个常量
含义理解
网络没经过学习之前, a a a和 p p p的欧式距离可能很大, a a a和 n n n的欧式距离可能很小;通过学习,使得类间的距离要大于类内的距离。
PyTorch 实现
torch.nn.TripletMarginLoss(margin=1.0, p=2.0, eps=1e-06, swap=False,size_average=None, reduce=None, reduction='mean')
>>> triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
>>> anchor = torch.randn(100, 128, requires_grad=True)
>>> positive = torch.randn(100, 128, requires_grad=True)
>>> negative = torch.randn(100, 128, requires_grad=True)
>>> output = triplet_loss(anchor, positive, negative)
>>> output.backward()
参数说明:
margin: α \alpha α 的值, 默认为1
p:为范数选择,默认为2
Triplet Loss变体
Improved triplet loss (改进三元组损失)
公式
L t = d a , p + ( d a , p − d a , n + α ) L_{t}=d_{a, p}+\left(d_{a, p}-d_{a, n}+\alpha\right) Lt=da,p+(da,p−da,n+α)
说明
传统的三元组损失只考虑正负样本对之间的相对距离,并没有考虑正样本对之间的绝对距离。改进后的triplet loss保证网络不仅能够在特征空间把正负样本推开,也能保证正样本对之间的距离很近。
Triplet loss with batch hard mining, TriHard loss(难样本采样三元组损失)
说明
传统的三元组随机从训练数据中抽样三张图片,这样的做法虽然比较简单,但是抽样出来的大部分都是简单易区分的样本对。如果大量训练的样本对都是简单的样本对,那么这是不利于网络学习到更好的表征。大量论文发现用更难的样本去训练网络能够提高网络的泛化能力,而采样难样本对的方法很多。
TriHard之前是在一篇ReID论文中提出
TriHard损失的核心思想是:对于每一个训练batch,随机挑选 P 个ID的行人,每个行人随机挑选 K 张不同的图片,即一个batch含有 P×K 张图片。之后对于batch中的每一张图片 a ,我们可以挑选一个最难的正样本和一个最难的负样本和 a 组成一个三元组。
公式
L t h = 1 P × K ∑ a ∈ batch ( max P ∈ A d a , p − min n ∈ B d a , n + α ) + L_{\mathrm{th}}=\frac{1}{P \times K} \sum_{a \in \text {batch}}\left(\max _{P \in A} d_{a, p}-\min _{n \in B} d_{a, n}+\alpha\right)_{+} Lth=P×K1∑a∈batch(maxP∈Ada,p−minn∈Bda,n+α)+
d d d 表示距离
E-negative triplet loss
公式
L ( q m a , q ^ m p , x ^ m n ) = ∑ m = 1 M N 1 E ∑ e = 1 E [ ∥ f ( q m a ) − f ( q ^ m p ) ) ∥ 2 2 − ∥ f ( q m a ) − f ( q ^ m n e ) ∥ 2 2 + α ] + \begin{aligned} \mathcal{L}\left(q_{m}^{a}, \hat{q}_{m}^{p}, \hat{x}_{m}^{n}\right)=\sum_{m=1}^{M N} & \frac{1}{E} \sum_{e=1}^{E}\left[\| f\left(q_{m}^{a}\right)-f\left(\hat{q}_{m}^{p}\right)\right) \|_{2}^{2}\left.-\left\|f\left(q_{m}^{a}\right)-f\left(\hat{q}_{m}^{n_{e}}\right)\right\|_{2}^{2}+\alpha\right]_{+} \end{aligned} L(qma,q^mp,x^mn)=m=1∑MNE1e=1∑E[∥f(qma)−f(q^mp))∥22−∥f(qma)−f(q^mne)∥22+α]+
说明
改进在于选择了E个负样本, 论文中对于这个改进的描述:
This extension makes the feature comparison more effective and faithful to the few-shot learning procedure, since in each update, the network can compare a sample with multiple negative classes.
Quadruplet loss(四元组损失)
公式
L t = ( d a , p − d a , n 1 + α ) + + ( d a , p − d n 1 , n 2 + β ) + L_{t}=\left(d_{a, p}-d_{a, n_{1}}+\alpha\right)_{+}+\left(d_{a, p}-d_{n_{1}, n_{2}}+\beta\right)_{+} Lt=(da,p−da,n1+α)++(da,p−dn1,n2+β)+
α \alpha α β \beta β是手动设置的正常数, 通常设置 β \beta β小于 α \alpha α
说明
输入四张图片,即 < a , p , n 1 , n 2 > <a, p, n_1, n_2> <a,p,n1,n2>, 和传统三元组比多了一个负样本。公式中,前一项称为强推动,后一项称为弱推动。
PyTorch 实现
目前好像没有专门的四元组损失接口,但是我理解可以用两个torch.nn.TripletMarginLoss实现?
上一篇: 不用CSS,如何控制如下显示
下一篇: 在YII中实现三个公共方法