Towards Dynamic Computation Graphs via Sparse Latent Structure

Deep NLP models benefit from underlying structures in the data—e.g., parse trees—typically extracted using off-the-shelf parsers. Recent attempts to jointly learn the latent structure encounter a tradeoff: either make factorization assumptions that limit expressiveness, or sacrifice end-to-end differentiability. Using the recently proposed SparseMAP inference, which retrieves a sparse distribution over latent structures, we propose a novel approach for end-to-end learning of latent structure predictors jointly with a downstream predictor. To the best of our knowledge, our method is the first to enable unrestricted dynamic computation graph construction from the global latent structure, while maintaining differentiability.


Introduction
Latent structure models are a powerful tool for modeling compositional data and building NLP pipelines (Smith, 2011). An interesting emerging direction is to dynamically adapt a network's computation graph, based on structure inferred from the input; notable applications include learning to write programs (Bosnjak et al., 2017), answering visual questions by composing specialized modules (Hu et al., 2017; Johnson et al., 2017), and composing sentence representations using latent syntactic parse trees.
But how can we learn a model that conditions on such combinatorial variables? The question becomes how to marginalize over all possible latent structures. For tractability, existing approaches must make a choice. Some eschew global latent structure, resorting to computation graphs built from smaller local decisions: e.g., structured attention networks use local posterior marginals as attention weights (Kim et al., 2017; Liu and Lapata, 2018), and Maillard et al. (2017) construct sentence representations from parser chart entries. Others allow more flexibility at the cost of losing end-to-end differentiability, ending up with reinforcement learning problems (Hu et al., 2017; Johnson et al., 2017; Williams et al., 2018). More traditional approaches employ an off-line structure predictor (e.g., a parser) to define the computation graph (Tai et al., 2015; Chen et al., 2017), sometimes with some parameter sharing (Bowman et al., 2016). However, these off-line methods cannot jointly train the latent model and the downstream classifier via error gradient information.
We propose here a new strategy for building dynamic computation graphs with latent structure, through sparse structure prediction. Sparsity allows selecting and conditioning on a tractable number of global structures, eliminating the limitations stated above. Namely, our approach is the first that: A) is fully differentiable; B) supports latent structured variables; C) can marginalize over full global structures.
This contrasts with off-line and with reinforcement learning-based approaches, which satisfy B and C but not A; and with local marginal-based methods such as structured attention networks, which satisfy A and B, but not C. Key to our approach is the recently proposed SparseMAP inference (Niculae et al., 2018), which induces, for each data example, a very sparse posterior distribution over the possible structures, allowing us to compute the expected network output efficiently and explicitly in terms of a small, interpretable set of latent structures. Our model can be trained end-to-end with gradient-based methods, without the need for policy exploration or sampling.
We demonstrate our strategy on inducing latent dependency TreeLSTMs, achieving competitive results on sentence classification, natural language inference, and reverse dictionary lookup.

Sparse Latent Structure Prediction
We describe our proposed approach for learning with combinatorial structures (in particular, nonprojective dependency trees) as latent variables.
Figure 1: Our method computes a sparse probability distribution over all possible latent structures: here, only two have nonzero probability. For each selected tree h, we evaluate p_ξ(y | h, x) by dynamically building the corresponding computation graph (e.g., a TreeLSTM). The final, posterior prediction is a sparse weighted average.

Latent Structure Models
Let x and y denote classifier inputs and outputs, and h ∈ H(x) a latent variable; for example, H(x) can be the set of possible dependency trees for x. We would like to train a neural network to model

$$ p(y \mid x) = \sum_{h \in \mathcal{H}(x)} p_\theta(h \mid x)\, p_\xi(y \mid h, x), \qquad (1) $$

where p_θ(h | x) is a structured-output parsing model that defines a distribution over trees, and p_ξ(y | h, x) is a classifier whose computation graph may depend freely and globally on the structure h (e.g., a TreeLSTM). The rest of this section focuses on the challenge of defining p_θ(h | x) such that Eqn. 1 remains tractable and differentiable.
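To make Eqn. 1 concrete, here is a minimal numpy sketch of the mixture, assuming a toy enumerated support with made-up posterior and classifier values; in the actual model these come from SparseMAP and a TreeLSTM, respectively.

```python
import numpy as np

# Toy sketch of Eqn. 1: p(y | x) = sum_h p_theta(h | x) * p_xi(y | h, x).

# Hypothetical posterior over three trees (sparse: one tree has zero mass).
p_theta = np.array([0.7, 0.3, 0.0])

# Hypothetical p_xi(y | h, x) per tree: a distribution over 2 output classes.
p_xi = np.array([[0.9, 0.1],
                 [0.4, 0.6],
                 [0.5, 0.5]])

# Only trees with nonzero posterior contribute to the expectation.
support = p_theta > 0
p_y = p_theta[support] @ p_xi[support]
print(p_y)  # [0.75 0.25]
```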

Global Inference
Denote by f_θ(h; x) a scoring function, assigning each tree a non-normalized score. For instance, we may have an arc-factored score f_θ(h; x) := Σ_{a∈h} s_θ(a; x), where we interpret a tree h as a set of directed arcs a, each receiving an atomic score s_θ(a; x). Deriving p_θ given f_θ is known as structured inference. This can be written as an Ω-regularized optimization problem of the form

$$ p_\theta(\cdot \mid x) := \arg\max_{q \in \triangle^{|\mathcal{H}(x)|}} \; \sum_{h \in \mathcal{H}(x)} q(h)\, f_\theta(h; x) - \Omega(q), $$

where △^{|H(x)|} is the set of all possible probability distributions over H(x). Examples follow.
Marginal inference. With negative entropy regularization, i.e., Ω(q) := Σ_{h∈H(x)} q(h) log q(h), we recover marginal inference, and the probability of a tree becomes (Wainwright and Jordan, 2008)

$$ p_\theta(h \mid x) \propto \exp f_\theta(h; x). $$

This closed-form derivation, detailed in Appendix A, provides a differentiable expression for p_θ. However, crucially, since exp(·) > 0, every tree is assigned strictly nonzero probability. Therefore, unless the downstream p_ξ is constrained to also factor over arcs, as in Kim et al. (2017) and Liu and Lapata (2018), the sum in Eqn. 1 requires enumerating the exponentially large H(x). This is generally intractable, and hard to approximate via sampling even when p_θ is tractable.
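For illustration, the following sketch enumerates all three dependency trees of a toy 2-word sentence and performs marginal inference explicitly, with hypothetical arc scores; note that every tree ends up with strictly positive probability.

```python
import numpy as np

# All dependency trees of a toy 2-word sentence, as sets of (head, modifier)
# arcs; 0 denotes the root.
trees = [
    {(0, 1), (1, 2)},   # root -> w1 -> w2
    {(0, 2), (2, 1)},   # root -> w2 -> w1
    {(0, 1), (0, 2)},   # flat: both words attach to the root
]

# Hypothetical atomic arc scores s_theta(a; x).
s = {(0, 1): 1.0, (0, 2): -0.5, (1, 2): 2.0, (2, 1): 0.3}

# Arc-factored tree scores f_theta(h; x) = sum of the arc scores in h.
f = np.array([sum(s[a] for a in h) for h in trees])

# Entropic regularization yields a softmax over whole trees: every tree,
# however implausible, receives strictly positive probability.
p = np.exp(f - f.max())
p /= p.sum()
print(p)   # approx. [0.89 0.04 0.07] -- dense support
```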
MAP inference. At the polar opposite, setting Ω(q) := 0 yields maximum a posteriori (MAP) inference (see Appendix A). MAP assigns probability 1 to the highest-scoring tree and 0 to all others, yielding a very sparse p_θ. However, since the top-scoring tree (or top-k, for fixed k) does not vary with small changes in θ, error gradients cannot propagate through MAP. This prevents end-to-end gradient-based training of MAP-based latent variable models, making them more difficult to use. Related reinforcement learning approaches also yield only one structure, but sidestep non-differentiability by instead introducing more challenging search problems.
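A corresponding sketch of MAP inference on the same toy scores: the one-hot distribution is piecewise constant in the scores, so gradients vanish almost everywhere.

```python
import numpy as np

# MAP puts all mass on the single best tree: a one-hot distribution.
f = np.array([3.0, -0.2, 0.5])          # the tree scores from above
p_map = np.zeros_like(f)
p_map[np.argmax(f)] = 1.0
print(p_map)                             # [1. 0. 0.]

# The argmax is piecewise constant: a small perturbation of the scores
# leaves p_map unchanged, so dp_map/df = 0 almost everywhere, and no
# error signal can reach the parser parameters theta.
p_perturbed = np.zeros_like(f)
p_perturbed[np.argmax(f + 1e-3 * np.random.randn(3))] = 1.0
assert (p_map == p_perturbed).all()
```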

Sparse Inference
In this work, we propose using SparseMAP inference (Niculae et al., 2018) to sparsify the set H while preserving differentiability. SparseMAP uses a quadratic penalty on the posterior arc marginals,

$$ \Omega(q) := \tfrac{1}{2} \Big\| \sum_{h \in \mathcal{H}(x)} q(h)\, a(h) \Big\|_2^2, $$

where a(h) denotes the arc-indicator vector of tree h. Situated between marginal inference and MAP inference, SparseMAP assigns nonzero probability to only a small set of plausible trees H̄ ⊂ H, of size at most equal to the number of arcs (Martins et al., 2015, Proposition 11). This guarantees that the summation in Eqn. 1 can be computed efficiently by iterating over H̄; this is depicted in Figure 1 and described in the next paragraphs.
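The following toy sketch solves this quadratic-regularized problem over the same enumerated 2-word example by projected gradient ascent (our own illustrative choice; the actual solver is the active set method described under "Forward pass" below), showing how exact zeros appear in the posterior.

```python
import numpy as np

# Toy SparseMAP: maximize <q, f> - 1/2 ||M q||^2 over the simplex, where
# column h of M is the arc-indicator vector a(h), so that M @ q gives the
# posterior arc marginals.

def project_simplex(v):
    """Euclidean projection of v onto the probability simplex."""
    u = np.sort(v)[::-1]
    css = np.cumsum(u) - 1.0
    rho = np.nonzero(u * np.arange(1, len(v) + 1) > css)[0][-1]
    return np.maximum(v - css[rho] / (rho + 1.0), 0.0)

# The 2-word toy again: rows index the arcs (0,1), (0,2), (1,2), (2,1);
# columns index the three trees.
M = np.array([[1., 0., 1.],
              [0., 1., 1.],
              [1., 0., 0.],
              [0., 1., 0.]])
f = np.array([3.0, 0.5, 2.5])             # hypothetical tree scores

q = np.full(3, 1 / 3)                      # start from the uniform distribution
for _ in range(500):
    q = project_simplex(q + 0.1 * (f - M.T @ (M @ q)))

print(q.round(3))   # [0.75 0.   0.25]: only two trees keep nonzero mass
```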
Forward pass. To compute p(y | x) (Eqn. 1), we observe that the SparseMAP posterior p_θ is nonzero only on a small set of trees H̄, and thus we only need to compute p_ξ(y | h, x) for h ∈ H̄.
The support and values of p_θ are obtained by solving the SparseMAP inference problem, as described in Niculae et al. (2018). The strategy, based on the active set algorithm (Nocedal and Wright, 1999, chapter 16), involves a sequence of MAP calls (here, maximum spanning tree problems).

Backward pass. We next show how to compute end-to-end gradients efficiently. Recall from Eqn. 1 that

$$ p(y \mid x) = \sum_{h \in \mathcal{H}(x)} p_\theta(h \mid x)\, p_\xi(y \mid h, x), $$

where h is a discrete index of a tree. To train the classifier, we have

$$ \frac{\partial p(y \mid x)}{\partial \xi} = \sum_{h \in \mathcal{H}(x)} p_\theta(h \mid x)\, \frac{\partial p_\xi(y \mid h, x)}{\partial \xi}, $$

so only the terms with nonzero probability (i.e., h ∈ H̄) contribute to the gradient; ∂p_ξ(y | h, x)/∂ξ is readily available by implementing p_ξ in an automatic differentiation library.1 To train the latent parser, the total gradient is

$$ \frac{\partial p(y \mid x)}{\partial \theta} = \sum_{h \in \mathcal{H}(x)} p_\xi(y \mid h, x)\, \frac{\partial p_\theta(h \mid x)}{\partial \theta}. $$

We derive the expression of ∂p_θ(h | x)/∂θ in Appendix B. Crucially, the gradient sum is also sparse, like p_θ, and efficient to compute, amounting to multiplication by an |H̄|-by-|H̄| matrix. The proof, given in Appendix B, is a novel extension of the SparseMAP backward pass (Niculae et al., 2018).

1 Here we assume θ and ξ to be disjoint, but weight sharing is easily handled by automatic differentiation via the product rule. Differentiation w.r.t. the summation index h is not necessary: p_ξ may use the discrete structure h freely and globally.
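A schematic PyTorch sketch of the classifier gradient, assuming the sparse posterior has already been computed by a SparseMAP solver and treating it as constant; the gradient w.r.t. θ via the SparseMAP backward pass (Appendix B) is not shown, and all module and variable names are hypothetical.

```python
import torch

torch.manual_seed(0)
clf = torch.nn.Linear(4, 2)                  # stand-in for the classifier p_xi
x = torch.randn(3, 4)                        # word vectors of a 3-word sentence

p_theta = torch.tensor([0.7, 0.3])           # sparse posterior over two trees
# Each selected tree induces its own computation graph over x; here we fake
# two different composition orders with two different pooling operations.
reps = [x.mean(dim=0), x.max(dim=0).values]

p_y = sum(w * torch.softmax(clf(r), dim=0) for w, r in zip(p_theta, reps))
loss = -torch.log(p_y[0])                    # negative log-likelihood of y = 0
loss.backward()                              # only supported trees contribute
print(clf.weight.grad.shape)                 # torch.Size([2, 4]): grads w.r.t. xi
```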
Generality. Our description focuses on probabilistic classifiers, but our method can be readily applied to networks that output any representation, not necessarily a probability. For this, we define a function r_ξ(h, x), consisting of any auto-differentiable computation w.r.t. x, conditioned on the discrete latent structure h in arbitrary, non-differentiable ways. We then compute

$$ \bar{r}(x) := \sum_{h \in \bar{\mathcal{H}}} p_\theta(h \mid x)\, r_\xi(h, x). $$

This strategy is demonstrated in our reverse dictionary experiments in §3.4. In addition, our approach is not limited to trees: any structured model with tractable MAP inference may be used.
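A small numpy sketch of this generalization, with a hypothetical r_ξ that branches non-differentiably on the structure, a stand-in for building a different computation graph per tree:

```python
import numpy as np

# Expected representation r_bar(x) = sum_h p_theta(h | x) * r_xi(h, x),
# where r_xi may use the discrete structure h freely.

def r_xi(tree, x):
    if all(head == 0 for head, _ in tree):   # flat tree: additive composition
        return x.sum(axis=0)
    return np.tanh(x).sum(axis=0)            # other trees: a toy recursive rule

x = np.random.randn(2, 4)                     # two word vectors
posterior = [
    (0.6, frozenset({(0, 1), (0, 2)})),       # flat tree
    (0.4, frozenset({(0, 1), (1, 2)})),       # chain tree
]

r_bar = sum(p * r_xi(h, x) for p, h in posterior)
print(r_bar.shape)                            # (4,): a plain representation
```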

Experiments
We evaluate our approach on three natural language processing tasks: sentence classification, natural language inference, and reverse dictionary lookup.

Common aspects
Word vectors. Unless otherwise mentioned, we initialize with 300-dimensional GloVe word embeddings (Pennington et al., 2014). We transform every sentence via a bidirectional LSTM encoder, to produce a context-aware vector v_i encoding word i.
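A minimal PyTorch sketch of this encoder, with a hypothetical vocabulary size and sentence; the hidden size is chosen so that the concatenated directions give 300-dimensional vectors v_i.

```python
import torch

# Embeddings (initialized from GloVe in the real model) followed by a
# BiLSTM; 150 units per direction yield 300-dimensional vectors v_i.
embed = torch.nn.Embedding(20000, 300)
bilstm = torch.nn.LSTM(300, 150, bidirectional=True, batch_first=True)

tokens = torch.randint(0, 20000, (1, 7))     # one 7-word sentence
v, _ = bilstm(embed(tokens))
print(v.shape)                               # torch.Size([1, 7, 300])
```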
Dependency TreeLSTM. We combine the word vectors v_i in a sentence into a single vector using a tree-structured Child-Sum LSTM, which allows an arbitrary number of children at any node (Tai et al., 2015). Our baselines consist of two extreme cases of dependency trees: one where the parent of word i is word i+1 (resulting in a left-to-right sequential LSTM), and one where all words are direct children of the root node (resulting in a flat additive model). We also consider off-line dependency trees precomputed by Stanford CoreNLP.
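The sketch below shows one possible Child-Sum TreeLSTM composition over a parent array, along with the parent arrays corresponding to the two baselines; it is a simplified recursive reading of Tai et al. (2015), not the exact implementation.

```python
import torch

# A tree is a parent array over words 1..n, with 0 the (zero-input)
# virtual root; the two baselines are fixed parent arrays.

class ChildSumTreeLSTM(torch.nn.Module):
    def __init__(self, d):
        super().__init__()
        self.iou = torch.nn.Linear(2 * d, 3 * d)  # input/output/update gates
        self.fg = torch.nn.Linear(2 * d, d)       # one forget gate per child
        self.d = d

    def compose(self, v, children, node):
        # Encode the subtree rooted at `node`, children first.
        x = v[node - 1] if node > 0 else torch.zeros(self.d)
        kids = [self.compose(v, children, k) for k in children[node]]
        hs = [h for h, _ in kids] or [torch.zeros(self.d)]
        cs = [c for _, c in kids] or [torch.zeros(self.d)]
        i, o, u = self.iou(torch.cat([x, torch.stack(hs).sum(0)])).chunk(3)
        fs = [torch.sigmoid(self.fg(torch.cat([x, h]))) for h in hs]
        c = torch.sigmoid(i) * torch.tanh(u) + sum(f * ck for f, ck in zip(fs, cs))
        return torch.sigmoid(o) * torch.tanh(c), c

    def forward(self, v, parent):
        children = [[] for _ in range(len(parent) + 1)]
        for mod, head in enumerate(parent, start=1):
            children[head].append(mod)
        return self.compose(v, children, 0)[0]    # root hidden state

n, d = 5, 8
v = torch.randn(n, d)                 # encoder outputs for a 5-word sentence
chain = [2, 3, 4, 5, 0]               # baseline: parent of word i is word i+1
flat = [0] * n                        # baseline: all words attach to the root
tlstm = ChildSumTreeLSTM(d)
print(tlstm(v, chain).shape, tlstm(v, flat).shape)  # two (8,) sentence vectors
```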
Experimental setup. All networks are trained via stochastic gradient with 16 samples per batch.
We tune the learning rate on a log-grid, using a decay factor of 0.9 after every epoch at which the validation performance is not the best seen, and stop after five epochs without improvement. At test time, we scale the arc scores s_θ by a temperature t chosen on the validation set, controlling the sparsity of the SparseMAP distribution. All hidden layers are 300-dimensional.

Table 2: Results on the reverse dictionary lookup task (Hill et al., 2016), reported for seen words, unseen words, and concepts. Following the authors, for an input definition, we rank a shortlist of approximately 50k candidate words according to the cosine similarity to the output vector, and report the median rank of the expected word, accuracy at 10, and accuracy at 100.

Sentence classification
We evaluate our models for sentence-level subjectivity classification (Pang and Lee, 2004) and for binary sentiment classification on the Stanford Sentiment Treebank (Socher et al., 2013). In both cases, we use a softmax output layer on top of the Dependency TreeLSTM output representation.

Natural language inference (NLI)
We apply our strategy to the SNLI corpus (Bowman et al., 2015), which consists of classifying premise-hypothesis sentence pairs into entailment, contradiction, or neutral relations. In this case, for each pair (x_P, x_H), the running sum is over two latent distributions over parse trees, i.e.,

$$ p(y \mid x_P, x_H) = \sum_{h_P} \sum_{h_H} p_\theta(h_P \mid x_P)\, p_\theta(h_H \mid x_H)\, p_\xi(y \mid h_P, h_H, x_P, x_H). $$

For each pair of trees, we independently encode the premise and hypothesis using a TreeLSTM. We then concatenate the two vectors, their difference, and their element-wise product (Mou et al., 2016). The result is passed through one tanh hidden layer, followed by the softmax output layer.
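A sketch of this head with hypothetical dimensions and stand-in encodings; the double sum runs only over tree pairs with nonzero posterior weight.

```python
import torch

# Stand-in TreeLSTM encodings for the selected premise/hypothesis trees,
# combined with the concatenation features of Mou et al. (2016).
torch.manual_seed(0)
d = 6
head = torch.nn.Sequential(
    torch.nn.Linear(4 * d, d), torch.nn.Tanh(),       # one tanh hidden layer
    torch.nn.Linear(d, 3), torch.nn.Softmax(dim=-1))  # softmax output layer

p_P = torch.tensor([0.8, 0.2])       # sparse posterior over premise trees
p_H = torch.tensor([0.5, 0.5])       # sparse posterior over hypothesis trees
enc_P = torch.randn(2, d)            # one TreeLSTM encoding per selected tree
enc_H = torch.randn(2, d)

p_y = torch.zeros(3)
for wp, u in zip(p_P, enc_P):
    for wh, v in zip(p_H, enc_H):
        feats = torch.cat([u, v, u - v, u * v])
        p_y = p_y + wp * wh * head(feats)
print(p_y)          # distribution over {entailment, contradiction, neutral}
```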

Reverse dictionary lookup
The reverse dictionary task aims to compose a dictionary definition into an embedding that is close to the defined word. We therefore use fixed input and output embeddings, set to the unit-norm 500-dimensional vectors provided, together with training and evaluation data, by Hill et al. (2016). The network output is a projection of the TreeLSTM encoding back to the dimension of the word embeddings, normalized to unit ℓ2 norm. We maximize the cosine similarity of the predicted vector with the embedding of the defined word.

Figure 2: Latent trees selected for the sentence "a vivid cinematic portrait ."; the highest-ranked tree (28%) is flat.

Figure 3: Examples where our model selects a symmetric composition order for coordinate structures, e.g., "lovely and poignant ."

On the reverse dictionary task (Table 2), our model also performs well, especially on concept classification, where the input definitions differ more from those seen during training. For context, we repeat the scores of the CKY-based latent TreeLSTM model of Maillard et al. (2017), as well as of the LSTM from Hill et al. (2016); these differently-sized models are not entirely comparable. We attribute our model's performance to the latent parser's flexibility, investigated below.
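For concreteness, a sketch of the ranking protocol from Table 2 with random stand-in data; all vectors are unit-normalized, so cosine similarity reduces to a dot product.

```python
import numpy as np

rng = np.random.default_rng(0)
cands = rng.standard_normal((50000, 500))              # candidate shortlist
cands /= np.linalg.norm(cands, axis=1, keepdims=True)

pred = rng.standard_normal(500)                        # network output
pred /= np.linalg.norm(pred)

order = np.argsort(-(cands @ pred))                    # best candidates first
target = 1234                                          # index of defined word
rank = int(np.where(order == target)[0][0]) + 1
print(rank, rank <= 10, rank <= 100)                   # rank, acc@10, acc@100
```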
Selected latent structures. We analyze the latent structures selected by our model on SST, where the flat composition baseline is remarkably strong. We find that our model, to maximize accuracy, prefers flat or nearly-flat trees, but not exclusively: the average posterior probability of the flat tree is 28.9%. In Figure 2, the highest-ranked tree is flat, but deeper trees are also selected, including the projective CoreNLP parser output. Syntax is not necessarily an optimal composition order for a latent TreeLSTM, as illustrated by the poor performance of the off-line parser (Table 1). Consequently, our (fully unsupervised) latent structures tend to disagree with CoreNLP: the average probability of CoreNLP arcs is 5.8%; Williams et al. (2018) make related observations. Indeed, some syntactic conventions may be questionable for recursive composition. Figure 3 shows two examples where our model identifies a plausible symmetric composition order for coordinate structures: this analysis disagrees with CoreNLP, which uses the asymmetrical Stanford / UD convention of assigning the left-most conjunct as head (Nivre et al., 2016). Assigning the conjunction as head instead seems preferable in a Child-Sum TreeLSTM.
Training efficiency. Our model must evaluate at least one TreeLSTM for each sentence, making it necessarily slower than the baselines, which evaluate exactly one. Thanks to sparsity and autobatching, the actual slow-down is not problematic; moreover, as the model trains, the latent parser gets more confident, and for many unambiguous sentences there may be only one latent tree with nonzero probability. On SST, our average training epoch is only 4.7× slower than the off-line parser and 6× slower than the flat baseline.

Conclusions and future work
We presented a novel approach for training latent structure neural models, based on the key idea of sparsifying the set of possible structures, and demonstrated our method with competitive latent dependency TreeLSTM models. Our method's generality opens up several avenues for future work: since it supports any structure for which MAP inference is available (e.g., matchings, alignments), and we place no restrictions on the downstream p_ξ(y | h, x), we may design latent versions of more complicated state-of-the-art models, such as ESIM for NLI (Chen et al., 2017). In concurrent work, Peng et al. (2018) proposed an approximate MAP backward pass, relying on a relaxation and a gradient projection. Unlike our method, theirs does not support multiple latent structures; we intend to further study the relationship between the two methods.