当前位置: 代码迷 >> 综合 >> Siamese Network Triplet NetWork
  详细解决方案

Siamese Network Triplet NetWork

热度:52   发布时间:2023-10-30 23:25:30.0

Siamese Network(孪生网络)

简单来说,孪生网络就是共享参数的两个神经网络

在孪生网络中,我们把一张图片X1X_1X1?作为输入,得到该图片的编码GW(X1)G_W(X_1)GW?(X1?)。然后,我们在不对网络参数进行任何更新的情况下,输入另一张图片X2X_2X2?,并得到改图片的编码GW(X2)G_W(X_2)GW?(X2?)。由于相似的图片应该具有相似的特征(编码),利用这一点,我们就可以比较并判断两张图片的相似性

孪生网络的损失函数

传统的Siamese Network使用Contrastive Loss(对比损失函数)
L=(1?Y)12(DW)2+(Y)12{max(0,m?DW)}2\mathcal{L} = (1-Y)\frac{1}{2}(D_W)^2+(Y)\frac{1}{2}\{max(0, m-D_W)\}^2 L=(1?Y)21?(DW?)2+(Y)21?{ max(0,m?DW?)}2
其中DWD_WDW?被定义为孪生网络两个输入之间的欧氏距离,即
DW={GW(X1)?GW(X2)}2D_W = \sqrt{\{G_W(X_1)-G_W(X_2)\}^2} DW?={ GW?(X1?)?GW?(X2?)}2 ?

  • YYY值为0或1,如果X1,X2X_1,X_2X1?,X2?这对样本属于同一类,则Y=0Y=0Y=0,反之Y=1Y=1Y=1
  • mmm是边际价值(margin value),即当Y=1Y=1Y=1,如果X1X_1X1?X2X_2X2?之间距离大于mmm,则不做优化(省时省力);如果X1X_1X1?X2X_2X2?之间的距离小于mmm,则调整参数使其距离增大到mmm
Contrastive Loss代码
import torch
import numpy as np
import torch.nn.functional as Fclass ContrastiveLoss(torch.nn.Module):"Contrastive loss function"def __init__(self, m=2.0):super(ContrastiveLoss, self).__init__()self.m = mdef forward(self, output1, output2, label):d_w = F.pairwise_distance(output1, output2)contrastive_loss = torch.mean((1-label) * 0.5 * torch.pow(d_w, 2) +(label) * 0.5 * torch.pow(torch.clamp(self.m - d_w, min=0.0), 2))return contrastive_loss

其中,F.pairwise_distance(x1, x2, p=2)函数公式如下
(∑i=1n(∣x1?x2∣p))1px1,x2∈Rb×n(\sum_{i=1}^n(|x_1-x_2|^p))^{\frac{1}{p}}\\ x_1,x_2 \in \mathbb{R}^{b\times n} (i=1n?(x1??x2?p))p1?x1?,x2?Rb×n

pairwise_distance(x1, x2, p) Computes the batchwise pairwise distance between vectors x1x_1x1?, x2x_2x2? using the p-norm

孪生网络的用途

简单来说,孪生网络的直接用途就是衡量两个输入的差异程度(或者说相似程度)。将两个输入分别送入两个神经网络,得到其在新空间的representation,然后通过Loss Function来计算它们的差异程度(或相似程度)

  • 词汇语义相似度分析,QA中question和answer的匹配
  • 手写体识别也可以用Siamese Network
  • Kaggle上Quora的Question Pair比赛,即判断两个提问是否为同一个问题
Pseudo-Siamese Network(伪孪生网络)

对于伪孪生网络来说,两边可以是不同的神经网络(如一个是lstm,一个是cnn),并且如果是相同的神经网络,是不共享参数

孪生网络和伪孪生网络分别适用的场景
  • 孪生网络适用于处理两个输入比较类似的情况
  • 伪孪生网络适用于处理两个输入有一定差别的情况

例如,计算两个句子或者词汇的语义相似度,使用Siamese Network比较合适;验证标题与正文的描述是否一致(标题和正文长度差别很大),或者文字是否描述了一幅图片(一个是图片,一个是文字)就应该使用Pseudo-Siamese Network

Triplet Network(三胞胎网络)

如果说Siamese Network是双胞胎,那Triplet Network就是三胞胎。它的输入是三个:一个正例+两个负例,或一个负例+两个正例。训练的目标仍然是让相同类别间的距离尽可能小,不同类别间的距离尽可能大。Triplet Network在CIFAR,MNIST数据集上效果均超过了Siamese Network

损失函数定义如下:
L=max(d(a,p)?d(a,n)+margin,0)\mathcal{L}=max(d(a,p)-d(a,n)+margin, 0) L=max(d(a,p)?d(a,n)+margin,0)

  • aaa表示anchor图像
  • ppp表示positive图像
  • nnn表示negative图像

我们希望aaappp的距离应该小于aaannn的距离。marginmarginmargin是个超参数,它表示d(a,p)d(a,p)d(a,p)d(a,n)d(a,n)d(a,n)之间应该相差多少,例如,假设margin=0.2margin=0.2margin=0.2,并且d(a,p)=0.5d(a,p)=0.5d(a,p)=0.5,那么d(a,n)d(a,n)d(a,n)应该大于等于0.70.70.7

Reference

  • 多种类型的神经网络(孪生网络)
  • Siamese network 孪生神经网络–一个简单神奇的结构
  • Siamese Network & Triplet Loss
  • A friendly introduction to Siamese Networks
  • Contrastive Loss
  相关解决方案