Estimating Marginal Probabilities of n-grams for Recurrent Neural Language Models

Recurrent neural network language models (RNNLMs) are the current standard-bearer for statistical language modeling. However, RNNLMs only estimate probabilities for complete sequences of text, whereas some applications require context-independent phrase probabilities instead. In this paper, we study how to compute an RNNLM's marginal probability: the probability that the model assigns to a short sequence of text when the preceding context is not known. We introduce a simple method of altering the RNNLM training to make the model more accurate at marginal estimation. Our experiments demonstrate that the technique is effective compared to baselines, including the traditional RNNLM probability and an importance sampling approach. Finally, we show how we can use marginal estimation to improve an RNNLM by training the marginals to match n-gram probabilities from a larger corpus.


Introduction
Recurrent neural networks (RNNs) are the state-of-the-art architecture for statistical language modeling (Jozefowicz et al., 2016; Melis et al., 2018), the task of assigning a probability distribution to a sequence of words. The relative likelihoods of sequences are useful in applications such as speech recognition, machine translation, automated conversation, and summarization (Mikolov et al., 2010; See et al., 2017; Wen et al., 2017). Typically, RNN language models (RNNLMs) are trained on complete sequences (e.g., a sentence or an utterance) or long sequences (e.g., several documents), and are used in the same fashion in applications or testing.
A question arises when we want to compute the probability of a short sequence without the preceding context. For instance, we may wish to query how likely the RNNLM is to generate a particular phrase, aggregated over all contexts. We refer to this context-independent probability of a short phrase as a marginal probability, or marginal.
These marginal probabilities are useful in three broad categories of applications. First, they allow us to inspect the behavior of a given RNNLM. We could check, for example, whether an RNNLM-based generator might ever output a given offensive phrase. Second, the marginals could be used in phrase-based information extraction, such as extracting cities by finding high-probability x's in the phrase "cities such as x" (Soderland et al., 2004; Bhagavatula et al., 2014). Finally, we can use the phrase probabilities to train an RNNLM itself, e.g., updating the RNNLM according to n-gram statistics instead of running text (Chelba et al., 2017; Noraset et al., 2018). In our experiments, we show an example of the last application.
Estimating marginals from an RNNLM is challenging. Unlike an n-gram language model (Chen and Goodman, 1996), an RNNLM does not explicitly store marginal probabilities as its parameters. Instead, previous words are recurrently combined with the RNN's hidden state to produce a new state, which is used to compute a probability distribution over the next word (Elman, 1990; Mikolov et al., 2010). When the preceding context is absent, however, the starting state is also missing. To compute the marginal probability, in principle we must marginalize over all possible previous contexts or all continuous-vector states. Both options pose a severe computational challenge.
In this paper, we study how to efficiently approximate marginal probabilities from an RNNLM without generating a large amount of text. Given an RNNLM and a phrase, our goal is to estimate how frequently the phrase will occur in text generated by the RNNLM. We present two approaches for estimating marginal probabilities: sampling the starting state, and using a single starting state with altered RNNLM training. We show empirically that we can use a zero vector as the starting state of an RNNLM to compute accurate marginal estimates, but we must randomly reset the RNNLM state to zero during training and add a unigram likelihood term to the training objective. Finally, we demonstrate that we can use marginal estimation to incorporate n-gram statistics from a larger corpus to improve the perplexity of an RNNLM trained on a similar, but smaller, corpus.

Marginal Estimation
The goal of marginal estimation is to determine the likelihood of a short phrase when the preceding context is not known; we refer to this likelihood as a marginal probability. In other words, the marginal probability of a query is how likely a language model is to generate the query regardless of context.

Problem settings
An RNNLM (Mikolov et al., 2010) defines a probability distribution over words conditioned on previous words as follows:

P(w_t | w_{1:t-1}) = exp(θ^o_{w_t} · h_t) / Σ_{w'} exp(θ^o_{w'} · h_t), where h_t = g(h_{t-1}, w_{t-1}),

where w_{1:t-1} is the sequence of previous words, θ^o_w denotes the output weights of a word w, and g(·) is a recurrent function such as an LSTM (Hochreiter and Schmidhuber, 1997) or GRU unit.
An initial state h_1 is needed to start the recurrent function g(h_1, w_1); it also defines the probability distribution of the first word, P(w_1 | h_1).
In the standard language-model setting, we compute h_1 using a start-of-sentence symbol "<s>" for w_0 and a special starting state h_0 (usually set to a vector of zeros). This initialization works fine for long sequences, because it is only used once and its effect is quickly swamped by the recurrent steps of the network. However, it is not effective for estimating marginal probabilities of a short phrase. For example, if we naively apply the standard approach to compute the probability of the phrase "of the", we would obtain:

P(of the) = P(of | h_1) P(the | h_1, of), where h_1 = g(h_0, <s>).

The initial setting of the network results in low likelihoods for the first few tokens. For instance, the probability P(of the) computed in this fashion will likely be a severe underestimate, because "of the" does not usually start a sentence.
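To make the standard-start computation concrete, here is a minimal sketch in NumPy. A toy, randomly initialized RNN stands in for a trained LSTM; the vocabulary size, state size, and weights are illustrative assumptions, not the paper's model.

```python
import numpy as np

# Toy single-layer RNNLM with random weights (illustrative only).
# Token 0 plays the role of the start-of-sentence symbol "<s>".
rng = np.random.default_rng(0)
V, d = 6, 8                      # toy vocabulary and state sizes
Wx = rng.normal(0, 0.5, (d, V))  # input embeddings
Wh = rng.normal(0, 0.5, (d, d))  # recurrent weights
Wo = rng.normal(0, 0.5, (V, d))  # output weights (rows play theta^o_w)

def step(h, w):
    """One recurrent step g(h, w)."""
    return np.tanh(Wx[:, w] + Wh @ h)

def next_dist(h):
    """P(w | h): softmax over the output logits."""
    z = Wo @ h
    e = np.exp(z - z.max())
    return e / e.sum()

def standard_logprob(query, bos=0):
    """log P(query) with the standard init: h0 = 0, consume <s> first."""
    h = step(np.zeros(d), bos)   # h1 after reading "<s>"
    lp = 0.0
    for w in query:
        lp += np.log(next_dist(h)[w])
        h = step(h, w)
    return lp

lp = standard_logprob([2, 3])    # chain-rule probability of a 2-token phrase
```

In a trained model, the h_1 computed this way encodes "start of sentence", so the probability of a mid-sentence phrase comes out systematically low.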
We would like to compute the likelihood of standalone phrases, where rather than assuming the starting state we instead marginalize out the preceding context. Let the RNN's state prior to our query sequence be z ∈ R^d, a vector-valued random variable representing the RNN initial state, and let w_{1:T} be a short sequence of text. The marginal probability is defined as:

P(w_{1:T}) = ∫ P(z) Π_{t=1}^T P(w_t | h_t) dz, where h_1 = z,    (1)

and h_t = g(h_{t-1}, w_{t-1}) for t > 1. This integral form of the marginal probability is intractable and requires an unknown density estimator of the state, P(z).

Trace-based approaches
The integral form of the marginal probability in Eq 1 can be approximated by sampling z. In this approach, we assume a source of samples that asymptotically approaches the true distribution of RNN states as the number of samples grows. In this work, we use a collection of RNN states generated in an evaluation, called a trace. Given a corpus of text, a trace of an RNNLM is the corresponding list of RNN states, H = (h^(tr)_1, ..., h^(tr)_N), produced when evaluating the corpus. We can estimate the marginal probability by sampling the initial state z from H as follows:

P(w_{1:T}) ≈ (1/K) Σ_{k=1}^K Π_{t=1}^T P(w_t | h^(k)_t), where h^(k)_1 = z^(k) is drawn uniformly from H,

and h^(k)_t = g(h^(k)_{t-1}, w_{t-1}) for t > 1 (i.e., the following states are the deterministic output of the RNN function). Given a large trace this may produce accurate estimates, but it is intractably expensive and also wasteful, since in general very few states in the trace yield a high likelihood for a given sequence.
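A minimal sketch of this trace-based estimator, again with a toy randomly initialized RNN standing in for a trained model (all sizes, weights, and the random "corpus" are illustrative assumptions):

```python
import numpy as np

# Toy RNN (illustrative weights only).
rng = np.random.default_rng(1)
V, d = 6, 8
Wx = rng.normal(0, 0.5, (d, V)); Wh = rng.normal(0, 0.5, (d, d))
Wo = rng.normal(0, 0.5, (V, d))

def step(h, w): return np.tanh(Wx[:, w] + Wh @ h)
def next_dist(h):
    z = Wo @ h; e = np.exp(z - z.max()); return e / e.sum()

def seq_prob_from(z, query):
    """P(w_1:T | start state z): score w_1 from z, then step deterministically."""
    h, p = z, 1.0
    for w in query:
        p *= next_dist(h)[w]
        h = step(h, w)
    return p

# Build a trace H by evaluating the model on a token stream.
corpus = rng.integers(0, V, size=500)
trace, h = [], np.zeros(d)
for w in corpus:
    trace.append(h)        # state *before* consuming w
    h = step(h, w)
trace = np.array(trace)

# Monte Carlo estimate: start states sampled uniformly from the trace.
samples = trace[rng.integers(0, len(trace), size=100)]
est = np.mean([seq_prob_from(z, [2, 3]) for z in samples])
```

The estimate converges as more trace states are sampled, but each sample requires a full forward pass over the query, which is what motivates the importance-sampling variant below.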
To reduce the number of times we run the model on the query, we use importance sampling over the trace. We train an encoder that, for a given n-gram query, outputs a state "near" the starting state(s) of the query, z_χ = q_χ(w_{1:T}). We define a sampling weight for a state h^(tr) in the trace, proportional to the dot product of the state and the encoder output z_χ:

q(h^(tr)) ∝ exp(h^(tr) · z_χ).    (2)

This distribution is biased toward states that are likely to precede the query w_{1:T}. We can then estimate the marginal probability as:

P(w_{1:T}) ≈ (1/K) Σ_{k=1}^K [P(z^(k)) / q(z^(k))] Π_{t=1}^T P(w_t | h^(k)_t), where z^(k) ∼ q and h^(k)_1 = z^(k).

Here the prior P(z) is chosen to be uniform over the states in the trace. The encoder q_χ(w_{1:T}) is a trained RNN with its input reversed, and z_χ is its final output state. To train the encoder, we randomly draw substrings w_{i:i+n} of random length from the text used to produce the trace, and minimize the mean-squared difference between z_χ and h^(tr)_i.
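The importance-sampling estimator can be sketched as follows. The paper's encoder q_χ is a trained reverse RNN; here a fixed random vector stands in for its output z_χ purely to show the estimator's shape, so the resulting numbers are not meaningful. All weights and sizes are toy assumptions.

```python
import numpy as np

# Toy RNN and trace, as in the plain trace-based sketch.
rng = np.random.default_rng(2)
V, d = 6, 8
Wx = rng.normal(0, 0.5, (d, V)); Wh = rng.normal(0, 0.5, (d, d))
Wo = rng.normal(0, 0.5, (V, d))

def step(h, w): return np.tanh(Wx[:, w] + Wh @ h)
def next_dist(h):
    z = Wo @ h; e = np.exp(z - z.max()); return e / e.sum()

def seq_prob_from(z, query):
    h, p = z, 1.0
    for w in query:
        p *= next_dist(h)[w]
        h = step(h, w)
    return p

corpus = rng.integers(0, V, size=500)
trace, h = [], np.zeros(d)
for w in corpus:
    trace.append(h); h = step(h, w)
trace = np.array(trace)

z_chi = rng.normal(0, 1.0, d)      # stand-in for the encoder output q_chi(w_1:T)
logits = trace @ z_chi
q = np.exp(logits - logits.max()); q /= q.sum()   # proposal over trace states

idx = rng.choice(len(trace), size=100, p=q)
# Uniform prior P(z) = 1/|H|, so each importance weight is 1 / (|H| * q(z)).
est = np.mean([seq_prob_from(trace[i], [2, 3]) / (len(trace) * q[i])
               for i in idx])
```

With a trained encoder, the proposal concentrates on states likely to precede the query, so far fewer samples are needed than with uniform sampling.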

Fixed-point approaches
While the trace-based approaches work on an existing (already trained) RNNLM, they may need many samples to accurately estimate the marginal probability. We would instead like a single point as the starting state, denoted z_ψ. We can either train this vector or simply set it to a zero vector. The marginal probability in Eq 1 can then be estimated with a single run, i.e., P(z_ψ) = 1 and P(z) = 0 for z ≠ z_ψ. The computation reduces to:

P(w_{1:T}) ≈ Π_{t=1}^T P(w_t | h_t), where h_1 = z_ψ,    (3)

and the rest of the state process is as usual, h_t = g(h_{t-1}, w_{t-1}). In this paper, we set z_ψ to be a zero vector, and call this method Zero.
As we previously discussed, our fixed-point state z_ψ is not a suitable starting state for all n-grams under a conventionally trained RNNLM, so we need to train the RNNLM to adapt to this state. To achieve this, we slightly modify the RNN's truncated back-propagation-through-time training: we randomly reset the state to z_ψ when computing a new state during training (a similar reset was used for a different purpose, regularization, in Melis et al. (2018)). As a result, z_ψ is trained to maximize the likelihood of many different subsequent texts of different lengths, and thus is approximately a good starting point for any sequence. Specifically, a new state is computed during training as:

h_t = z_ψ if r = 1, and h_t = g(h_{t-1}, w_{t-1}) otherwise, where r ∼ Bern(ρ),

and ρ is a hyper-parameter for the probability of resetting a state. A larger ρ means more training with z_ψ, but it can disrupt the long-term dependency information captured in the state. We keep ρ relatively small at 0.05.
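The reset step itself is a one-liner; the following standalone fragment sketches it, using ρ = 0.05 as in the paper and an arbitrary toy state dimension:

```python
import numpy as np

# State reset used during truncated-BPTT training: with probability rho,
# replace the carried-over state with the fixed start state z_psi (zeros here).
rng = np.random.default_rng(3)
d, rho = 8, 0.05
z_psi = np.zeros(d)

def maybe_reset(h_prev):
    """Return z_psi with probability rho, otherwise the incoming state."""
    return z_psi if rng.random() < rho else h_prev

# Over many steps, roughly a rho fraction of states get reset; the recurrent
# update g(., w) is then applied to whichever state maybe_reset returns.
resets = sum(np.array_equal(maybe_reset(np.ones(d)), z_psi)
             for _ in range(10000))
```

In a full training loop, `maybe_reset` would be applied to the hidden state before each recurrent step, so z_ψ sees gradients from many different continuations.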
In addition to the state reset, we introduce a unigram regularization to improve the accuracy of the marginal estimation. From Eq 3, z_ψ is used to predict the probability distribution of the first token, which should be the unigram distribution. To obtain this behavior, we maximize the likelihood of each token in the training data independently (as if the state were reset at every step). We call this the unigram regularizer:

L_U = − Σ_{t=1}^T log P(w_t | z_ψ),

and we add it to the standard training objective, L_text = − Σ_{t=1}^T log P(w_t | h_t). Thus, the overall training loss is L = L_text + L_U.
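The combined objective can be sketched on the same kind of toy random RNN used above (illustrative weights; the paper trains a medium LSTM). L_text scores each token from the recurrently updated state, while the unigram regularizer L_U scores every token directly from the fixed start state z_ψ = 0.

```python
import numpy as np

# Toy RNN (illustrative weights only).
rng = np.random.default_rng(4)
V, d = 6, 8
Wx = rng.normal(0, 0.5, (d, V)); Wh = rng.normal(0, 0.5, (d, d))
Wo = rng.normal(0, 0.5, (V, d))

def step(h, w): return np.tanh(Wx[:, w] + Wh @ h)
def next_dist(h):
    z = Wo @ h; e = np.exp(z - z.max()); return e / e.sum()

z_psi = np.zeros(d)
tokens = rng.integers(0, V, size=35)   # one truncated-BPTT window

# Standard NLL term: each token scored from the recurrent state.
L_text, h = 0.0, z_psi
for w in tokens:
    L_text -= np.log(next_dist(h)[w])
    h = step(h, w)

# Unigram regularizer: each token scored as if the state were just reset.
L_U = -sum(np.log(next_dist(z_psi)[w]) for w in tokens)

L = L_text + L_U   # overall training loss
```

With zero weights feeding the softmax, `next_dist(z_psi)` is uniform, so in this toy L_U is exactly T·log V; training would push that first-token distribution toward the corpus unigram distribution instead.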

Experimental Settings
We experiment with a standard medium-size LSTM language model (Zaremba et al., 2014) on two datasets: Penn Treebank (PTB) (Mikolov et al., 2010) and WikiText-2 (WT-2) (Merity et al., 2017). We use weight tying (Inan et al., 2017) and train all models with Adam (Kingma and Ba, 2014) for 40 epochs, with the learning rate starting at 0.003 and decaying every epoch at a rate of 0.85. We use a batch size of 64 and truncated backpropagation with 35 time steps, and employ variational dropout (Gal and Ghahramani, 2016). The final parameter set is chosen as the one minimizing validation loss. For the query model q_χ(w_{1:T}) used in importance sampling, we use the same model and training settings as above.

Marginal Estimation
In this subsection, we evaluate the accuracy of each approach at estimating marginal probabilities. Given a model and a phrase, we first obtain a target marginal probability, P_text(w_{1:T}), from the frequency of the phrase in text generated by the model. Then, we use each approach to estimate the marginal probability of the phrase, P_est(w_{1:T}). To measure performance, we compute the absolute value of the log ratio (lower is better) between the target and the estimated marginal probabilities:

E(w_{1:T}) = |log(P_text(w_{1:T}) / P_est(w_{1:T}))|    (4)

This evaluation measure gives equal importance to every n-gram regardless of its frequency. In the following experiments, we generate approximately 2 million and 4 million tokens for the PTB and WT-2 models respectively. The probability of phrases occurring in the generated text serves as our target marginal. We form a test set consisting of all n-grams in the generated text for n ≤ 5, excluding n-grams with frequency less than 20 to reduce noise from the generation. For the trace-based estimations, we average the marginal probabilities from 100 samples.

Table 1 shows the average discrepancy (Eq 4) between the marginal probabilities estimated from generated-text statistics and the methods discussed in Section 2. From the table, the RNNLM trained with the state reset and the unigram regularization (Zero^(RU)) performs better than both the zero-start and the trace-based approaches on the traditionally trained model. The importance sampling method (Trace-IW) has the second-lowest error and performs better than random sampling (Trace-Rand). Ablation analysis shows that both the state reset and the unigram regularization contribute to the accuracy. Note that the trace-based methods use the same model as Zero.
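The metric in Eq 4 is straightforward to compute from n-gram counts; a small sketch (the counts and totals below are made-up numbers for illustration):

```python
import math

# Eq 4: absolute log-ratio between the target marginal (n-gram frequency in
# model-generated text) and the estimated marginal probability.
def log_ratio_error(count, total_ngrams, p_est):
    """E = |log(P_text / P_est)|, with P_text estimated from counts."""
    p_text = count / total_ngrams
    return abs(math.log(p_text / p_est))

# A perfectly matched estimate gives zero error; a 2x over-estimate and a
# 2x under-estimate receive the same |log 2| penalty.
e_exact = log_ratio_error(20, 2_000_000, 20 / 2_000_000)
e_over = log_ratio_error(20, 2_000_000, 40 / 2_000_000)
```

The symmetry of the absolute log-ratio is what lets rare and frequent n-grams contribute equally to the average error.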
To show how performance varies with the query, we present results aggregated by n-gram length. Table 2 shows the errors on the WT-2 dataset. When the n-gram length is greater than 2, Trace-IW has better accuracy. This makes sense because the encoder has more evidence to use when inferring the likely start state.

Training with marginal probabilities
We now turn to an application of marginal estimation. One way to apply our marginal estimation techniques is to train an RNNLM with n-gram probabilities in addition to running text. This is helpful when we want to efficiently incorporate data from a much larger corpus without training the RNNLM on it directly (Chelba et al., 2017; Noraset et al., 2018). In this work, we frame the problem as a regression and use a loss equal to the squared difference of log probabilities:

L_N = α Σ_k (log P_est(x^(k)_{1:T}) − log P_ngram(x^(k)_{1:T}))^2,

where α is a hyper-parameter set to 0.1, x^(k)_{1:T} is the k-th n-gram, and P_ngram(x^(k)_{1:T}) is its target probability. Following the result in Table 1, we use the Zero method to estimate P_est(x^(k)_{1:T}) as in Eq 3, and add L_N to the training losses that use the running-text corpus.
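A sketch of this regression loss on made-up log probabilities; the function mirrors the squared log-difference form with α = 0.1 as in the paper (the input values are illustrative, not real model outputs):

```python
import numpy as np

# n-gram loss: alpha-weighted squared difference between the model's
# estimated log marginal and the target n-gram log probability.
alpha = 0.1

def ngram_loss(log_p_est, log_p_ngram):
    """L_N = alpha * sum_k (log P_est(x_k) - log P_ngram(x_k))^2."""
    diff = np.asarray(log_p_est) - np.asarray(log_p_ngram)
    return alpha * np.sum(diff ** 2)

# Two toy n-grams: one matched exactly, one off by 0.2 nats.
loss = ngram_loss([-5.0, -7.2], [-5.0, -7.0])
```

Working in log space keeps the regression well-scaled across n-grams whose probabilities span many orders of magnitude.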
To evaluate the approach, we follow Noraset et al. (2018) and use n-gram probabilities from a larger corpus (WikiText-103) to improve an RNNLM trained using WT-2. In our experiment, we use n-grams up to n = 5 with frequency greater than 50. We ignore n-grams containing <unk>, because the vocabulary sets are different. Table 3 shows the result. Since we do not use the same settings as the original work, we cannot compare to it directly: they use different optimization settings, a more expensive n-gram loss, and a Kneser-Ney bigram language model. However, we see that the proposed n-gram loss is beneficial when combined with the unigram loss. Importantly, unlike the approach in Noraset et al. (2018), our approach requires no sampling, which makes it several times faster.
In addition, we present a preliminary result comparing training with the marginal probabilities of n-grams to training with the complete data. Given a limited budget of optimization steps, we ask whether training on n-grams is more valuable than training on the full corpus. To keep the results comparable, we use the vocabulary set of WikiText-2 and convert all OOV tokens in the WikiText-103 training data to the "<unk>" token. Figure 1 shows the loss (average negative log-likelihood) on the validation data as the number of optimization steps increases.
We can see that training with the marginals does not perform as well as training with the WikiText-103 training data, but outperforms the model trained only on the WikiText-2 training data. This may be due to our choice of n-grams and optimization settings, such as the number of n-grams per batch, the weight of the n-gram loss, and the learning-rate decay. We leave exploring these hyper-parameters for future work.

Conclusion
We investigated how to estimate marginal probabilities of n-grams from an RNNLM when the preceding context is absent. We presented a simple method to train an RNNLM in which we occasionally reset the RNN's state and also maximize unigram likelihood along with the traditional objective.

[Figure 1: Loss (negative log-likelihood) on the WikiText-2 validation data over training steps, comparing models trained on WT-2 text, on WT-2 text plus WT-103 n-grams only, and on WT-103 text (with WT-2's vocabulary). Training with n-grams from a larger corpus is helpful, but not as helpful as training on the running text of the larger corpus itself.]

Our experiments showed that an RNNLM trained with our method outperformed other baselines on the marginal estimation task. Finally, we showed how to improve RNNLM perplexity by efficiently using additional n-gram probabilities from a larger corpus.
For future work, we would like to evaluate our approaches in more applications. For example, we can use the marginal statistics for information extraction, or to detect and remove abnormal phrases in text generation. In addition, we would like to continue improving the marginal estimation by experimenting with recent density estimation techniques such as NADE (Uria et al., 2016).