Understanding the Mechanics of SPIGOT: Surrogate Gradients for Latent Structure Learning

Latent structure models are a powerful tool for modeling language data: they can mitigate the error propagation and annotation bottleneck in pipeline systems, while simultaneously uncovering linguistic insights about the data. One challenge with end-to-end training of these models is the argmax operation, which has null gradient. In this paper, we focus on surrogate gradients, a popular strategy to deal with this problem. We explore latent structure learning through the lens of pulling back the downstream learning objective. In this paradigm, we discover a principled motivation for both the straight-through estimator (STE) as well as the recently-proposed SPIGOT, a variant of STE for structured models. Our perspective leads to new algorithms in the same family. We empirically compare the known and the novel pulled-back estimators against the popular alternatives, yielding new insight for practitioners and revealing intriguing failure cases.


Introduction
Natural language data is rich in structure, but most of the structure is not visible at the surface. Machine learning models tackling high-level language tasks would benefit from uncovering underlying structures such as trees, sequence tags, or segmentations. Traditionally, practitioners turn to pipeline approaches where an external, pretrained model is used to predict, e.g., syntactic structure. The benefit of this approach is that the predicted tree is readily available for inspection, but the downside is that errors can easily propagate throughout the pipeline and require further attention (Finkel et al., 2006; Sutton and McCallum, 2005; Toutanova, 2005). In contrast, deep neural architectures tend to eschew such preprocessing, and instead learn soft hidden representations, not easily amenable to visualization and analysis.
The best of both worlds would be to model structure as a latent variable, combining the transparency of the pipeline approach with the end-to-end unsupervised representation learning that makes deep models appealing. Moreover, large-capacity models tend to rediscover structure from scratch (Tenney et al., 2019), so structured latent variables may reduce the required capacity.
Learning with discrete, combinatorial latent variables is, however, challenging, due to the intersection of large cardinality and null gradient issues. For example, when learning a latent dependency tree, the latent parser must choose among an exponentially large set of possible trees; what's more, the parser may only learn from gradient information from the downstream task. If the highest-scoring tree is selected using an argmax operation, the gradients will be zero, preventing learning.
One strategy for dealing with the null gradient issue is to use a surrogate gradient, explicitly overriding the zero gradient from the chain rule, as if a different computation had been performed. The most commonly known example is the straight-through estimator (STE; Bengio et al., 2013), which pretends that the argmax node was instead an identity operator. Such methods lead to a fundamental mismatch between the objective and the learning algorithm. The effect of this mismatch is still insufficiently understood, and the design of successful new variants is therefore challenging. For example, the recently-proposed SPIGOT method (Peng et al., 2018) found it beneficial to use a projection as part of the surrogate gradient.
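To make the idea concrete, here is a minimal NumPy sketch (ours, not from any released implementation) of the straight-through trick for a categorical argmax: the forward pass emits a hard one-hot vector, while the backward pass passes the incoming gradient through unchanged, as if the node were the identity.

```python
import numpy as np

def argmax_forward(s):
    """Forward pass: hard one-hot argmax; its true Jacobian is zero a.e."""
    z = np.zeros_like(s)
    z[np.argmax(s)] = 1.0
    return z

def ste_backward(grad_z):
    """STE backward: pretend the argmax node was the identity, so ds = dz."""
    return grad_z
```

In an autodiff framework this is typically packaged as a custom backward pass for the argmax node; the forward output is unchanged, only the gradient is overridden.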
Figure 1: A model with a discrete latent variable z. Given an input x, we assign a score s_z = [f(x)]_z to each choice, and pick the highest-scoring one, ẑ, to predict ŷ = g_θ(ẑ). For simplicity, here g_θ does not access x directly. (a) Since argmax has null gradients, the encoder parameters φ do not receive updates. (b) If ground-truth supervision were available for the latent z, φ could be trained jointly with an auxiliary loss. (c) As such supervision is not available, we induce a best-guess label µ by pulling back the downstream loss. This strategy recovers the STE and SPIGOT estimators.

In this paper, we study surrogate gradient methods for deterministic learning with discrete structured latent variables. Our contributions are:
• We propose a novel motivation for surrogate gradient methods, based on optimizing a pulled-back loss, thereby inducing pseudo-supervision on the latent variable. This leads to new insight into both STE and SPIGOT.
• We show how our framework may be used to derive new surrogate gradient methods, by varying the loss function or the inner optimization algorithm used for inducing the pseudo-supervision.
• We experimentally validate our discoveries in a controlled experiment as well as on English-language sentiment analysis and natural language inference, comparing against stochastic and relaxed alternatives, yielding new insights, and identifying noteworthy failure cases.
While the discrete methods do not outperform the relaxed alternatives using the same building blocks, we hope that our interpretation and insights will inspire future latent structure research.
The code for the paper is available at https://github.com/deep-spin/understanding-spigot.

Related Work
Discrete latent variable learning is often tackled in stochastic computation graphs, by estimating the gradient of an expected loss. An established method is the score function estimator (SFE; Glynn, 1990; Williams, 1992; Kleijnen and Rubinstein, 1996). SFE is widely used in NLP, for tasks including minimum risk training in NMT (Shen et al., 2016; Wu et al., 2018) and latent linguistic structure learning (Havrylov et al., 2019). In this paper, we focus on the alternative strategy of surrogate gradients, which allows learning in deterministic graphs with discrete, argmax-like nodes, rather than in stochastic graphs. Examples are the straight-through estimator (STE; Hinton, 2012; Bengio et al., 2013) and the structured projection of intermediate gradients optimization technique (SPIGOT; Peng et al., 2018). Recent work focuses on studying and explaining STE. Yin et al. (2019) obtained a convergence result in shallow networks for the unstructured case. Cheng et al. (2018) show that STE can be interpreted as the simulation of the projected Wasserstein gradient flow. STE has also been studied in binary neural networks (Hubara et al., 2016) and in other applications (Tjandra et al., 2019).
Other methods based on surrogate gradients have been explored recently (Vlastelica et al., 2020; Meng et al., 2020).
A popular alternative is to relax the argmax into a continuous transformation such as softmax or sparsemax (Martins and Astudillo, 2016), as seen for instance in soft attention mechanisms (Vaswani et al., 2017) or structured attention networks (Kim et al., 2017; Maillard et al., 2017; Liu and Lapata, 2018; Mensch and Blondel, 2018; Niculae et al., 2018a). In between surrogate gradients and relaxation is Gumbel softmax, which uses the Gumbel-max reparametrization to sample from a categorical distribution, applying softmax either to relax the mapping or to induce surrogate gradients (Jang et al., 2017; Maddison et al., 2017). Gumbel softmax has been successfully applied to latent linguistic structure as well (Choi et al., 2018; Maillard and Clark, 2018). When sampling from a structured variable is required, the Perturb-and-MAP technique (Papandreou and Yuille, 2011) has been successfully applied to sampling latent structures in NLP applications (Corro and Titov, 2019a,b).
Notation and Background

We assume a general latent structure model involving input variables x ∈ X, output variables y ∈ Y, and latent discrete variables z ∈ Z. We assume that Z ⊆ {0, 1}^K, where K ≤ |Z| (typically, K ≪ |Z|): i.e., the latent discrete variable z can be represented as a K-dimensional binary vector. This often results from a decomposition of a structure into parts: for example, z could be a dependency tree for a sentence of L words, represented as a vector of size K = O(L²), indexed by pairs of word indices (i, j), with z_ij = 1 if arc i → j belongs to the tree, and 0 otherwise. This allows us to define the score of a structure as the sum of the scores of its parts. Given a vector s ∈ R^K, containing scores for all possible parts, we define score(z) := s⊤z. (1)

Notation. We denote by e_k the one-hot vector with all zeros except in the k-th coordinate. We denote the simplex by △_|Z| := {p ∈ R^|Z| | p ≥ 0, Σ_{z∈Z} p(z) = 1}. Given a distribution p ∈ △_|Z|, the expectation of a function h : Z → R^K is E_p[h(Z)] := Σ_{z∈Z} p(z) h(z).

Background. In the context of structured prediction, the set M := conv(Z) is known as the marginal polytope, since any point inside it can be interpreted as some marginal distribution over parts of the structure (arcs) under some distribution over structures. There are three relevant problems that may be formulated in a structured setting:
• MAP: finds the highest-scoring structure, MAP(s) := argmax_{µ ∈ M} s⊤µ;
• Marginals: finds the (unique) marginals induced by the scores s under a Gibbs distribution, Marg(s) := argmax_{µ ∈ M} s⊤µ + H̃(µ), where H̃ is the maximum entropy among all distributions over structures that achieve marginals µ (Wainwright and Jordan, 2008);
• SparseMAP: finds the (unique) sparse marginals induced by the scores s, given by a Euclidean projection onto M: SparseMAP(s) := argmin_{µ ∈ M} ‖µ − s‖² (4) (Niculae et al., 2018a).

Unstructured setting. As a check, we consider the encoding of a categorical variable with K distinct choices, encoding each choice as a one-hot vector e_k and setting Z = {e_1, . . . , e_K}. In this case, conv(Z) = △_K. The optimization problems above then recover some well-known transformations, as described in Table 1.
Table 1: Correspondence between the unstructured and structured settings.

                       unstructured   structured
vertices               e_k            z
interior points        p              µ
maximization           argmax         MAP
expectation            softmax        Marg
Euclidean projection   sparsemax      SparseMAP
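The unstructured column of Table 1 can be made concrete with a small NumPy sketch; the sparsemax implementation follows the closed-form simplex projection of Martins and Astudillo (2016), and the function names are ours.

```python
import numpy as np

def hard_argmax(s):
    """Maximization: the vertex (one-hot vector) with the highest score."""
    z = np.zeros_like(s)
    z[np.argmax(s)] = 1.0
    return z

def softmax(s):
    """Expectation: the Gibbs distribution over choices."""
    e = np.exp(s - s.max())
    return e / e.sum()

def sparsemax(s):
    """Euclidean projection of s onto the probability simplex."""
    srt = np.sort(s)[::-1]              # scores in decreasing order
    css = np.cumsum(srt) - 1.0
    ks = np.arange(1, len(s) + 1)
    support = srt - css / ks > 0        # coordinates kept in the projection
    k = ks[support][-1]
    tau = css[k - 1] / k                # threshold subtracted from scores
    return np.maximum(s - tau, 0.0)
```

All three map a score vector to a point of the simplex; argmax returns a vertex, softmax a dense interior point, and sparsemax a possibly sparse point.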

Latent Structure Models
Throughout, we assume a classifier parametrized by φ and θ, which consists of three parts:
• An encoder function f_φ which, given an input x ∈ X, outputs a vector of "scores" s ∈ R^K, as s = f_φ(x);
• An argmax node which, given these scores, outputs the highest-scoring structure, ẑ = argmax_{z ∈ Z} s⊤z = MAP(s);
• A decoder function g_θ which, given x ∈ X and z ∈ Z, makes a prediction ŷ ∈ Y as ŷ = g_θ(x, z). We will sometimes write ŷ(z) to emphasize the dependency on z.
For reasons that will be clear in the sequel, we must assume that the decoder also accepts averaged structures, i.e., it can also output predictions g_θ(x, µ) where µ ∈ conv(Z) is a convex combination (weighted average) of structures.
Thus, given input x ∈ X, this network predicts ŷ = g_θ(x, MAP(f_φ(x))). To train this network, we minimize a loss function L(ŷ, y), where y denotes the target label; a common example is the negative log-likelihood loss. The gradient w.r.t. the decoder parameters, ∇_θ L(ŷ, y), is easy to compute using automatic differentiation on g_θ. The main challenge is to propagate gradient information through the argmax node into the encoder parameters. Indeed, the argmax node is piecewise constant, so ∂ẑ/∂s = 0 almost everywhere, and no gradient will flow to the encoder. We list below the three main categories of approaches that tackle this issue.
Introducing stochasticity. Replace the argmax node by a stochastic node where z is modeled as a random variable Z parametrized by s (e.g., using a Gibbs distribution p(z; s) ∝ exp(s⊤z)). Then, instead of optimizing a deterministic loss L(ŷ(ẑ), y), optimize the expectation of the loss under the predicted distribution, E_{p(z; s)}[L(ŷ(z), y)]. The expectation ensures that the gradients are no longer null. This is sometimes referred to as minimum risk training (Smith and Eisner, 2006; Stoyanov et al., 2011), and typically optimized using the score function estimator (SFE; Glynn, 1990; Williams, 1992; Kleijnen and Rubinstein, 1996).
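As an illustration of this strategy, the following toy NumPy sketch estimates the gradient of the expected loss for a single categorical variable with SFE, using the identity ∇_s E[L(z)] = E[L(z) ∇_s log p(z; s)], where ∇_s log p(k; s) = e_k − softmax(s); the helper names are illustrative only.

```python
import numpy as np

def softmax(s):
    e = np.exp(s - s.max())
    return e / e.sum()

def sfe_gradient(s, loss, n_samples=2000, seed=0):
    """Monte Carlo score-function estimate of grad_s E_{z ~ softmax(s)}[loss(z)]."""
    rng = np.random.default_rng(seed)
    p = softmax(s)
    K = len(s)
    grad = np.zeros(K)
    for _ in range(n_samples):
        k = rng.choice(K, p=p)            # sample a choice
        score_fn = (np.arange(K) == k).astype(float) - p  # grad of log p(k; s)
        grad += loss(k) * score_fn
    return grad / n_samples
```

Without a control variate the estimate is high-variance, which is exactly the issue variance-reduction techniques address.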
Relaxing the argmax. Keep the network deterministic, but relax the argmax node into a continuous function, for example replacing it with softmax or sparsemax (Martins and Astudillo, 2016). In the structured case, this gives rise to structured attention networks (Kim et al., 2017) and their SparseMAP variant (Niculae et al., 2018a). This corresponds to moving the expectation inside the loss, optimizing L(ŷ(E_p[Z]), y).

Inventing a surrogate gradient. Keep the argmax node and perform the usual forward computation, but backpropagate a different, non-null gradient in the backward pass. This is the approach underlying straight-through estimators (Hinton, 2012; Bengio et al., 2013) and SPIGOT (Peng et al., 2018). This method introduces a mismatch between the measured objective and the optimization algorithm. In this work, we propose a novel, principled justification for inducing surrogate gradients. In what follows, we assume that:
• We can compute the gradient γ(µ) := ∇_µ L(ŷ(µ), y) (8) for any µ, usually by automatic differentiation;
• We want to replace the null gradient ∇_s L(ŷ(ẑ), y) by a surrogate ∇̃_s L(ŷ(ẑ), y).

SPIGOT as the Approximate Optimization of a Pulled Back Loss
We next provide a novel interpretation of SPIGOT as the minimization of a "pulled back" loss. SPIGOT uses the surrogate gradient ∇̃_s L(ŷ(ẑ), y) := ẑ − proj_M(ẑ − η γ(ẑ)); (9) we highlight that SparseMAP (Niculae et al., 2018a) computes exactly such a Euclidean projection (Eq. 4).

Intermediate Latent Loss
To begin, consider a much simpler scenario: if we had supervision for the latent variable z (e.g., if the true label z was revealed to us), we could define an intermediate loss ℓ(ẑ, z) which would induce nonzero updates to the encoder parameters. Of course, we do not have access to this z. Instead, we consider the following alternative: we induce a best guess µ for what the unknown z ∈ Z should be, informed by the downstream loss. Figure 1 provides the intuition of the pulled-back label and loss.

We take a moment to justify picking µ ∈ M rather than directly in Z. In fact, if K = |Z| is small, we can enumerate all possible values of z and define the guess as the latent value minimizing the downstream loss, µ = argmin_{z ∈ Z} L(ŷ(z), y). This is sensible, but intractable in the structured case. Moreover, early on in the training process, while g_θ is untrained, the minimizing vertex carries little information. Thus, for robustness and tractability, we allow for some uncertainty by picking a convex combination µ ∈ M so as to approximately minimize
µ ≈ argmin_{µ ∈ M} L(ŷ(µ), y). (10)
For most interesting predictive models ŷ(µ) (e.g., deep networks), this optimization problem is nonconvex and lacks a closed-form solution. One common strategy is the projected gradient algorithm (Goldstein, 1964; Levitin and Polyak, 1966), which augments gradient descent with one more step: projection of the updated point onto the constraint set. It iteratively performs the updates
µ^(t+1) = proj_M(µ^(t) − η_t γ(µ^(t))), (11)
where η_t is a step size and γ is as in Eq. 8. With a suitable choice of step sizes, the projected gradient algorithm converges to a local optimum of Eq. 10 (Bertsekas, 1999, Proposition 2.3.2). In the sequel, for simplicity we use a constant η. If we initialize µ^(0) = ẑ = argmax_{z ∈ Z} s⊤z, a single iteration of projected gradient yields the guess
µ^(1) = proj_M(ẑ − η γ(ẑ)). (12)
Treating the induced µ as if it were the "ground truth" label of z, we may train the encoder f_φ(x) by supervised learning.
With a perceptron loss, a single iteration yields the gradient ∇_s Perc(s, µ^(1)) = ẑ − µ^(1) = ẑ − proj_M(ẑ − η γ(ẑ)), which is precisely the SPIGOT gradient surrogate in Eq. 9. This leads to the following insight into how SPIGOT updates the encoder parameters: SPIGOT minimizes the perceptron loss between ẑ and a pulled-back target computed by one projected gradient step on min_{µ ∈ M} L(ŷ(µ), y), starting at ẑ = MAP(s).
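In the unstructured case, M is the probability simplex, so the recipe above is easy to spell out in code: one projected-gradient step from ẑ produces the pulled-back guess µ^(1), and the surrogate gradient is their difference. A NumPy sketch (function names ours):

```python
import numpy as np

def project_simplex(v):
    """Euclidean projection of v onto the probability simplex."""
    srt = np.sort(v)[::-1]
    css = np.cumsum(srt) - 1.0
    ks = np.arange(1, len(v) + 1)
    k = ks[srt - css / ks > 0][-1]
    tau = css[k - 1] / k
    return np.maximum(v - tau, 0.0)

def spigot_surrogate(z_hat, gamma, eta=1.0):
    """SPIGOT: pull back a guess mu by one projected-gradient step;
    the surrogate gradient w.r.t. the scores is z_hat - mu."""
    mu = project_simplex(z_hat - eta * gamma)
    return z_hat - mu

def ste_i_surrogate(z_hat, gamma, eta=1.0):
    """STE-I: same recipe without the projection, mu = z_hat - eta*gamma,
    so the surrogate collapses to eta * gamma."""
    return eta * gamma
```

Here gamma is the downstream-loss gradient γ(ẑ) supplied by automatic differentiation; the only difference between SPIGOT and STE-I is whether the pulled-back guess is projected.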
This construction suggests possible alternatives, the first of which uncovers a well-known algorithm.
Relaxing the M constraint. The constraints in Eq. 10 make the optimization problem more complicated. We relax them and define µ ≈ argmin_{µ ∈ R^K} L(ŷ(µ), y). This problem still requires iteration, but the projection step can now be avoided. One iteration of gradient descent yields µ^(1) = ẑ − η γ(ẑ). The perceptron update then recovers a novel derivation of straight-through with identity (STE-I), where the backward pass acts as if ∂ẑ(s)/∂s := Id (Bengio et al., 2013). This leads to the following insight into straight-through and its relationship to SPIGOT: Straight-through (STE-I) minimizes the perceptron loss between ẑ and a pulled-back target computed by one gradient step on min_{µ ∈ R^K} L(ŷ(µ), y), starting at ẑ = MAP(s).
From this intuition, we readily obtain new surrogate gradient methods, which we explore below.

New Surrogate Gradient Methods
Multiple gradient updates. Instead of a single projected gradient step, we could run multiple steps of Eq. 11. We would expect this to yield a better approximation of µ. This comes at a computational cost: each update involves running a forward and backward pass in the decoder g_θ with the current guess µ^(t), to obtain γ(µ^(t)) := ∇_µ L(ŷ(µ^(t)), y).
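A small NumPy sketch of this multi-step variant, for the unstructured case; the names and the toy quadratic loss used in testing are illustrative, and grad_loss stands in for the decoder call that computes γ(µ):

```python
import numpy as np

def project_simplex(v):
    """Euclidean projection of v onto the probability simplex."""
    srt = np.sort(v)[::-1]
    css = np.cumsum(srt) - 1.0
    ks = np.arange(1, len(v) + 1)
    k = ks[srt - css / ks > 0][-1]
    tau = css[k - 1] / k
    return np.maximum(v - tau, 0.0)

def pull_back_guess(z_hat, grad_loss, eta=0.1, steps=5):
    """Run several projected-gradient steps of Eq. (11), starting at mu = z_hat.
    Each call to grad_loss costs one decoder forward and backward pass."""
    mu = z_hat.copy()
    for _ in range(steps):
        mu = project_simplex(mu - eta * grad_loss(mu))
    return mu
```

With more steps the guess tracks the constrained minimizer more closely, at the price of one extra decoder pass per step.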
Different initialization. The projected gradient update in Eq. 12 uses µ^(0) = ẑ = argmax_{z ∈ Z} s⊤z as the initial point. This is a sensible choice if we believe the encoder prediction ẑ is close enough to the optimal µ, and it is computationally convenient: because the forward pass uses ẑ, γ(ẑ) is readily available in the backward pass, so the first inner iteration comes for free. However, other initializations are possible, for example µ^(0) = Marg(s) or µ^(0) = 0, at the cost of an extra computation of γ(µ^(0)). In this work, we do not consider alternate initializations for their own sake; they are needed for the following two directions.

Different intermediate loss: SPIGOT-CE.
For simplicity, consider the unstructured case where M = △, and use the initial guess µ^(0) = softmax(s). Replacing the perceptron loss by the cross-entropy loss CE(µ^(0), µ^(1)) = − Σ_{k=1}^K µ_k^(1) log µ_k^(0) yields the surrogate gradient ∇̃_s = µ^(0) − µ^(1). In the structured case, the corresponding loss is the CRF loss (Lafferty et al., 2001), which corresponds to the KL divergence between two distributions over structures. In this case, we initialize µ^(0) = Marg(s) and update µ^(1) = proj_M(µ^(0) − η γ(µ^(0))), giving ∇̃_s = Marg(s) − µ^(1).

Exponentiated gradient updates: SPIGOT-EG.
In the unstructured case, optimization over M = △ can also be tackled via the exponentiated gradient (EG) algorithm (Kivinen and Warmuth, 1997), which minimizes Eq. 10 with the multiplicative update
µ^(t+1) ∝ µ^(t) ⊙ exp(−η γ(µ^(t))), (18)
where ⊙ is elementwise multiplication; each iterate µ^(t) is thus strictly positive, and normalized to lie inside △. EG cannot be initialized on the boundary of △, so again we must take µ^(0) = softmax(s). A single iteration of EG yields µ^(1) = softmax(s − η γ(µ^(0))). It is natural to use the cross-entropy loss, giving ∇̃_s = softmax(s) − softmax(s − η γ), i.e., the surrogate gradient is the difference between the softmax prediction and a "perturbed" softmax. To generalize to the structured case, we observe that both EG and projected gradient are instances of mirror descent under KL divergences (Beck and Teboulle, 2003). Unlike the unstructured case, we must iteratively keep track of both perturbed scores and marginals, since Marg⁻¹ is non-trivial. This leads to the mirror descent updates s^(t+1) = s^(t) − η γ(µ^(t)), µ^(t+1) = Marg(s^(t+1)). With a single iteration and the CRF loss, we get ∇̃_s = Marg(s) − Marg(s − η γ).

Algorithm 1: Surrogate gradients as custom backward passes.
  Function GradLoss(µ, x, y):
      return γ ← ∇_µ L(ŷ(µ), y)   // Eq. (8)
  Function BackwardSPIGOT(s, x, y):
      ẑ ← MAP(s); γ ← GradLoss(ẑ, x, y)
      µ ← proj_M(ẑ − η γ)          // Eq. (12), via SparseMAP
      return ẑ − µ
  Function BackwardSTE-I(s, x, y):
      ẑ ← MAP(s); γ ← GradLoss(ẑ, x, y)
      return η γ

Algorithm 1 sketches the implementation of the proposed surrogate gradients for the structured case. The forward pass is the same for all variants: given the scores s for the parts of the structure, it calculates the MAP structure ẑ. The surrogate gradients are implemented as custom backward passes. The function GradLoss uses automatic differentiation to compute γ(µ) at the current guess µ; each call thus involves a forward and backward pass through g_θ. Due to the convenient initialization, the first iteration of STE-I and SPIGOT comes for free, since both µ^(0) and γ(µ^(0)) are available as a byproduct when computing the forward and, respectively, backward pass through g_θ in order to update θ.
For SPIGOT-CE and SPIGOT-EG, even with k = 1 we need a second call to the decoder: since µ^(0) ≠ ẑ, an additional decoder call is necessary for obtaining the gradient of the loss with respect to µ^(0). The unstructured case is essentially identical, with Marg replaced by softmax.
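For the unstructured case, the whole SPIGOT-EG backward computation collapses into two softmax calls. A minimal NumPy sketch, assuming the single-step, constant-η setting (names ours):

```python
import numpy as np

def softmax(s):
    e = np.exp(s - s.max())
    return e / e.sum()

def spigot_eg_surrogate(s, gamma, eta=1.0):
    """SPIGOT-EG, unstructured: one exponentiated-gradient step from
    mu = softmax(s); the cross-entropy surrogate is the difference between
    the original and the perturbed softmax."""
    return softmax(s) - softmax(s - eta * gamma)
```

Here gamma stands for γ(µ^(0)), i.e., the decoder-loss gradient evaluated at the softmax guess.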

Experiments
Armed with a selection of surrogate gradient methods, we now proceed to an experimental comparison. For maximum control, we first study a synthetic unstructured experiment with known data generating process. This allows us to closely compare the various methods, and to identify basic failure cases. We then study the structured case of latent dependency trees for sentiment analysis and natural language inference in English. Full training details are described in Appendix A.

Categorical Latent Variables
For the unstructured case, we design a synthetic dataset from a mixture model: z ∼ Categorical(1/K), x ∼ Normal(m_z, σI), y = sign(w_z⊤x + b_z), where m_z are randomly placed cluster centers, and w_z, b_z are parameters of a different ground-truth linear model for each cluster. Given cluster labels, one could learn the optimal linear classifier separating the data in that cluster. Without knowing the cluster, a global linear model cannot fit the data well. This setup provides a test bed for discrete variable learning, since accurate clustering leads to a good fit. The architecture, following §4, is:
• Encoder: A linear mapping from the input to a K-dimensional score vector, s = f_φ(x);
• Latent mapping: ẑ = ρ(s), where ρ is argmax or a continuous relaxation such as softmax or sparsemax.
• Decoder: A bilinear transformation combining the input x and the latent variable z, g_θ(x, z) = z⊤W_g x + b_g, where θ = (W_g, b_g) ∈ R^{K×dim(X)} × R are model parameters. If ẑ = e_k, this selects the k-th linear model from the rows of W_g.
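The data-generating process above can be sketched in a few lines of NumPy; the hyperparameter values and helper names are illustrative, not those used in our experiments:

```python
import numpy as np

def make_mixture_data(n=1000, K=5, dim=2, sigma=0.3, seed=0):
    """Synthetic mixture: sample a cluster z, a point around its center,
    and a binary label from that cluster's own linear model."""
    rng = np.random.default_rng(seed)
    centers = rng.normal(scale=3.0, size=(K, dim))      # cluster centers m_z
    w = rng.normal(size=(K, dim))                       # per-cluster weights w_z
    b = rng.normal(size=K)                              # per-cluster biases b_z
    z = rng.integers(K, size=n)                         # z ~ Categorical(1/K)
    x = centers[z] + sigma * rng.normal(size=(n, dim))  # x ~ Normal(m_z, sigma I)
    y = np.sign(np.einsum("nd,nd->n", x, w[z]) + b[z])  # y = sign(w_z.x + b_z)
    return x, y, z
```

Because the label rule changes per cluster, recovering z is necessary for a linear decoder to fit the data, which is what makes the benchmark diagnostic.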
We evaluate two baselines: a linear model, and an oracle where g_θ(x, z) has access to the true z. In addition to the methods discussed in the previous section, we evaluate the softmax and sparsemax end-to-end differentiable relaxations, and the STE-S variant, which uses the softmax backward pass while doing argmax in the forward pass. We also compare stochastic methods, including score function estimators (with an optional moving average control variate), and two Gumbel estimator variants (Jang et al., 2017; Maddison et al., 2017): Gumbel-Softmax, which uses a relaxed softmax in the forward pass, and ST-Gumbel, which uses argmax in the style of STE.
Results. We compare the discussed methods in Table 2. Knowledge of the data-generating process allows us to measure not only downstream accuracy, but also clustering quality, by comparing the model predictions with the known true z. We measure the latter via the V-measure (Rosenberg and Hirschberg, 2007), a clustering score independent of the cluster labels, i.e., invariant to permuting the labels (between 0 and 100, with 100 representing perfect cluster recovery). The linear and gold cluster oracle baselines confirm that cluster separation is needed for good performance. Stochastic models perform well across both criteria. Crucially, SFE requires variance reduction to perform well, but even a simple control variate will do.
Deterministic models may be preferable when likelihood assessment or sampling is not tractable. Among these, STE-I and SPIGOT-{CE,EG} are indistinguishable from the best models. Surprisingly, the vanilla SPIGOT fails, especially in cluster recovery. Finally, the relaxed deterministic models perform very well on accuracy and learn very fast (Figure 2), but appear to rely on mixing clusters, and therefore remarkably fail to recover cluster assignments. This is in line with the structured results of Corro and Titov (2019b). Therefore, if latent structure recovery is less important than downstream accuracy, relaxations seem preferable.
Impact of multiple updates. One possible explanation for the failure of SPIGOT is that SPIGOT-CE and SPIGOT-EG perform more work per iteration, since they use a softmax initial guess and thus require a second pass through the decoder. We rule out this possibility in Figure 3: even when tuning the number of updates, SPIGOT does not substantially improve. We observe, however, that SPIGOT-CE improves slightly with more updates, outperforming STE-I. However, since each update step performs an additional decoder call, this also increases the training time.

Structured Latent Variables
For learning structured latent variables, we study sentiment classification on the English-language Stanford Sentiment Treebank (SST) (Socher et al., 2013), and natural language inference on the SNLI dataset (Bowman et al., 2015).

Sentiment Classification
The model predicts a latent projective arc-factored dependency tree for the sentence, then uses the tree in predicting the downstream binary sentiment label. The model has the following components:
• Encoder: Computes a score for every possible dependency arc i → j between words i and j. Each word is represented by its embedding h_i, then processed by an LSTM, yielding contextual vectors ←→h_i. Arc scores s_{i→j} are then computed from the concatenation of the two contextual vectors, [←→h_i ; ←→h_j].

Table 3: SST and SNLI average accuracy and standard deviation over three runs, with latent dependency trees. Baselines are described in Section 7.2. Stochastic methods are marked with *.
• Latent parser: We use the arc score vector s to get a parse ẑ = ρ(s) for the sentence, where ρ(s) is the argmax (MAP), or a combination of trees, such as Marg or SparseMAP.
• Decoder: Following Peng et al. (2018), we concatenate each ←→h_i with its predicted head ←→h_head(i). For relaxed methods, we average all possible heads, weighted by the corresponding marginals: ←→h_head(i) = Σ_j µ_{j→i} ←→h_j. The concatenation is passed through an affine layer, a ReLU activation, and an attention mechanism, and the result is fed into a linear output layer.
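The head-averaging step of the decoder can be sketched as follows in NumPy, assuming arc marginals stored as a matrix with mu[i, j] the (marginal) probability that word i heads word j; the function name is ours:

```python
import numpy as np

def expected_head_repr(h, mu):
    """h: (L, d) contextual word vectors; mu: (L, L) arc marginals, where
    mu[i, j] is the probability that word i is the head of word j (each
    column sums to 1). Returns [h_j ; E[h_head(j)]] for every word j."""
    heads = mu.T @ h                       # (L, d) weighted average of heads
    return np.concatenate([h, heads], axis=-1)
```

For hard (MAP) parses, mu has one-hot columns and the average reduces to picking the single predicted head, so the same code covers both cases.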
For marginal inference, we use pytorch-struct (Rush, 2020). For the SparseMAP projection, we use the active set algorithm (Niculae et al., 2018a). The baseline we compare our models against is a BiLSTM, followed by feeding the sum of all hidden states to a two-layer ReLU-MLP.
Results. The results of the experiments with the different methods are shown in Table 3. As in the unstructured case, the relaxed models lead to strong downstream classifiers. Unlike the unstructured case, SPIGOT is a top performer here. The effect of tuning the number of gradient update steps is smaller than in the unstructured case and did not lead to significant improvement. This can be explained by a "moving target" intuition: since the decoder g_θ is far from optimal, a more accurate µ does not necessarily help learning overall.

Natural Language Inference
We build on top of the decomposable attention model (DA; Parikh et al., 2016). Following the setup of Corro and Titov (2019b), we induce structure on the premise and the hypothesis. For computing the score of the arc from word i to j, we concatenate the representations of the two words, as in the sentiment model. In the decoder, after the latent parse tree is calculated, we concatenate each word with the average of its heads. We do this separately for the premise and the hypothesis. As a baseline, we use the DA model with no intra-attention.
Results. The SNLI results are shown in Table 3. Here, the straight-through (argmax) methods are outperformed by the more stable relaxation-based methods. This can be attributed to the word-level alignment in the DA model, where soft dependency relations appear better suited than hard ones.

Conclusions
In this work, we provide a novel motivation for the straight-through estimator (STE) and SPIGOT, based on pulling back the downstream loss. We derive promising new algorithms, and novel insight into existing ones. Unstructured controlled experiments suggest that our new algorithms, which use the cross-entropy loss instead of the perceptron loss, can be more stable than SPIGOT while accurately disentangling the latent variable. Differentiable relaxation models (using softmax and sparsemax) are the easiest to optimize to high downstream accuracy, but they fail to correctly identify the latent clusters. On structured NLP experiments, relaxations (SparseMAP and Marginals) tend to perform better overall and be more stable than straight-through variants in terms of classification accuracy. However, the lack of ground-truth latent structures makes it impossible to assess recovery performance. We hope that our insights, including some of our negative results, may encourage future research on learning with latent structures.

Figure: Example latent dependency trees produced by STE and SPIGOT-CE for the sentence "An intelligent, moving and invigorating film." The majority of the models produce mostly flat trees. In contrast, SPIGOT-CE identifies the adjectives describing the keyword "drama" and attaches them correctly.