Iterative Refinement in the Continuous Space for Non-Autoregressive Neural Machine Translation

We propose an efficient inference procedure for non-autoregressive machine translation that iteratively refines translations purely in the continuous space. Given a continuous latent variable model for machine translation (Shu et al., 2020), we train an inference network to approximate the gradient of the marginal log probability of the target sentence, using only the latent variable as input. This allows us to use gradient-based optimization to find a target sentence at inference time that approximately maximizes its marginal probability. As each refinement step only involves computation in the latent space of low dimensionality (we use 8 in our experiments), we avoid the computational overhead incurred by existing non-autoregressive inference procedures that often refine in the token space. We compare our approach to a recently proposed EM-like inference procedure (Shu et al., 2020) that optimizes in a hybrid space, consisting of both discrete and continuous variables. We evaluate our approach on WMT'14 En-De, WMT'16 Ro-En and IWSLT'16 De-En, and observe two advantages over the EM-like inference: (1) it is computationally efficient, i.e. each refinement step is twice as fast, and (2) it is more effective, resulting in higher marginal probabilities and BLEU scores with the same number of refinement steps. On WMT'14 En-De, for instance, our approach is able to decode 6.2 times faster than the autoregressive model with minimal degradation in translation quality (0.9 BLEU).


Introduction
Most neural machine translation systems are autoregressive, hence decoding latency grows linearly with the length of the target sentence. For faster generation, several works have proposed non-autoregressive models with sub-linear decoding latency given sufficient parallel computation (Gu et al., 2018a; Lee et al., 2018; Kaiser et al., 2018).
As it is challenging to precisely model the dependencies among the tokens without autoregression, many existing non-autoregressive models first generate an initial translation which is then iteratively refined to yield better output (Lee et al., 2018; Gu et al., 2019; Ghazvininejad et al., 2019). While various training objectives are used to admit refinement (e.g. denoising, evidence lowerbound maximization and masked language modeling), the generation processes of these models are similar in that the refinement happens in the discrete space of sentences.
Meanwhile, another line of work proposed to use continuous latent variables for non-autoregressive translation, such that the distribution of the target sentences can be factorized over time given the latent variables (Ma et al., 2019; Shu et al., 2020). Unlike the models discussed above, finding the most likely target sentence under these models requires searching over continuous latent variables. To this end, Shu et al. (2020) proposed an EM-like inference procedure that optimizes over a hybrid space consisting of both continuous and discrete variables. By introducing a deterministic delta posterior, it maximizes a proxy lowerbound by alternating between matching the delta posterior to the original approximate posterior (continuous optimization), and finding a target sentence that maximizes the proxy lowerbound (discrete search).
In this work, we propose an iterative inference procedure for latent variable non-autoregressive models that operates purely in the continuous space. Given a latent variable model, we train an inference network to estimate the gradient of the marginal log probability of the target sentence, using only the latent variable as input. At inference time, we find the target sentence that approximately maximizes the log probability by (1) initializing the latent variable, e.g. as the mean of the prior, and (2) following the gradients estimated by the inference network.
We compare the proposed approach with the EM-like inference (Shu et al., 2020) on three machine translation datasets: WMT'14 En→De, WMT'16 Ro→En and IWSLT'16 De→En. The advantages of our approach are twofold: (1) each refinement step is twice as fast, as it avoids discrete search over a large vocabulary, and (2) it is more effective, giving higher marginal probabilities and BLEU scores with the same number of refinement steps. Our procedure results in significantly faster inference, for instance giving a 6.2× speedup over the autoregressive baseline on WMT'14 En→De at the expense of 0.9 BLEU.

Background: Iterative Refinement for Non-Autoregressive Translation
We motivate our approach by reviewing existing refinement-based non-autoregressive models for machine translation in terms of their inference procedure. Let us use V, D, T and L to denote vocabulary size, latent dimensionality, target sentence length and the number of refinement steps, respectively.
Most machine translation models are trained to maximize the conditional log probability log p(y|x) of the target sentence y given the source sentence x, averaged over the training data consisting of sentence pairs {(x_n, y_n)}_{n=1}^N. To find the most likely target sentence at test time, one performs maximum-a-posteriori inference by solving the search problem ŷ = argmax_y log p(y|x).

Refinement in a Discrete Space
As the lack of autoregression makes it challenging to model the dependencies among the target tokens, most of the existing non-autoregressive translation models use iterative refinement to impose dependencies in the generation process. Various training objectives are used to incorporate refinement, e.g. denoising (Lee et al., 2018), masked language modeling (Ghazvininejad et al., 2019) and evidence lowerbound maximization (Chan et al., 2019; Gu et al., 2019). However, the inference procedures employed by these models are similar in that an initial hypothesis is generated and then successively refined. We refer the readers to Mansimov et al. (2019) for a formal definition of a sequence generation framework that unifies these models, and briefly discuss the inference procedure below.
By viewing each refinement step as introducing a discrete random variable z_i (a T × V-dimensional matrix, where each row is one-hot), inference with L refinement steps requires finding y that maximizes the marginal log probability

log p_θ(y|x) = log Σ_{z_{1:L}} p_θ(y, z_{1:L} | x).   (1)

As the marginalization over z_{1:L} is intractable, inference for these models instead maximizes the joint log probability with respect to ẑ_{1:L} and y. Approximate search methods are used to find ẑ_{1:L} as ẑ_i = argmax_{z_i} log p_θ(z_i | ẑ_{<i}, x).

Refinement in a Hybrid Space
Learning On the other hand, Ma et al. (2019) and Shu et al. (2020) proposed to use continuous latent variables for non-autoregressive translation. By letting the latent variables z (of dimensionality T × D) capture the dependencies between the target tokens, the decoder p_θ(y|z, x) can be factorized over time. As exact posterior inference and learning are intractable for most deeply parameterized prior and decoder distributions, these models are trained to maximize the evidence lowerbound (ELBO) (Kingma and Welling, 2014; Wainwright and Jordan, 2008).
Inference Exact maximization of the ELBO with respect to y is challenging due to the expectation over z ∼ q_φ. To approximately maximize the ELBO, Shu et al. (2020) proposed to optimize a deterministic proxy lowerbound using a Dirac delta posterior q(z|x, y) = δ(z − µ). With this posterior, the ELBO reduces to the following proxy lowerbound: log p_θ(y|µ, x) + log p_θ(µ|x). Shu et al. (2020) proposed to approximately maximize the ELBO with an EM-like inference procedure, to which we refer as delta inference.
It alternates between continuous and discrete optimization: (1) the E-step matches the delta posterior to the approximate posterior by minimizing their KL divergence, which for a Gaussian approximate posterior sets µ_i to the mean of q_φ(z | ŷ_{i−1}, x); (2) the M-step maximizes the proxy lowerbound with respect to y: ŷ_i = argmax_y log p_θ(y|µ_i, x). Overall, delta inference finds y and µ that maximize log p_θ(y|µ, x) + log q_φ(µ|y, x). This iterative inference procedure in a hybrid space was empirically shown to improve BLEU scores and the ELBO at each refinement step (Shu et al., 2020).
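The alternation between the two steps can be sketched in a toy setting. Everything below (the linear decoder, the fixed token embeddings, the tiny vocabulary) is a hypothetical stand-in for the Transformer-parameterized model, not the paper's implementation:

```python
import numpy as np

# Toy stand-ins for the model components (all hypothetical).
rng = np.random.default_rng(0)
W = rng.normal(size=(4, 2))  # maps a 2-d latent to logits over 4 "tokens"

def m_step(mu):
    """M-step: pick the most likely token given the latent mean.

    With a factorized decoder, the argmax over y is exact and parallel
    across positions; we use a single position here for brevity."""
    return int(np.argmax(W @ mu))

def e_step(y):
    """E-step: move the delta posterior to the approximate-posterior mean.

    Here the "posterior mean" is a fixed toy embedding per token."""
    embeddings = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]])
    return embeddings[y]

def delta_inference(mu0, steps=4):
    mu = mu0
    for _ in range(steps):
        y = m_step(mu)   # discrete search over the vocabulary
        mu = e_step(y)   # continuous update of the delta posterior
    return m_step(mu), mu
```

Note that the M-step is the expensive part: it searches over the vocabulary at every position, which is exactly what the proposed continuous procedure avoids.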

Iterative Refinement in a Continuous Space
While the delta inference procedure is an effective inference algorithm for machine translation models with continuous latent variables, it is unsatisfactory, as the M-step requires searching over V tokens T times for each refinement step. As V is large for most machine translation models, this is an expensive operation, even when the T searches can be parallelized. We thus propose to replace delta inference with continuous optimization performed only in the latent space, given the underlying latent variable model.

Learning
Let us define τ_θ(z; x) as the marginal log probability of the most likely target sentence under the latent variable model given z:

τ_θ(z; x) = log p_θ(ŷ | x), where ŷ = argmax_y log p_θ(y | z, x).   (2)

Our goal is to find a function −E_ψ(z; x) that approximates τ_θ(z; x) up to an additive constant and a positive multiplicative factor, such that maximizing −E_ψ(z; x) over z is equivalent to maximizing τ_θ(z; x). In this work, instead of directly approximating τ_θ, we train −E_ψ to learn the difference of τ_θ between a pair of configurations of latent variables (z, z̃). Omitting the source sentence x and the model parameters θ for notational simplicity, we solve the following problem for z̃ ≠ z:

min_ψ ‖ (z̃ − z) + ∇_z E_ψ(z) ‖²,   (3)

which drives −∇_z E_ψ(z) toward the displacement (z̃ − z), as Eq. 3 maximizes their dot product while minimizing its squared norm. Ideally the pair would be related through the true gradient; however, as τ_θ(z; x) is not differentiable with respect to z due to the argmax operation in Eq. 2, ∇_z τ_θ(z; x) is not defined. We thus use a proxy gradient obtained from delta inference. Furthermore, we weigh the latent configuration z according to the prior. Our final training objective for E_ψ is then:

min_ψ E_{z ∼ p_θ(z|x)} ‖ (z̃ − z) + ∇_z E_ψ(z; x) ‖²,   (4)

where z̃ is the output of applying k steps of delta inference on z. If delta inference improves the log probability at each iteration, we hypothesize that (z̃ − z) is a reasonable approximation to the true gradient ∇_z τ_θ(z; x). We empirically show that this is indeed the case in Sec. 5.2.

Parameterization
We have two options for parameterizing ∇_z E_ψ(z; x) when minimizing Eq. 4. First, we can parameterize it as the gradient of a scalar-valued function E, to which earlier work has referred as an energy function (Teh et al., 2003; LeCun et al., 2006). Second, we can parameterize it as a function S_ψ(z; x) that directly outputs the gradient of the log probability with respect to z (often referred to as a score function (Hyvärinen, 2005)), without estimating the energy itself.
While previous work found direct score estimation that bypasses energy estimation unstable (Alain and Bengio, 2014;Saremi et al., 2018), it leads to faster inference by avoiding backpropagation in each refinement step. We compare the two approaches in our experiments.
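The two parameterizations can be contrasted on a toy quadratic energy (hypothetical; the real functions are Transformer networks). The energy route recovers the gradient by differentiation, emulated here with central differences in place of backpropagation, while the score route outputs it directly:

```python
import numpy as np

# Hypothetical toy energy E(z) = 0.5 * z^T A z.
A = np.array([[2.0, 0.0], [0.0, 0.5]])

def energy(z):
    return 0.5 * z @ A @ z

def score_direct(z):
    # Option 2: a network predicts -grad E(z) without materializing E.
    return -(A @ z)

def score_from_energy(z, eps=1e-5):
    # Option 1 needs differentiation of the scalar energy;
    # emulated here with central finite differences.
    g = np.zeros_like(z)
    for i in range(len(z)):
        e = np.zeros_like(z)
        e[i] = eps
        g[i] = (energy(z + e) - energy(z - e)) / (2 * eps)
    return -g

z = np.array([0.7, -1.3])
```

Both routes yield the same gradient here; the difference at inference time is that the direct score skips the extra backward pass per refinement step.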

Algorithm 1: Inference for Latent Variable Models using Learned Gradients
Input: source sentence x, step size α, number of refinement steps L
1: initialize z_0 (mean of the prior, or a sample from it)
2: for i = 1, ..., L do
3:   z_i ← z_{i−1} − α ∇_z E_ψ(z_{i−1}; x)   (or z_i ← z_{i−1} + α S_ψ(z_{i−1}; x))
4: end for
5: return ŷ = argmax_y log p_θ(y | z_L, x)
Inference
At inference time, we initialize the latent variable (using either a sample from the prior or its mean) and iteratively update it using the estimated gradients (see Alg. 1). As each refinement step only involves optimization in the continuous space, we avoid having to search over a large vocabulary. We can either perform iterative refinement for a fixed number of steps, or until some convergence condition is satisfied.
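The inference loop can be sketched as follows, under toy assumptions: a quadratic energy standing in for the learned E_ψ, and a placeholder decoder (both hypothetical):

```python
import numpy as np

# Hypothetical stand-ins: the energy minimum c plays the role of the
# most likely latent configuration; decode() is a placeholder decoder.
c = np.array([0.5, -1.0])

def score(z):
    """Estimated -grad E(z), as output by a learned score network."""
    return -(z - c)

def decode(z):
    """Placeholder for argmax_y log p(y | z, x) under the factorized decoder."""
    return "token_%d" % int(np.argmax(np.append(z, 0.0)))

def refine(z0, steps=4, alpha=1.0):
    """Initialize the latent, follow the estimated gradients, then decode."""
    z = z0
    for _ in range(steps):
        z = z + alpha * score(z)  # one continuous refinement step
    return z, decode(z)
```

Every step is a vector update in the D-dimensional latent space (D = 8 in the experiments), with no search over the vocabulary until the final decoding.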
Data and preprocessing We use sentencepiece tokenization (Kudo and Richardson, 2018) with 32K sentencepieces on all datasets. For WMT'16 Ro→En, we follow Sennrich et al. (2016) and normalize Romanian and remove diacritics before applying tokenization. For training, we discard sentence pairs if either the source or the target length exceeds 64 tokens.
Following Lee et al. (2018), we remove repetitions from the translations with a simple post-processing step before computing BLEU scores. We use detokenized BLEU with SacreBLEU (Post, 2018).
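A minimal sketch of the repetition-removal post-processing, assuming the heuristic simply collapses consecutive duplicate tokens (the exact rule used by Lee et al. (2018) may differ):

```python
def remove_repetitions(tokens):
    """Collapse runs of consecutive duplicate tokens.

    Non-autoregressive outputs often repeat tokens at adjacent
    positions; this keeps only the first token of each run."""
    out = []
    for t in tokens:
        if not out or out[-1] != t:
            out.append(t)
    return out
```

For example, `remove_repetitions(["the", "the", "cat", "sat", "sat"])` keeps one copy per run.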
Distillation Following previous work on non-autoregressive translation, we train non-autoregressive models on the target sentences generated by an autoregressive model (Kim and Rush, 2016;Gu et al., 2018a) trained using the FairSeq framework (Ott et al., 2019).
Non-autoregressive latent variable models We closely follow the implementation details of Shu et al. (2020). The prior and the approximate posterior distributions are spherical Gaussian distributions with learned mean and variance, and the decoder is factorized over time. The only difference is that at inference time, the target sentence length is predicted once and fixed throughout the refinement procedure. Therefore, the latent variable dimensionality R^{T × D} does not change.
The decoder, prior and approximate posterior distributions are all parameterized with n_layers Transformer decoder layers (the last two also have a final linear layer that outputs the mean and variance). For IWSLT'16 De→En, we use (d_model, d_filter, n_layers, n_heads) = (256, 1024, 3, 4). For WMT'14 En→De and WMT'16 Ro→En, we use (512, 2048, 6, 8). The latent dimensionality d_latent is set to 8 across all datasets. The source sentence encoder is implemented with a standard Transformer encoder. Given the hidden states of the source sentence, the length predictor (a 2-layer MLP) predicts the length difference between the source and target sentences as a categorical distribution over [−50, 50].
Energy function The energy E_ψ(z; x) is parameterized with n_layers Transformer decoder layers and a final linear layer with an output dimensionality of 1. We average the last Transformer hidden states across time and feed the result to the linear layer to yield a scalar energy value.
Score function When directly estimating the gradient of the log probability with respect to z, S_ψ(z; x) is parameterized with n_layers Transformer decoder layers and a final linear layer with an output dimensionality of d_latent.

Table 1: Results of delta inference (Delta) (Shu et al., 2020) and the proposed inference procedure with estimated energy (Energy) or score (Score). Speed: inference speedup compared to the autoregressive model with beam width 4. Time: average wall-clock time per example in milliseconds on a Tesla V100 GPU (with standard deviations). b: beam width, L: the number of refinement steps. Search: parallel decoding with 5 length candidates and 5 samples from the prior, with 1 refinement step. Results above Search are obtained by initializing the latent variable as the mean of the prior. We boldface the highest BLEU among the latent variable models.

Training and Optimization
We use the Adam optimizer (Kingma and Ba, 2015) with a batch size of 8192 tokens and the learning rate schedule used by Vaswani et al. (2017) with a warmup of 8K steps. When training our inference networks, we fix the underlying latent variable model. Our inference networks are trained for 1M steps to minimize Eq. 4, where z̃ is obtained by applying k (= 4) iterations of delta inference on z sampled from the prior. We also find that stochastically applying one gradient update (using the estimated gradients) to z before computing z̃ leads to better performance.
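The learning-rate schedule of Vaswani et al. (2017) with warmup can be written as:

```python
def transformer_lr(step, d_model=512, warmup=8000):
    """Inverse-square-root schedule of Vaswani et al. (2017):

    lr = d_model^-0.5 * min(step^-0.5, step * warmup^-1.5)

    i.e. linear warmup for `warmup` steps, then decay as step^-0.5."""
    step = max(step, 1)
    return d_model ** -0.5 * min(step ** -0.5, step * warmup ** -1.5)
```

The rate peaks exactly at the warmup step (8K here) and decays afterwards.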

Inference
Step size For the proposed inference procedure, we use the step size α = 1.0 as it performed well on the development set.
Length prediction Given a distribution over target sentence lengths, we can either (1) take the argmax, or (2) select the top l candidates and decode them in parallel (Ghazvininejad et al., 2019). In the second case, we select the output candidate with the highest log probability under an autoregressive model, normalized by its length.
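Selecting among the parallel candidates by length-normalized log probability can be sketched as follows; the scoring function passed in is a stand-in for the autoregressive rescorer:

```python
import numpy as np

def select_output(candidates, logprob_fn):
    """Pick the candidate with the highest length-normalized log probability.

    `candidates` are token lists decoded in parallel (one per length
    candidate); `logprob_fn` is a stand-in for the autoregressive scorer."""
    scores = [logprob_fn(c) / max(len(c), 1) for c in candidates]
    return candidates[int(np.argmax(scores))]
```

The same selection rule is reused below for search over multiple latent samples.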
Latent search In Alg. 1, we can initialize the latent variable either with a sample from the prior or with its mean. We use n_w samples from the prior and perform iterative refinement (e.g. delta inference or the proposed inference procedures) in parallel. Similarly to length prediction, we select the output with the highest log probability. To avoid stochasticity, we fix the random seed during sampling.

Table 1 presents the translation performance and inference speed of several inference procedures for the non-autoregressive latent variable models, along with the autoregressive baselines. We emphasize that the same underlying latent variable model is used across the three inference procedures (Delta, Energy, Score), to compare their efficiency and effectiveness.

Translation Quality and Speed
Translation quality We observe that both of the proposed inference procedures improve translation quality with more refinement steps. For instance, 4 refinement steps using the learned score function improve BLEU by 2.1 on IWSLT'16 De→En. Among the proposed inference procedures, we find it more effective to use a learned score function, as it gives comparable or better performance than delta inference on all datasets. A learned energy function results in performance comparable to delta inference. Parallel decoding over multiple target length candidates and sampled latent variables leads to significant improvements in BLEU, resulting in an increase of 1 BLEU or more on all datasets. Similarly to delta inference, we find that the proposed iterative inference procedures converge quite quickly, and often 1 refinement step gives translation quality comparable to running 4 refinement steps.
Inference speed We observe that using a learned score function is significantly faster than delta inference: twice as fast on IWSLT'16 De→En and WMT'16 Ro→En and almost four times as fast on WMT'14 En→De. On WMT'14 En→De, the decoding latency for 4 steps using the score is close to (within one standard deviation of) running 1 refinement step of delta inference. On the other hand, we find that using the learned energy function is slower, presumably due to the overhead from backpropagation. We find its wall clock time to be similar to delta inference. As the entire inference process can be parallelized, we find that parallel decoding with multiple length candidates and latent variable samples only incurs minimal overhead. Finally, we confirm that decoding latency for non-autoregressive models is indeed constant with respect to the sequence length (given parallel computation), as the standard deviation is small (< 10 ms) across test examples.
Overall result Overall, we find the proposed inference procedure using the learned score function highly effective and efficient. On WMT'14 En→De, using 1 refinement step and parallel search leads to 6.2× speedup over the autoregressive baseline with minimal degradation to translation quality (0.9 BLEU score).

Log Probability Comparison
In Fig. 1, we report the marginal log probability log p_θ(ŷ|x) of ŷ found after L steps of each iterative inference procedure on IWSLT'16 De→En. We estimate the marginal log probability by importance sampling with 500 samples from the approximate posterior. We observe that the log probability improves with more refinement steps for all inference procedures (delta inference and the proposed procedures). We draw two conclusions from this. First, delta inference indeed increases the log probability at each iteration. Second, the proposed optimization scheme increases the objective it was trained on (the log probability).
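The importance-sampling estimate of the marginal log probability can be sketched as follows, computed in log space for numerical stability. The toy densities below are stand-ins for the model's approximate posterior and joint distribution:

```python
import numpy as np

def log_q(z):
    """Toy approximate-posterior density: standard normal log-pdf."""
    return -0.5 * z ** 2 - 0.5 * np.log(2 * np.pi)

def log_marginal_is(log_joint, log_q, z_samples):
    """log p(y|x) ~= log (1/S) sum_s p(y, z_s | x) / q(z_s | y, x),

    with z_s drawn from the approximate posterior q; the sum is
    computed with the log-sum-exp trick."""
    w = np.array([log_joint(z) - log_q(z) for z in z_samples])
    m = w.max()
    return m + np.log(np.mean(np.exp(w - m)))
```

As a sanity check, if the joint is exactly half the posterior density, the estimate recovers log 0.5 regardless of the samples.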

Token Statistics
We compare delta inference and the proposed inference with a learned score function in terms of token statistics in the output translations on IWSLT'16 De→En. In Figure 2 (left), we compute the average edit distance (in sentencepieces) per test example from the initial output (decoded from the mean of the prior). It is clear that each refinement step using a learned score function results in more changes, in terms of edit distance, than delta inference. In Figure 2 (right), we compute the number of token repetitions in the output translations (before removing them in a post-processing step), relative to the initial output. We observe that refining with a learned score function results in less repetitive output than delta inference.
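The per-example edit distance can be computed with the standard Levenshtein dynamic program over token lists (sentencepieces in our case):

```python
def edit_distance(a, b):
    """Levenshtein distance between two token lists.

    Classic O(len(a) * len(b)) dynamic program keeping one row
    of the DP table at a time."""
    prev = list(range(len(b) + 1))  # distance from empty prefix of a
    for i, x in enumerate(a, 1):
        cur = [i]
        for j, y in enumerate(b, 1):
            cur.append(min(prev[j] + 1,        # deletion
                           cur[-1] + 1,        # insertion
                           prev[j - 1] + (x != y)))  # substitution / match
        prev = cur
    return prev[-1]
```

Averaging `edit_distance(initial, refined)` over the test set gives the statistic plotted in Figure 2 (left).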
Qualitative Results

Visualization of learned gradients
We visualize the learned gradients and the optimization trajectory in Figure 3, from a score inference network trained on a two-dimensional latent variable model on IWSLT'16 De→En. The example used to generate the visualization is shown below.

Figure 3: Visualization of estimated gradients and optimization trajectory. Above each plot are tokens predicted from the following latent variables: (1) approximate posterior mean, (2) prior mean, (3) delta inference and (4) inference with the learned score. Black star: latent variable before refinement (prior mean). Blue cross: latent variables after L = {1, 2, 3, 4} steps of delta inference (collapsed into a single point). Green circle: latent variables after L steps of inference with a learned score function. Marker size decreases with successive refinement steps.

We observe that for tokens 1, 2 and 6, delta inference converges quickly to the approximate posterior mean. We also find that the local optima estimated by the score function do not necessarily coincide with the approximate posterior mean. For token 4, while the local optimum estimated by the score function (green circle) is far from the posterior mean (red square), both map to the reference translation ("my"), indicating that there exist multiple latent variables that map to the reference output.

Sample translations
We demonstrate that refining in the continuous space results in non-local, non-trivial revisions to the original sentence. For each example in Table 2, we show the English source sentence, German reference sentence, original translation decoded from a sample from the prior, and the revised translation with one gradient update using the estimated score function.
In Example 1, the positions of the main clause ("Es gibt nicht viele Ärzte") and the prepositional phrase ("im westafrikanischen Land") are reversed in the continuous refinement process. Inside the main clause, "es gibt" is revised to "gibt es", the correct grammatical form in German when the prepositional phrase comes before the main clause.
In Example 2, the two numbers ("1,2 Milliarden Dollar" and "6,9 Milliarden Dollar") are exchanged in the revised translation. Also, the phrase "aus den" (out of the) is correctly inserted between the two.
In Example 3, the noun phrase "Weisheit in Bedouin" is combined into a single German compound noun "Bedouin-Weisheit". Also, the phrases "Der erste ..." and "mit dieser ..." are swapped in the refinement process, to better resemble the reference sentence.

Related Work
Learning Our training objective is closely related to the score matching objective (Hyvärinen, 2005), with the following differences. First, we approximate the gradient of the data log density using a proxy gradient, whereas this term is replaced by the Hessian of the energy in the original score matching objective. Second, we only consider samples from the prior. Saremi et al. (2018) proposed a denoising interpretation of the Parzen score objective (Vincent, 2011) that avoids estimating the Hessian. Although score function estimation that bypasses energy estimation was found to be unstable (Alain and Bengio, 2014;Saremi et al., 2018), it has been successfully applied to generative modeling of images (Song and Ermon, 2019).
Inference While we categorize inference methods for machine translation as (1) discrete search, (2) hybrid optimization (Shu et al., 2020) and (3) continuous optimization (this work) in Section 2, another line of work relaxes discrete search into continuous optimization (Hoang et al., 2017;Gu et al., 2018b;Tu et al., 2020). By using Gumbelsoftmax relaxation (Maddison et al., 2017;Jang et al., 2017), they train an inference network to generate target tokens that maximize the log probability under a pretrained model.

Example 1
Source There aren't many doctors in the west African country; just one for every 5,000 people
Reference In dem westafrikanischen Land gibt es nicht viele Ärzte, nur einen für 5.000 Menschen
Original Es gibt nicht viele Ärzte im westafrikanischen Land, nur eine für 5.000 Menschen.
Refined Im westafrikanischen Land gibt es nicht viele Ärzte, nur eine für 5.000 Menschen.
Example 2
Source Costumes are expected to account for $1.2 billion dollars out of the $6.9 billion spent, according to the NRF.
Reference Die Kostüme werden etwa 1,2 Milliarden der 6,9 Milliarden ausgegebenen US-Dollar ausmachen, so der NRF.
Example 3
Source It was with this piece of Bedouin wisdom that the first ever chairman Wolfgang Henne described the history and fascination behind the "Helping Hands" society.
Reference Mit dieser Beduinenweisheit beschrieb der erste Vorsitzende Wolfgang Henne die Geschichte und Faszination des Vereins "Helfende Hände".
Original Der erste Vorsitzende Wolfgang Henne beschrieb mit dieser erste Weisheit in Bedouin" die Geschichte und Faszination hinter der "Helenden Hands" Gesellschaft
Refined Mit diesem Stück Bedouin-Weisheit beschrieb der erste Vorsitzende Wolfgang Henne jemals die Geschichte und Faszination hinter der "Heling Hands" Gesellschaft

Table 2: Sample translations on WMT'14 En→De. We show the translation from a latent variable sampled from the prior (Original) and the translation after one refinement step in the continuous space with the learned score function (Refined). We emphasize phrases whose positions are swapped in the refinement process in red and blue.
Gradient-based Inference Performing gradient descent over structured outputs was mentioned in LeCun et al. (2006), and has been successfully applied to many structured prediction tasks (Belanger and McCallum, 2016; Wang et al., 2016; Belanger et al., 2017). Other work performed gradient descent over the latent variables to optimize objectives for a wide variety of tasks, including chemical design (Gómez-Bombarelli et al., 2018) and text generation (Mueller et al., 2017).

Generation by Refinement Refinement has a long history in text generation. The retrieve-and-refine framework retrieves an (input, output) pair from the training set that is similar to the test example, and performs edit operations on the corresponding output (Sumita and Iida, 1991; Song et al., 2016; Hashimoto et al., 2018; Weston et al., 2018; Gu et al., 2018c). The idea of refinement has also been applied in automatic post-editing (Novak et al., 2016; Grangier and Auli, 2017).

Conclusion
We propose an efficient inference procedure for non-autoregressive machine translation that refines translations purely in the continuous space. Given a latent variable model for machine translation, we train an inference network to approximate the gradient of the marginal log probability of the target sentence, using only the latent variable as input. This allows us to use gradient-based optimization to find a target sentence at inference time that approximately maximizes the marginal log probability. As we avoid discrete search over a large vocabulary, our inference procedure is more efficient than previous inference procedures that refine in the token space. We compare our approach with the recently proposed delta inference procedure, which optimizes jointly in discrete and continuous space, on three machine translation datasets: WMT'14 En→De, WMT'16 Ro→En and IWSLT'16 De→En. With the same underlying latent variable model, the proposed inference procedure using a learned score function has the following advantages: (1) it is twice as fast as delta inference, and (2) it finds target sentences with higher marginal probabilities and BLEU scores.
While we showed that iterative inference with a learned score function is effective for spherical Gaussian priors, more work is required to investigate whether such an approach will also be successful for more sophisticated priors, such as Gaussian mixtures or normalizing flows. This is particularly interesting, as a recent study showed that latent variable models with a flexible prior give high test log-likelihoods but suffer from poor generation quality, as inference is challenging.