Evaluating Attribution Methods using White-Box LSTMs

Interpretability methods for neural networks are difficult to evaluate because we do not understand the black-box models typically used to test them. This paper proposes a framework in which interpretability methods are evaluated using manually constructed networks, which we call white-box networks, whose behavior is understood a priori. We evaluate five methods for producing attribution heatmaps by applying them to white-box LSTM classifiers for tasks based on formal languages. Although our white-box classifiers solve their tasks perfectly and transparently, we find that all five attribution methods fail to produce the expected model explanations.


Introduction
Attribution methods are a family of interpretability techniques for individual neural network predictions that attempt to measure the importance of input features for determining the model's output. Given an input, an attribution method produces a vector of attribution or relevance scores, which is typically visualized as a heatmap that highlights portions of the input that contribute to model behavior. In the context of NLP, attribution scores are usually computed at the token level, so that each score represents the importance of a token within an input sequence. These heatmaps can be used to identify keywords upon which networks base their decisions (Li et al., 2016; Sundararajan et al., 2017; Arras et al., 2017a,b; Murdoch et al., 2018, inter alia).
One of the main challenges facing the evaluation of attribution methods is that it is difficult to assess the quality of a heatmap when the network in question is not understood in the first place. If a word is deemed relevant by an attribution method, we do not know whether the model actually considers that word relevant, or whether the attribution method has erroneously estimated its importance. Indeed, previous studies have argued that attribution methods are sensitive to features unrelated to model behavior in some cases (e.g., Kindermans et al., 2019), and altogether insensitive to model behavior in others (Adebayo et al., 2018).
To tease the evaluation of attribution methods apart from the interpretation of models, this paper proposes an evaluation framework for attribution methods in NLP that uses only models that are fully understood a priori. Instead of testing attribution methods on black-box models obtained through training, we construct white-box models for testing by directly setting network parameters by hand. Our focus is on white-box LSTMs that implement intuitive strategies for solving simple classification tasks based on formal languages with deterministic solutions. We apply our framework to five attribution methods: occlusion (Zeiler and Fergus, 2014), saliency (Simonyan et al., 2014; Li et al., 2016), gradient × input (G × I, Shrikumar et al., 2017), integrated gradients (IG, Sundararajan et al., 2017), and layer-wise relevance propagation (LRP, Bach et al., 2015). In doing so, we make the following contributions.
• We construct four white-box LSTMs that can be used to test attribution methods. We provide a complete description of our model weights in Appendix A. 1 Beyond the five methods considered here, our white-box networks can be used to test any attribution method compatible with LSTMs.
• Empirically, we show that all five attribution methods produce erroneous heatmaps for our white-box networks, despite the models' transparent behavior. As a preview of our results, Table 1 shows sample heatmaps computed for two models designed to identify the non-contiguous subsequence ab in the input aacb. Even though both models' outputs are determined by the presence of the two as and the b, all five methods either incorrectly highlight the c or fail to highlight at least one of the as in at least one case.

Table 1: Sample heatmaps for the input aacb, computed by the five methods (Occlusion, Saliency, G × I, IG, LRP) for two white-box networks: a "counter-based" network (top) and an "FSA-based" network (bottom). Task: determine whether the input contains one of the subsequences ab, bc, cd, or dc. Output: True, since the input aacb contains two (non-contiguous) instances of ab. The features relevant to the output are the two as and the b.
• We identify two general ways in which four of the five methods do not behave as intended. Firstly, while saliency, G × I and IG are theoretically invariant to differences in model implementation (Sundararajan et al., 2017), in practice we find that these methods can still produce qualitatively different heatmaps for nearly identical models. Secondly, we find that LRP is susceptible to numerical issues, which cause heatmaps to be zeroed out when values are rounded to zero.

Related Work
Several approaches have been taken in the literature for understanding how to evaluate attribution methods. On a theoretical level, axiomatic approaches propose formal desiderata that attribution methods should satisfy, such as implementation invariance (Sundararajan et al., 2017), input translation invariance (Kindermans et al., 2019), continuity with respect to inputs (Montavon et al., 2018; Ghorbani et al., 2019), or the existence of relationships between attribution scores and logit or softmax scores (Sundararajan et al., 2017; Ancona et al., 2018; Montavon, 2019). The degree to which attribution methods fulfill these criteria can be determined either mathematically or empirically.
Other approaches, which are more experimental in nature, attempt to directly assess the relationship between attribution scores and model behavior. A common test, due to Bach et al. (2015) and Samek et al. (2017) and applied to sequence modeling by Arras et al. (2017a), involves ablating or perturbing parts of the input, from those with the highest attribution scores to those with the lowest, and counting the number of features that need to be ablated in order to change the model's prediction. Another test, proposed by Adebayo et al. (2018), tracks how heatmaps change as layers of a network are incrementally randomized.
A third kind of approach evaluates the extent to which heatmaps identify salient input features. For example, Zhang et al. (2018) propose the pointing game task, in which the highest-relevance pixel for an image classifier input must belong to the object described by the target output class. Within this framework, Poerner et al. (2018), Arras et al. (2019), and Yang and Kim (2019) construct datasets in which input features exhibit experimentally controlled notions of importance, yielding "ground truth" attributions against which heatmaps can be evaluated.
Our paper incorporates elements of the groundtruth approaches, since it is straightforward to determine which input features are important for our formal language tasks. We enhance these approaches by using white-box models that are guaranteed to be sensitive to those features.

Formal Language Tasks
Formal languages are often used to evaluate the expressive power of RNNs. Here, we focus on formal languages that have been recently used to probe LSTMs' ability to capture three kinds of dependencies: counting, long-distance, and hierarchical dependencies. We define a classification task based on each of these formal languages.

Counting Dependencies
Counter languages (Fischer, 1966; Fischer et al., 1968) are languages recognized by automata equipped with counters. Weiss et al. (2018) demonstrate, using an acceptance task for the languages a^n b^n and a^n b^n c^n, that LSTMs naturally learn to use cell state units as counters. Merrill's (2019) asymptotic analysis shows that LSTM acceptors accept only counter languages when their weights are fully saturated. Thus, counter languages may be viewed as a characterization of the expressive power of LSTMs. We define the counting task based on a simple example of a counting language.
Task 1 (Counting Task). Given a string x ∈ {a, b}*, determine whether or not x has strictly more as than bs.
Example 2. The counting task classifies aaab as True, ab as False, and bbbba as False.
A counter automaton can solve the counting task by incrementing its counter whenever an a is encountered and decrementing it whenever a b is encountered. It outputs True if and only if its counter is at least 1. We expect attribution scores for all input symbols to have roughly the same magnitude, but that scores assigned to a will have the opposite sign to those assigned to b.
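The counter automaton just described is easy to state as code; the sketch below is a direct transcription of that strategy (the function name is ours):

```python
def counting_task(x: str) -> bool:
    """Counter automaton for Task 1: True iff x has strictly more a's than b's."""
    counter = 0
    for symbol in x:
        if symbol == "a":
            counter += 1  # increment on a
        elif symbol == "b":
            counter -= 1  # decrement on b
    return counter >= 1   # accept iff the counter is at least 1
```

This mirrors the expected white-box behavior: each a and each b contributes to the output with equal magnitude and opposite sign.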

Long-Distance Dependencies
Strictly piecewise (SP, Heinz, 2007) languages were used by Avcu et al. (2017) and Mahalunkar and Kelleher (2018, 2019a,b) to test the propensity of LSTMs to learn long-distance dependencies, compared to Elman's (1990) simple recurrent networks. SP languages are regular languages whose membership is defined by the presence or absence of certain subsequences, which may or may not be contiguous. For example, ad is a subsequence of abcde, since both letters of ad occur in abcde, in the same order. Based on these ideas, we define the SP task as follows.
Task 3 (SP Task). Given x ∈ {a, b, c, d} * , determine whether or not x contains at least one of the following subsequences: ab, bc, cd, dc.
Example 4. In the SP task, aab is classified as True, since it contains the subsequence ab. Similarly, acb is classified as True, since it contains ab non-contiguously. The string aaa is classified as False.
The choice of SP languages as a test for long-distance dependencies is motivated by the fact that symbols in a non-contiguous subsequence may occur arbitrarily far from one another. The SP task yields a variant of the pointing game task in the sense that the input string may or may not contain an "object" (one of the four subsequences) that the network must identify. Therefore, we expect an input symbol to receive a nonzero attribution score if and only if it forms part of one of the four subsequences.
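For reference, the SP task itself has a short deterministic implementation; the sketch below (names are ours) checks each of the four subsequences with a standard iterator idiom:

```python
def contains_subsequence(x: str, pattern: str) -> bool:
    """True iff every symbol of `pattern` occurs in x, in order,
    but not necessarily contiguously."""
    it = iter(x)
    # `symbol in it` advances the iterator, so matches must occur in order.
    return all(symbol in it for symbol in pattern)

def sp_task(x: str) -> bool:
    """Task 3: True iff x contains ab, bc, cd, or dc as a subsequence."""
    return any(contains_subsequence(x, p) for p in ("ab", "bc", "cd", "dc"))
```

For instance, `sp_task("acb")` is True because the a and the final b form the non-contiguous subsequence ab.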

Hierarchical Dependencies
The Dyck language is the language D generated by the following context-free grammar, where ε is the empty string:

D → ε | ( D ) | [ D ] | D D
D contains all balanced strings of parentheses and square brackets. Since D is often viewed as a canonical example of a context-free language (Chomsky and Schützenberger, 1959), several recent studies, including Sennhauser and Berwick (2018), Bernardy (2018), Skachkova et al. (2018), and Yu et al. (2019), have used D to evaluate whether LSTMs can learn hierarchical dependencies implemented by pushdown automata. Here, we consider the bracket prediction task proposed by Sennhauser and Berwick (2018).
Task 5 (Bracket Prediction Task). Given a prefix p of some string in D, identify the next valid closing bracket for p.

In heatmaps for the bracket prediction task, we expect the last unclosed bracket to receive the highest-magnitude relevance score.
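The bracket prediction task has a simple stack-based reference solution, sketched below (assuming well-formed prefixes, as the task guarantees; names are ours):

```python
def next_closing_bracket(prefix: str):
    """Task 5: given a prefix of a balanced string over ()[], return the next
    valid closing bracket, or None if every opened bracket is already closed."""
    stack = []
    for ch in prefix:
        if ch in "([":
            stack.append(ch)  # push opening brackets
        else:
            stack.pop()       # a closing bracket always matches the top item
    if not stack:
        return None
    return ")" if stack[-1] == "(" else "]"
```

The last unclosed bracket, i.e., the top of the stack, fully determines the answer, which is why we expect it to dominate the heatmap.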

White-Box Networks
We use two approaches to construct white-box networks for our tasks. In the counter-based approach, the cell state contains a set of counters, which are incremented or decremented throughout the computation. The network's final output is based on the values of the counters. In the automaton-based approach, we use the LSTM to simulate an automaton, with the cell state containing a representation of the automaton's state. We use a counter-based network to solve the counting task and an automaton-based network to solve the bracket prediction task. We use both kinds of networks to solve the SP task. All networks perfectly solve the tasks they were designed for. This section describes our white-box networks at a high level; a detailed description is given in Appendix A.
In the rest of this paper, we identify the alphabet symbols a, b, c, and d with the one-hot vectors for indices 1, 2, 3, and 4, respectively. The vectors f (t) , i (t) , and o (t) represent the forget, input, and output gates, respectively. g (t) is the value added to the cell state at each time step, and σ represents the sigmoid function. We assume that the hidden state h (t) and cell state c (t) are updated as follows:

c (t) = f (t) ⊙ c (t−1) + i (t) ⊙ g (t)
h (t) = o (t) ⊙ tanh(c (t))
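These update equations can be written out directly; the sketch below assumes a particular weight layout (per-gate matrices over the concatenated input and hidden state), which is our own choice rather than anything prescribed by the paper:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, h_prev, c_prev, W):
    """One LSTM step under the update equations assumed in the paper:
    c(t) = f(t) * c(t-1) + i(t) * g(t) and h(t) = o(t) * tanh(c(t)).
    W is a dict of per-gate weight matrices and biases (our layout)."""
    z = np.concatenate([x, h_prev])
    f = sigmoid(W["f"] @ z + W["bf"])  # forget gate
    i = sigmoid(W["i"] @ z + W["bi"])  # input gate
    o = sigmoid(W["o"] @ z + W["bo"])  # output gate
    g = np.tanh(W["g"] @ z + W["bg"])  # candidate cell-state update
    c = f * c_prev + i * g
    h = o * np.tanh(c)
    return h, c
```

With saturated gate biases and an appropriate g-weight row, a single such unit behaves as the counter described in Section 4.1.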

Counter-Based Networks
In the counter-based approach, each position of the cell state contains the value of a counter. To adjust the counter in position j by some value v ∈ (−1, 1), we set g (t) j = v, and we saturate the gates by setting them to σ(m) ≈ 1, where m ≫ 0 is a large constant. For example, our network for the counting task uses a single hidden unit, with the gates always saturated and with g (t) given by

g (t) = tanh(u) if x (t) = a, and g (t) = −tanh(u) if x (t) = b,

where u > 0 is a hyperparameter that scales the counter by a factor of v = tanh(u). 2 The counter is thus incremented by v on each a and decremented by v on each b.

For the SP task, we use seven counters. The first four counters record how many occurrences of each symbol have been observed at time step t. The next three counters record the number of bs, cs, and ds that form one of the four distinguished subsequences with an earlier symbol. For example, after seeing the input aaabbc, the counter-based network for the SP task satisfies

c (t) = v · (3, 2, 1, 0, 2, 1, 0)⊤.

The first four counters represent the fact that the input has 3 as, 2 bs, 1 c, and no ds. Counter #5 is 2v because the two bs form a subsequence with the as, and counter #6 is v because the c forms a subsequence with the bs.

The logit scores of our counter-based networks are computed by a linear decoder using the tanh of the counter values. For the counting task, the score of the True class is h (t) , while the score of the False class is fixed to tanh(v)/2. This means that the network outputs True if and only if the final counter value is at least v. For the SP task, the score of the True class is h (t) 5 + h (t) 6 + h (t) 7 , the sum of the hidden units for the three subsequence counters, while the score of the False class is again tanh(v)/2.
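The seven-counter bookkeeping can be simulated at the counter level, abstracting away the LSTM gates; the sketch below reflects our reading of the construction (variable names, and the gating conditions written as if-statements, are ours):

```python
V = 1.0  # work in multiples of the scaling factor v

def sp_counters(x: str):
    """Counter-level simulation of the counter-based SP network: counters 1-4
    count a, b, c, d; counters 5-7 count the b's, c's, and d's that complete
    one of the subsequences ab, bc, cd, dc with an earlier symbol."""
    c = [0.0] * 7
    for ch in x:
        c["abcd".index(ch)] += V               # counters 1-4: symbol counts
        if ch == "b" and c[0] > 0:
            c[4] += V                          # b after some a  -> ab
        if ch == "c" and (c[1] > 0 or c[3] > 0):
            c[5] += V                          # c after b (bc) or after d (dc)
        if ch == "d" and c[2] > 0:
            c[6] += V                          # d after some c  -> cd
    return c
```

The input is classified True exactly when one of counters #5-#7 is positive, matching the decoder described above.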

Automata-Based Networks
We consider two types of automata-based networks: one that implements a finite-state automaton (FSA) for the SP task, and one that implements a pushdown automaton (PDA) for the bracket prediction task.
Our FSA construction is similar to Korsky and Berwick's (2019) FSA construction for simple recurrent networks. Consider a deterministic FSA A with states Q and alphabet Σ. To simulate A using an LSTM, we use |Q| · |Σ| hidden units, with the following interpretation. Suppose that A transitions to state q after reading input symbol x. Then h (t) contains a one-hot representation of the pair ⟨q, x⟩, which encodes both the current state of A and the most recent input symbol. Since the FSA undergoes a state transition with each input symbol, the forget gate always clears c (t) , so that information written to the cell state does not persist beyond a single time step. The output layer simply detects whether or not the FSA is in an accepting state. Details are provided in Appendix A.3.
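The automaton being simulated can be sketched directly; the three-state FSA below is a hypothetical fragment of the SP task (detecting only the subsequence ab) that we use to illustrate the ⟨state, symbol⟩ one-hot layout:

```python
def run_fsa(delta, q0, accepting, x):
    """Run a deterministic FSA. The LSTM construction stores the pair
    (current state, last input symbol) as a one-hot vector over Q x Sigma."""
    q = q0
    for sym in x:
        q = delta[(q, sym)]
    return q in accepting

# Hypothetical FSA over {a, b} accepting strings containing the subsequence ab:
DELTA = {
    (0, "a"): 1, (0, "b"): 0,  # state 0: no a seen yet
    (1, "a"): 1, (1, "b"): 2,  # state 1: an a has been seen
    (2, "a"): 2, (2, "b"): 2,  # state 2: ab found (absorbing, accepting)
}

def pair_index(q, sym, sigma="ab"):
    """Index of the hidden unit for the pair (q, sym) in the one-hot layout."""
    return q * len(sigma) + sigma.index(sym)
```

The full SP automaton over {a, b, c, d} tracks the most recent symbol in the same way; only the transition table is larger.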
Next, we describe how to implement a PDA for the bracket prediction task. We use a stack containing all unclosed brackets observed in the input string, and make predictions based on the top item of the stack. We represent a bounded stack of size k using 2k + 1 hidden units. The first k − 1 positions contain all stack items except the top item, with ( represented by the value 1, [ represented by −1, and empty positions represented by 0. The kth position contains the top item of the stack. The next k positions contain the height of the stack in unary notation, and the last position contains a bit indicating whether or not the stack is empty. For example, after reading the input ([(() with a stack of size 4, the stack contents ([( are represented by

c (t) = (1, −1, 0, 1, 1, 1, 1, 0, 0)⊤.

The 1 in position 4 indicates that the top item of the stack is (, and the 1, −1, and 0 in positions 1-3 indicate that the remainder of the stack is ([. The three 1s in positions 5-8 indicate that the stack height is 3, and the 0 in position 9 indicates that the stack is not empty.

When the current input symbol is an opening bracket, the previous top item in position k is copied to the highest empty position in c (t) :k−1 and the new bracket is written to position k, pushing the opening bracket to the stack. The empty stack bit is then set to 0, marking the stack as non-empty. When the current input symbol is a closing bracket, the highest item of positions 1 through k − 1 is deleted and copied to position k, popping the top item from the stack. Because the PDA network is quite complex, we focus here on describing how the top stack item in position k is determined, and leave other details for Appendix A.4. The value written to position k is tanh(m u (t)), where m ≫ 0 and

u (t) = 2^k α (t) + Σ_{j=1}^{k−1} 2^j h (t−1) j ,

where α (t) contains the stack encoding of the current input symbol if it is an opening bracket. Since each term receives a larger weight than all lower-position terms combined, the sign of u (t) matches its highest nonzero term. If the current input symbol is a closing bracket, then α (t) = 0, so the sign of u (t) is determined by the highest item of h (t−1) :k−1 .
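The stack encoding can be reproduced in a few lines of code; the sketch below implements the (2k + 1)-unit layout described above for a stack holding at most k items (the function name is ours):

```python
def encode_stack(stack: str, k: int):
    """Encode a stack of opening brackets (top = last character) in 2k + 1
    cell-state units, following the layout described above: ( -> 1, [ -> -1.
    Assumes len(stack) <= k."""
    val = {"(": 1, "[": -1}
    c = [0] * (2 * k + 1)
    below, top = stack[:-1], stack[-1:]
    for j, ch in enumerate(below):
        c[j] = val[ch]                    # positions 1..k-1: rest of the stack
    if top:
        c[k - 1] = val[top]               # position k: top of the stack
    for j in range(len(stack)):
        c[k + j] = 1                      # positions k+1..2k: height in unary
    c[2 * k] = 1 if not stack else 0      # position 2k+1: empty-stack bit
    return c
```

Running this on the stack contents ([( with k = 4 reproduces the example vector given above.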

Attribution Methods
Let X be a matrix of input vectors, such that the input at time t is the row vector X t,: = (x (t))⊤. Given X, an LSTM classifier produces a vector ŷ of logit scores. Based on X, ŷ, and possibly a baseline input X̄, an attribution method assigns an attribution score R (c) t,i (X) to input feature X t,i for each output class c. These feature-level scores are then aggregated to produce token-level scores:

R (c) t (X) = Σ_i R (c) t,i (X).

Broadly speaking, our five attribution methods are grouped into three types: one perturbation-based, three gradient-based, and one decomposition-based. The following subsections describe how each method computes R (c) t,i (X).

Perturbation-and Gradient-Based Methods
Perturbation-based methods are premised on the idea that if X t,i is an important input feature, then changing the value of X t,i would cause ŷ to change. The one perturbation method we consider is occlusion. In this method, R (c) t (X) is the change in ŷ c observed when X t,: is replaced by 0.
Gradient-based methods rely on the same intuition as perturbation-based methods, but use automatic differentiation to simulate infinitesimal perturbations. The definitions of our three gradient-based methods are given in Table 2. The most basic of these is saliency, which simply measures relevance by the derivative of the logit score with respect to each input feature. G × I attempts to improve upon saliency by using the first-order terms in a Taylor-series approximation of the model instead of the gradients on their own. IG is designed to address the issue of small gradients found in saturated units by integrating G × I along the line connecting X to a baseline input X̄, here taken to be the zero matrix.
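The four methods above can be sketched generically for any differentiable scalar logit; the code below takes the logit and its gradient as black-box callables (this interface is ours) and approximates the IG path integral with a midpoint rule over the zero-matrix baseline:

```python
import numpy as np

def occlusion(logit, X):
    """R_t = logit(X) - logit(X with row t zeroed out)."""
    base = logit(X)
    scores = np.empty(X.shape[0])
    for t in range(X.shape[0]):
        X_occ = X.copy()
        X_occ[t, :] = 0.0
        scores[t] = base - logit(X_occ)
    return scores

def saliency(grad, X):
    """Relevance = derivative of the logit with respect to each feature."""
    return grad(X)

def grad_times_input(grad, X):
    """First-order Taylor terms: gradient elementwise times the input."""
    return grad(X) * X

def integrated_gradients(grad, X, steps=100):
    """Average gradient along the line from the zero baseline to X,
    times the input (midpoint-rule approximation of the path integral)."""
    alphas = (np.arange(steps) + 0.5) / steps
    avg_grad = sum(grad(a * X) for a in alphas) / steps
    return avg_grad * X
```

For a linear logit, IG and G × I coincide, and the IG scores sum to the logit difference from the baseline (the completeness property).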

Decomposition-Based Methods
Decomposition-based methods are methods that satisfy the relation

ŷ c = Σ_t Σ_i R (c) t,i (X) + R (c) bias , (2)

where R (c) bias is a relevance score assigned to the bias units of the network. The interpretation of equation (2) is that the logit score ŷ c is "distributed" among the input features and the bias units, so that the relevance scores form a "decomposition" of ŷ c .
The one decomposition-based method we consider is LRP, which computes scores using a backpropagation algorithm that distributes scores layer by layer. The scores of the output layer are initialized to ŷ c for the unit corresponding to class c and to 0 for all other units. For each layer l with activation z (l) , activation function f (l) , and output a (l) = f (l) ( z (l) ) , the relevance r (c,l) of a (l) is determined by the following propagation rule:

r (c,l) j = Σ_{l′} Σ_i ( a (l) j W (l′←l) i,j / ( z (l′) i + ε · sign(z (l′) i) ) ) r (c,l′) i ,

where l′ ranges over all layers to which l has a forward connection via W (l′←l) and ε > 0 is a stabilizing constant. 3 For the LSTM gate interactions, we follow Arras et al. (2017b) in treating multiplicative connections of the form a (l 1 ) ⊙ a (l 2 ) as activation functions of the form a (l 1 ) ⊙ f (l 2 ) (·), where a (l 1 ) is f (t) , i (t) , or o (t) . The final attribution scores are given by the values propagated to the input layer: R (c) t,i (X) = r (c,0) t,i .
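The ε-stabilized propagation rule has a compact form for a single linear layer; the sketch below is a generic LRP-ε backward step (not the full LSTM treatment with the gate rule of Arras et al.):

```python
import numpy as np

def lrp_epsilon(a, W, b, r_out, eps=1e-6):
    """One LRP-epsilon backward step through a linear layer z = W a + b.
    Input unit i receives a share of r_out[j] proportional to a[i] * W[j, i],
    normalized by the stabilized pre-activation z[j]."""
    z = W @ a + b
    denom = z + eps * np.sign(z)  # epsilon-stabilization against small |z|
    return a * (W.T @ (r_out / denom))
```

When the biases are zero, the propagated scores conserve the total relevance, as required by equation (2); nonzero biases absorb part of it as R (c) bias.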

Qualitative Evaluation
To evaluate attribution methods under our framework, we begin with a qualitative description of the heatmaps that are computed for our white-box networks, based on the illustrative sample of heatmaps appearing in Table 3.

Counting Task
Occlusion, G × I, and IG are well-behaved for the counting task. As expected, these methods assign a a positive value and b a negative value when the output class for attribution is c = True. When the number of as is different from the number of bs, occlusion assigns a lower-magnitude score to the symbol with fewer instances. When c = False, all relevance scores are 0. This is because ŷ False is fixed to a constant value supplied by a bias term, so input features cannot affect its value. Saliency and LRP both fail to produce nonzero scores, at least in some cases. Saliency scores satisfy R (True) t,1 (X) = −R (True) t,2 (X), resulting in token-level scores of 0 for all inputs. Heatmaps #3 and #4 show that LRP assigns scores of 0 to prefixes containing equal numbers of as and bs. We will see in Subsection 7.1 that this phenomenon appears to be related to the fact that the LSTM gates are saturated.

SP Task
We obtain radically different heatmaps for the two SP task networks, despite the fact that they produce the same classifications for all inputs.
For the counter-based network, all methods except for saliency assign positive scores for c = True to symbols constituting one of the four subsequences, and scores of zero elsewhere. The saliency heatmaps do not adhere to this pattern, and instead generally assign higher scores to tokens occurring near the end of the input. Heatmaps #7-10 show that LRP fails to assign positive scores to the first symbol of each subsequence, while the other methods generally do not. 4 The LRP behavior reflects the fact that the initial a does not increment the subsequence counters, which determine the final logit score. In contrast, the behavior of occlusion, G × I, and IG is explained by the fact that removing either the a or the b destroys the subsequence. Note that the as in heatmap #9 receive scores of 0 from occlusion and G × I, since removing only one of the two as does not destroy the subsequence.
For the FSA-based network, saliency, G × I, and LRP assign only the last symbol a nonzero score when the relevance output class c matches the network's predicted class. IG appears to produce erratic heatmaps, exhibiting no immediately obvious pattern. Although occlusion appears to be erratic at first glance, its behavior can be explained by the fact that changing x (t) to 0 causes h (t) to be 0, which the LSTM interprets as the initial state of the FSA; thus, R (c) t (X) ̸ = 0 precisely when X t+1:,: is classified differently from X. In all cases, the heatmaps for the FSA-based network diverge significantly from the expected heatmaps.

Bracket Prediction Task
The heatmaps for the PDA-based network also differ strikingly from those of the other networks, in that the gradient-based methods never assign nonzero scores. This is because equation (1) causes g (t) to be highly saturated, resulting in zero gradients. In the case of LRP, the matching bracket is highlighted when c ̸= None. When the matching bracket is not the last symbol of the input, the other unclosed brackets are also highlighted, with progressively smaller magnitudes, and with brackets of the opposite type from c receiving negative scores. This pattern reflects the mechanism of (1), in which progressively larger powers of 2 are used to determine the content copied to c (t) k . When the relevance output class is c = None, LRP assigns opening brackets a negative score, revealing the fact that those input symbols set the bit c (t) 2k+1 to indicate that the stack is not empty. Although occlusion sometimes highlights the matching bracket, it does not appear to be consistent in doing so. For example, it fails to highlight the matching bracket in heatmap #21, and highlights one other bracket in heatmaps #23-24.

Detailed Evaluations
We now turn to focused investigations of particular phenomena that attribution methods exhibit when applied to white-box networks. Subsection 7.1 begins by discussing the effect of network saturation on the gradient-based methods and LRP. In Subsection 7.2 we apply Bach et al.'s (2015) ablation test to our attribution methods for the SP task.

Saturation
As mentioned in the previous section, network saturation causes gradients to be approximately 0 when using sigmoid or tanh activation functions. To illustrate how the gradient-based methods are affected by saturation, Table 4 shows heatmaps for the input accb generated by gradient-based methods for different instantiations of the counter-based SP network with varying degrees of saturation. Recall from Section 4 that counter values for this network are expressed in multiples of the scaling factor v. We control the saturation of the network via the parameter u = tanh −1 (v). For all three gradient-based methods, scores for a decrease and scores for b increase as u increases. Additionally, saliency scores for the first c decrease when u increases. When u = 8, the network is almost completely saturated, causing G × I to produce all-zero heatmaps.
On the other hand, IG is still able to produce nonzero heatmaps even at u = 64. Thus, IG is much more resistant to the effects of saturation than G × I. According to Sundararajan et al. (2017), gradient-based methods satisfy the axiom of implementation invariance: they produce the same heatmaps for any two networks that compute the same function. This formal property is seemingly at odds with the diverse array of heatmaps appearing in Table 4, which are produced for networks that all yield identical classifiers. In particular, the networks with u = 8, 16, and 64 yield qualitatively different heatmaps, despite the fact that the three networks are distinguished only by differences in v of less than 0.001. Because the three functions are technically not equal, implementation invariance is not violated in theory; but the fact that IG produces different heatmaps for three nearly identical networks shows that the intuition described by implementation invariance is not borne out in practice.
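The interplay between saturation and IG can be checked on a one-dimensional toy model (ours, not the LSTM itself): the pointwise gradient of tanh(u·x) at x = 1 vanishes as u grows, while the average gradient along the path from 0 to x, which is what IG computes, stays near tanh(u) ≈ 1:

```python
import numpy as np

def grad_at(u, x=1.0):
    """Pointwise gradient d/dx tanh(u x): vanishes once tanh is saturated."""
    return u * (1.0 - np.tanh(u * x) ** 2)

def ig_score(u, x=1.0, steps=100_000):
    """What IG computes for this scalar model: the average gradient along
    the path from the zero baseline to x, times x. Equals tanh(u x)."""
    alphas = (np.arange(steps) + 0.5) / steps
    return np.mean(grad_at(u, alphas * x)) * x

for u in (1.0, 8.0, 64.0):
    print(f"u={u:5.1f}  saliency-style grad={grad_at(u):.2e}  IG={ig_score(u):.4f}")
```

This mirrors the table: saliency-style and G × I scores collapse to zero under saturation, while IG remains informative even at u = 64.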
Besides the gradient-based methods, LRP is also susceptible to problems arising from saturation. Recall from heatmaps #3 and #4 of Table 3 that for the counting task network, LRP assigns scores of 0 to prefixes with equal numbers of as and bs. We hypothesize that this phenomenon is related to the fact that c (t) = 0 after reading such prefixes, since the counter has been incremented and decremented in equal amounts. Accordingly, we test whether this phenomenon can be mitigated by desaturating the gates so that c (t) does not exactly reach 0. Recall that the white-box LSTM gates approximate 1 ≈ σ(m) using a constant m ≫ 0. We construct networks with varying values of m and compute LRP scores on a randomly generated testing set of 1000 strings, each of which contains at least one prefix with equal numbers of as and bs. In Table 5 we report the percentage of examples for which such prefixes receive LRP scores of 0, along with the network's accuracy on this testing set and the average value of c (t) when the counter reaches 0. Indeed, the percentage of prefixes receiving scores of 0 increases as the approximation c (t) ≈ 0 becomes more exact.

Ablation Test
Table 6: Mean and standard deviation results of the ablation test, normalized by string length and expressed as a percentage. "Optimal" is the best possible score.

So far, we have primarily compared attribution methods via visual inspection of individual examples. To compare the five methods quantitatively,
we apply the ablation test of Bach et al. (2015) to our two white-box networks for the SP task. 5 Given an input string classified as True, we iteratively remove the symbol with the highest relevance score, recomputing heatmaps at each iteration, until the string no longer contains any of the four subsequences. We apply the ablation test to 100 randomly generated input strings, and report the average percentage of each string that is ablated in Table 6. A peculiar property of the SP task is that removing a symbol preserves the validity of input strings. This means that, unlike in NLP settings, our ablation test does not suffer from the issue that ablation produces invalid inputs. Saliency, G × I, and LRP perform close to the random baseline on the FSA network; this is unsurprising, since these methods only assign nonzero scores to the last input symbol. While Table 3 shows some variation in the IG heatmaps, IG also performs close to the random baseline. Only occlusion performs considerably better, since it is able to identify symbols whose ablation would destroy subsequences.
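The ablation loop itself is straightforward; the sketch below takes the heatmap and the classifier as black-box callables (this interface is ours):

```python
def ablation_test(heatmap_fn, classify_fn, x: str) -> int:
    """Bach et al.-style ablation: repeatedly delete the symbol with the
    highest relevance score, recomputing the heatmap each round, until the
    input is no longer classified True. Returns the number of deletions."""
    removed = 0
    while classify_fn(x):
        scores = heatmap_fn(x)
        t = max(range(len(x)), key=lambda j: scores[j])  # most relevant symbol
        x = x[:t] + x[t + 1:]
        removed += 1
    return removed
```

A sharper heatmap needs fewer deletions: pointing directly at a subsequence-completing symbol can terminate the loop in one step, while an uninformative (uniform) heatmap deletes symbols blindly.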
On the counter-based SP network, IG performs remarkably close to the optimal benchmark, which represents the best possible performance on this task. Occlusion, G × I, and LRP achieve a similar level of performance to one another, while saliency performs worse than the random baseline.

Conclusion
Of all the heatmaps considered in this paper, only those computed by G × I and IG for the counting task fully matched our expectations. In other cases, all attribution methods fail to identify at least some of the input features that should be considered relevant, or assign relevance to input features that do not affect the model's behavior. Among the five methods, saliency achieves the worst performance: it never assigns nonzero scores for the counting and bracket prediction tasks, and it does not identify the relevant symbols for either of the two SP networks. Saliency also achieves the worst performance on the ablation test for both the counter-based and the FSA-based SP networks. Among the four white-box networks, the two automata-based networks proved to be much more challenging for the attribution methods than the counter-based networks. While the LRP heatmaps for the PDA network correctly identify the matching bracket when available, no other method produces reasonable heatmaps for the PDA network, and all five methods fail to interpret the FSA network.
Taken together, our results suggest that attribution heatmaps should be viewed with skepticism. This paper has identified cases in which heatmaps fail to highlight relevant features, as well as cases in which heatmaps incorrectly highlight irrelevant features. Although most of the methods perform better for the counter-based networks than the automaton-based networks, in practical settings we do not know what kinds of computations are implemented by a trained network, making it impossible to determine whether the network under analysis is compatible with the attribution method being used.
In future work, we encourage the use of our four white-box models as qualitative benchmarks for evaluating interpretability methods. For example, the style of evaluation we have developed can be replicated for attribution methods not covered in this paper, including DeepLIFT (Shrikumar et al., 2017) and contextual decomposition (Murdoch et al., 2018). We believe that insights gleaned from white-box analysis can help researchers choose between different attribution methods and identify areas of improvement in current techniques.

A Detailed Descriptions of White-Box Networks
This appendix provides detailed descriptions of our four white-box networks.
A.1 Counting Task Network

As described in Subsection 4.1, the network for the counting task simply sets g (t) to v = tanh(u) when x (t) = a and −v when x (t) = b. All gates are fixed to 1. The output layer uses h (t) = tanh ( c (t) ) as the score for the True class and tanh(v)/2 as the score for the False class.
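As a check on this description, the one-unit network can be run end to end in a few lines; the constants M and U below are our own picks for m and u:

```python
import numpy as np

M, U = 100.0, 0.5   # saturation constant m and counter scale u (our picks)
V = np.tanh(U)      # v = tanh(u)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def counting_lstm(x: str) -> bool:
    """Forward pass of the one-unit counting network as we read Appendix A.1:
    all gates saturated at sigma(M) ~ 1, g(t) = +/-V, and a decoder comparing
    the True score tanh(c) against the fixed False score tanh(V)/2."""
    h = c = 0.0
    for ch in x:
        f = i = o = sigmoid(M)        # all gates saturated near 1
        g = V if ch == "a" else -V    # increment on a, decrement on b
        c = f * c + i * g
        h = o * np.tanh(c)
    return bool(h > np.tanh(V) / 2)   # True iff the counter reached at least v
```

The network classifies exactly like the counter automaton of Task 1, since the final counter value is (number of as − number of bs) · v up to negligible gate leakage.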
A.2 Counter-Based SP Network

The seven counters for the SP task are implemented as follows. First, we compute g (t) under the assumption that one of the first four counters is always incremented, and one of the last three counters is always incremented as long as x (t) ̸= a.
Then, we use the input gate to condition the last three counters on the value of the first four counters. For example, if h (t−1) 1 = 0, then no as have been encountered in the input string before time t. In that case, the input gate for counter #5, which represents subsequences ending with b, is set to i (t) 5 = σ(−m) ≈ 0. This is because a b encountered at time t would not form part of a subsequence if no as have been encountered so far, so counter #5 should not be incremented.
All other gates are fixed to 1. The output layer sets the score of the True class to h (t) 5 + h (t) 6 + h (t) 7 and the score of the False class to tanh(v)/2.
A.3 FSA Network

Here we describe a general construction of an LSTM simulating an FSA with states Q, accepting states Q F ⊆ Q, alphabet Σ, and transition function δ : Q × Σ → Q. Recall that h (t) contains a one-hot representation of pairs in Q × Σ encoding the current state of the FSA and the most recent input symbol. The initial state h (0) = 0 represents the starting configuration of the FSA.
At a high level, the state transition system works as follows. First, g (t) marks all the positions corresponding to the current input x (t) . 6 The input gate then filters out any positions that do not represent valid transitions from the previous state q ′ , which is recovered from h (t−1) .
Now, we describe how this behavior is implemented in our LSTM. The cell state update is straightforwardly implemented as

g (t) = tanh(W (c,x) x (t)), where W (c,x) ⟨q,x⟩,j = { 1 if j is the index for x; 0 otherwise }.
Observe that the matrix W (c,x) essentially contains a copy of I 4 for each state, such that each copy is distributed across the different cell state units designated for that state. The input gate is more complex. First, the bias term handles the case where the current state is the starting state q 0 . This is necessary because the initial configuration of the network is represented by h (0) = 0. The bias vector sets i (t) ⟨q,x⟩ to be 1 if the FSA transitions from q 0 to q after reading x, and 0 otherwise. We replicate this behavior for other values of h (t−1) by using the weight matrix W (i,h) , taking the bias vector into account. The forget gate is fixed to 0, since the state needs to be updated at every time step. The output gate is fixed to 1.

6 We use v = tanh(1) ≈ 0.762.
The output layer simply selects hidden units that represent accepting and rejecting states:

ŷ c = Σ ⟨q,x⟩ W c,⟨q,x⟩ h (t) ⟨q,x⟩ , where W c,⟨q,x⟩ = { 1 if c = True and q ∈ Q F ; 1 if c = False and q ∉ Q F ; 0 otherwise }.

A.4 PDA Network
Finally, we describe how the PDA network for the bracket prediction task is implemented. Of the four networks, this one is the most intricate. Recall from Subsection 4.2 that we implement a bounded stack of size k using 2k + 1 hidden units, with the following interpretation:

• c (t) :k−1 contains all stack items except the top item, with ( represented by 1, [ represented by −1, and empty positions represented by 0.
• c (t) k contains the top item of the stack.
• c (t) k+1:2k contains the height of the stack in unary notation.
• c (t) 2k+1 is a bit, which is set to be positive if the stack is empty and nonpositive otherwise.
We represent the brackets (, [, ), and ] in one-hot encoding with the indices 1, 2, 3, and 4, respectively. The opening brackets ( and [ are represented on the stack by 1 and −1, respectively. The input gate ensures that when the stack height is incremented, only the appropriate position is updated. Finally, the bias vector ensures that the top stack item and the empty stack indicator are always updated.
The forget gate is responsible for deleting portions of memory when stack items are popped.
To complete the construction, we fix the output gate to 1, and have the output layer read the top stack position:

ŷ c = Σ j W c,j h (t) j , where W c,j = { 1 if c = ) and j = k; −1 if c = ] and j = k; 1 if c = None and j = 2k + 1; 0 otherwise }.