Distilling weighted finite automata from arbitrary probabilistic models

Weighted finite automata (WFA) are often used to represent probabilistic models, such as n-gram language models, since they are efficient for recognition tasks in time and space. The probabilistic source to be represented as a WFA, however, may come in many forms. Given a generic probabilistic model over sequences, we propose an algorithm to approximate it as a weighted finite automaton such that the Kullback-Leibler divergence between the source model and the WFA target model is minimized. The proposed algorithm involves a counting step and a difference of convex optimization, both of which can be performed efficiently. We demonstrate the usefulness of our approach on some tasks including distilling n-gram models from neural models.


Introduction
Given a sequence of symbols \(x_1, x_2, \ldots, x_{n-1}\), where symbols are drawn from the alphabet \(\Sigma\), a probabilistic model \(S\) assigns to the next symbol \(x_n \in \Sigma\) the conditional probability \(p_s[x_n \mid x_{n-1} \ldots x_1]\).

Such a model might be Markovian, where
\[ p_s[x_n \mid x_{n-1} \ldots x_1] = p_s[x_n \mid x_{n-1} \ldots x_{n-k+1}], \]
such as a k-gram language model (LM) (Chen and Goodman, 1998), or it might be non-Markovian, such as a long short-term memory (LSTM) neural network language model (Sundermeyer et al., 2012). Our goal is to approximate a probabilistic model as a weighted finite automaton (WFA) such that the weight assigned by the WFA is close to the probability assigned by the source model. Specifically, we will seek to minimize the Kullback-Leibler (KL) divergence between the source S and the target WFA model. Representing the target model as a WFA has many advantages including efficient use, compact representation, interpretability, and composability. WFA models have been used in many applications including speech recognition, speech synthesis (Ebden and Sproat, 2015), optical character recognition (Breuel, 2008), machine translation (Iglesias et al., 2011), computational biology (Durbin et al., 1998), and image processing (Albert and Kari, 2009). One particular problem of interest is language models for on-device (virtual) keyboard decoding, where WFA models are used due to space and time constraints. However, storing the training data in a centralized server and training k-gram or other WFA models directly may not be feasible due to privacy constraints (Hard et al., 2018). Alternatively, an LSTM model can be trained by federated learning (Konečnỳ et al., 2016; Hard et al., 2018), then converted to a WFA at the server for fast on-device inference. This may not only improve performance but also provide additional privacy.
We allow failure transitions (Aho and Corasick, 1975;Mohri, 1997) in the target WFA, which are taken only when no immediate match is possible at a given state, for compactness. For example, in the WFA representation of a backoff k-gram model, failure transitions can compactly implement the backoff (Katz, 1987;Chen and Goodman, 1998;Allauzen et al., 2003;Novak et al., 2013;Hellsten et al., 2017). The inclusion of failure transitions will complicate our analysis and algorithms but is highly desirable in applications such as keyboard decoding. Further, to avoid redundancy that leads to inefficiency, we assume the target model is deterministic, which requires at each state there is at most one transition labeled with a given symbol.
The approximation problem can be divided into two steps: (1) select an unweighted automaton A that will serve as the topology of the target automaton and (2) weight the automaton A to form our weighted approximation Â. The main goal of this paper is the latter determination of the automaton's weighting in the approximation. In some applications, the topology may be unknown. In such cases, one choice is to build a k-gram deterministic finite automaton (DFA) topology from a corpus drawn from S (Allauzen et al., 2003). This could be from an existing corpus or from random samples drawn from S. Figure 1 shows a trigram topology for the very simple corpus aab. This representation makes use of failure transitions. These allow modeling strings unseen in the corpus (e.g., abab) in a compact way by failing or backing off to states that correspond to lower-order histories. Such models can be made more elaborate if some transitions represent classes, such as names or numbers, that are themselves represented by sub-automata. As mentioned previously, we will mostly assume we have a topology either pre-specified or inferred by some means and focus on how to weight that topology to best approximate the source distribution.

Figure 1: 3-gram topology example derived from the corpus aab. States are labeled with the context that is remembered: ∧ denotes the initial context, ε the empty context, $ the final context (which terminates accepted strings), and a wildcard symbol matches any symbol in a context. Failure transitions, labeled with ϕ, implement backoff from histories xy to y to ε.
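The k-gram topology construction described above can be sketched as follows. This is a minimal illustration, not the construction of Allauzen et al. (2003) or of the OpenGrm library: the function names `kgram_topology` and `accepts`, the tuple encoding of histories, and the label strings are all assumptions made for this example, the code requires k ≥ 2, and the resulting automaton may differ in detail from Figure 1.

```python
PHI, START, END, FINAL = "phi", "^", "$", "FINAL"

def kgram_topology(sentences, k):
    """Map each history (a tuple of symbols) to {label: next_state}; needs k >= 2."""
    padded = [[START] + list(s) + [END] for s in sentences]
    # Pass 1: every history of length < k that precedes some symbol is a state.
    states = {()}
    for seq in padded:
        for i in range(1, len(seq)):
            for m in range(0, min(k - 1, i) + 1):
                states.add(tuple(seq[i - m:i]))
    # Pass 2: symbol transitions go to the longest suffix that is a known state.
    arcs = {h: {} for h in states}
    for seq in padded:
        for i in range(1, len(seq)):
            for m in range(0, min(k - 1, i) + 1):
                h, x = tuple(seq[i - m:i]), seq[i]
                if x == END:
                    arcs[h][x] = FINAL
                    continue
                nxt = (h + (x,))[-(k - 1):]
                while nxt not in states:
                    nxt = nxt[1:]
                arcs[h][x] = nxt
    # One failure transition per non-empty history, dropping the oldest symbol.
    for h in states:
        if h:
            arcs[h][PHI] = h[1:]
    arcs[FINAL] = {}
    return arcs

def accepts(arcs, string):
    """Read string + END, failing only when a symbol cannot be read directly."""
    state = (START,)
    for x in list(string) + [END]:
        while x not in arcs[state]:
            if PHI not in arcs[state]:
                return False
            state = arcs[state][PHI]
        state = arcs[state][x]
    return state == FINAL

arcs = kgram_topology(["aab"], 3)
```

With this sketch, the unseen string abab is accepted through backoff, while a string containing a symbol absent from the corpus is rejected.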
In previous work, there have been various approaches for estimating weighted automata. Methods include state merging and weight estimation from a prefix tree data representation (Carrasco and Oncina, 1994, 1999), the EM algorithm (Dempster et al., 1977) applied to fully connected HMMs or specific topologies (Eisner, 2001), and spectral methods applied to automata (Balle and Mohri, 2012; Balle et al., 2014). For approximating neural network (NN) models as WFAs, methods have been proposed to build n-gram models from RNN samples (Deoras et al., 2011), from DNNs trained at different orders (Arisoy et al., 2014; Adel et al., 2014), and from RNNs with quantized hidden states (Tiño and Vojtek, 1997; Lecorvé and Motlicek, 2012).
Our paper is distinguished in several respects from previous work. First, our general approach does not depend on the form of the source distribution. Second, our targets are a wide class of deterministic automata with failure transitions. Third, we search for the minimal KL divergence between the source and target distributions, given a fixed target topology. We remark that if the source probabilistic model is represented as a WFA, our approximation will in general give a different solution than forming the finite-state intersection with the topology and weight-pushing to normalize the result (Mohri, 2009). Our approximation has the same states as the topology, whereas a weight-pushed intersection could have many more states and is not an approximation, but an exact representation, of the source distribution. Before presenting and validating algorithms for a minimum KL divergence approximation, both when the source itself is finite-state and when it is not (in which case sampling is involved), we next present the theoretical formulation of the problem and the minimum KL divergence approximation.

Theoretical analysis

Probabilistic models
Let \(\Sigma\) be a finite alphabet. Let \(x_i^n \in \Sigma^*\) denote the string \(x_i x_{i+1} \ldots x_n\) and \(x^n \triangleq x_1^n\). A probabilistic model \(p\) over \(\Sigma\) is a probability distribution over the next symbol \(x_n\), given the previous symbols \(x^{n-1}\), such that (see footnote 1)
\[ \sum_{x \in \Sigma} p(x_n = x \mid x^{n-1}) = 1 \quad \text{and} \quad \forall x \in \Sigma,\ p(x_n = x \mid x^{n-1}) \ge 0. \]
Without loss of generality, we assume that the model maintains an internal state \(q\) and updates it after observing the next symbol (see footnote 2). Furthermore, the probability of the subsequent state just depends on the state \(q\):
\[ p(x_{i+1}^n \mid x^i) = p(x_{i+1}^n \mid q(x^i)) \]
for all \(i, n, x^i, x_{i+1}^n\), where \(q(x^i)\) is the state the model has reached after observing sequence \(x^i\). Let \(Q(p)\) be the set of possible states. Let the language \(L(p) \subseteq \Sigma^*\) defined by the distribution \(p\) be
\[ L(p) \triangleq \{x^n \in \Sigma^* : p(x^n) > 0,\ x_n = \$ \text{ and } x_i \ne \$ \text{ for all } i < n\}. \]
The symbol $ is used as a stopping criterion. Further, for all \(x^n \in \Sigma^*\), \(p(x_n \mid x^{n-1} : x_{n-1} = \$) = 0\).
Footnote 1: We define \(x^0 \triangleq \epsilon\), the empty string, and adopt \(p(\epsilon) = 0\).
Footnote 2: In the most general case, \(q(x^n) = x^n\).
The KL divergence between two models \(p_s\) and \(p_a\) is given by
\[ D(p_s \| p_a) = \sum_{x^n \in \Sigma^*} p_s(x^n) \log \frac{p_s(x^n)}{p_a(x^n)}, \]
where for notational simplicity, we adopt the conventions 0/0 = 1 and 0 log(0/0) = 0 throughout the paper. Note that for the KL divergence to be finite, we need \(L(p_s) \subseteq L(p_a)\). We first reduce the KL divergence between two models as follows (cf. Carrasco, 1997; Cortes et al., 2008). In the following, let \(q_*\) denote the states of the probability distribution \(p_*\).
The reduction (Lemma 1) is expressed through a quantity \(c(x, q_a)\) that does not depend on \(p_a\). Proof is omitted due to space limitations.
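The KL divergence and its conventions can be illustrated with a minimal sketch for the special case where both models have finite support and are given as explicit dictionaries from complete strings to probabilities. That representation, and the name `kl_divergence`, are assumptions for illustration only; the models in the paper are defined through next-symbol conditionals.

```python
import math

def kl_divergence(p_s, p_a):
    """D(p_s || p_a) over a finite support, using the convention 0 log(0/.) = 0."""
    total = 0.0
    for x, ps in p_s.items():
        if ps == 0.0:
            continue  # strings outside L(p_s) contribute nothing
        pa = p_a.get(x, 0.0)
        if pa == 0.0:
            return math.inf  # x is in L(p_s) but not in L(p_a)
        total += ps * math.log(ps / pa)
    return total

p_s = {"a$": 0.5, "b$": 0.5}
p_a = {"a$": 0.25, "b$": 0.75}
divergence = kl_divergence(p_s, p_a)
```

The infinite return value corresponds to the requirement that the language of the source be contained in the language of the target.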

Weighted finite automata
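A weighted finite automaton \(A = (\Sigma, Q, E, i, f)\) over \(\mathbb{R}_+\) is given by a finite alphabet \(\Sigma\), a finite set of states \(Q\), a finite set of transitions \(E \subseteq Q \times \Sigma \times \mathbb{R}_+ \times Q\), an initial state \(i \in Q\), and a final state \(f \in Q\). A transition \(e = (p[e], \ell[e], w[e], n[e]) \in E\) represents a move from the source state \(p[e]\) to the destination state \(n[e]\) with label \(\ell[e]\) and weight \(w[e]\). The transitions with source state \(q\) are denoted by \(E[q]\) and their labels by \(L[q]\).
Transitions \(e_1\) and \(e_2\) are consecutive if \(n[e_1] = p[e_2]\). A path \(\pi = e_1 \cdots e_n \in E^*\) is a finite sequence of consecutive transitions, the source and destination states of which we denote by \(p[\pi]\) and \(n[\pi]\), respectively. The label of a path is the concatenation of its transition labels, \(\ell[\pi] = \ell[e_1] \cdots \ell[e_n]\). The weight of a path is obtained by multiplying its transition weights, \(w[\pi] = w[e_1] \times \cdots \times w[e_n]\). For a non-empty path, the i-th transition is denoted by \(\pi_i\). \(P(q, q')\) denotes the set of all paths in A from state \(q\) to \(q'\). We extend this to sets in the obvious way: \(P(q, R)\) denotes the set of all paths from state \(q\) to any \(q' \in R\), and so forth. A path \(\pi\) is successful if it is in \(P(i, f)\), and in that case the automaton is said to accept the input string \(\alpha = \ell[\pi]\).
The language accepted by an automaton A is the regular set \(L(A) = \{\alpha \in \Sigma^* : \alpha = \ell[\pi],\ \pi \in P(i, f)\}\). We assume that each string in \(L(A)\) is terminated by the symbol $, which occurs nowhere else in the string. Thus all successful paths are terminated by the symbol $.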
For a symbol \(x \in \Sigma\) and a state \(q \in Q\) of a deterministic, probabilistic WFA A, define a distribution \(p_a(x \mid q) \triangleq w\) if \((q, x, w, q') \in E\) and \(p_a(x \mid q) \triangleq 0\) otherwise. Then \(p_a\) is a probabilistic model over \(\Sigma\) as defined in the previous section. If \(A = (\Sigma, Q, E, i, f)\) is an unweighted deterministic automaton, we denote by \(\mathcal{P}(A)\) the set of all probabilistic models \(p_a\) representable as a weighted WFA \(\hat{A} = (\Sigma, Q, \hat{E}, i, f)\) with the same topology as A, where \(\hat{E} = \{(q, x, p_a(x \mid q), q') : (q, x, 1, q') \in E\}\).
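To make the definition concrete, here is a minimal sketch of a deterministic, probabilistic WFA stored as a nested dictionary, together with the probability it assigns to a string. The dictionary layout, the state names, and the function `string_probability` are assumptions chosen for this example and are not part of any library API.

```python
# Transitions: state -> {symbol: (weight, next_state)}; "$" terminates a string.
wfa = {
    "i": {"a": (0.6, "i"), "b": (0.3, "i"), "$": (0.1, "f")},
    "f": {},
}

def string_probability(wfa, string, initial="i", final="f"):
    """Multiply the weights along the unique path labeled by `string`."""
    state, prob = initial, 1.0
    for symbol in string:
        if symbol not in wfa[state]:
            return 0.0  # determinism: no matching transition means no path
        weight, state = wfa[state][symbol]
        prob *= weight
    return prob if state == final else 0.0

prob_ab = string_probability(wfa, "ab$")
```

Because the automaton is deterministic, each string labels at most one path, so no summation over paths is needed.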

Weighted finite automata with failure transitions
A ϕ transition does not add to a path label; it consumes no input. However, it is followed only when the input cannot be read immediately. Specifically, a path \(e_1 \cdots e_n\) in a ϕ-WFA is disallowed if it contains a subpath \(e_i \cdots e_j\) such that \(\ell[e_k] = \phi\) for all \(k\), \(i \le k < j\), and there is another transition \(e \in E\) such that \(p[e_i] = p[e]\) and \(\ell[e_j] = \ell[e] \in \Sigma\) (see Figure 2). Since the label \(x = \ell[e_j]\) can be read on \(e\), we do not follow the failure transitions to read it on \(e_j\) as well.
We use \(P^*(q, q') \subseteq P(q, q')\) to denote the set of allowed (that is, not disallowed) paths from state \(q\) to \(q'\) in a ϕ-WFA. This again extends to sets in the obvious way. A path \(\pi\) is successful in a ϕ-WFA if \(\pi \in P^*(i, f)\), and only in that case is the input string \(\alpha = \ell[\pi]\) accepted.
The language accepted by the ϕ-automaton A is the regular set \(L(A) = \{\alpha \in \Sigma^* : \alpha = \ell[\pi],\ \pi \in P^*(i, f)\}\). We assume each string in \(L(A)\) is terminated by the symbol $ as before. We also assume there are no ϕ-labeled cycles and there is at most one exiting failure transition per state.
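We express the ϕ-extended transitions leaving \(q\) as
\[ E^*[q] = \{(q, x, \omega, q') : \pi \in P^*(q, q'),\ x = \ell[\pi] = \ell[\pi_{|\pi|}] \in \Sigma,\ \omega = w[\pi]\}. \]
This is a set of (possibly new) transitions \((q, x, \omega, q')\), one for each allowed path from source state \(q\) to destination state \(q'\) with optional leading failure transitions and a final x-labeled transition. Denote the labels of \(E^*[q]\) by \(L^*[q]\).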
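A probabilistic (or stochastic) ϕ-WFA satisfies
\[ \sum_{e \in E^*[q]} w[e] = 1 \quad \text{and} \quad w[e] \ge 0 \]
for all states \(q \in Q\) and transitions \(e \in E^*[q]\). A deterministic ϕ-WFA is backoff-complete if a failure transition from state \(q\) to \(q'\) implies \(L[q] \cap \Sigma \subseteq L[q'] \cap \Sigma\) and, if \(\phi \notin L[q']\), the containment is strict. In other words, if a symbol can be read immediately from a state \(q\), it can also be read from a state failing (backing off) from \(q\), and if \(q'\) does not have a backoff arc, then at least one additional label can be read from \(q'\) that cannot be read from \(q\). For example, the topology depicted in Figure 1 has this property. We restrict our target automata to have a topology with the backoff-complete property since it simplifies our analysis, makes our algorithms efficient, and is commonly found in applications.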
For a symbol \(x \in \Sigma\) and a state \(q \in Q\) of a deterministic, probabilistic ϕ-WFA A, define \(p_a^*(x \mid q) \triangleq w\) if \((q, x, w, q') \in E^*[q]\) and \(p_a^*(x \mid q) \triangleq 0\) otherwise. Then \(p_a^*\) is a probabilistic model over \(\Sigma\) as defined in Section 2.1. Note that the distribution \(p_a^*\) at a state \(q\) is defined over the ϕ-extended transitions \(E^*[q]\), whereas \(p_a\) in the previous section is defined over the transitions \(E[q]\). It is convenient to define a companion distribution \(p_a \in \mathcal{P}(A)\) to \(p_a^*\) as follows: given a symbol \(x \in \Sigma \cup \{\phi\}\) and state \(q \in Q\), define \(p_a(x \mid q) \triangleq p_a^*(x \mid q)\) when \(x \in L[q] \cap \Sigma\), \(p_a(\phi \mid q) \triangleq 1 - \sum_{x \in L[q] \cap \Sigma} p_a^*(x \mid q)\), and \(p_a(x \mid q) \triangleq 0\) otherwise. The companion distribution is thus defined solely over the transitions \(E[q]\).
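When \(A = (\Sigma, Q, E, i, f)\) is an unweighted deterministic, backoff-complete ϕ-WFA, we denote by \(\mathcal{P}^*(A)\) the set of all probabilistic models \(p_a^*\) representable as a weighted ϕ-WFA \(\hat{A} = (\Sigma, Q, \hat{E}, i, f)\) of the same topology as A with
\[ \hat{E} = \{(q, x, p_a(x \mid q), q') : (q, x, 1, q') \in E,\ x \in \Sigma\} \cup \{(q, \phi, \alpha(q, q'), q') : (q, \phi, 1, q') \in E\}, \tag{6} \]
where \(p_a \in \mathcal{P}(A)\) is the companion distribution to \(p_a^*\) and \(\alpha(q, q') = p_a(\phi \mid q) / d(q, q')\) is the weight of the failure transition from state \(q\) to \(q'\) with
\[ d(q, q') = 1 - \sum_{x \in L[q] \cap \Sigma} p_a(x \mid q'). \]
Note that we have specified the weights on the automaton that represents \(p_a^* \in \mathcal{P}^*(A)\) entirely in terms of the companion distribution \(p_a \in \mathcal{P}(A)\), thanks to the backoff-complete property.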
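Conversely, each distribution \(p_a \in \mathcal{P}(A)\) can be associated to a distribution \(p_a^* \in \mathcal{P}^*(A)\) given a deterministic, backoff-complete ϕ-WFA A. First extend \(\alpha(q, q')\) to any failure path as follows. Denote a failure path from state \(q\) to \(q'\) by \(\pi_\phi(q, q')\). Then define
\[ \alpha(\pi_\phi(q, q')) = \prod_{e \in \pi_\phi(q, q')} \alpha(p[e], n[e]), \]
where this quantity is taken to be 1 when the failure path is empty (\(q = q'\)). Finally define
\[ p_a^*(x \mid q) = \begin{cases} \alpha(\pi_\phi(q, q^x))\, p_a(x \mid q^x), & x \in L^*[q], \\ 0, & \text{otherwise,} \end{cases} \]
where for \(x \in L^*[q]\), \(q^x\) signifies the first state \(q'\) on a ϕ-labeled path in A from state \(q\) for which \(x \in L[q']\).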
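The passage from a companion distribution to the ϕ-extended distribution can be sketched on a toy backoff-complete topology. This is an illustration under assumptions: the two-state topology, the probability values, and the names `p_a`, `d`, and `p_star` are invented for the example, and the normalizer `d` is taken to be one minus the mass that the backoff state assigns to symbols readable directly at the backing-off state.

```python
PHI = "phi"

# Companion distribution p_a: state -> {label: (probability, next_state)}.
# State "a" backs off to "eps"; its symbol set {a} is contained in that of "eps".
p_a = {
    "eps": {"a": (0.5, "a"), "b": (0.3, "eps"), "$": (0.2, "f")},
    "a": {"a": (0.4, "a"), PHI: (0.6, "eps")},
    "f": {},
}

def d(q, q_next):
    """d(q, q'): one minus the mass q' gives to symbols readable directly at q."""
    return 1.0 - sum(p_a[q_next][x][0] for x in p_a[q] if x != PHI)

def p_star(x, q):
    """Follow failure transitions until x can be read, scaling by alpha each time."""
    scale = 1.0
    while x not in p_a[q]:
        if PHI not in p_a[q]:
            return 0.0  # x cannot be read from q, even after backing off
        phi_prob, q_next = p_a[q][PHI]
        scale *= phi_prob / d(q, q_next)  # alpha(q, q')
        q = q_next
    return scale * p_a[q][x][0]

total_at_a = sum(p_star(x, "a") for x in ["a", "b", "$"])
```

In this example the failure weight is 0.6 / 0.5 = 1.2, which exceeds one, yet the ϕ-extended distribution at state "a" still sums to one because the backed-off mass is renormalized over the symbols not readable at "a".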
For (6) to be well-defined, we need \(d(p[e], n[e]) > 0\). To ensure this condition, we restrict \(\mathcal{P}(A)\) to contain distributions such that \(p_a(x \mid q) \ge \epsilon\) for each \(x \in L[q]\) (see footnote 4).
Given an unweighted deterministic, backoff-complete automaton A, our goal is to find the target distribution \(p_a^* \in \mathcal{P}^*(A)\) that has the minimum KL divergence from our source probability model \(p_s\). We can restate our goal in terms of the companion distribution \(p_a \in \mathcal{P}(A)\). Let \(B_n(q)\) be the set of states in A that back off to state \(q\) in \(n\) failure transitions and let \(B(q) = \bigcup_{n=0}^{|Q_a|} B_n(q)\).
Lemma 2 restates this goal in terms of quantities \(C(x, q)\) that do not depend on \(p_a\). Proof is omitted due to space limitations. The quantity in braces in the statement of Lemma 2 depends on the distribution \(p_a\) only at state \(q\), so the minimum KL divergence \(D(p_s \| p_a^*)\) can be found by maximizing that quantity independently for each state.

Algorithms
Approximating a probabilistic source algorithmically as a weighted finite automaton requires two steps: (1) compute the quantity \(C(x, q)\) in Lemma 2 and (2) use this quantity to find the minimum KL divergence solution. The first step, which we will refer to as counting, is covered in the next section and the KL divergence minimization step is covered afterwards.
Footnote 4: For brevity, we do not include \(\epsilon\) in the notation of \(\mathcal{P}(A)\).

Counting
How the counts are computed will depend on the source model form. We divide this into two cases.

ϕ-WFA source and target
When the source and target models are represented as ϕ-WFAs, we compute \(C(x, q_a)\) from Lemma 2. From Equation 9 this can be written (Equation 11) in terms of a quantity \(\gamma(q_s, q_a)\), which can be computed on \(S \cap A\), the weighted intersection of automata S and A formed using an efficient ϕ-WFA intersection that compactly retains failure transitions in the result, as described in Allauzen and Riley (2018). The quantity \(\gamma(q_s, q_a)\) is the (generalized) shortest distance from the initial state to a specified state computed over the positive real semiring (Mohri, 2002; Allauzen and Riley, 2018). Equation 11 is the weighted count of the paths in \(S \cap A\) allowed by the failure transitions that begin at the initial state and end in any transition leaving a state \((q_s, q)\) labeled with \(x\). This computation can be simplified by the following transformation. First we convert \(S \cap A\) to an equivalent WFA by replacing each failure transition with an epsilon transition and introducing a negatively-weighted transition to compensate for formerly disallowed paths (Allauzen and Riley, 2018). The result is then promoted to a transducer T with the output label used to keep track of the source state in A of the compensated positive transition (see Figure 3).

Figure 3: panels \(S \cap A\) (states \((q_s, q_a)\) and \((q_s', q_a')\); transition labels x/ω, ϕ/α, x/ν) and T (transition labels x:q_a/ω, ε:-/α, x:q_a'/−αν, x:q_a'/ν).

Then, for \(x \in \Sigma\), Equation 12 expresses the count as a sum over the transitions in T, weighted by \(\gamma_T(q_s, q_a)\), the shortest distance from the initial state to \((q_s, q_a)\) in T computed over the real semiring as described in Allauzen and Riley (2018). Equation 12 is the weighted count of all paths in \(S \cap A\) that begin at the initial state and end in any transition leaving a state \((q_s, q)\) labeled with \(x\), minus the weighted count of those paths that are disallowed by the failure transitions.
Finally, we compute \(C(\phi, q)\) as follows. The count mass entering a state must equal the count mass leaving a state, so
\[ C(\phi, q) = \sum_{(q_a, x, 1, q) \in A} C(x, q_a) - \sum_{x' \in L[q] \cap \Sigma} C(x', q). \]
This quantity can be computed iteratively in the topological order of states with respect to the ϕ-labeled transitions.

Arbitrary source and ϕ-WFA target
In some cases, the source is a distribution with possibly infinitely many states, e.g., LSTMs. For these sources, computing \(C(x, q)\) can be computationally intractable, as (11) requires a summation over all possible states in the source machine, \(Q_s\). We propose to use a sampling approach to approximate \(C(x, q)\) for these cases. Let \(x(1), x(2), \ldots, x(m)\) be independent random samples from \(p_s\). Instead of \(C(x, q)\), we propose to use an estimate \(\hat{C}(x, q)\) in which \(\gamma(q_s, q_a)\) is replaced by a sample-based estimate \(\hat{\gamma}(q_s, q_a)\). The expectation \(\mathbb{E}[\hat{\gamma}(q_s, q_a)]\) equals \(\gamma(q_s, q_a)\); hence \(\hat{\gamma}(q_s, q_a)\) is an unbiased, asymptotically consistent estimator of \(\gamma(q_s, q_a)\). Given \(\hat{C}(x, q)\), we compute \(C(\phi, q)\) similarly to the previous section.
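The sampling idea can be sketched as follows. This is an illustration under stated assumptions rather than the estimator as implemented by the authors: the toy topology, the hypothetical non-Markovian source `source_next`, and the helper names are invented, and the sketch assumes that at each visited prefix the full next-symbol distribution of the source is credited to the target state where each symbol would be read.

```python
import random
from collections import defaultdict

PHI = "phi"
# Hypothetical backoff-complete target topology: state -> {label: next_state}.
topology = {
    "eps": {"a": "a", "b": "eps", "$": "f"},
    "a": {"a": "a", PHI: "eps"},
    "f": {},
}

def source_next(prefix):
    """Hypothetical source whose stopping probability grows with the prefix length."""
    stop = min(0.9, 0.2 + 0.1 * len(prefix))
    return {"a": 0.6 * (1 - stop), "b": 0.4 * (1 - stop), "$": stop}

def resolve(state, x):
    """Follow failure transitions to the first state where x can be read."""
    while x not in topology[state]:
        state = topology[state][PHI]
    return state

def sampled_counts(num_samples, seed=0):
    rng = random.Random(seed)
    counts, lengths = defaultdict(float), []
    for _ in range(num_samples):
        prefix, state = "", "eps"
        while not prefix.endswith("$"):
            probs = source_next(prefix)
            for x, p in probs.items():  # credit every symbol, not only the sampled one
                counts[(x, resolve(state, x))] += p / num_samples
            x = rng.choices(list(probs), weights=list(probs.values()))[0]
            state = topology[resolve(state, x)][x]
            prefix += x
        lengths.append(len(prefix))
    return counts, lengths

counts, lengths = sampled_counts(200)
```

Since each visited prefix contributes a total mass of one, the counts sum to the average sample length, which gives a quick sanity check on the bookkeeping.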

KL divergence minimization
As noted before, the quantity in braces in the statement of Lemma 2 depends on the distribution p a only at state q so the minimum KL divergence D(p s ||p * a ) can be found by maximizing that quantity independently for each state.
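Fix a state \(q\) and let \(y_x \triangleq p_a(x \mid q)\) for \(x \in L[q]\) and let \(y \triangleq [y_x]_{x \in L[q]}\). Then our goal reduces to
\[ \operatorname*{argmax}_{y} \sum_{x \in L[q]} C(x, q) \log y_x - \sum_{q_0 \in B_1(q)} C(\phi, q_0) \log\Bigl(1 - \sum_{x \in L[q_0] \cap \Sigma} y_x\Bigr) \tag{13} \]
subject to the constraints \(y_x \ge \epsilon\) for \(x \in L[q]\) and \(\sum_{x \in L[q]} y_x = 1\).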

Figure 4: KL-MINIMIZATION Algorithm
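This is a difference of two concave functions in \(y\) since \(\log(f(y))\) is concave for any linear function \(f(y)\), the \(C(x, q)\) are always non-negative, and the sum of concave functions is also concave. We give a DC programming solution to this optimization (Horst and Thoai, 1999). Let \(\Omega = \{y : y_x \ge \epsilon \text{ for all } x \in L[q],\ \sum_{x \in L[q]} y_x = 1\}\) be the function domain. The DC programming solution for such a problem uses an iterative procedure that linearizes the subtrahend in the concave difference about the current estimate and then solves the resulting concave objective for the next estimate. Using this procedure on Equation 13 gives \(y^{n+1}\) as
\[ \operatorname*{argmax}_{y \in \Omega} \sum_{x \in L[q]} \bigl( C(x, q) \log y_x + y_x f(x, q, y^n) \bigr), \tag{14} \]
where
\[ f(x, q, y^n) = \sum_{q_0 \in B_1(q)} \frac{C(\phi, q_0)\, \mathbb{1}[x \in L[q_0]]}{1 - \sum_{x' \in L[q_0] \cap \Sigma} y^n_{x'}}. \]
Observe that \(1 - \sum_{x' \in L[q_0] \cap \Sigma} y^n_{x'} \ge \epsilon\) as the automaton is backoff-complete and \(y^n \in \Omega\).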
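Let \(C(q)\) be defined as \(C(q) \triangleq \sum_{x' \in L[q]} C(x', q)\). The following lemma provides the solution to the optimization problem in (14), which leads to a stationary point of the objective.
Lemma 3. The solution to (14) is given by
\[ y^{n+1}_x = \max\Bigl( \frac{C(x, q)}{\lambda - f(x, q, y^n)},\ \epsilon \Bigr), \]
where \(\lambda\) is chosen so that \(\sum_x y^{n+1}_x = 1\) and lies in
\[ \Bigl[ \max_{x \in L[q]} \bigl( f(x, q, y^n) + C(x, q) \bigr),\ \max_{x \in L[q]} f(x, q, y^n) + C(q) \Bigr]. \]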
Proof is omitted due to space limitations. From this, we form the KL-MINIMIZATION algorithm in Figure 4. Observe that if all the counts are zero, then any solution is an optimal solution and the algorithm returns a uniform distribution over labels. In other cases, we initialize the model based on counts such that y 0 ∈ Ω. We then repeat the DC programming algorithm iteratively until convergence. Since Ω is a convex compact set and both the functions are continuous and differentiable in Ω, the KL-MINIMIZATION converges to a stationary point (Sriperumbudur and Lanckriet, 2009, Theorem 4).
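The iteration described above can be sketched for a single state. This is a minimal illustration and not the algorithm of Figure 4, whose listing is not reproduced here. It assumes the per-state objective is the count-weighted log-likelihood minus the failure-count-weighted log normalizers of the states backing off into q, that all symbol counts are positive, and that each update has the form max(C/(λ − f), ε) with λ found by bisection. The names `dc_step`, `kl_minimization`, `counts`, and `backoff_in`, and the value of `EPS`, are hypothetical.

```python
EPS = 1e-9  # assumed value for the lower bound epsilon

def dc_step(counts, backoff_in, y):
    """One DC iteration: linearize the subtrahend at y, then solve for lambda."""
    f = {x: 0.0 for x in counts}
    for phi_count, labels in backoff_in:  # one entry per state backing off into q
        denom = 1.0 - sum(y[x] for x in labels)
        for x in labels:
            f[x] += phi_count / denom

    def solve(lam):
        return {x: max(c / (lam - f[x]), EPS) if c > 0 else EPS
                for x, c in counts.items()}

    total = sum(counts.values())
    lo = max(f[x] + c for x, c in counts.items())
    hi = max(f.values()) + total
    while sum(solve(hi).values()) > 1.0:  # the epsilon floor can push the root past hi
        hi += total
    for _ in range(100):  # bisection: the sum decreases as lambda grows
        mid = 0.5 * (lo + hi)
        if sum(solve(mid).values()) > 1.0:
            lo = mid
        else:
            hi = mid
    return solve(hi)

def kl_minimization(counts, backoff_in, iterations=200):
    total = sum(counts.values())
    if total == 0:
        return {x: 1.0 / len(counts) for x in counts}  # uniform when all counts are zero
    y = {x: c / total for x, c in counts.items()}  # count-based initialization
    for _ in range(iterations):
        y = dc_step(counts, backoff_in, y)
    return y

y = kl_minimization({"a": 5.0, "b": 3.0, "$": 2.0}, [(4.0, ["a"])])
```

In this example one state backs off into q with failure count 4 and reads only "a" directly. Setting the derivative of the assumed objective to zero gives a stationary point at y_a = 5/6, which the iteration approaches at a linear rate. With no states backing off into q, the update reduces to normalized counts.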

Experiments
We now provide experimental evidence of the theory's validity and show its usefulness in various applications. For the ease of notation, we use WFA-APPROX to denote the exact counting algorithm described in Section 3.1.1 followed by the KL-MINIMIZATION algorithm of Section 3.2. Similarly, we use WFA-SAMPLEAPPROX(N) to denote the sampled counting described in Section 3.1.2 with N sampled sentences followed by KL-MINIMIZATION.
We first give experimental evidence that supports the theory in Section 4.1. We then show how to approximate neural models as WFAs in Section 4.2. We also use the proposed method to provide lower bounds on the perplexity given a target topology in Section 4.3.
For all the experiments we use the 1996 CSR Hub4 Language Model data (LDC98T31) from the Broadcast News (BN) task. We use the processed form of the corpus and further process it to downcase all the words and remove punctuation. The resulting dataset has 132M words in the training set, 20M words in the test set, and 240K unique words. From this, we create a vocabulary of approximately 32K words consisting of all words that appeared more than 50 times in the training corpus. Using this vocabulary, we create a trigram Katz model and prune it to contain 2M n-grams using entropy pruning (Stolcke, 2000), which we use as a baseline in all our experiments. We use Katz smoothing since it is amenable to pruning (Chelba et al., 2010). The perplexity of this model on the test set is 144.4 (see footnote 7). All algorithms were implemented using the open-source OpenFst and OpenGrm n-gram and stochastic automata (SFst) libraries (see footnote 8), with the last library including these implementations (Allauzen et al., 2007; Roark et al., 2012; Allauzen and Riley, 2018).

Empirical evidence of theory
Recall that our goal is to find the distribution on a target DFA topology that minimizes the KL divergence to the source distribution. However, as shown in Section 3.2, when the target topology has failure transitions, the optimization objective is not convex so the stationary point solution may not be the global optimum. We now show that the model indeed converges to a good solution in various cases empirically.
Idempotency: When the target topology is the same as the source topology, we show that the performance of the approximated model matches the source model. Let \(p_s\) be the pruned Katz word model described above. We approximate \(p_s\) onto the same topology using WFA-APPROX and WFA-SAMPLEAPPROX(·) and then compute perplexity on the test corpus. The results are presented in Figure 5. The test perplexity of the WFA-APPROX model matches that of the source model, and the performance of the WFA-SAMPLEAPPROX(N) model approaches that of the source model as the number of samples N increases.
Footnote 7: For all perplexity measurements we treat the unknown word as a single token instead of a class. To compute the perplexity with the unknown token being treated as a class, multiply the perplexity by \(k^{0.0115}\), where \(k\) is the number of tokens in the unknown class and 0.0115 is the out-of-vocabulary rate in the test dataset.
Footnote 8: These libraries are available at www.openfst.org and www.opengrm.org.
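The conversion described in the perplexity footnote is a one-line computation, sketched here for convenience. The function name `class_perplexity` and the example class size are assumptions; only the rate 0.0115 and the baseline perplexity 144.4 come from the text.

```python
def class_perplexity(token_perplexity, k, oov_rate=0.0115):
    """Perplexity when the unknown token stands for a class of k words.

    Each unknown token costs an extra log(k) in log-perplexity, and such tokens
    occur at the out-of-vocabulary rate, hence the factor k ** oov_rate.
    """
    return token_perplexity * k ** oov_rate

# Baseline Katz perplexity from the text, with a hypothetical class of 1000 words.
adjusted = class_perplexity(144.4, 1000)
```

With k = 1 the factor is one and the perplexity is unchanged, so the two conventions coincide when the unknown class has a single member.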
Comparison to greedy pruning: Recall that entropy pruning (Stolcke, 2000) greedily removes n-grams such that the KL divergence to the original model \(p_s\) is small. Let \(p_{\text{greedy}}\) be the resulting model and \(A_{\text{greedy}}\) be the topology of \(p_{\text{greedy}}\). If KL-MINIMIZATION converges to a good solution, then approximating \(p_s\) onto \(A_{\text{greedy}}\) would give a model that is at least as good as \(p_{\text{greedy}}\). We show that this is indeed the case; in fact, approximating \(p_s\) onto \(A_{\text{greedy}}\) performs better than \(p_{\text{greedy}}\). In particular, let \(p_s\) again be the 2M n-gram Katz model described above. We prune it to have 1M n-grams and obtain \(p_{\text{greedy}}\), which has a test perplexity of 157.4. We then approximate \(p_s\) on \(A_{\text{greedy}}\), and the resulting model has a test perplexity of 155.6, which is smaller than the test perplexity of \(p_{\text{greedy}}\). This shows that the approximation algorithm indeed finds a good solution.

Neural models to WFA conversion
Since neural models such as LSTMs give improved performance over n-gram models, we investigated whether an LSTM distilled onto a WFA model can obtain better performance than the baseline WFA trained directly with Katz smoothing. As stated in the introduction, this could then be used together with federated learning for fast and private on-device inference.
To explore this, we trained an LSTM language model on the training data. The model has 2 LSTM layers with 1024 states and an embedding size of 1024. The resulting model has a test perplexity of 60.5. We approximate this model as a WFA in two ways from samples drawn from the LSTM.
The first way is to construct a Katz n-gram model on N LSTM samples and entropy-prune to 2M n-grams, which we denote by WFA-SAMPLEKATZ(N). The second way is to approximate onto the baseline Katz 2M n-gram topology described above using WFA-SAMPLEAPPROX(N).
The results are included in Figure 5. It shows that the WFA-SAMPLEKATZ(N) model performs significantly worse than the baseline Katz model even at 32M samples, while the WFA-SAMPLEAPPROX(N) models have better perplexity than the baseline Katz model with as few as 1M samples. With 32M samples, this way of approximating the LSTM model as a WFA is 3.6 points lower in perplexity than the baseline Katz model.

Lower bounds on perplexity
The neural model in Section 4.2 has a perplexity of 60.5, but the best perplexity for the approximated model is 140.8. Is there a better approximation algorithm for the given target topology? We place bounds on that next, in our final experiment.
Let T be the set of test sentences. The test-set log-perplexity of a model \(p\) can be written in terms of \(\hat{p}_t\), the empirical distribution of test sentences. Observe that the best model with topology A can be computed as
\[ \operatorname*{argmin}_{p_a \in \mathcal{P}^*(A)} D(\hat{p}_t \| p_a), \]
which is the model with topology A that has minimal KL divergence from the test distribution \(\hat{p}_t\). This can be computed using WFA-APPROX. If we use this approach on the BN test set with the 2M n-gram Katz model, the result has a perplexity of 121.1. This demonstrates that, under the assumption that the algorithm finds the global KL divergence minimum, the test perplexity with this topology cannot be improved beyond 121.1, irrespective of the method.
What if we approximate the LSTM onto the best trigram topology? To test this, we build a trigram model from the test data and approximate the LSTM on the trigram topology. This approximated model has 11M n-grams and a perplexity of 81. This shows that for large datasets, the shortfall of n-gram models in the approximation is in the n-gram topology.

Summary
In this paper, we have presented an algorithm for minimizing the KL divergence between a probabilistic source model over sequences and a WFA target model. Our algorithm is general enough to permit source models of arbitrary form (e.g., RNNs) and a wide class of target WFA models, importantly including those with failure transitions, such as n-gram models. We provide some experimental validation of our algorithm, including demonstrating that it is well-behaved in common scenarios and that it yields improved performance over baseline n-gram models using the same WFA topology. Additionally, we use our methods to provide lower bounds on how well a given WFA topology can model a given test set. All of the algorithms reported here are available in the open-source OpenGrm libraries at http://www.opengrm.org.
In addition to the above-mentioned results, we also demonstrated that optimizing the WFA topology for the given test set yields far better perplexities than were obtained using WFA topologies derived from training data alone, suggesting that the problem of deriving an appropriate WFA topology -something we do not really touch on in this paper -is particularly important.