How Can Self-Attention Networks Recognize Dyck-n Languages?

We focus on the recognition of Dyck-n (Dn) languages with self-attention (SA) networks, which has been deemed a difficult task for these networks. We compare the performance of two variants of SA, one with a starting symbol (SA+) and one without (SA-). Our results show that SA+ is able to generalize to longer sequences and deeper dependencies. For D2, we find that SA- completely breaks down on long sequences, whereas SA+ still reaches an accuracy of 58.82%. We find the attention maps learned by SA+ to be amenable to interpretation and compatible with a stack-based language recognizer. Surprisingly, the performance of SA networks is on par with LSTMs, which provides evidence of the ability of SA to learn hierarchies without recursion.


Introduction
There is a growing interest in using formal languages to study fundamental properties of neural architectures, which has led to the extraction of interpretable models (Weiss et al., 2018; Merrill et al., 2020). Recent work (Hao et al., 2018; Suzgun et al., 2019; Skachkova et al., 2018) has explored the generalized Dyck-n (Dn) languages, a subset of context-free languages. Dn consists of "well-balanced" strings of parentheses with n different types of bracket pairs, and it is the canonical formal language for studying nested structures (Chomsky and Schützenberger, 1959). Weiss et al. (2018) show that LSTMs (Hochreiter and Schmidhuber, 1997) are a variant of the k-counter machine and can recognize D1 languages. Dynamic counting mechanisms, however, are not sufficient for Dn>1, which requires emulating a pushdown automaton. Hahn (2020) shows that for sufficiently long sequences, Transformers (Vaswani et al., 2017) will fail to transduce the D2 language.
Figure 1: The rows and columns denote queries and keys, respectively. The layer produces virtually hard attention, in which each symbol attends only to one preceding symbol or itself. The attended symbol is either the starting symbol (T) or the last unmatched opening bracket.

We empirically show that with the addition of a starting symbol to the vocabulary, a two-layer multi-headed SA network (i.e., the encoder of a Transformer) is able to learn Dn languages and generalize to longer sequences, although not perfectly. As shown in Figure 1, the network is able to identify the corresponding closing bracket for an opening bracket, in a way that resembles a stack-based automaton. For example, the symbol "]" in the string "([])" will first pop "[" from the stack and then attend to "(", the last unmatched symbol, which determines the next valid closing bracket. The starting symbol (T) enables the model to learn the occurrence of the end of a clause or the end of the sequence, which can be regarded as a mechanism for representing an empty stack. Our work is the first to perform an empirical exploration of SA on formal languages. We present a detailed comparison between an SA network that incorporates a starting symbol (SA+) and one that does not (SA−), and demonstrate significant differences in their generalization across the length of sequences and the depth of dependencies.
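To make the stack analogy concrete, the following minimal sketch (ours, not part of the trained model) traces the stack for a Dyck-2 string and reports, for each closing bracket, the position it should attend to under this reading: the last unmatched opening bracket, or the starting symbol T once the stack is empty.

```python
# Minimal sketch (not the model): trace the stack analogy for a Dyck-2 string
# and report which preceding symbol each closing bracket should attend to,
# i.e., the last unmatched opening bracket after its own counterpart is popped,
# or the starting symbol T if no unmatched bracket remains.

OPEN_OF = {")": "(", "]": "["}

def attention_targets(seq):
    stack = []      # indices of currently unmatched opening brackets
    targets = {}    # position of closing bracket -> position it should attend to
    for i, sym in enumerate(seq):
        if sym in "([":
            stack.append(i)
        else:
            assert seq[stack.pop()] == OPEN_OF[sym], "not a valid Dyck-2 prefix"
            # after popping its own counterpart, attend to the last unmatched
            # opening bracket; if none is left, attend to the starting symbol T
            targets[i] = stack[-1] if stack else "T"
    return targets

print(attention_targets("([])"))   # {2: 0, 3: 'T'}: "]" attends to "(", ")" attends to T
```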
Recent work has suggested that the ability of self-attention mechanisms to model hierarchical structures is limited. Shen et al. (2019) show that the performance of Transformers on tasks such as logical inference (Bowman et al., 2015) and ListOps (Nangia and Bowman, 2018) is either poor or worse than that of LSTMs. Tran et al. (2018) report similar results for SA, concluding that recurrence is necessary to model hierarchical structures. In comparison, our results show that SA+ outperforms LSTM on Dn languages, except for D2 on longer sequences. Papadimitriou and Jurafsky (2020) posit that the ability of neural models to learn hierarchical structures can be attributed to a "looking back" capability rather than to directly encoding hierarchies. Our analysis sheds light on the ability of SA to learn hierarchical structures by elegantly attending to the correct preceding symbol.

Related Work
Formal languages such as a^n b^n and a^n b^n c^m d^m (context-free), and a^n b^n c^n and a^(n+m) b^n c^m (context-sensitive), have been extensively studied and recognized using RNNs (Elman, 1990; Das et al., 1992; Steijvers and Grünwald, 1996). The performance of the same recurrent architectures on Dn languages, however, is poor and suffers from a lack of generalization. Sennhauser and Berwick (2018) and Bernardy (2018) study the ability of RNNs to predict the next possible closing parenthesis at each position in a Dn string and find that generalization at higher recursion depths is poor. Hao et al. (2018) report that stack-augmented LSTMs achieve better generalization on Dn languages, but the network computation does not emulate a stack. More recently, Suzgun et al. (2019) proposed memory-augmented recurrent neural networks and defined a sequence classification task for the recognition of Dn languages. Yu et al. (2019) explored an attention-based seq2seq framework for D2 languages and found that generalization to sequences of higher depths is still lacking. Besides empirical investigations, formal languages have been studied theoretically to understand the complexity of neural networks (Siegelmann and Sontag, 1992; Pérez et al., 2019), mostly under assumptions that cannot be met in an experiment, such as infinite precision or unbounded computation time.

Experiments
We follow prior work (Gers and Schmidhuber, 2001; Suzgun et al., 2019) and formulate the recognition of Dn languages as a transduction task: given a valid string, we ask the model to predict the set of next possible symbols auto-regressively. To illustrate, given the input prefix "[ ( ) ] ( [" in the D2 language, we seek to predict the set of valid next brackets: (, [, or ]. We consider an input to be accurately recognized only if the model correctly predicts the set of all possible brackets at each position in the input sequence. Throughout the paper, we refer to a clause as a substring in which the numbers of opening and closing brackets of each bracket type are equal.

We train two multi-headed self-attention networks (i.e., only the encoder part of a Transformer), one of which incorporates an additional starting symbol in the vocabulary (SA+), while the other does not (SA−). For each model, the number of layers is 2, the number of attention heads is h = 4, and the model dimension is d = 256. We use learnable embeddings to convert each input symbol into a 256-dimensional vector. We also add residual connections around each layer followed by layer normalization, as in the standard Transformer (Vaswani et al., 2017). We further train two unidirectional LSTMs, one with the starting symbol (LSTM+) and the other without it (LSTM−). The LSTMs use 320-dimensional hidden states and 320-dimensional learned input embeddings. Our SA and LSTM variants all have around 1.6M parameters.

We use Adam (Kingma and Ba, 2015) for optimization. For SA+ and SA−, we vary the learning rate η as

η = const · min(itr^−0.5, itr · warmup^−1.5),    (1)

where itr is the iteration number and warmup is set to 10k. We tuned the hyper-parameter const over the values [0.01, 0.1, 1.0, 10] and chose 0.1. For the LSTMs, we use an initial learning rate of 0.001 with no learning-rate scheduling.
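A minimal sketch of the schedule in Eq. (1), with the const of 0.1 and warmup of 10k iterations stated above (the function name is ours):

```python
# Sketch of the learning-rate schedule in Eq. (1): linear warm-up followed by
# inverse-square-root decay, as used for the SA+ and SA- models.

def learning_rate(itr, const=0.1, warmup=10_000):
    """Return the learning rate at iteration `itr` (itr >= 1)."""
    return const * min(itr ** -0.5, itr * warmup ** -1.5)

# During warm-up the rate grows linearly, then decays as itr^-0.5.
print(learning_rate(1_000))    # 1e-4 (warm-up phase)
print(learning_rate(10_000))   # 1e-3 (peak at the end of warm-up)
print(learning_rate(40_000))   # 5e-4 (decay phase)
```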
We re-generate the synthetic dataset for our experiments through the probabilistic context-free grammar (PCFG) described in the existing literature (Suzgun et al., 2019). For instance, the PCFG for the Dyck-2 language can be defined by the rules (1) S → ( S ), (2) S → [ S ], (3) S → S S, and (4) S → ε, each with probability p = 0.25. For each Dn language, we train on 32k sequences of length 2-50, validate on 3.2k sequences of length 52-74, and evaluate on 10k sequences divided equally over the length intervals 76-100 and 102-126. We perform experiments on the D1, D2, D3, and D4 languages. Note that the number of pairs of parentheses cannot be increased arbitrarily without modifications to the experimental setup: we vary the length of sequences during training from 2 to 50, which can contain at most 25 different pairs.

In our sequence prediction task, the input vocabulary V_n^in for a Dn language consists of 2n+1 symbols: n pairs of brackets (or parentheses) and the additional starting symbol T, whereas the output vocabulary V_n^out does not include T. Since there may be multiple possibilities for the next bracket in a sequence, we adopt a multi-label classification approach in which the outputs are encoded as a k-hot vector and the network is optimized with the binary cross-entropy loss

L = − Σ_{i=1}^{|V_n^out|} [ ŷ_i log(y_i) + (1 − ŷ_i) log(1 − y_i) ],

where |V_n^out| is the output vocabulary size (2 for D1, 4 for D2, 6 for D3, 8 for D4), and ŷ_i ∈ {0, 1} and y_i are the target and the prediction for label i, respectively.

Table 1 compares the accuracy of SA+ and SA− on the D1, D2, D3, and D4 languages. For both models, the performance on D1 is almost perfect (> 98%) and does not degrade with increasing sequence length. The accuracy of SA− on D2 is 14.52% for sequences of length 76-100, and it fails completely beyond that. In comparison, the performance of SA+ on D2 is significantly better: 93.34% and 58.82% for sequences of length 76-100 and 102-126, respectively. The performance of SA− improves on D3 and D4 compared to D2, with accuracies of 32.62% and 42.94%, respectively, for sequences of length 76-100. The performance of SA+ is nearly constant (∼93%) on Dn≥2 for sequences of length 76-100, but there is a significant improvement from D2 (58.82%) to D3 (66.88%) and D4 (72.38%) for sequences of length 102-126. Unlike SA, the performance of LSTM degrades after the addition of the starting symbol, with the biggest drop (4.3%) on D4 for sequences of length 102-126. The starting symbol enables SA to attend to the correct preceding token, but it is ineffective for LSTM. For D2 sequences of length 102-126, LSTM− achieves an accuracy of 73.20%, an improvement of ∼14% over SA+. In all other comparisons, SA+ outperforms LSTM−.
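The data generation and target encoding described above can be sketched as follows (our own illustration, not the authors' released code): sampling Dyck-2 strings from the four-rule PCFG and building k-hot vectors of valid next symbols at each position. The vocabulary ordering, helper names, and recursion cap are our own choices.

```python
# Sketch (ours) of the data pipeline: sample Dyck-2 strings from the four-rule
# PCFG and build, for every prefix, a k-hot vector over the output vocabulary
# marking the valid next symbols.
import random

PAIRS = {"(": ")", "[": "]"}
OUT_VOCAB = ["(", ")", "[", "]"]          # output vocabulary for D2 (|V| = 4)

def sample_dyck2(depth=0, max_depth=25, rng=random):
    """Expand S using (S), [S], SS, eps with probability 0.25 each.
    `max_depth` (our addition) bounds recursion; the paper filters by length."""
    if depth >= max_depth:
        return ""
    r = rng.random()
    if r < 0.25:
        return "(" + sample_dyck2(depth + 1, max_depth, rng) + ")"
    if r < 0.50:
        return "[" + sample_dyck2(depth + 1, max_depth, rng) + "]"
    if r < 0.75:
        return sample_dyck2(depth + 1, max_depth, rng) + sample_dyck2(depth + 1, max_depth, rng)
    return ""

def khot_targets(seq):
    """For each position, a k-hot list marking every bracket that may come next."""
    stack, targets = [], []
    for sym in seq:
        if sym in PAIRS:
            stack.append(sym)
        else:
            stack.pop()
        valid = set(PAIRS)                  # any opening bracket is always valid
        if stack:
            valid.add(PAIRS[stack[-1]])     # plus the closing bracket matching the stack top
        targets.append([int(s in valid) for s in OUT_VOCAB])
    return targets

print(sample_dyck2())          # a random well-balanced D2 string (possibly empty)
print(khot_targets("([])"))    # e.g. after the prefix "([", the valid next symbols are (, [ and ]
```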

Evaluation
We observe another interesting distinction between the two architectures. The accuracy of LSTM deteriorates as the number of pairs of brackets increases, while the accuracy of SA+ and SA− improves. To understand this phenomenon, we examined the training, validation, and test sets of each language and found that, while the validation and test sets of each Dn language almost always (> 99%) include sequences with n different bracket types, the training set can include sequences with 1 ≤ m < n bracket types. This implies that SA benefits from data augmentation with sequences from other languages, whereas LSTM does not. Put differently, these results suggest that LSTM has a strong inductive bias, perhaps toward counting (Kharitonov and Chaabouni, 2020), which might degrade its performance on higher Dyck languages.

Figure 3: In (a), we plot the distribution of the errors made by SA+ and SA−, based on the position of the mispredicted symbol and its distance to its head. In (b) and (c), we plot the performance of the models as depth increases.

Error Analysis
We define failure position (f_p) as the position of the first symbol in the sequence at which the model fails to correctly predict the next set of possible parentheses. For each symbol in a Dn sequence: (i) depth (d_p) is the number of unmatched parentheses up to and including that symbol, and (ii) distance to head (d_h) is the number of symbols between a mis-classified closing bracket and its opening counterpart. Figure 3a plots the error distribution of SA+ and SA− in terms of failure position (f_p) and distance to head (d_h). There is a clear separation between the two models in terms of the "types" of errors they make. SA− breaks down early in the sequence, with the majority of its errors occurring at f_p = 25-75, whereas the errors of SA+ are mostly concentrated at f_p > 80. Figures 3b and 3c show how the performance of SA+ and SA− changes with depth (d_p) for the D2 and D4 languages. SA− is very sensitive to depth: for D2, accuracy decreases rapidly from ∼38% at d_p = 10 to complete failure beyond d_p = 20. In comparison, the drop in accuracy for SA+ is less severe, from ∼94% at d_p = 10 to ∼72% at d_p = 20.
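The two per-symbol quantities above can be computed directly from the bracket structure; the following is a minimal sketch (ours), under one natural reading in which a closing bracket shares the depth of its opening counterpart:

```python
# Sketch (ours) of the per-symbol statistics used in the error analysis:
# depth d_p = number of unmatched brackets up to and including the symbol,
# distance-to-head d_h = number of symbols between a closing bracket and
# its opening counterpart.

OPENING = "(["

def depth_and_distance(seq):
    stack = []      # indices of currently unmatched opening brackets
    stats = []      # (symbol, d_p, d_h or None) per position
    for i, sym in enumerate(seq):
        if sym in OPENING:
            stack.append(i)
            stats.append((sym, len(stack), None))       # opening brackets have no head
        else:
            head = stack.pop()
            # one reading of d_p: the closing bracket shares the depth of its counterpart
            stats.append((sym, len(stack) + 1, i - head - 1))
    return stats

for sym, d_p, d_h in depth_and_distance("([([])])"):
    print(sym, d_p, d_h)
```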

Compatibility With a Stack-Based Recognizer
The ability of (memory-less) SA networks to recognize Dn>1 languages is intriguing. In this section, we contrast the second-layer attention maps produced by SA+ and SA−, and provide insights into the underlying mechanism that leads to the success of SA+. We define compatibility as a quantitative measure of the alignment between the state of a stack-based language recognizer (M) and the attention maps. M has access to the top of a hypothetical stack, and pushes and pops upon opening and closing brackets, respectively. Under this analogy, every opening bracket should attend to itself, and every closing bracket should first perform a pop and then attend to the last unmatched bracket. For example, the symbol "]" in the string "([])" will first pop "[" from the stack and then attend to "(", the last unmatched symbol, which determines the next valid closing bracket. If, for every closing symbol in the sequence, the highest attention score of at least one of the heads points to the correct bracket, we consider the SA network compatible. Furthermore, for a fair comparison between SA+ and SA−, we do not push the starting symbol onto the stack and only consider closing brackets that are not at the end of a clause.

We measure compatibility versus sequence length. SA− on D2 has almost zero compatibility, even for sequence lengths seen during training (40-50), on which it achieves close-to-perfect accuracy. In comparison, SA+ has perfect compatibility for sequence lengths seen during training and maintains a high degree of compatibility for longer ones. Further, and perhaps not surprisingly, the Pearson correlation between the distributions of accuracy and compatibility across lengths 50-100 is 90% for all SA+ models. Figure 4 shows the attention maps of all four heads of SA+ and SA− for the D2 sequence "([([])])". We observe that the third head of SA+ matches our expectation of a stack-based recognizer. An important feature of the third head is that the last symbol attends to the starting symbol T. The starting symbol has enabled the model to learn the occurrence of the end of a clause and the end of the whole sequence.
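The compatibility check itself is straightforward to express; below is a sketch (our own, hypothetical code), assuming second-layer attention maps of shape (heads, length, length), with rows as queries, columns as keys, and the starting symbol T at position 0:

```python
# Sketch (ours) of the compatibility measure: for every closing bracket that is
# not at the end of a clause, at least one head's argmax attention must point
# to the last unmatched opening bracket (after the bracket's own counterpart
# is popped). Attention maps are assumed to have shape (heads, len, len),
# rows as queries and columns as keys, with position 0 reserved for T.
import numpy as np

def is_compatible(seq, attn):
    """seq: Dyck string; attn: np.ndarray of shape (heads, 1+len(seq), 1+len(seq))
    over the sequence prefixed with the starting symbol T at position 0."""
    stack = []                                  # positions of unmatched opening brackets
    for i, sym in enumerate(seq, start=1):      # offset by 1 for the T prefix
        if sym in "([":
            stack.append(i)
            continue
        stack.pop()                             # pop the bracket's own counterpart
        if not stack:
            continue                            # end of a clause: excluded from the check
        expected = stack[-1]                    # last unmatched opening bracket
        attended = attn[:, i, :].argmax(axis=-1)
        if not np.any(attended == expected):    # no head points to the expected key
            return False
    return True
```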

Conclusion and Future Work
We provide empirical evidence of the ability of self-attention (SA) networks to learn generalized Dn languages. We compare the performance of two SA networks, SA+ and SA−, which differ only in the inclusion of a starting symbol in their vocabulary. We demonstrate that the simple addition of the starting symbol helps SA+ generalize to sequences that are longer and have greater depths. The competitive performance of SA (which has no recurrence) against LSTMs might seem surprising, considering that the recognition of Dn languages is an inherently hierarchical task. From our experiments, we conclude that recognizing Dyck languages is not tied to recursion, but rather to learning the right representations for looking up the head token. Further, we find that the representations learned by SA+ are highly interpretable and that the network performs computations similar to those of a stack automaton. Our results suggest that formal languages could be an interesting avenue for exploring the interplay between performance and interpretability for SA. Comparisons between SA and LSTM reveal an interesting contrast between the two architectures that calls for further investigation. Recent work (Katharopoulos et al., 2020) shows how to express the Transformer as an RNN through linearization of the attention mechanism, which could lay the groundwork for more theoretical analyses of these neural architectures (e.g., their inductive biases and complexity).