当前位置: 代码迷 >> 综合 >> EMNLP’19-Mask-Predict: Parallel Decoding of Conditional Masked Language Models
  详细解决方案

EMNLP’19-Mask-Predict: Parallel Decoding of Conditional Masked Language Models

热度:7   发布时间:2024-02-27 20:28:42.0

Mask-Predict: Parallel Decoding ofConditional Masked Language Models

  • Intorduction
  • Conditional Masked Language Models
    • Architecture
    • Training Objective
    • Predicting Target Sequence Length
  • Decoding with Mask-Predict
    • Formal Description

Intorduction

大多数机器翻译系统使用顺序译玛的策略,其中单词是一个一个预测的。本文展示了一个并行译码的模型,该模型在恒定译码迭代次数下得到翻译结果。本文提出的条件掩码语言模型(CMLMS
解码器的输入是一个完全被masked的句子,并行的预测所有的单词,并在恒定数量的屏蔽-预测循环之后结束。这种整体策略使模型可以在丰富的双向上下文中反复重新考虑单词的选择,并且正如我们将要展示的那样,它仅在几个周期内就可以产生高质量的翻译。Mask?predictMask-predictMask?predict反复掩盖并重新预测模型对当前转换最不满意的单词子集。

Conditional Masked Language Models

  • YYY:目标语句
  • XXX:源语句
  • Yobs,YmaskY_{obs},Y_{mask}Yobs?Ymask?:将目标语句划分为两类。

CMLMCMLMCMLM根据X与YobsX与Y_{obs}XYobs?预测YmaskY_{mask}Ymask?

Architecture

EncoderEncoderEncoder:基于自注意机制对原文本进行编码。
DecoderDecoderDecoder:目标语言的译码器,具有面向编码器输出的自注意机制,以及另外一组面向目标文本的自注意力机制。作者通过删除自注意mask机制来改进标准解码器。

Training Objective

先对目标语句随机的选择YmaskY_{mask}Ymask?,被mask的token的数量遵循正态分布。之后选中的token被一个特殊的MASKMASK\;MASKtoken来代替。作者利用交叉熵来优化模型。并且,尽管译码器的输出是整个目标语句,但只对YmaskY_{mask}Ymask?执行交叉熵损失函数。

Predicting Target Sequence Length

在非自回归机器翻译中,通常将整个编码器的输出作为一个目标语句长度预测模型的输入来得到目标语句的长度。本文中作者,直接将LENGTHLENGTHLENGTH作为一个输入token输入编码器,利用编码器来预测目标语句的长度,即编码器的一个输出为目标序列的长度NNN,并利用交叉熵损失来训练。

Decoding with Mask-Predict

Mask?PredictMask-PredictMask?Predict:首先,先选择若干token进行mask,然后用decoder去预测它们,将输出中,预测概率值小的token再次mask,并输入decoder中再次预测目标序列。

Formal Description

  • 根据Encoder预测出来的目标序列长度NNN,作者定义了两个变量(y1,...,yN)(y_{1},...,y_{N})(y1?,...,yN?)以及(p1,...,pN)(p_{1},...,p_{N})(p1?,...,pN?)。这个过程将进行T个循环(T可以是一个常数或者序列长度N的函数),并且在每次迭代过程中,都会执行Mask操作,然后是预测目标序列。

  • Mask\mathbf{Mask}Mask:在第一次迭代时,作者将N个token全部mask,之后的迭代过程中,作者只mask掉预测概率值最低的n(n是迭代次数的函数即n=N?T?tTn=N\cdot{\frac {T-t}{T}}n=N?TT?t?)个token。即:
    在这里插入图片描述

  • Predict\mathbf{Predict}Predict:在得到Ymask(t)Y_{mask}^{\left(t\right)}Ymask(t)?后,CMLM将根据原文本XXXYobs(t)Y_{obs}^{\left(t\right)}Yobs(t)?来预测被mask掉的token。

举例介绍:
在这里插入图片描述

  • 首先是序列长度预测:作者选择lll个序列长度,并行计算。
  • 给定如上的句子,首先将全部mask的序列输入译码器中,如下:
    在这里插入图片描述
    接下来作者选择这12个token中预测概率值最小的八个,将他们mask掉,在第二次迭代中重新预测。第二次迭代的输出如上,再将预测概率值最低的四个token执行mask,并再次预测。最终经过继续的迭代的到最终的输出,作者对比lll个输出,将概率值对大的序列作为最终输出:
    在这里插入图片描述
  相关解决方案