欢迎您访问程序员文章站本站旨在为大家提供分享程序员计算机编程知识!
您现在的位置是: 首页

多元组损失整理(二元组损失、三元组损失、四元组损失)

程序员文章站 2024-01-16 11:48:40
...

参考链接

Contrastive Loss

来源

主要用于处理成对的数据

公式

L ( W , Y , X ⃗ 1 , X ⃗ 2 ) = ( 1 − Y ) 1 2 ( D W ) 2 + ( Y ) 1 2 { max ⁡ ( 0 , m − D W ) } 2 L\left(W, Y, \vec{X}_{1}, \vec{X}_{2}\right)=(1-Y) \frac{1}{2}\left(D_{W}\right)^{2}+(Y) \frac{1}{2}\left\{\max \left(0, m-D_{W}\right)\right\}^{2} L(W,Y,X 1,X 2)=(1Y)21(DW)2+(Y)21{max(0,mDW)}2

说明

D w D_w Dw 是数据对特征的欧式距离; Y Y Y 用于表示是否同类,同类表示为0, 不同类表示为1 m m m 我理解为是一个阈值

含义理解

当数据对同类的时候, Y Y Y 为0,公式变为 L ( W , Y , X ⃗ 1 , X ⃗ 2 ) = 1 2 ( D W ) 2 L\left(W, Y, \vec{X}_{1}, \vec{X}_{2}\right)=\frac{1}{2}\left(D_{W}\right)^{2} L(W,Y,X 1,X 2)=21(DW)2, 这是想让同类之间的距离尽量小

当数据对不同类的时候, Y Y Y 为1,公式变为 L ( W , Y , X ⃗ 1 , X ⃗ 2 ) = 1 2 { max ⁡ ( 0 , m − D W ) } 2 L\left(W, Y, \vec{X}_{1}, \vec{X}_{2}\right)=\frac{1}{2}\left\{\max \left(0, m-D_{W}\right)\right\}^{2} L(W,Y,X 1,X 2)=21{max(0,mDW)}2, 这是想让不同类之间的距离尽量大,但是当 D w ( n e g a t i v e ) > m D_w(negative) > m Dw(negative)>m 的时候,距离大的很明显了, 这种easy negative pair就不处理了

Triplet Loss

来源

参考链接

和contrastive loss的不同在于,构建了一个 < a , p , n > <a, p, n> <a,p,n> 三元组来计算损失。其中 a a a 代表anchor, n n n 代表negative 负匹配, p p p 代表positive 正匹配

公式

∥ f ( x i a ) − f ( x i p ) ∥ 2 2 + α < ∥ f ( x i a ) − f ( x i n ) ∥ 2 2 \left\|f\left(x_{i}^{a}\right)-f\left(x_{i}^{p}\right)\right\|_{2}^{2}+\alpha<\left\|f\left(x_{i}^{a}\right)-f\left(x_{i}^{n}\right)\right\|_{2}^{2} f(xia)f(xip)22+α<f(xia)f(xin)22

L = ∑ i N [ ∥ f ( x i a ) − f ( x i p ) ∥ 2 2 − ∥ f ( x i a ) − f ( x i n ) ∥ 2 2 + α ] + L=\sum_{i}^{N}\left[\left\|f\left(x_{i}^{a}\right)-f\left(x_{i}^{p}\right)\right\|_{2}^{2}-\left\|f\left(x_{i}^{a}\right)-f\left(x_{i}^{n}\right)\right\|_{2}^{2}+\alpha\right]_{+} L=iN[f(xia)f(xip)22f(xia)f(xin)22+α]+

说明

f ( x ) f(x) f(x) 表示embedding,将image x x x 编码到一个 d d d 维的欧式空间中。 x i a x^a_i xia 表示anchor, x i p x^p_i xip 表示positive样本, x i p x^p_i xip 表示neigative样本, α \alpha α 表示为 margin,是一个常量

含义理解

多元组损失整理(二元组损失、三元组损失、四元组损失)

网络没经过学习之前, a a a p p p的欧式距离可能很大, a a a n n n的欧式距离可能很小;通过学习,使得类间的距离要大于类内的距离

PyTorch 实现

torch.nn.TripletMarginLoss(margin=1.0, p=2.0, eps=1e-06, swap=False,size_average=None, reduce=None, reduction='mean')
>>> triplet_loss = nn.TripletMarginLoss(margin=1.0, p=2)
>>> anchor = torch.randn(100, 128, requires_grad=True)
>>> positive = torch.randn(100, 128, requires_grad=True)
>>> negative = torch.randn(100, 128, requires_grad=True)
>>> output = triplet_loss(anchor, positive, negative)
>>> output.backward()

参数说明:

margin: α \alpha α 的值, 默认为1

p:为范数选择,默认为2

Triplet Loss变体

Improved triplet loss (改进三元组损失)

公式

L t = d a , p + ( d a , p − d a , n + α ) L_{t}=d_{a, p}+\left(d_{a, p}-d_{a, n}+\alpha\right) Lt=da,p+(da,pda,n+α)

说明

传统的三元组损失只考虑正负样本对之间的相对距离,并没有考虑正样本对之间的绝对距离。改进后的triplet loss保证网络不仅能够在特征空间把正负样本推开,也能保证正样本对之间的距离很近。

Triplet loss with batch hard mining, TriHard loss(难样本采样三元组损失)

说明

传统的三元组随机从训练数据中抽样三张图片,这样的做法虽然比较简单,但是抽样出来的大部分都是简单易区分的样本对。如果大量训练的样本对都是简单的样本对,那么这是不利于网络学习到更好的表征。大量论文发现用更难的样本去训练网络能够提高网络的泛化能力,而采样难样本对的方法很多。

TriHard之前是在一篇ReID论文中提出

TriHard损失的核心思想是:对于每一个训练batch,随机挑选 P 个ID的行人,每个行人随机挑选 K 张不同的图片,即一个batch含有 P×K 张图片。之后对于batch中的每一张图片 a ,我们可以挑选一个最难的正样本和一个最难的负样本和 a 组成一个三元组。

公式

L t h = 1 P × K ∑ a ∈ batch ( max ⁡ P ∈ A d a , p − min ⁡ n ∈ B d a , n + α ) + L_{\mathrm{th}}=\frac{1}{P \times K} \sum_{a \in \text {batch}}\left(\max _{P \in A} d_{a, p}-\min _{n \in B} d_{a, n}+\alpha\right)_{+} Lth=P×K1abatch(maxPAda,pminnBda,n+α)+

d d d 表示距离

E-negative triplet loss

来源

公式

L ( q m a , q ^ m p , x ^ m n ) = ∑ m = 1 M N 1 E ∑ e = 1 E [ ∥ f ( q m a ) − f ( q ^ m p ) ) ∥ 2 2 − ∥ f ( q m a ) − f ( q ^ m n e ) ∥ 2 2 + α ] + \begin{aligned} \mathcal{L}\left(q_{m}^{a}, \hat{q}_{m}^{p}, \hat{x}_{m}^{n}\right)=\sum_{m=1}^{M N} & \frac{1}{E} \sum_{e=1}^{E}\left[\| f\left(q_{m}^{a}\right)-f\left(\hat{q}_{m}^{p}\right)\right) \|_{2}^{2}\left.-\left\|f\left(q_{m}^{a}\right)-f\left(\hat{q}_{m}^{n_{e}}\right)\right\|_{2}^{2}+\alpha\right]_{+} \end{aligned} L(qma,q^mp,x^mn)=m=1MNE1e=1E[f(qma)f(q^mp))22f(qma)f(q^mne)22+α]+

说明

改进在于选择了E个负样本, 论文中对于这个改进的描述:

This extension makes the feature comparison more effective and faithful to the few-shot learning procedure, since in each update, the network can compare a sample with multiple negative classes.

Quadruplet loss(四元组损失)

公式

L t = ( d a , p − d a , n 1 + α ) + + ( d a , p − d n 1 , n 2 + β ) + L_{t}=\left(d_{a, p}-d_{a, n_{1}}+\alpha\right)_{+}+\left(d_{a, p}-d_{n_{1}, n_{2}}+\beta\right)_{+} Lt=(da,pda,n1+α)++(da,pdn1,n2+β)+

α \alpha α β \beta β是手动设置的正常数, 通常设置 β \beta β小于 α \alpha α

说明

输入四张图片,即 < a , p , n 1 , n 2 > <a, p, n_1, n_2> <a,p,n1,n2>, 和传统三元组比多了一个负样本。公式中,前一项称为强推动,后一项称为弱推动。

PyTorch 实现

目前好像没有专门的四元组损失接口,但是我理解可以用两个torch.nn.TripletMarginLoss实现?