Torch-Struct: Deep Structured Prediction Library

The literature on structured prediction for NLP describes a rich collection of distributions and algorithms over sequences, segmentations, alignments, and trees; however, these algorithms are difficult to utilize in deep learning frameworks. We introduce Torch-Struct, a library for structured prediction designed to take advantage of and integrate with vectorized, auto-differentiation based frameworks. Torch-Struct includes a broad collection of probabilistic structures accessed through a simple and flexible distribution-based API that connects to any deep learning model. The library utilizes batched, vectorized operations and exploits auto-differentiation to produce readable, fast, and testable code. Internally, we also include a number of general-purpose optimizations to provide cross-algorithm efficiency. Experiments show significant performance gains over fast baselines, and case studies demonstrate the benefits of the library. Torch-Struct is available at https://github.com/harvardnlp/pytorch-struct.


Introduction
Structured prediction is an area of machine learning focusing on representations of spaces with combinatorial structure, as well as algorithms for inference and parameter estimation over these structures. Core methods include both tractable exact approaches like dynamic programming and spanning tree algorithms as well as heuristic techniques such as linear programming relaxations and greedy search.
Structured prediction has played a key role in the history of natural language processing. Example methods include techniques for sequence labeling and segmentation (Lafferty et al., 2001; Sarawagi and Cohen, 2005), discriminative dependency and constituency parsing (Finkel et al., 2008; McDonald et al., 2005), unsupervised learning for labeling and alignment (Vogel et al., 1996; Goldwater and Griffiths, 2007), approximate translation decoding with beam search (Tillmann and Ney, 2003), among many others.

Figure 1: Distribution of binary trees over a 1000-token sequence. Coloring shows the marginal probabilities of every span. Torch-Struct is an optimized collection of common CRF distributions used in NLP that is designed to integrate with deep learning frameworks.
In recent years, research into deep structured prediction has studied how these approaches can be integrated with neural networks and pretrained models. One line of work has utilized structured prediction as the final layer for deep models (Collobert et al., 2011; Durrett and Klein, 2015). Another has incorporated structured prediction within deep learning models, exploring novel models for latent-structure learning, unsupervised learning, or model control (Johnson et al., 2016; Yogatama et al., 2016; Wiseman et al., 2018). We aspire to make both of these use cases as easy to use as standard neural networks.
Name         Structure      Parts           Algorithm        LoC   T/S    Source
Simple CKY   Labeled Tree   Splits (CN²)    0-th order CKY   30    118k   (Kasami, 1966)
Dependency   Proj. Tree     Arcs (N²)       Eisner Alg       40    28k    (Eisner, 2000)

Table 1: Models and algorithms implemented in Torch-Struct. Notation is developed in Section 5. Parts are described in terms of sequence lengths N, M, label size C, segment length K, and layers / grammar size L, G. Lines of code (LoC) is from the log-partition (A(ℓ)) implementation. T/S is the tokens per second of a batched computation, computed with batch 32, N = 25, C = 20, K = 5, L = 3 (K80 GPU run on Google Colab).

The practical challenge of employing structured
prediction is that many required algorithms are difficult to implement efficiently and correctly. Most projects reimplement custom versions of standard algorithms or focus particularly on a single well-defined model class. This research style makes it difficult to combine and try out new approaches, a problem that has compounded with the complexity of research in deep structured prediction. With this challenge in mind, we introduce Torch-Struct with three specific contributions:
• Modularity: models are represented as distributions with a standard flexible API integrated into a deep learning framework.
• Completeness: a broad array of classical algorithms are implemented and new models can easily be added.
• Efficiency: implementations target computational/memory efficiency for GPUs and the backend includes extensions for optimization.
In this system description, we first motivate the approach taken by the library, then present a technical description of the methods used, and finally present several example use cases.

Related Work
Several software libraries target structured prediction.

Motivating Case Study
While structured prediction is traditionally presented at the output layer, recent applications have deployed structured models broadly within neural networks (Johnson et al., 2016; Kim et al., 2017; Yogatama et al., 2016, inter alia). Torch-Struct aims to encourage this general use case.
To illustrate, we consider a latent tree model. ListOps (Nangia and Bowman, 2018) is a dataset of mathematical functions. Each input/output pair consists of a prefix expression x and its result y, e.g. x = [MAX 2 9 [MIN 4 7] 0] with y = 9.
Models such as a flat RNN will fail to capture the hierarchical structure of this task. However, if a model can induce an explicit latent z, the parse tree of the expression, then the task is easy to learn by a tree-RNN model p(y|x, z) (Yogatama et al., 2016;Havrylov et al., 2019).
Let us briefly summarize a latent-tree RL model for this task. The objective is to maximize the probability of the correct prediction under the expectation of a prior tree model p(z|x; φ),

O = E_{z ∼ p(z|x; φ)} [ log p(y | z, x) ]

Computing this expectation is intractable, so policy gradient is used. First a tree is sampled, z̃ ∼ p(z|x; φ); then the gradient with respect to φ is approximated as,

∂O/∂φ ≈ (log p(y | z̃, x) − b) ∂/∂φ log p(z̃ | x; φ)

where b is a variance reduction baseline. A common choice is the self-critical baseline (Rennie et al., 2017),

b = log p(y | z*, x),  with  z* = argmax_z p(z | x; φ)

Finally, an entropy regularization term is added to the objective to encourage exploration of different trees, O + λ H(p(z | x; φ)). Even in this brief overview, we can see how complex a latent structured learning problem can be. To compute these terms, we need five different properties of the structured prior model p(z|x; φ):
• Sampling, z̃ ∼ p(z | x; φ)
• Density, log p(z̃ | x; φ)
• Gradient, ∂/∂φ log p(z̃ | x; φ)
• Argmax, z* = argmax_z p(z | x; φ)
• Entropy, H(p(z | x; φ))
For structured models, each of these terms is non-trivial to compute. A goal of Torch-Struct is to make it seamless to deploy structured models for these complex settings; a sketch of such a step is given below.
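The snippet below sketches one such policy-gradient step, assuming the TreeCRF distribution object introduced in the next section; downstream_logprob (standing in for the tree-RNN term log p(y|x, z)) and the weight lam are hypothetical stand-ins, not part of the library.

```python
import torch
from torch_struct import TreeCRF

def reinforce_step(log_potentials, downstream_logprob, lam=0.01):
    # Gradient step for the tree prior p(z | x; phi) only; the
    # downstream model is assumed to be trained separately.
    dist = TreeCRF(log_potentials)               # p(z | x; phi)
    z = dist.sample(torch.Size([1]))[0]          # sampled tree z~
    reward = downstream_logprob(z)               # log p(y | x, z~)
    with torch.no_grad():
        b = downstream_logprob(dist.argmax)      # self-critical baseline
    surrogate = (reward - b).detach() * dist.log_prob(z)
    return -(surrogate + lam * dist.entropy).mean()
```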

Library Design
The library design of Torch-Struct follows the distributions API used by both TensorFlow and PyTorch (Dillon et al., 2017). For each structured model in the library, we define a conditional random field (CRF) distribution object. From a user's standpoint, this object provides all necessary distributional properties. Given log-potentials ℓ output from a deep network, the user can request samples z ∼ CRF(ℓ), probabilities CRF(z; ℓ), modes argmax_z CRF(ℓ), or other distributional properties such as H(CRF(ℓ)). The library is agnostic to how these are utilized, and when possible, they allow for backpropagation to update the input network. The same distributional object can be used for standard output prediction as for more complex operations like attention or reinforcement learning. Figure 2 demonstrates this API for a binary tree CRF over an ordered sequence, such as p(z | x; φ) from the previous section. The distribution takes in log-potentials which score each possible span in the input. The distribution converts these to probabilities of a specific tree. This distribution can be queried for predicting over the set of trees, sampling a tree for model structure, or even computing entropy over all trees; a usage sketch is shown below. Table 1 shows all of the structures and distributions implemented in Torch-Struct. While each is internally implemented using different specialized algorithms and optimizations, from the user's perspective they all utilize the same external distributional API, and pass a generic set of distributional tests.¹ This approach hides the internal complexity of the inference procedure, while giving the user full access to the model.

¹ The test suite for each distribution enumerates over all structures to ensure that properties hold. While this is intractable for large spaces, it can be done for small sets and was extremely useful for development.
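The sketch below illustrates this API, following Figure 2. The class and attribute names mirror the library's distribution objects, but the tensor shapes shown are indicative rather than exact.

```python
import torch
from torch_struct import TreeCRF

# Log-potentials from any network, scoring every span of the input
# (shapes indicative; see the library docs for exact conventions).
batch, N = 2, 10
log_potentials = torch.randn(batch, N, N)

dist = TreeCRF(log_potentials)
marginals = dist.marginals                  # per-span probabilities
best_tree = dist.argmax                     # highest-scoring tree
entropy = dist.entropy                      # H(CRF(l)) over all trees
tree_sample = dist.sample(torch.Size([1]))  # z ~ CRF(l)
```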

Conditional Random Fields
We now describe the technical approach underlying the library. To establish notation, first consider the implementation of a softmax categorical distribution, CAT(ℓ), with one-hot categories z with z_i = 1 from a set Z and probabilities given by the softmax over logits ℓ,

p(z_i = 1; ℓ) = exp(ℓ_i) / Σ_{j=1}^{K} exp(ℓ_j)

Define the log-partition as A(ℓ) = LSE(ℓ), i.e. the log of the denominator, where LSE is the log-sum-exp operator. Computing probabilities or sampling from this distribution requires enumerating Z to compute the log-partition A. A useful identity is that derivatives of A yield category probabilities,

p(z_i = 1; ℓ) = ∂A(ℓ) / ∂ℓ_i

Other distributional properties can be similarly extracted from variants of the log-partition. For instance, define A*(ℓ) = log max_{j=1}^{K} exp ℓ_j; then²: I(z*_i = 1) = ∂A*(ℓ) / ∂ℓ_i.

Conditional random fields, CRF(ℓ), extend the softmax to combinatorial spaces where Z is exponentially sized. Each z is now represented as a binary vector over a polynomial-sized set of parts, P, i.e. Z ⊂ {0, 1}^{|P|}. Similarly, log-potentials are now defined over parts, ℓ ∈ R^{|P|}. For instance, in Figure 2 each span is a part and the vector ℓ is shown in the top-left figure. Define the probability of a structure z as,

p(z; ℓ) = exp(ℓᵀz − A(ℓ)),  where  A(ℓ) = log Σ_{z ∈ Z} exp(ℓᵀz)

Computing probabilities or sampling from this distribution requires computing the log-partition term A. In general, computing this term is now intractable; however, for many core algorithms in NLP there exist efficient combinatorial algorithms for this term (a list of examples is given in Table 1).
² This is a subgradient identity; deep learning libraries like PyTorch generally default to this value.
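As a concrete check of the categorical identity, the following minimal sketch in plain PyTorch (not part of the library) backpropagates through the log-partition and recovers the softmax probabilities:

```python
import torch

# Derivatives of the log-partition A(l) = LSE(l) yield category
# probabilities: dA / dl_i = softmax(l)_i.
logits = torch.randn(5, requires_grad=True)
A = torch.logsumexp(logits, dim=0)   # log-partition of a categorical
A.backward()
assert torch.allclose(logits.grad, torch.softmax(logits, dim=0))
```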

Table 2: Semirings implemented in Torch-Struct, with columns Name, Ops (⊕, ⊗), Backprop, and Gradients; rows include an expectation semiring for entropy (Li and Eisner, 2009) and sparsemax (Mensch and Blondel, 2018).

Derivatives of the log-partition again provide useful distributional properties. For instance, the marginal probabilities of parts are given by,

p(z_p = 1; ℓ) = ∂A(ℓ) / ∂ℓ_p

Similarly, derivatives of A* correspond to whether a part appears in the argmax structure, I(z*_p = 1) = ∂A*(ℓ) / ∂ℓ_p. While these gradient identities are well-known (Eisner, 2016), they are not commonly deployed in practice. Computing CRF properties is typically done through two-step specialized algorithms, such as forward-backward, inside-outside, or similar variants such as Viterbi with backpointers (Jurafsky and Martin, 2014). Common wisdom is that these approaches are more efficient implementations.
However, we observe that recent engineering of faster gradient computation for deep learning has made gradient-based calculations competitive with hand-written calculations. In our experiments, we found that using these identities with autodifferentiation was often faster, and much simpler, than custom two-pass approaches. Torch-Struct is thus designed around using gradients for distributional computations.

Dynamic Programming and Semirings
Torch-Struct is a collection of generic algorithms for CRF inference. Each CRF distribution object, CRF(ℓ), is constructed by providing ℓ ∈ R^{|P|}, where the parts P are specific to the type of distribution. Internally, each distribution is implemented through a single function for computing the log-partition function A(ℓ). From this function, the library uses autodifferentiation and the identities from the previous section to define a complete distribution object. The core models implemented by the library are shown in Table 1.
To make the approach concrete, we consider the example of the simplest structured model, a linear-chain CRF p(z_1, z_2, z_3 | x).
The model has C labels per node and length N, utilizing a first-order linear-chain (Markov) model. This model has (N − 1) × C × C parts corresponding to edges in the chain, and thus log-potentials ℓ ∈ R^{(N−1)×C×C}. The log-partition function A(ℓ) factors into two reduce computations,

A(ℓ) = LSE_{c_1, …, c_N} [ Σ_{n=1}^{N−1} ℓ_{n, c_n, c_{n+1}} ]

Computing this function left-to-right using dynamic programming yields the standard forward algorithm for computing the log-partition of sequence models. As we have seen, the gradient with respect to ℓ produces marginals for each part, i.e. the probability of a specific labeled edge. A sketch of this computation is shown below.
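The following is a minimal sketch of this computation in plain PyTorch (not the library's implementation): the forward recursion computes A(ℓ), and a single backward pass recovers the edge marginals.

```python
import torch

def log_partition(edge):
    """Forward algorithm under (LSE, +); edge has shape (N - 1, C, C)."""
    alpha = torch.zeros(edge.size(1))            # uniform start scores
    for n in range(edge.size(0)):
        # alpha'_c = LSE_{c'} [ alpha_{c'} + edge[n, c', c] ]
        alpha = torch.logsumexp(alpha.unsqueeze(1) + edge[n], dim=0)
    return torch.logsumexp(alpha, dim=0)

edge = torch.randn(9, 5, 5, requires_grad=True)  # N = 10, C = 5
A = log_partition(edge)
A.backward()
marginals = edge.grad    # probability of each labeled edge appearing
```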
We can further extend the same function to support generic semiring dynamic programming (Goodman, 1999). A semiring is defined by a pair (⊕, ⊗) with a commutative ⊕, distributivity of ⊗ over ⊕, and appropriate identities.
The log-partition utilizes ⊕, ⊗ = (LSE, +), but we can substitute alternatives. For instance, utilizing the log-max semiring (max, +) in the forward algorithm yields the max score. As we have seen, its gradient with respect to ℓ is the argmax sequence, obviating the need for a separate argmax (Viterbi) algorithm; a sketch of this variant is shown below. Some distributional properties cannot be computed directly through gradient identities but still use a forward-backward style compute structure. For instance, sampling requires first computing the log-partition term and then sampling each part (forward filtering / backward sampling). We can compute this value by overriding each backpropagation operation for the ⊕ to instead compute a sample. Table 2 shows the set of semirings and backpropagation steps for computing different terms of interest. We note that many of the terms necessary in the case study can be computed with variant semirings, obviating the need for specialized algorithms.
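A sketch of the (max, +) substitution, again in plain PyTorch rather than the library's internals: the same recursion with max in place of LSE returns the max score, and its (sub)gradient is an indicator over the edges of the argmax sequence.

```python
import torch

def max_score(edge):
    """Same forward recursion under (max, +): returns the Viterbi score."""
    alpha = torch.zeros(edge.size(1))
    for n in range(edge.size(0)):
        alpha = torch.max(alpha.unsqueeze(1) + edge[n], dim=0).values
    return alpha.max()

edge = torch.randn(9, 5, 5, requires_grad=True)
max_score(edge).backward()
argmax_parts = edge.grad  # one-hot over edges of the best sequence
```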

Optimizations
Torch-Struct aims for computational and memory efficiency. Implemented naively, dynamic programming algorithms in Python are prohibitively slow. As such, Torch-Struct provides key primitives to help batch and vectorize these algorithms to take advantage of GPU computation and to minimize the overhead of backpropagating through chart-based dynamic programming. We discuss three optimizations: a) Parallel Scan, b) Vectorization, and c) Semiring Matrix Multiplications. Figure 3 shows the impact of these optimizations on the core algorithms.
Parallel Scan Inference The commutative properties of semiring algorithms allow flexibility in the order in which we compute A(ℓ). Typical implementations of dynamic programming algorithms are serial in the length of the sequence. On parallel hardware, an appealing approach is a parallel scan ordering (Särkkä and García-Fernández, 2019), typically used for computing prefix sums. To compute A(ℓ) in this manner, we first pad the sequence length N out to the nearest power of two, and then compute a balanced parallel tree over the parts, shown in Figure 4. Concretely, each node layer would compute a semiring matrix multiplication, e.g. ⊕_c ℓ_{n,·,c} ⊗ ℓ_{n+1,c,·}. Under this approach, assuming enough parallel cores, we only need O(log N) steps in Python and can use parallel operations for the rest. A similar parallel approach can also be used for computing sequence alignment and semi-Markov models.

Vectorization Computational complexity is even more of an issue for algorithms that cannot easily be parallelized. For example, parsing algorithms that generalize CKY are common in NLP. The CKY algorithm has a bottleneck in that it must compute each width from 1 through N in serial; however, internally each one of these steps can be vectorized. Assuming we have computed all inside spans of width less than d, computing the inside span of width d requires, for all i,

C[i, i+d] = ⊕_{j=1}^{d−1} C[i, i+j] ⊗ C[i+j, i+d]

In order to vectorize this loop over i, j, we need to reindex the chart. Instead of using a single chart C, we split it into two parts: one right-facing, C_r[i, d] = C[i, i+d], and one left-facing, C_l[i+d, N−d] = C[i, i+d]. After this reindexing, the update can be written,

C_r[i, d] = ⊕_{j=1}^{d−1} C_r[i, j] ⊗ C_l[i+d, N−(d−j)]
Unlike the original, this formula can easily be computed as a vectorized semiring dot product. This allows us to compute C_r[·, d] in one operation. Variants of this same approach can be used for many more complex dynamic programs. Sketches of both optimizations in plain PyTorch are given below.
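First, a sketch of the vectorized inside algorithm for an unlabeled binary-tree model. It uses a width-indexed left-facing chart C_l[j, d] = C[j − d, j] rather than the N − d layout above (the two are equivalent up to reindexing), and the span-score layout is a hypothetical stand-in.

```python
import torch

def inside_log_partition(scores):
    # scores[i, j]: log-potential for span (i, j) over fenceposts 0..N
    # (hypothetical layout; only entries with i < j are used).
    N = scores.size(0) - 1
    Cr = torch.full((N + 1, N + 1), float("-inf"))  # Cr[i, d] = C[i, i+d]
    Cl = torch.full((N + 1, N + 1), float("-inf"))  # Cl[j, d] = C[j-d, j]
    idx = torch.arange(N)
    Cr[idx, 1] = scores[idx, idx + 1]               # width-1 spans
    Cl[idx + 1, 1] = scores[idx, idx + 1]
    for d in range(2, N + 1):                       # serial over widths
        i = torch.arange(N - d + 1)
        left = Cr[: N - d + 1, 1:d]                 # C[i, i+j], j < d
        right = Cl[d:, 1:d].flip(1)                 # C[i+j, i+d]
        val = torch.logsumexp(left + right, dim=1) + scores[i, i + d]
        Cr[i, d] = val                              # one vectorized update
        Cl[i + d, d] = val
    return Cr[0, N]                                 # log-partition A

scores = torch.randn(6, 6, requires_grad=True)      # N = 5 tokens
inside_log_partition(scores).backward()
span_marginals = scores.grad                         # Figure 1-style marginals
```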
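Second, a sketch of the parallel scan for the linear chain: edge potentials are padded with semiring identity matrices to a power of two, then contracted pairwise with a log-semiring matrix multiplication, giving O(log N) Python steps. The broadcast-based log_matmul here is the memory-hungry version that the custom kernels of the next section replace.

```python
import torch

def log_matmul(A, B):
    # Semiring matmul under (LSE, +); materializes (..., C, C, C).
    return torch.logsumexp(A.unsqueeze(-1) + B.unsqueeze(-3), dim=-2)

def parallel_log_partition(edge):
    C = edge.size(-1)
    n = 1
    while n < edge.size(0):
        n *= 2
    # Identity under (LSE, +): 0 on the diagonal, -inf elsewhere.
    eye = torch.full((C, C), float("-inf")).fill_diagonal_(0.0)
    chart = torch.cat([edge, eye.expand(n - edge.size(0), C, C)], dim=0)
    while chart.size(0) > 1:                 # balanced tree, log N steps
        chart = log_matmul(chart[0::2], chart[1::2])
    return torch.logsumexp(chart[0], dim=(0, 1))  # matches the serial forward
```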

Semiring Matrix Operations
The two previous optimizations reduce most of the cost to semiring matrix multiplication. In the specific case of the (+, ×) semiring, these can be computed very efficiently using standard matrix multiplication, which is highly tuned on GPU hardware. However, this semiring is not particularly useful and is prone to underflow. For other semirings, such as log and max, these operations are either slow or very memory-inefficient. For instance, for matrices T and U of size N × M and M × O, we can broadcast with ⊗ to a tensor of size N × M × O and then reduce over dimension M, at a huge memory cost.

Figure 4: Parallel scan implementation of the linear-chain CRF inference algorithm (parallel forward). Here ⊗ represents a semiring matrix operation and I is padding to produce a balanced tree.
To avoid this issue, we implement custom CUDA kernels targeting fast and memory-efficient tensor operations. For log, this corresponds to computing,

V_{m,o} = log Σ_n exp(T_{m,n} + U_{n,o} − q_{m,o}) + q_{m,o},  where  q_{m,o} = max_n (T_{m,n} + U_{n,o})

To optimize this operation on the GPU, we utilize the TVM language (Chen et al., 2018) to lay out the CUDA loops and tune them to hardware. This produces much faster operations, although still less efficient than matrix multiplication, which is heavily customized to hardware.
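For reference, the shift trick above can be written in a few lines of plain PyTorch, though this version still materializes the full N × M × O broadcast that the fused TVM/CUDA kernel avoids:

```python
import torch

def log_matmul_stable(T, U):
    # V[m, o] = log sum_n exp(T[m, n] + U[n, o] - q[m, o]) + q[m, o]
    X = T.unsqueeze(-1) + U.unsqueeze(0)    # (N, M, O): the memory cost
    q = X.max(dim=1, keepdim=True).values   # q[m, o] = max_n (...)
    return (X - q).exp().sum(dim=1).log() + q.squeeze(1)
```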

Conclusion and Future Work
We present Torch-Struct, a library for deep structured prediction. The library achieves modularity through its adoption of a generic distributional API, completeness by utilizing CRFs and semirings to make it easy to add new algorithms, and efficiency through core optimizations to vectorize important dynamic programming steps. In addition to the problems discussed so far, Torch-Struct also includes several other example implementations, including supervised dependency parsing with BERT, unsupervised tagging, structured attention, and connectionist temporal classification (CTC) for speech. The accompanying code demonstrates that these models replicate standard deep learning results, although we focus here on the fidelity and implementation approach of the core library. The full library is available at https://github.com/harvardnlp/pytorch-struct.
In the future, we hope to support research and production applications employing structured models. We also believe the library provides a strong foundation for building generic tools for interpretability, control, and visualization through its probabilistic API. Finally, we hope to explore further optimizations to make core algorithms competitive with highly-optimized neural network components. These approaches provide a benchmark for improving autodifferentiation systems and extending their functionality to higher-order properties.