Mask-Predict: Parallel Decoding of Conditional Masked Language Models

Most machine translation systems generate text autoregressively from left to right. We, instead, use a masked language modeling objective to train a model to predict any subset of the target words, conditioned on both the input text and a partially masked target translation. This approach allows for efficient iterative decoding, where we first predict all of the target words non-autoregressively, and then repeatedly mask out and regenerate the subset of words that the model is least confident about. By applying this strategy for a constant number of iterations, our model improves state-of-the-art performance levels for non-autoregressive and parallel decoding translation models by over 4 BLEU on average. It is also able to reach within about 1 BLEU point of a typical left-to-right transformer model, while decoding significantly faster.


Introduction
Most machine translation systems use sequential decoding strategies where words are predicted one-by-one. In this paper, we present a model and a parallel decoding algorithm which, for a relatively small sacrifice in performance, can be used to generate translations in a constant number of decoding iterations.
We introduce conditional masked language models (CMLMs), which are encoder-decoder architectures trained with a masked language model objective (Devlin et al., 2018; Lample and Conneau, 2019). This change allows the model to learn to predict, in parallel, any arbitrary subset of masked words in the target translation. We use transformer CMLMs, where the decoder's self-attention (Vaswani et al., 2017) can attend to the entire sequence (left and right context) to predict each masked word. We train with a simple masking scheme where the number of masked target tokens is distributed uniformly, presenting the model with both easy (single mask) and difficult (completely masked) examples. Unlike recently proposed insertion models (Gu et al., 2019; Stern et al., 2019), which treat each token as a separate training instance, CMLMs can train from the entire sequence in parallel, resulting in much faster training.
We also introduce a new decoding algorithm, mask-predict, which uses the order-agnostic nature of CMLMs to support highly parallel decoding. Mask-predict repeatedly masks out and repredicts the subset of words in the current translation that the model is least confident about, in contrast to recent parallel decoding translation approaches that repeatedly predict the entire sequence. Decoding starts with a completely masked target text, to predict all of the words in parallel, and ends after a constant number of mask-predict cycles. This overall strategy allows the model to repeatedly reconsider word choices within a rich bi-directional context and, as we will show, produce high-quality translations in just a few cycles.
Experiments on benchmark machine translation datasets show the strengths of mask-predict decoding for transformer CMLMs. With just 4 iterations, BLEU scores already surpass the performance of the best non-autoregressive and parallel decoding models. With 10 iterations, the approach outperforms the current state-of-the-art parallel decoding model by gaps of 4-5 BLEU points on the WMT'14 English-German translation benchmark, and up to 3 BLEU points on WMT'16 English-Romanian, but with the same model complexity and decoding speed. When compared to standard autoregressive transformer models, CMLMs with mask-predict offer a tradeoff between speed and performance, trading up to 2 BLEU points in translation quality for a 3x speed-up during decoding.

Conditional Masked Language Models
A conditional masked language model (CMLM) predicts a set of target tokens $Y_{mask}$ given a source text $X$ and part of the target text $Y_{obs}$. It makes the strong assumption that the tokens $Y_{mask}$ are conditionally independent of each other (given $X$ and $Y_{obs}$), and predicts the individual probabilities $P(y \mid X, Y_{obs})$ for each $y \in Y_{mask}$. Since the number of tokens in $Y_{mask}$ is given in advance, the model is also implicitly conditioning on the length of the target sequence $N = |Y_{mask}| + |Y_{obs}|$.
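The conditional independence assumption can be written out as a factorization. The following is a restatement of the definition above in LaTeX, not an equation quoted from elsewhere in the paper:

```latex
P(Y_{mask} \mid X, Y_{obs}) = \prod_{y \in Y_{mask}} P(y \mid X, Y_{obs}),
\qquad N = |Y_{mask}| + |Y_{obs}|
```

Each factor depends only on $X$ and $Y_{obs}$, which is why all masked tokens can be predicted in a single parallel step.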

Architecture
We adopt the standard encoder-decoder transformer for machine translation (Vaswani et al., 2017): a source-language encoder that does self-attention, and a target-language decoder that has one set of attention heads over the encoder's output and another set for the target language (self-attention). In terms of parameters, our architecture is identical to the standard one. We deviate from the standard decoder by removing the self-attention mask that prevents left-to-right decoders from attending to future tokens. In other words, our decoder is bi-directional, in the sense that it can use both left and right contexts to predict each token.
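To make the difference between the two decoders concrete, the sketch below builds the self-attention visibility pattern for each as a plain boolean matrix. The function names are illustrative and do not come from any particular library; a real implementation would pass an equivalent mask to its attention layers.

```python
def causal_mask(n):
    """Left-to-right decoder: position i may attend only to positions j <= i."""
    return [[j <= i for j in range(n)] for i in range(n)]


def bidirectional_mask(n):
    """CMLM decoder: every position may attend to every other position."""
    return [[True] * n for _ in range(n)]


causal = causal_mask(4)
full = bidirectional_mask(4)

# Print the causal pattern: 'x' marks an allowed attention link.
for row in causal:
    print("".join("x" if allowed else "." for allowed in row))
```

Removing the causal mask amounts to replacing the lower-triangular pattern with the all-true one; no parameters are added or removed.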

Training Objective
During training, we randomly select $Y_{mask}$ among the target tokens. We first sample the number of masked tokens from a uniform distribution between one and the sequence's length, and then randomly choose that number of tokens. Following Devlin et al. (2018), we replace the inputs of the tokens $Y_{mask}$ with a special MASK token.
We optimize the CMLM for cross-entropy loss over every token in $Y_{mask}$. This can be done in parallel, since the model assumes that the tokens in $Y_{mask}$ are conditionally independent of each other.
While the architecture can technically make predictions over all target-language tokens (including $Y_{obs}$), we only compute the loss for the tokens in $Y_{mask}$.
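The masking scheme and the restricted loss can be sketched in a few lines of plain Python. This is a minimal illustration under stated assumptions: the `<mask>` string, the helper names, and the choice to sum (rather than average) the per-token losses are ours, and the "model" is represented only by the probabilities it assigns to the gold tokens.

```python
import math
import random

MASK = "<mask>"  # hypothetical placeholder for the special MASK token


def sample_masked_positions(length, rng):
    """Sample how many tokens to mask uniformly from 1..length, then pick them."""
    num_masked = rng.randint(1, length)  # inclusive on both ends
    return sorted(rng.sample(range(length), num_masked))


def mask_target(tokens, masked_positions):
    """Replace the selected target tokens with MASK in the decoder input."""
    masked = set(masked_positions)
    return [MASK if i in masked else tok for i, tok in enumerate(tokens)]


def masked_cross_entropy(gold_probs, masked_positions):
    """Sum of -log p(gold token), computed over masked positions only."""
    return -sum(math.log(gold_probs[i]) for i in masked_positions)


rng = random.Random(0)
target = ["The", "withdrawal", "was", "completed", "."]
positions = sample_masked_positions(len(target), rng)
decoder_input = mask_target(target, positions)

# Probabilities a hypothetical model assigns to each gold token.
example_loss = masked_cross_entropy([0.5, 0.5, 0.25, 1.0, 1.0], [0, 2])
```

Observed positions contribute nothing to the loss even though the decoder produces an output distribution for them.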

Predicting Target Sequence Length
In traditional left-to-right machine translation, where the target sequence is predicted token by token, it is natural to determine the length of the sequence dynamically by simply predicting a special EOS (end of sentence) token. However, for CMLMs to predict the entire sequence in parallel, they must know its length in advance. This problem was recognized by prior work in non-autoregressive translation, where the length is predicted with a fertility model (Gu et al., 2018) or by pooling the encoder's outputs into a length classifier.
We follow Devlin et al. (2018) and add a special LENGTH token to the encoder, akin to the CLS token in BERT. The model is trained to predict the length of the target sequence N as the LENGTH token's output, similar to predicting another token from a different vocabulary, and its loss is added to the cross-entropy loss from the target sequence.
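As a rough sketch of how the two losses combine, the snippet below treats the length prediction as a classification over candidate lengths and adds its negative log-likelihood to the token-level loss. The dictionary representation and the unweighted sum are assumptions made for illustration.

```python
import math


def length_loss(length_probs, gold_length):
    """Negative log-likelihood of the gold target length."""
    return -math.log(length_probs[gold_length])


def combined_loss(token_loss, length_probs, gold_length):
    """Token-level cross-entropy plus the length-prediction loss (unweighted sum)."""
    return token_loss + length_loss(length_probs, gold_length)


# Hypothetical distribution read off the LENGTH token's output.
length_probs = {11: 0.25, 12: 0.5, 13: 0.25}
loss = combined_loss(token_loss=1.0, length_probs=length_probs, gold_length=12)
```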

Decoding with Mask-Predict
We introduce the mask-predict algorithm, which decodes an entire sequence in parallel within a constant number of cycles. At each iteration, the algorithm selects a subset of tokens to mask, and then predicts them (in parallel) using an underlying CMLM. Masking the tokens where the model has doubts while conditioning on previous high-confidence predictions lets the model re-predict the more challenging cases, but with more information. At the same time, the ability to make large parallel changes at each step allows mask-predict to converge on a high quality output sequence in a sub-linear number of decoding iterations.

Formal Description
Given the target sequence's length $N$ (see Section 3.3), we define two variables: the target sequence $(y_1, \ldots, y_N)$ and the probability of each token $(p_1, \ldots, p_N)$. The algorithm runs for a predetermined number of iterations $T$, which is either a constant or a simple function of $N$. At each iteration, we perform a mask operation, followed by predict.

[Figure 1. src: Der Abzug der französischen Kampftruppen wurde am 20. November abgeschlossen.]

Mask For the first iteration ($t = 0$), we mask all the tokens. For later iterations, we mask the $n$ tokens with the lowest probability scores:

$$Y_{mask}^{(t)} = \arg\min_i (p_i, n), \qquad Y_{obs}^{(t)} = Y \setminus Y_{mask}^{(t)}$$

The number of masked tokens $n$ is a function of the iteration $t$; specifically, we use linear decay $n = N \cdot \frac{T - t}{T}$, where $T$ is the total number of iterations. For example, if $T = 10$, we will mask 90% of the tokens at $t = 1$, 80% at $t = 2$, and so forth.
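The mask step can be illustrated with a short sketch. The helper names are ours, and rounding $n$ down to an integer is an assumption, since the text does not say how fractional values are handled.

```python
def num_to_mask(N, T, t):
    """Linear decay n = N * (T - t) / T, rounded down to an integer."""
    return (N * (T - t)) // T


def lowest_confidence_positions(probs, n):
    """Indices of the n smallest probabilities (ties broken by position)."""
    ranked = sorted(range(len(probs)), key=lambda i: probs[i])
    return sorted(ranked[:n])


# A 12-token target decoded in T = 3 iterations.
schedule = [num_to_mask(12, 3, t) for t in range(3)]
print(schedule)
```

For a 12-token target and $T = 3$, this gives 12, 8 and 4 masked tokens at $t = 0, 1, 2$ respectively.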
Predict After masking, the CMLM predicts the masked tokens $Y_{mask}^{(t)}$, conditioned on the source text $X$ and the unmasked target tokens $Y_{obs}^{(t)}$. We select the prediction with the highest probability for each masked token $y_i \in Y_{mask}^{(t)}$ and update its probability score accordingly:

$$y_i^{(t)} = \arg\max_w P(y_i = w \mid X, Y_{obs}^{(t)}), \qquad p_i^{(t)} = \max_w P(y_i = w \mid X, Y_{obs}^{(t)})$$

The values and the probabilities of unmasked tokens $Y_{obs}^{(t)}$ remain unchanged:

$$y_i^{(t)} = y_i^{(t-1)}, \qquad p_i^{(t)} = p_i^{(t-1)}$$

We tried updating or decaying these probabilities in preliminary experiments, but found that this heuristic works well despite the fact that some probabilities are stale. Figure 1 illustrates how mask-predict can generate a good translation in just three iterations.
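Putting the mask and predict operations together, the loop below is a minimal sketch of the decoding procedure, not the released implementation. The `predict_fn` interface, the `<mask>` string, and the floor rounding of $n$ are assumptions, and `toy_predict` is a stand-in for a trained CMLM that simply upper-cases the source tokens and grows more confident as more target tokens are observed.

```python
MASK = "<mask>"  # hypothetical placeholder for the special MASK token


def mask_predict(predict_fn, source, N, T):
    """Run T mask-predict iterations for a target of length N.

    predict_fn(source, decoder_input) must return one (token, probability)
    pair per target position.
    """
    tokens = [MASK] * N
    probs = [0.0] * N
    for t in range(T):
        if t == 0:
            masked = list(range(N))  # first iteration: mask everything
        else:
            n = (N * (T - t)) // T   # linear decay, rounded down
            masked = sorted(range(N), key=lambda i: probs[i])[:n]
        decoder_input = list(tokens)
        for i in masked:
            decoder_input[i] = MASK
        predictions = predict_fn(source, decoder_input)
        for i in masked:             # unmasked tokens and scores stay unchanged
            tokens[i], probs[i] = predictions[i]
    return tokens, probs


def toy_predict(source, decoder_input):
    """Stand-in model: confidence rises with the number of observed tokens."""
    observed = sum(tok != MASK for tok in decoder_input)
    confidence = (1 + observed) / (1 + len(decoder_input))
    return [(src_tok.upper(), confidence) for src_tok in source]


tokens, probs = mask_predict(toy_predict, ["a", "b", "c", "d"], N=4, T=2)
```

With the toy model, the two tokens repredicted at $t = 1$ end with a higher score than the two that were kept from $t = 0$, mirroring the stale-probability behaviour described above.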

Example
In the first iteration ($t = 0$), the entire target sequence is masked ($Y_{mask}^{(0)} = Y$ and $Y_{obs}^{(0)} = \emptyset$), and is thus generated by the CMLM in a purely non-autoregressive process:

$$P(Y_{mask}^{(0)} \mid X, Y_{obs}^{(0)}) = P(Y \mid X)$$

This produces an ungrammatical translation with repetitions ("completed completed"), which is typical of non-autoregressive models due to the multi-modality problem (Gu et al., 2018).
In the second iteration ($t = 1$), we select 8 of the 12 tokens generated in the previous step; these tokens were predicted with the lowest probabilities at $t = 0$. We mask them and repredict with the CMLM, while conditioning on the 4 unmasked tokens $Y_{obs}^{(1)}$ = {"The", "20", "November", "."}. This results in a more grammatical and accurate translation. Our analysis shows that this second iteration removes most repetitions, perhaps because conditioning on even a little bit of the target sequence is enough to collapse the multi-modal target distribution into a single output (Section 5.1).
In the last iteration ($t = 2$), we select the 4 of the 12 tokens that had the lowest probabilities. Two of those tokens were predicted at the first step ($t = 0$), and not repredicted at the second step ($t = 1$). It is quite common for earlier predictions to be masked at later iterations because they were predicted with less information and thus tend to have lower probabilities. Now that the model is conditioning on 8 tokens, it is able to produce a more fluent translation; "withdrawal" is a better fit for describing troop movement, and "November 20th" is a more common date format in English.

Deciding Target Sequence Length
When generating, we first run the CMLM's encoder, and then use the LENGTH token's encoding to predict a distribution over the target sequence's length (see Section 2.3). Since much of the CMLM's computation can be batched, we select the top $\ell$ length candidates with the highest probabilities, and decode the same example with different lengths in parallel. We then select the sequence with the highest average log-probability as our result:

$$\frac{1}{N} \sum_i \log p_i^{(T)}$$

Our analysis reveals that translating multiple candidate sequences of different lengths can improve performance (see Section 5.3).
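The selection among length candidates can be sketched as follows. The function names and the toy candidates are illustrative assumptions; in practice each candidate would be the output of a full mask-predict run at one of the top predicted lengths.

```python
import math


def top_length_candidates(length_probs, num_candidates):
    """Return the most probable target lengths, best first."""
    ranked = sorted(length_probs, key=length_probs.get, reverse=True)
    return ranked[:num_candidates]


def average_log_prob(probs):
    """Average log-probability of the tokens in one decoded candidate."""
    return sum(math.log(p) for p in probs) / len(probs)


def select_best(candidates):
    """candidates: list of (tokens, probs) pairs, one per length candidate."""
    return max(candidates, key=lambda cand: average_log_prob(cand[1]))


lengths = top_length_candidates({3: 0.2, 4: 0.5, 5: 0.3}, num_candidates=2)
best_tokens, best_probs = select_best([
    (["a", "b"], [0.5, 0.5]),
    (["a", "b", "c"], [0.9, 0.9, 0.4]),
])
```

Averaging rather than summing the log-probabilities keeps longer candidates from being penalized merely for having more tokens.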

Experiments
We evaluate CMLMs with mask-predict decoding on standard machine translation benchmarks. We find that our approach significantly outperforms prior parallel decoding machine translation methods and even approaches the performance of standard autoregressive models (Section 4.2), while decoding significantly faster (Section 4.3). We evaluate performance with BLEU (Papineni et al., 2002) for all language pairs, except from EN to ZH, where we use SacreBLEU (Post, 2018).

Experimental Setup
Hyperparameters We follow most of the standard hyperparameters for transformers in the base configuration (Vaswani et al., 2017): 6 layers per stack, 8 attention heads per layer, 512 model dimensions, 2048 hidden dimensions. We also experiment with 512 hidden dimensions, for comparison with previous parallel decoding models (Gu et al., 2018). We follow the weight initialization scheme from BERT (Devlin et al., 2018), which samples weights from $\mathcal{N}(0, 0.02)$, initializes biases to zero, and sets layer normalization parameters to $\beta = 0$, $\gamma = 1$. For regularization, we use 0.3 dropout, 0.01 $L_2$ weight decay, and label-smoothed cross-entropy loss with $\varepsilon = 0.1$. We train batches of 128k tokens using Adam (Kingma and Ba, 2015) with $\beta = (0.9, 0.999)$ and $\varepsilon = 10^{-6}$. The learning rate warms up to a peak of $5 \cdot 10^{-4}$ within 10,000 steps, and then decays with the inverse square-root schedule. We trained all models for 300k steps, measured the validation loss at the end of each epoch, and averaged the 5 best checkpoints to create the final model. During decoding, we use a beam size of $b = 5$ for autoregressive decoding, and similarly use $\ell = 5$ length candidates for mask-predict decoding. We trained with mixed precision floating point arithmetic on two DGX-1 machines, each with eight 16GB Nvidia V100 GPUs interconnected by Infiniband (Micikevicius et al., 2018).
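For reference, the learning-rate schedule described above can be sketched as a simple function of the step count. The linear shape of the warm-up and the starting value of zero are assumptions; the text only states the peak, the warm-up length, and the inverse square-root decay.

```python
def learning_rate(step, peak=5e-4, warmup_steps=10_000):
    """Linear warm-up to `peak`, then inverse square-root decay."""
    if step < warmup_steps:
        return peak * step / warmup_steps          # assumed linear warm-up
    return peak * (warmup_steps / step) ** 0.5     # decays as 1 / sqrt(step)


for step in (5_000, 10_000, 40_000, 300_000):
    print(step, learning_rate(step))
```

Under this sketch the rate halves every time the step count quadruples after warm-up, so it is half of the peak at step 40,000.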

Translation Quality
We compare our approach to three other parallel decoding translation methods: the fertility-based sequence-to-sequence model of Gu et al. (2018), the CTC-loss transformer of Libovický and Helcl (2018), and an iterative refinement approach. The first two methods are purely non-autoregressive, while the iterative refinement approach is only non-autoregressive in the first decoding iteration, similar to our approach. In terms of speed, each mask-predict iteration is virtually equivalent to a refinement iteration. Table 1 shows that among the parallel decoding methods, our approach yields the highest BLEU scores by a considerable margin. When controlling for the number of parameters (i.e. considering only the smaller CMLM configuration), CMLMs score roughly 4 BLEU points higher than the previous state of the art on WMT'14 EN-DE, in both directions. Another striking result is that a CMLM with only 4 mask-predict iterations yields higher scores than 10 iterations of the iterative refinement model; in fact, only 3 mask-predict iterations are necessary for achieving a new state of the art on both directions of WMT'14 EN-DE (not shown).
The translations produced by CMLMs with mask-predict also score competitively when compared to strong transformer-based autoregressive models. In all 4 benchmarks, our base CMLM comes within 0.5-1.2 BLEU points of a well-tuned base transformer, a relative decrease of less than 4% in translation quality. In many scenarios, this is an acceptable price to pay for a significant speedup from parallel decoding. Table 2 shows that these trends also hold for English-Chinese translation, in both directions, despite major linguistic differences between the two languages.

Decoding Speed
Because CMLMs can predict the entire sequence in parallel, mask-predict can translate an entire sequence in a constant number of decoding iterations. Does this appealing theoretical property translate into a wall-time speed-up in practice? By comparing the actual decoding times, we show that, for some sacrifice in performance, our parallel method can translate much faster than standard sequential transformers.
Setup As the baseline system, we use the base transformer with beam search ($b = 5$) to translate WMT'14 EN-DE; we also use greedy search ($b = 1$) as a faster but less accurate baseline. For CMLMs, we vary the number of mask-predict iterations ($T = 4, \ldots, 10$) and length candidates ($\ell = 1, 2, 3$). For both models, we decode batches of 10 sentences. For each decoding run, we measure the performance (BLEU) and wall time (seconds) from when the model and data have been loaded until the last example has been translated, and calculate the relative decoding speed-up (baseline time / CMLM time) to assess the speed-performance trade-off.
The implementation of both the baseline transformer and our CMLM is based on fairseq (Gehring et al., 2017), which efficiently decodes left-to-right transformers by caching the state. Caching reduces the baseline's decoding time from 210 seconds to 128.5; CMLMs do not use cached decoding. All experiments used exactly the same machine and the same single GPU.
Results Figure 2 shows the speed-performance trade-off. We see that mask-predict is versatile; on one hand, we can translate over 3 times faster than the baseline at a cost of 2 BLEU points ($T = 4$, $\ell = 2$), or alternatively retain a high quality of 27.03 BLEU while gaining a 30% speed-up ($T = 10$, $\ell = 2$). Surprisingly, this latter configuration outperforms an autoregressive transformer with greedy decoding ($b = 1$) in both quality and speed. We also observe that more balanced configurations (e.g. $T = 8$, $\ell = 2$) yield similar performance to the single-beam autoregressive transformer, but decode much faster.
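As a small worked example of the speed-up arithmetic, the snippet below applies it to the two baseline timings quoted above. It assumes the speed-up is defined as the reference time divided by the system time, so that values above 1 mean faster decoding; the function name is ours.

```python
def speed_up(reference_seconds, system_seconds):
    """Relative speed-up: how many times faster the system is than the reference."""
    return reference_seconds / system_seconds


# Effect of state caching on the autoregressive baseline (timings from the text).
caching_gain = speed_up(210.0, 128.5)
print(round(caching_gain, 2))
```

By the same definition, caching alone makes the baseline roughly 1.6 times faster, which is why the comparison uses the cached baseline.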

Analysis
To complement the quantitative results in Section 4, we present qualitative analysis that provides some intuition as to why our approach works and where future work could potentially improve it.

Why Are Multiple Iterations Necessary?
Various non-autoregressive translation models, including our own CMLM, make the strong assumption that the individual token predictions are conditionally independent of each other. Such a model might consider two or more possible translations, A and B, but because there is no coordination mechanism between the token predictions, it could predict one token from A and another token from B. This problem, known as the multi-modality problem (Gu et al., 2018), often manifests as token repetitions in the output when the model has multiple hypotheses that predict the same word w with high confidence, but at different positions.
We hypothesize that multiple mask-predict iterations alleviate the multi-modality problem by allowing the model to condition on parts of the input, thus collapsing the multi-modal distribution into a sharper uni-modal distribution. To test our hypothesis, we measure the percentage of repetitive tokens produced by each iteration of mask-predict as a proxy metric for the multi-modality problem. Table 3 shows that, indeed, the proportion of repetitive tokens drops drastically during the first 2-3 iterations. This finding suggests that the first few iterations are critical for converging into a uni-modal distribution. The decrease in repetitions also correlates with the steep rise in translation quality (BLEU), supporting the conjecture of Gu et al. (2018) that multi-modality is a major roadblock for purely non-autoregressive machine translation.
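One plausible way to compute such a proxy is sketched below. The exact definition of a repetitive token is not spelled out here, so this sketch assumes that a token counts as repetitive when it is identical to the token immediately before it; the example sentence is made up.

```python
def repetition_rate(tokens):
    """Percentage of tokens that repeat the immediately preceding token."""
    if not tokens:
        return 0.0
    repeats = sum(1 for prev, cur in zip(tokens, tokens[1:]) if prev == cur)
    return 100.0 * repeats / len(tokens)


# Hypothetical first-iteration output containing one repeated token.
rate = repetition_rate(["the", "troops", "completed", "completed", "."])
print(rate)
```

Tracking this value after each iteration gives a cheap signal of whether the output is still mixing several candidate translations.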

Do Longer Sequences Need More Iterations?
A potential concern with using a constant number of decoding iterations is that it may be effective for short sequences (where the number of iterations $T$ is closer to the output's length $N$), but insufficient for longer sequences. To determine whether this is the case, we use compare-mt (Neubig et al., 2019) to bucket the evaluation data by target sentence length and compute the performance with different values of $T$. Table 4 shows that increasing the number of decoding iterations ($T$) appears to mainly improve the performance on longer sequences. Having said that, the performance differences across length buckets are not very large, and it seems that even 4 mask-predict iterations are enough to produce decent translations for long sequences ($40 \le N$).

Do More Length Candidates Help?
Traditional autoregressive models can dynamically decide the length of the target sequence by generating a special END token when they are done, but that is not true for models that decode multiple tokens in parallel, such as CMLMs.
To address this problem, our model predicts the length of the target sequence (Section 2.3) and decodes multiple length candidates in parallel (Section 3.3). We compare our model's performance with a varying number of length candidates to its performance when conditioned on the reference (gold) target length in order to determine how accurate it is at predicting the correct length and assess the relative contribution of decoding with multiple length candidates. Table 5 shows that having multiple candidates can increase performance almost as much as conditioning on the gold length. Surprisingly, adding too many candidates can even degrade performance. We suspect that because CMLMs are implicitly conditioned on the target length, producing a translation that is too short (i.e. high precision, low recall) will have a high average log probability. In preliminary experiments, we tried to address this issue by weighting the different candidates according to the model's length prediction, but this approach gave too much weight to the top candidate and resulted in lower performance.

Is Model Distillation Necessary?
Previous work on non-autoregressive and insertion-based machine translation reported that it was necessary to train their models on text generated by an autoregressive teacher model, a process known as distillation. To determine CMLM's dependence on this process, we train models on both raw and distilled data, and compare their performance. Table 6 shows that in every case, training with model distillation substantially outperforms training on raw data. The gaps are especially large when decoding with a single iteration (purely non-autoregressive). Overall, it appears as though CMLMs are heavily dependent on model distillation.
On the English-Romanian benchmark, the differences are much smaller, and after 10 iterations the raw-data model can perform comparably with the distilled model. A possible explanation is that our teacher model was weaker for this dataset due to insufficient hyperparameter tuning. Alternatively, it could also be the case that the English-German dataset is much noisier than the English-Romanian one, and that the teacher model essentially cleans the training data. Unfortunately, we do not have enough evidence to support or refute either hypothesis at this time.

Related Work
Training Masked Language Models with Translation Data Recent work by Lample and Conneau (2019) shows that training a masked language model on sentence-pair translation data, as a pre-training step, can improve performance on cross-lingual tasks, including autoregressive machine translation. Our training scheme builds on their work, with the following differences: we use separate model parameters for source and target texts (encoder and decoder), and we also use a different masking scheme. Specifically, we mask a varying percentage of tokens, only from the target, and do not replace input tokens with noise. Most importantly, the goal of our work is different; we do not use CMLMs for pre-training, but to directly generate text with mask-predict decoding.
Concurrently with our work, Song et al. (2019) extend the approach of Lample and Conneau (2019) by using separate encoder and decoder parameters (as in our model) and pre-training them jointly in an autoregressive version of masked language modeling, although with monolingual data. While this work demonstrates that pretraining CMLMs can improve autoregressive machine translation, it does not try to leverage the parallel and bi-directional nature of CMLMs to generate text in a non-left-to-right manner.

Generating from Masked Language Models
One approach for generating text from a masked language model casts BERT (Devlin et al., 2018), a non-conditional masked language model, as a Markov random field (Wang and Cho, 2019). By masking a sequence of length $N$ and then iteratively sampling a single token at a time from the model (either sequentially or in arbitrary order), one can produce grammatical examples. While this sampling process has a theoretical justification, it also requires $N$ forward passes of the model; mask-predict decoding, on the other hand, can produce text in a constant number of iterations.

Parallel Decoding for Machine Translation
There have been several advances in parallel decoding machine translation by training non-autoregressive models. Gu et al. (2018) introduce a transformer-based approach with explicit word fertility, and identify the multi-modality problem. Libovický and Helcl (2018) approach the multi-modality problem by collapsing repetitions with the Connectionist Temporal Classification training objective (Graves et al., 2006). Perhaps most similar to our work is the iterative refinement approach, in which the model corrects the original non-autoregressive prediction by passing it multiple times through a denoising autoencoder. A major difference is that the iterative refinement model is trained as a noisy autoencoder that deals with corrupt inputs by applying stochastic corruption heuristics on the training data, while we simply mask a random number of input tokens. We also show that our approach outperforms all of these models by wide margins.
Arbitrary Order Language Generation Finally, recent work has developed insertion-based transformers for arbitrary, but fixed, word order generation (Gu et al., 2019; Stern et al., 2019). While they do not decode in a constant number of iterations, Stern et al. (2019) show strong results in logarithmic time. Both models treat each token insertion as a separate training example, which cannot be computed in parallel with every other insertion in the same sequence. This makes training significantly more expensive than standard transformers (which use causal attention masking) and our CMLMs (which can predict all of the masked tokens in parallel).

Conclusion
This work introduces conditional masked language models and a novel mask-predict decoding algorithm that leverages their parallelism to generate text in a constant number of decoding iterations. We show that, in the context of machine translation, our approach substantially outperforms previous parallel decoding methods, and can approach the performance of sequential autoregressive models while decoding much faster. While there are still open problems, such as the need to condition on the target's length and the dependence on knowledge distillation, our results provide a significant step forward in non-autoregressive and parallel decoding approaches to machine translation. In a broader sense, this paper shows that masked language models are useful not only for representing text, but also for generating text efficiently.