Learning Context-free Languages with Nondeterministic Stack RNNs

We present a differentiable stack data structure that simultaneously and tractably encodes an exponential number of stack configurations, based on Lang’s algorithm for simulating nondeterministic pushdown automata. We call the combination of this data structure with a recurrent neural network (RNN) controller a Nondeterministic Stack RNN. We compare our model against existing stack RNNs on various formal languages, demonstrating that our model converges more reliably to algorithmic behavior on deterministic tasks, and achieves lower cross-entropy on inherently nondeterministic tasks.


Introduction
Although recent neural models of language have made advances in learning syntactic behavior, research continues to suggest that inductive bias plays a key role in data efficiency and human-like syntactic generalization (van Schijndel et al., 2019; Hu et al., 2020). Based on the long-held observation that language exhibits hierarchical structure, previous work has proposed coupling recurrent neural networks (RNNs) with differentiable stack data structures (Joulin and Mikolov, 2015; Grefenstette et al., 2015) to give them some of the computational power of pushdown automata (PDAs), the class of automata that recognize context-free languages (CFLs). However, previously proposed differentiable stack data structures only model deterministic stacks, which store only one version of the stack contents at a time, theoretically limiting the power of these stack RNNs to the deterministic CFLs.
A sentence's syntactic structure often cannot be fully resolved until its conclusion (if ever), requiring a human listener to track multiple possibilities while hearing the sentence. Past work in psycholinguistics has suggested that models that keep multiple candidate parses in memory at once can explain human reading times better than models which assume harsher computational constraints. This ability also plays an important role in calculating expectations that facilitate more efficient language processing (Levy, 2008). Current neural language models do not track multiple parses, if they learn syntactic generalizations at all (McCoy et al., 2020).
We propose a new differentiable stack data structure that explicitly models a nondeterministic PDA, adapting an algorithm by Lang (1974) and reformulating it in terms of tensor operations. The algorithm is able to represent an exponential number of stack configurations at once using cubic time and quadratic space complexity. As with existing stack RNN architectures, we combine this data structure with an RNN controller, and we call the resulting model a Nondeterministic Stack RNN (NS-RNN).
We predict that nondeterminism can help language processing in two ways. First, it will improve trainability, since all possible sequences of stack operations contribute to the objective function, not just the sequence used by the current model. Second, it will improve expressivity, as it is able to model concurrent parses in ways that a deterministic stack cannot. We demonstrate these claims by comparing the NS-RNN to deterministic stack RNNs on formal language modeling tasks of varying complexity. To show that nondeterminism aids training, we show that the NS-RNN achieves lower cross-entropy, in fewer parameter updates, on some deterministic CFLs. To show that nondeterminism improves expressivity, we show that the NS-RNN achieves lower cross-entropy on nondeterministic CFLs, including the "hardest context-free language" (Greibach, 1973), a language which is at least as difficult to parse as any other CFL and inherently requires nondeterminism. Our code is available at https://github.com/bdusell/nondeterministic-stack-rnn.

Background and Motivation
In all differentiable stack-augmented networks that we are aware of (including ours), a network called the controller, which is some kind of RNN (typically an LSTM), is augmented with a differentiable stack, which has no parameters of its own. At each time step, the controller emits weights for various stack operations, which at minimum include push and pop. To maintain differentiability, the weights need to be continuous; different designs for the stack interpret fractionally-weighted operations differently. The stack then executes the fractional operations and produces a stack reading, which is a vector that represents the top of the updated stack. The stack reading is used as an extra input to the next hidden state update.
Designs for differentiable stacks have proceeded generally along two lines. One approach, which we call superposition (Joulin and Mikolov, 2015), treats fractional weights as probabilities. The other, which we call stratification (Sun et al., 1995; Grefenstette et al., 2015), treats fractional weights as "thicknesses."

Superposition In the model of Joulin and Mikolov (2015), the controller emits at each time step a probability distribution over three stack operations: push a new vector, pop the top vector, and no-op. The stack simulates all three operations at once, setting each stack element to the weighted interpolation of the elements above, at, and below it in the previous time step, weighted by the push, no-op, and pop probabilities respectively. Thus, each stack element is a superposition of possible values for that element. Because stack elements depend only on a fixed number of elements from the previous time step, the stack update can largely be parallelized. Yogatama et al. (2018) developed an extension to this model that allows a variable number of pops per time step, up to a fixed limit K. Suzgun et al. (2019) also proposed a modification of the controller parameterization.
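To make the superposition update concrete, here is a toy Python sketch with scalar stack cells (the actual models use vector-valued cells and batched tensor operations; the function name and interface are our own):

```python
def superposition_update(stack, push_p, noop_p, pop_p, pushed_value):
    """One step of a Joulin-and-Mikolov-style superposition stack update.
    `stack` is a list of scalar cell values, top first; each new cell is an
    interpolation of the cells above, at, and below it at the previous step,
    weighted by the push, no-op, and pop probabilities respectively."""
    prev = stack + [0.0, 0.0]  # pad so reads below the bottom return 0
    new_len = len(stack) + 1   # the stack can grow by at most one cell
    new = []
    for i in range(new_len):
        # the value "above" cell 0 is the newly pushed value
        above = pushed_value if i == 0 else prev[i - 1]
        new.append(push_p * above + noop_p * prev[i] + pop_p * prev[i + 1])
    return new
```

With one-hot operation weights, this reduces to an exact push, no-op, or pop; fractional weights mix the three outcomes per cell, which is why only per-cell marginals (not whole-stack configurations) are tracked.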
Stratification The model proposed by Sun et al. (1995) and later studied by Grefenstette et al. (2015) takes a different approach, assigning a strength between 0 and 1 to each stack element. If the stack elements were the layers of a cake, then the strengths would represent the thickness of each layer. At each time step, the controller emits a push weight between 0 and 1 which determines the strength of a new vector pushed onto the stack, and a pop weight between 0 and 1 which determines how much to slice off the top of the stack. The stack reading is computed by examining the top layer of unit thickness and interpolating the vectors proportional to their strengths. This relies on min and max operations, which can have zero gradients. In practice, the model can get trapped in local optima and requires random restarts (Hao et al., 2018). This model also affords less opportunity for parallelization because of the interdependence of stack elements within the same time step. Hao et al. (2018) proposed an extension that uses memory buffers to allow variable-length transductions.
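The thickness-based update and reading can likewise be sketched in plain Python with scalar cells (vector cells and the controller's learned gates are omitted; names and interfaces are ours):

```python
def strat_update(values, strengths, push_val, push_w, pop_w):
    """One stratification-style step: slice pop_w of thickness off the top
    of the stack, then push a new cell with strength push_w. The top of the
    stack is the end of each list."""
    new_s = []
    remaining = pop_w
    for s in reversed(strengths):      # remove thickness from the top down
        removed = min(s, remaining)
        remaining -= removed
        new_s.append(s - removed)
    new_s = list(reversed(new_s))
    return values + [push_val], new_s + [push_w]

def strat_read(values, strengths):
    """Read the top unit of thickness, interpolating cell values in
    proportion to their strengths (the min here is the source of the
    zero-gradient regions mentioned in the text)."""
    total, out = 0.0, 0.0
    for v, s in zip(reversed(values), reversed(strengths)):
        take = min(s, 1.0 - total)
        out += take * v
        total += take
        if total >= 1.0:
            break
    return out
```

Note that each cell's new strength depends on all the strengths above it, which is the sequential dependence that limits parallelization.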
Nondeterminism In all the above models, the stack is essentially deterministic in design. In order to recognize a nondeterministic CFL like {ww^R} from left to right, it must be possible, at each time step, for the stack to track all prefixes of the input string read so far. None of the foregoing models, to our knowledge, can represent a set of possibilities like this. Even for deterministic CFLs, this has consequences for trainability; at each time step, training can only update the model from the vantage point of a single stack configuration, making the model prone to getting stuck in local minima.
To overcome this weakness, we propose incorporating a nondeterministic stack, which affords the model a global view of the space of possible ways to use the stack. Our controller emits a probability distribution over stack operations, as in the superposition approach. However, whereas superposition only maintains the per-element marginal distributions over the stack elements, we propose to maintain the full distribution over the whole stack contents. We marginalize the distribution as late as possible, when the controller queries the stack for the current top stack symbol.
In the following sections, we explain our model and compare it against those of Joulin and Mikolov (2015) and Grefenstette et al. (2015). Despite taking longer in wall-clock time to train, our model learns to solve the tasks optimally with a higher rate of success.

Pushdown Automata
In this section, we give a definition of nondeterministic PDAs ( §3.2), describe how to process strings with nondeterministic PDAs in cubic time ( §3.3), and reformulate this algorithm in terms of tensor operations ( §3.4).

Notation
Let ε be the empty string. Let 1[φ] be 1 when proposition φ is true, 0 otherwise. If A is a matrix, let A_i: and A_:j be the ith row and jth column, respectively, and define analogous notation for tensors.

Definition
A weighted PDA is a tuple M = (Q, Σ, Γ, δ, q_0, ⊥), where:
• Q is a finite set of states.
• Σ is a finite input alphabet.
• Γ is a finite stack alphabet.
• δ is a function mapping transitions, which we write as q, x −a→ r, y, to weights. A transition q, x −a→ r, y means: in state q with x on top of the stack, scan a, go to state r, and replace x with y.
• q_0 ∈ Q is the start state.
• ⊥ ∈ Γ is the initial stack symbol.
In this paper, we do not allow non-scanning transitions (that is, those where a = ). Although this does not reduce the weak generative capacity of PDAs (Autebert et al., 1997), it could affect their ability to learn; we leave exploration of nonscanning transitions for future work.
For simplicity, we will assume that all transitions have one of the following three forms:
• q, x −a→ r, xy (push y on top of x)
• q, x −a→ r, y (replace x with y)
• q, x −a→ r, ε (pop x)
This also does not reduce the weak generative capacity of PDAs.
Given an input string w ∈ Σ* of length n, a configuration is a triple (i, q, β), where i ∈ [0, n] is an input position indicating that all symbols up to and including w_i have been scanned, q ∈ Q is a state, and β ∈ Γ* is the contents of the stack (written bottom to top). For all i, q, r, β, x, y, we say that configuration (i−1, q, βx) yields configuration (i, r, βy) if the transition q, x −w_i→ r, y has nonzero weight. A run is a sequence of configurations starting with (0, q_0, ⊥) in which each configuration (except the last) yields the next.
Because our model does not use the PDA to accept or reject strings, we omit the usual definitions for the language accepted by a PDA. This is also why our definition lacks accept states.
As an example, consider a PDA for the language {ww^R | w ∈ {0, 1}*} with two states: q_1, in which it pushes each scanned symbol onto the stack, and q_2, in which it pops a symbol matching each scanned symbol, entered by nondeterministically guessing the midpoint. This PDA has a possible configuration with an empty stack (⊥) iff the input string read so far is of the form ww^R.
To make a weighted PDA probabilistic, we require that all transition weights be nonnegative and, for all a, q, x: ∑_{r,y} δ(q, x −a→ r, y) = 1. Whereas many definitions make the model generate symbols (Abney et al., 1999), our definition makes the PDA operations conditional on the input symbol a. The difference is not very important, because the RNN controller will eventually assume responsibility for reading and writing symbols, but our definition makes the shift to an RNN controller below slightly simpler.

Lang (1974) gives an algorithm for simulating all runs of a nondeterministic PDA, related to Earley's algorithm (Earley, 1970). At any point in time, there can be exponentially many possibilities for the contents of the stack. In spite of this, Lang's algorithm is able to represent the set of all possibilities using only quadratic space. As this set is regular, its representation can be thought of as a weighted finite automaton, which we call the stack WFA, similar to the graph-structured stack used in GLR parsing (Tomita, 1987). Figure 1 depicts Lang's algorithm as a set of inference rules, similar to a deductive parser (Shieber et al., 1995; Goodman, 1999), although the visual presentation is rather different. Each inference rule is drawn as a fragment of the stack WFA. If the transitions drawn with solid lines are present in the stack WFA, and the side conditions in the right column are met, then the transition drawn with a dashed line can be added to the stack WFA. The algorithm repeatedly applies inference rules to add states and transitions to the stack WFA; no states or transitions are ever deleted.

Recognition
Each state of the stack WFA is of the form (i, q, x), where i is a position in the input string, q is a PDA state, and x is the top stack symbol. We briefly explain each of the inference rules:

Axiom creates an initial state and pushes ⊥ onto the stack.
Push pushes a y on top of an x. Unlike Lang's original algorithm, this inference rule applies whether or not state ( j−1, q, x) is reachable.
Replace pops a z and pushes a y, by backing up the z transition (without deleting it) and adding a new y transition.
Pop pops a z, by backing up the z transition as well as the preceding y transition (without deleting them) and adding a new y transition.
The set of accept states of the stack WFA changes from time step to time step; at step j, the accept states are {( j, q, x) | q ∈ Q, x ∈ Γ}. The language recognized by the stack WFA at time j is the set of possible stack contents at time j.
An example run of the algorithm is shown in Figure 2, using our example PDA and the string 0110. At time step j = 3, the PDA reads 1 and either pushes a 1 (path ending in state (3, q 1 , 1)) or pops a 1 (path ending in state (3, q 2 , 0)). Similarly at time step j = 4, and the existence of a state with top stack symbol ⊥ indicates that the string is of the form ww R .
The total running time of the algorithm is proportional to the number of ways that the inference rules can be instantiated. Since the Pop rule contains three string positions (i, j, and k), the time complexity is O(n 3 ). The total space requirement is characterized by the number of possible WFA transitions. Since transitions connect two states, each with a string position (i and j), the space complexity is O(n 2 ).
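To make the inference rules concrete, here is a small boolean (unweighted) Python sketch of Lang's algorithm, run on a toy PDA for {ww^R} like the example described above. The PDA's transitions and the code's interface are our own reconstruction, and for simplicity the Push rule here fires only from reachable states (unlike the variant described in the text):

```python
from collections import defaultdict

def make_pda():
    """A toy PDA for {w w^R | w in {0,1}*}: q1 pushes scanned symbols,
    q2 pops matching ones after a nondeterministic midpoint guess.
    Maps (state, top_symbol, scanned_symbol) -> list of (next_state, action);
    "B" stands in for the bottom symbol."""
    delta = defaultdict(list)
    for a in "01":
        for x in "B01":
            delta[("q1", x, a)].append(("q1", ("push", a)))
        delta[("q1", a, a)].append(("q2", ("pop",)))  # guess the midpoint
        delta[("q2", a, a)].append(("q2", ("pop",)))  # keep popping
    return delta

START = (-1, None, None)  # source node of the stack WFA's initial edge

def lang_recognize(delta, w, q0="q1", bottom="B"):
    """edges[(j, r, y)] is the set of source nodes with an edge labeled y
    into stack WFA state (j, r, y). Returns True iff some run leaves
    exactly the bottom symbol on the stack after scanning all of w."""
    edges = defaultdict(set)
    edges[(0, q0, bottom)].add(START)                 # Axiom
    for j in range(1, len(w) + 1):
        a = w[j - 1]
        new = defaultdict(set)
        for (k, s, z), sources in edges.items():
            if k != j - 1:
                continue
            for r, action in delta.get((s, z, a), []):
                if action[0] == "push":
                    # Push: add edge (j-1,s,z) -> (j,r,y) labeled y
                    new[(j, r, action[1])].add((j - 1, s, z))
                elif action[0] == "replace":
                    # Replace: back up over the z edge, add a y edge
                    for src in sources:
                        new[(j, r, action[1])].add(src)
                else:
                    # Pop: back up over the z edge and the y edge below it
                    for src in sources:
                        if src is START:
                            continue   # nothing below the bottom symbol
                        y = src[2]
                        for src2 in edges.get(src, ()):
                            new[(j, r, y)].add(src2)
        for node, srcs in new.items():
            edges[node] |= srcs
    n = len(w)
    return any(START in edges.get((n, q, bottom), ()) for q in ("q1", "q2"))
```

Running this on 0110 reproduces the behavior described for Figure 2: at j = 3 the WFA contains both a push path and a pop path, and at j = 4 an edge labeled with the bottom symbol reaches position 4, signaling membership.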

Inner and Forward Weights
To implement this algorithm in a typical neural-network framework, we reformulate it in terms of tensor operations. We use the assumption that all transitions are scanning, although it would be possible to extend the model to handle non-scanning transitions using matrix inversions (Stolcke, 1995).
Define Act(Γ) = •Γ ∪ Γ ∪ {ε} to be the set of possible stack actions: if y ∈ Γ, then •y means "push y," y means "replace the top symbol with y," and ε means "pop." Given an input string w, we pack the transition weights of the PDA into a tensor ∆ with dimensions n × |Q| × |Γ| × |Q| × |Act(Γ)|, so that ∆[j][q, x → r, u] is the weight of the transition that scans w_j in state q with x on top of the stack, goes to state r, and performs action u. (1) We compute the transition weights of the stack WFA (except for the initial transition) as a tensor of inner weights γ, with dimensions n × n × |Q| × |Γ| × |Q| × |Γ|. Each element, which we write as γ[i → j][q, x → r, y], is the weight of the stack WFA transition from state (i, q, x) to state (j, r, y), labeled y. The equations defining γ are shown in Figure 3. Because these equations are a recurrence relation, we cannot compute γ all at once, but (for example) in order of increasing j.
Additionally, we compute a tensor α of forward weights of the stack WFA. This tensor has dimensions n × |Q| × |Γ|, and its elements are defined by the recurrence

α[1][r, y] = γ[0 → 1][q_0, ⊥ → r, y]
α[j][r, y] = γ[0 → j][q_0, ⊥ → r, y] + ∑_{i=1}^{j−1} ∑_{q,x} α[i][q, x] γ[i → j][q, x → r, y]   (2 ≤ j ≤ n).
The weight α[j][r, y] is the total weight of reaching a configuration (j, r, βy) for any β from the initial configuration, and we can use α to compute the probability distribution over top stack symbols at time step j:

τ[j][y] = (∑_r α[j][r, y]) / (∑_{r,y′} α[j][r, y′]).
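This final marginalization is the only point at which the distribution over stacks is collapsed. As a sketch (a hypothetical helper; in the real model this is a tensor operation over α):

```python
def top_symbol_dist(alpha_j):
    """Given forward weights at step j as a dict (state, top_symbol) -> weight,
    marginalize out the PDA state and normalize to obtain the distribution
    over top stack symbols (the stack reading described in the text)."""
    totals = {}
    for (r, y), weight in alpha_j.items():
        totals[y] = totals.get(y, 0.0) + weight
    z = sum(totals.values())
    return {y: w / z for y, w in totals.items()}
```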

Neural Pushdown Automata
Now we couple the tensor formulation of Lang's algorithm for nondeterministic PDAs with an RNN controller.

Model
The controller can be any type of RNN; in our experiments, we used an LSTM. At each time step, it computes a hidden vector h(j) with d dimensions from the previous hidden vector, an input vector x(j), and the distribution over top stack symbols from the previous step, τ(j−1), defined above:

h(j) = R(h(j−1), x(j), τ(j−1))

where R can be any RNN unit. This state is used to compute an output vector y(j) as usual. The controller also computes the stack action weights from the hidden state:

∆[j][q, x → r, y] = softmax over (r, y) of (C[q, x, r, y] · h(j) + D[q, x, r, y])

where C and D are tensors of parameters with dimensions |Q| × |Γ| × |Q| × |Act(Γ)| × d and |Q| × |Γ| × |Q| × |Act(Γ)|, respectively. (This is just an affine transformation followed by a softmax over r and y.) These equations replace equations (1).
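The action layer can be sketched as follows, with nested lists standing in for the parameter tensors C and D (a simplified scalar-loop version of what would be a single batched tensor operation; names are ours):

```python
import math

def action_distribution(h, C, D):
    """For each (state q, top symbol x), apply an affine map of the hidden
    vector h followed by a softmax over (next state r, action u). C[q][x][r][u]
    is a weight vector of the same length as h; D[q][x][r][u] is a scalar bias."""
    dist = {}
    for q in range(len(C)):
        for x in range(len(C[q])):
            scores = {}
            for r in range(len(C[q][x])):
                for u in range(len(C[q][x][r])):
                    scores[(r, u)] = D[q][x][r][u] + sum(
                        c * hi for c, hi in zip(C[q][x][r][u], h))
            m = max(scores.values())             # stabilize the softmax
            exps = {ru: math.exp(s - m) for ru, s in scores.items()}
            z = sum(exps.values())
            dist[(q, x)] = {ru: e / z for ru, e in exps.items()}
    return dist
```

Each dist[(q, x)] is a proper distribution over (r, u), which is what makes the simulated PDA probabilistic conditioned on its current state and top symbol.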

Implementation
We implemented the NS-RNN using PyTorch (Paszke et al., 2019), and doing so efficiently required a few crucial tricks. The first was a workaround to update the γ and α tensors in-place in a way that was compatible with PyTorch's automatic differentiation; this was necessary to achieve the theoretical quadratic space complexity. The second was an efficient implementation of a differentiable einsum operation that supports the log semiring (as well as other semirings), available at https://github.com/bdusell/semiring-einsum, which allowed us to implement the equations of Figure 3 in a reasonably fast, memory-efficient way that avoids underflow. Our einsum implementation splits the operation into fixed-size blocks where the multiplication and summation of terms can be fully parallelized. This enforces a reasonable upper bound on memory usage while suffering only a slight decrease in speed compared to fully parallelizing the entire einsum operation.
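A minimal sketch of the idea, using matrix multiplication as the einsum (our library's actual interface differs; this only illustrates summation in the log semiring over fixed-size blocks):

```python
import math

def logsumexp(xs):
    """Numerically stable log(sum(exp(x))) over a list of log-weights."""
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))

def log_matmul(A, B, block=2):
    """Matrix 'multiplication' in the log semiring (+ becomes logsumexp,
    * becomes +), reducing over the shared index in fixed-size blocks,
    as a toy stand-in for a blocked semiring einsum."""
    n, k, m = len(A), len(A[0]), len(B[0])
    out = [[-math.inf] * m for _ in range(n)]
    for i in range(n):
        for j in range(m):
            partials = []
            for start in range(0, k, block):
                terms = [A[i][t] + B[t][j]
                         for t in range(start, min(start + block, k))]
                partials.append(logsumexp(terms))
            out[i][j] = logsumexp(partials)
    return out
```

Working in log space keeps products of many small probabilities from underflowing, and the block size bounds how many intermediate terms must be materialized at once.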

Experiments
In this section, we describe our experiments comparing our NS-RNN and three baseline language models on several formal languages.

Tasks
Marked reversal The language of palindromes with an explicit middle marker, with strings of the form w#w R , where w ∈ {0, 1} * . This task should be easily solvable by a model with a deterministic stack, as the model can push the string w to the stack, change states upon reading #, and predict w R by popping w from the stack in reverse.
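Training strings for a task like this could be sampled with a simple generator along the following lines (a hypothetical sketch; the paper's actual PCFGs, parameters, and length filtering are described in its appendices):

```python
import random

def sample_marked_reversal(p_continue=0.9, max_len=40, rng=random):
    """Sample a string of the form w#w^R with w drawn from {0,1}*, where
    each additional symbol of w is generated with probability p_continue
    (the parameter names and values are illustrative only)."""
    w = []
    while rng.random() < p_continue and len(w) < max_len:
        w.append(rng.choice("01"))
    return "".join(w) + "#" + "".join(reversed(w))
```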
Unmarked reversal The language of (evenlength) palindromes without a middle marker, with strings of the form ww R , where w ∈ {0, 1} * . When the length of w can vary, a language model reading the string from left to right must use nondeterminism to guess where the boundary between w and w R lies. At each position, it must either push the input symbol to the stack, or else guess that the middle point has been reached and start popping symbols from the stack. An optimal language model will interpolate among all possible split points to produce a final prediction.
Padded reversal Like the unmarked reversal language, but with a long stretch of repeated symbols in the middle, with strings of the form wa p w R , where w ∈ {0, 1} * , a ∈ {0, 1}, and p ≥ 0. The purpose of the padding is to confuse a language model attempting to guess where the middle of the palindrome is based on the content of the string. In the general case of unmarked reversal, a language model can disregard split points where a valid palindrome does not occur locally. Since all substrings of a p are palindromes, the language model must deal with a larger number of candidates simultaneously.
Dyck language The language D 2 of strings with two kinds of balanced brackets.
Hardest CFL Designed by Greibach (1973) to be at least as difficult to parse as any other CFL: L_0 consists of strings of the form x_1,y_1,z_1; ··· x_n,y_n,z_n; with n ≥ 0 and y_1 ··· y_n ∈ $D_2, where the x_i and z_i are subject to additional conditions (see Appendix A). Intuitively, L_0 contains strings formed by dividing a member of $D_2 into pieces (y_i) and interleaving them with "decoy" pieces (substrings of x_i and z_i). While processing the string, the machine has to nondeterministically guess whether each piece is genuine or a decoy. Greibach shows that for any CFL L, there is a string homomorphism h such that a parser for L_0 can be run on h(w) to find a parse for w. See Appendix A for more information.

Data
For each task, we construct a probabilistic context-free grammar (PCFG) for the language (see Appendix B for the full grammars and their parameters). We then randomly sample a training set of 10,000 examples from the PCFG, filtering samples so that the length of a string is in the interval [40, 80] (see Appendix C for our sampling method). The training set remains the same throughout the training process and is not re-sampled from epoch to epoch, since we want to test how well the model can infer the probability distribution from a finite sample.
We sample a validation set of 1,000 examples from the same distribution and a test set with string lengths varying from 40 to 100, with 100 examples per length. The validation set is randomized in each experiment, but for each task, the test set remains the same across all models and random restarts. For simplicity, we do not filter training samples from the validation or test sets, assuming that the chance of overlap is very small.

Evaluation
Since, in these languages, the next symbol cannot always be predicted deterministically from previous symbols, we do not use prediction accuracy as in previous work. Instead, we compute per-symbol cross-entropy on a set of strings S. Let p be any distribution over strings; then:

H(S, p) = −(∑_{w∈S} log p(w)) / (∑_{w∈S} |w|).

We compute the cross-entropy for both the stack RNN and the distribution from which S is sampled and report the difference. This can be seen as an approximation of the KL divergence of the stack RNN from the true distribution. Technically, because the RNN models do not predict the end of the string, they estimate p(w | |w|), not p(w). However, they do not actually use any knowledge of the length, so it seems reasonable to compare the RNN's estimate of p(w | |w|) with the true p(w). (This is why, when we bin by length in Figure 5, some of the differences are negative.) A benefit of using cross-entropy instead of prediction accuracy is that we can easily incorporate new tasks as long as they are expressed as a PCFG. We do not, for example, need to define a language-dependent subsequence of symbols to evaluate on.
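Concretely, the metric can be computed as follows (a sketch; `logprob` stands for any model's string log-probability function, an assumed interface):

```python
import math

def per_symbol_cross_entropy(strings, logprob):
    """Per-symbol cross-entropy of a distribution on a sample S: total
    negative log-probability divided by the total number of symbols. The
    reported score is this quantity for the model minus the same quantity
    for the true source distribution."""
    total_nll = -sum(logprob(w) for w in strings)
    total_symbols = sum(len(w) for w in strings)
    return total_nll / total_symbols
```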

Baselines
We compare our NS-RNN against three baselines: an LSTM, the Stack LSTM of Joulin and Mikolov (2015) ("JM"), and the Stack LSTM of Grefenstette et al. (2015) ("Gref"). We deviate slightly from the original definitions of these models in order to standardize the controller-stack interface to the one defined in Section 4.1, and to isolate the effects of differences in the stack data structure, rather than the controller mechanism. For all three stack models, we use an LSTM controller whose initial hidden state is fixed to 0, and we use only one stack for the JM and Gref models. (In early experiments, we found that using multiple stacks did not make a meaningful difference in performance.) For JM, we include a bias term in the layers that compute the stack actions and network output. We do allow the no-op operation, and the stack reading consists of only the top stack cell. For Gref, we set the controller output o t equal to the hidden state h t , so we compute the stack actions, pushed vector, and network output directly from the hidden state. We encode all input symbols as one-hot vectors; there are no embedding layers.

Hyperparameters
For all models, we use a single-layer LSTM with 20 hidden units. We selected this number because we found that an LSTM of this size could not completely solve the marked reversal task, indicating that the hidden state is a memory bottleneck. For each task, we perform a hyperparameter grid search for each model. We search for the initial learning rate, which has a large impact on performance, from the set {0.01, 0.005, 0.001, 0.0005}. For JM and Gref, we search for stack embedding sizes in {2, 20, 40}. We manually choose a small number of PDA states and stack symbol types for the NS-RNN for each task. For marked reversal, unmarked reversal, and Dyck, we use 2 states and 2 stack symbol types. For padded reversal, we use 3 states and 2 stack symbol types. For the hardest CFL, we use 3 states and 3 stack symbol types.
As noted by Grefenstette et al. (2015), initialization can play a large role in whether a Stack LSTM converges on algorithmic behavior or becomes trapped in a local optimum. To mitigate this, for each hyperparameter setting in the grid search, we run five random restarts and select the hyperparameter setting with the lowest average difference in cross-entropy on the validation set. This gives us a picture not only of the model's performance, but of its rate of success. We initialize all fully-connected layers except for the recurrent LSTM layer with Xavier uniform initialization (Glorot and Bengio, 2010), and all other parameters uniformly from [−0.1, 0.1].
We train all models with Adam (Kingma and Ba, 2015) and clip gradients whose magnitude is above 5. We use mini-batches of size 10; to generate a batch, we first select a length and then sample 10 strings of that length. We train models until convergence, multiplying the learning rate by 0.9 after 5 epochs of no improvement in cross-entropy on the validation set, and stopping after 10 epochs of no improvement.

Results
We show plots of the difference in cross-entropy on the validation set between each model and the source distribution in Figure 4. For all tasks, stack-based models outperform the LSTM baseline, indicating that the tasks are effective benchmarks for differentiable stacks. For the marked reversal, unmarked reversal, and hardest CFL tasks, our model consistently achieves cross-entropy closer to the source distribution than any other model. Even for the marked reversal task, which can be solved deterministically, the NS-RNN, besides achieving lower cross-entropy on average, learns to solve the task in fewer updates and with much higher reliability across random restarts. In the case of the mildly nondeterministic unmarked reversal and highly nondeterministic hardest CFL tasks, the NS-RNN converges on the lowest validation cross-entropy. On the Dyck language, which is a deterministic task, all stack models converge quickly on the source distribution. We hypothesize that this is because the Dyck language represents a case where stack usage is locally advantageous everywhere, so it is particularly conducive to learning stack-like behavior. On the other hand, we note that our model struggles on padded reversal, in which stack-friendly signals are intentionally made very distant. Although the NS-RNN outperforms the LSTM baseline, the JM model solves the task most effectively, though still imperfectly.
In order to show how each model performs when evaluated on strings longer than those seen during training, in Figure 5, we show cross-entropy on separately sampled test data as a function of string length. All test sets are identical across models and random restarts, and there are 100 samples per length. The NS-RNN consistently does well on string lengths it was trained on, but it is sometimes surpassed by other stack models on strings that are outside the distribution of lengths it was trained on. This suggests that the NS-RNN conforms more tightly to the real distribution seen during training.

Conclusion
We presented the NS-RNN, a neural language model with a differentiable stack that explicitly models nondeterminism. We showed that it offers improved trainability and modeling power over previous stack-based neural language models; the NS-RNN learns to solve some deterministic tasks more effectively than other stack LSTMs, and achieves the best results on a challenging nondeterministic context-free language. However, we note that the NS-RNN struggled on a task where signals in the data were distant, and did not generalize to longer lengths as well as other stack LSTMs; we hope to address these shortcomings in future work. We believe that the NS-RNN will prove to be a powerful tool for learning and modeling ambiguous syntax in natural language.

Figure 5: Cross-entropy difference in nats on the test set, binned by string length. Some models achieve a negative difference, for reasons explained in §5.3. Each line is the average of the same five random restarts shown in Figure 4.