ALBEF论文笔记
简介
现有方法的缺陷
- 图像特征和单词标记嵌入位于它们自己的空间中,这使得多模态编码器学习建模其交互具有挑战性;
- 目标检测器既注释昂贵又计算昂贵,因为它在训练前需要边界框注释,以及高分辨率(例如。600×1000)图像;
- 广泛使用的图像-文本数据集是从web收集的,具有固有的噪声,现有的训练前目标如MLM可能会过拟合噪声文本,降低模型的泛化性能。
ALBEF的优势
- 图文对齐后再融合。对于图片的 Embedding 和文本的 Embedding 引入一个对比学习的损失函数 image-text contrastive loss,在融合之前提前把图片和文本的表征对齐。使得后续的多模态 Transformer 更容易执行跨模态学习。
- 不使用目标检测器,不需要边界框注释或高分辨率的图像。
- 提出了动量蒸馏方法,这是一种从动量模型产生的伪目标中学习的自训练方法,改进了从有噪声的web数据中学习。
方法
- 用一个无检测器的图像编码器和一个文本编码器独立地编码图像和文本。
- 然后利用多模态编码器,通过跨模态注意的方法,将图像特征与文本特征进行融合。
- 在来自单模态编码器的表示中引入了一个中间的图像-文本对比(ITC)损失,目的:
它对齐图像特征和文本特征,使多模态编码器更容易执行跨模态学习;
改进单模态编码器,更好地理解图像和文本的语义意义;
学习一个通用的低维空间嵌入图像和文本,使图像-文本匹配目标(ITM)通过对比硬负挖掘(Hard Negative Mining)找到信息更丰富的样本。
Hard Negative Mining:基本思想是在训练过程中,重点关注那些模型难以正确分类的负样本。在训练过程中,模型会对每个样本进行预测,并生成一个预测概率。对于负样本,如果模型给出的预测概率很高(即模型错误地认为它是一个正样本),那么这个负样本就是一个hard negative。通过将这些hard negatives加入到训练集中,我们可以帮助模型更好地学习如何区分正样本和负样本
- 为了改进在噪声监督下的学习,我们提出了动量蒸馏(MoD),使模型能够利用一个更大的未经管理的web数据集。在训练过程中,我们通过取模型参数的移动平均数来保持模型的动量版本,并使用动量模型生成伪目标作为额外的监督。使用MoD时,该模型不会因为产生其他不同于web注释的合理输出而受到惩罚。我们证明了MoD不仅改进了训练前的任务,而且还改进了具有干净注释的下游任务。
三个优化目标:单模态编码器上的图像文本对比学习(ITC)、多模态编码器上的掩码语言建模(MLM)和图像文本匹配(ITM)。其中,通过在线对比Hard Negative Mining改进了ITM
图像文本对比学习(ITC)
即Align before Fuse的核心
s是相似度函数,$v_{cls}$是[CLS] token的embedding(图像编码器编码得到的),$w_{cls}$是文本编码其得到的, $g_v$和$g_w$是一种线性转换,它们将[CLS]embedding映射到归一化的低维(256-d)表示
维护两个队列来存储来自动量单模态编码器的最新的M个图像-文本表示(当作当前图像、文本的负样本)。动量编码器的归一化特征分别记为$g^{‘}_v(v^{‘}_{cls})$和$g^{‘}_w(w^{‘}_{cls})$。
动量在数学上就是加权移动平均。例如$ y_t=m \times y_{t-1}+(1-m) \times x_t $,$ y_{t-1}$为上一时刻的输出,$x_t$为当前输入,$m$为动量参数;当 $m$很大时,$y_t$就取决于上一时刻输出,其更新就很缓慢;当$m$很小时,$y_t$就取决于当前时刻输入
$p^{i2t}_m、p^{t2i}_m$是文本到图像、图像到文本的softmax归一化相似度。$y^{i2t}(I)$和$y^{t2i}(T)$表示真实独热编码相似度,其中负对的概率为0,正对的概率为1。 $\tau$是个可学习的参数,H是交叉熵损失。
多模态编码器上的掩码语言建模(MLM)
利用图像和上下文文本来预测掩蔽词。我们以15%的概率随机掩码出输入标记,并用特殊的标记$[MASK]$替换它们。设$\hat{T}$表示一个被mask过的文本,而$p^{msk}(I,\hat{T})$表示模型对一个该文本被mask处的token的预测概率, $y^{msk}$是一个独热编码的词汇分布,其中ground-truth token的概率为1。MLM使交叉熵损失最小化:
图像文本匹配(ITM)
可以预测一对图像和文本是正的(匹配)还是负的(不匹配)。使用多模态编码器的输出嵌入$[CLS]$ token作为图像-文本对的联合表示,并附加一个全连接(FC)层,然后使用softmax来计算一个二分类概率$p^{itm}$。ITM损失为:
$y^{itm}$是一个二维独热编码向量,表示真实标签。
Hard Negative Mining:我们提出了一种在零计算开销的ITM任务中进行硬负样本采样的策略。如果一个负的图像-文本对共享相似的语义,但在细粒度的细节上有所不同,那么它们是硬的(hard)。我们使用方程1中的对比相似度来寻找一个batch内的硬负样本:对于一个batch中的每一幅图像,按照对比相似度分布采样与图像最相似的文本作为负文本(即模型认为它非常可能是正样本,但其实它是负样本)。同样地,为每个文本采样一个最hard的负图像。
设 B 为 Batch Size,从代码中可以看出最终的预测维度是 [3B, 2],标签维度是 3B,前 B 个样本都是正样本,其余 2B 都是负样本。
动量蒸馏
解决的问题:从网上收集的数据存在噪声。对于ITC学习,图像的负文本也可能与图像的内容相匹配。对于MLM,可能存在其他不同于描述图像但描述得同样好(或更好)的注释的词。然而,ITC和MLM的one-hot标签惩罚所有负面预测,不管它们的正确性。
解决办法:从动量模型产生的伪目标中学习。动量模型是一个连续发展的教师模型,它由单模态和多模态编码器的指数-移动平均版本组成。在训练过程中,我们训练base模型,使其预测与动量模型的预测相匹配。具体来说,对于ITC,我们首先使用动量单模态编码器的特征来计算图像-文本相似度,即$s^{‘}(I, T) = g_v^{‘}(v_{cls})^T g_w^{‘}(w^{‘}_{cls})$和$s^{‘}(T, I) = g_w^{‘}(w_{cls})^T g_v^{‘}(v^{‘}_{cls})$。然后把方程1中的$s$替换成$s^{‘}$来计算软伪目标$q^{i2t}$和$q^{t2i}$。$ITC_{MoD}$损失的定义为:
类似的,$MLM_{MoD}$损失的定义为:
KL:KL散度计算,又称相对熵或信息散度。我们设定两个概率分布分别为P和Q,在设定为连续随机变量的前提下,他们对应的概率密度函数分别为p(x)和q(x)。如果我们用p(x)去近似q(x),则KL散度可以表示为: