Learning Variational Word Masks to Improve the Interpretability of Neural Text Classifiers

To build an interpretable neural text classifier, most prior work has focused on designing inherently interpretable models or finding faithful explanations. A new line of work on improving model interpretability has only just started, and many existing methods require either prior information or human annotations as additional inputs during training. To address this limitation, we propose the variational word mask (VMASK) method, which automatically learns task-specific important words and reduces irrelevant information in classification, ultimately improving the interpretability of model predictions. The proposed method is evaluated with three neural text classifiers (CNN, LSTM, and BERT) on seven benchmark text classification datasets. Experiments show the effectiveness of VMASK in improving both model prediction accuracy and interpretability.


Introduction
Neural network models have achieved remarkable performance on text classification due to their capacity for representation learning on natural language texts (Zhang et al., 2015; Yang et al., 2016; Joulin et al., 2017; Devlin et al., 2018). However, the lack of understanding of their prediction behaviors has become a critical issue for reliability and trustworthiness and has hindered their applications in the real world (Lipton, 2016; Ribeiro et al., 2016; Jacovi and Goldberg, 2020). Many explanation methods have been proposed to provide post-hoc explanations for neural networks (Ribeiro et al., 2016; Lundberg and Lee, 2017; Sundararajan et al., 2017), but they can only explain model predictions and cannot help improve their interpretability.
In this work, we consider interpretability as an intrinsic property of neural network models. Furthermore, we hypothesize that neural network models with similar network architectures could have different levels of interpretability, even though they may have similar prediction performance.

Table 1: Two post-hoc explanation methods, LIME (Ribeiro et al., 2016) and SampleShapley (Kononenko et al., 2010), are used to explain the model predictions on examples 1 and 2 respectively. The top three important words are shown in pink or blue for models A and B. Whichever post-hoc method is used, explanations from model B are easier to understand because the sentiment keywords "clever" and "gimmicky" are highlighted.

Table 1 shows explanations extracted from two neural text classifiers with similar network architectures. Although both models make correct predictions of the sentiment polarities of the two input texts (positive for example 1 and negative for example 2), they have different explanations for their predictions.
In both examples, no matter which explanation generation method is used, explanations from model B are easier to interpret with respect to the corresponding predictions. Motivated by this difference in interpretability, we investigate the possibility of building more interpretable neural classifiers with a simple modification of the input layer. The proposed method does not demand significant effort in engineering network architectures (Rudin, 2019; Melis and Jaakkola, 2018). Also, unlike prior work on improving model interpretability (Erion et al., 2019; Plumb et al., 2019), it does not require pre-defined important attributions or pre-collected explanations. Specifically, we propose variational word masks (VMASK), which are inserted into a neural text classifier after the word embedding layer and trained jointly with the model. VMASK learns to restrict the information of globally irrelevant or noisy word-level features flowing to subsequent network layers, hence forcing the model to focus on important features to make predictions. Experiments in section 5 show that this method can improve model interpretability and prediction performance. As VMASK is deployed on top of the word embedding layer and the main network structure remains unchanged, it is model-agnostic and can be applied to any neural text classifier.
The contribution of this work is three-fold: (1) we proposed the VMASK method to learn global task-specific important features that can improve both model interpretability and prediction accuracy; (2) we formulated the problem in the framework of information bottleneck (IB) (Tishby et al., 2000;Tishby and Zaslavsky, 2015) and derived a lower bound of the objective function via the variational IB method (Alemi et al., 2016); and (3) we evaluated the proposed method with three neural network models, CNN (Kim, 2014), LSTM (Hochreiter and Schmidhuber, 1997), and BERT (Devlin et al., 2018), on seven text classification tasks via both quantitative and qualitative evaluations.

Related Work
Various approaches have been proposed to interpret DNNs, ranging from designing inherently interpretable models (Melis and Jaakkola, 2018; Rudin, 2019), to tracking the inner workings of neural networks (Jacovi et al., 2018; Murdoch et al., 2018), to generating post-hoc explanations (Ribeiro et al., 2016; Lundberg and Lee, 2017). Beyond interpreting model predictions, explanation generation methods are also promising for improving a model's performance. We propose an information-theoretic method to improve both prediction accuracy and interpretability.
Explanation from the information-theoretic perspective. A line of work that motivates ours leverages information theory to produce explanations, either maximizing mutual information to recognize important features (Chen et al., 2018; Guan et al., 2019) or optimizing the information bottleneck to identify feature attributions (Schulz et al., 2020; Bang et al., 2019). Information-theoretic approaches are efficient and flexible in identifying important features. Different from generating post-hoc explanations for well-trained models, we utilize the information bottleneck to train a more interpretable model with better prediction performance.
Improving prediction performance via explanations. Human-annotated explanations have been utilized to help improve model prediction accuracy (Zhang et al., 2016). Recent work has used post-hoc explanations to regularize models' prediction behaviors and force them to place more emphasis on predefined important features, hence improving their performance (Ross et al., 2017; Ross and Doshi-Velez, 2018; Liu and Avci, 2019; Rieger et al., 2019). Different from these methods, which require expert prior information or human annotations, the VMASK method learns global important features automatically during training and incorporates them seamlessly to improve model prediction behavior.
Improving interpretability via explanations. Some work focuses on improving a model's interpretability by aligning explanations with human judgements (Camburu et al., 2018; Du et al., 2019b; Chen and Ji, 2019; Erion et al., 2019; Plumb et al., 2019). Similar to the prior work on improving model prediction performance, these methods still rely on annotations or external resources. Although they enhance model interpretability, they may cause a drop in prediction accuracy due to the inconsistency between human recognition and the model's reasoning process (Jacovi and Goldberg, 2020). Our approach can improve both prediction accuracy and interpretability without resorting to human judgements.

Method
This section introduces the proposed VMASK method. For a given neural text classifier, the only modification of the neural network architecture is to insert a word mask layer between the input layer (e.g., word embeddings) and the representation learning layer. We formulate our idea within the information bottleneck framework (Tishby et al., 2000), where the word mask layer restricts the information flowing from words to the final prediction.

Interpretable Text Classifier with Word Masks
Consider an input text x = [x_1, · · · , x_T], where x_t (t ∈ {1, . . . , T}) denotes the word (or the word index in a predefined vocabulary), and let x_t ∈ R^d be the word embedding of x_t. A neural text classifier is denoted as f_θ(·) with parameters θ, which by default takes x as input and generates a probability of the output Y, p(Y |x), over all possible class labels. In this work, beyond prediction accuracy, we also expect the neural network model to be more interpretable by focusing on important words to make predictions. To help neural network models with feature selection, we add a random layer R after the word embeddings, where R = [R_{x_1}, . . . , R_{x_T}] has the same length as x. Each R_{x_t} ∈ {0, 1} is a binary random variable associated with the word type x_t rather than the word position. This random layer together with the word embeddings forms the input to the neural network model, i.e.,

Z = x ⊙ R,    (1)

where ⊙ is element-wise multiplication and each Z_t = R_{x_t} · x_t. Intuitively, Z contains only a subset of x, selected randomly by R. Since R is applied directly on the words as a sequence of 0-1 masks, we also call it the word mask layer in this work.
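The masking operation (Z_t = R_{x_t} · x_t, with the mask tied to the word type) can be sketched in a few lines. This is a minimal illustrative numpy sketch with hypothetical names; in the actual method the keep probabilities are produced by a trained inference network (section 3.4), not stored in a fixed table:

```python
import numpy as np

def apply_word_mask(embeddings, word_ids, keep_probs, rng):
    """Z_t = R_{x_t} * x_t with R_{x_t} ~ Bernoulli(keep_probs[x_t]).

    embeddings: (T, d) embeddings of the input words
    word_ids:   (T,) vocabulary indices; the keep probability is tied to
                the word *type*, not the word position
    keep_probs: (V,) per-vocabulary-word probability of being kept
    """
    keep = rng.random(len(word_ids)) < keep_probs[word_ids]
    return embeddings * keep[:, None]

rng = np.random.default_rng(0)
emb = rng.standard_normal((4, 3))          # T = 4 words, d = 3
ids = np.array([1, 0, 2, 1])               # word type 1 occurs twice
probs = np.array([0.0, 1.0, 0.5])          # type 0 always masked, type 1 always kept
z = apply_word_mask(emb, ids, probs, rng)  # row 1 is zeroed; rows 0 and 3 pass through
```

In this toy configuration the embedding of word type 0 (position 1) is always zeroed out, while word type 1 (positions 0 and 3) always passes through unchanged.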
To ensure that Z has enough information for predicting Y while containing the least redundant information from x, we follow the standard practice in information bottleneck theory (Tishby et al., 2000) and write the objective function as

max I(Z; Y) − β · I(Z; X),    (2)

where X is a random variable representing a generic word sequence as input, Y is the one-hot output random variable, I(·; ·) is mutual information, and β ∈ R+ is a coefficient balancing the two mutual information terms. This formulation reflects our exact expectation on Z. The main challenge here is to compute the mutual information.

Variational Word Masks
Inspired by the variational information bottleneck proposed by Alemi et al. (2016), instead of computing p(X, Y, Z), we start from an approximation distribution q(X, Y, Z). Then, with a few assumptions specified in the following, we construct a tractable lower bound of the objective in Equation 2; the detailed derivation is provided in Appendix A. For I(Z; Y) under q, we have I(Z; Y) = Σ_{y,z} q(y, z) log(q(y|z)/q(y)). By replacing log q(y|z) with the conditional probability derived from the true distribution, log p(y|z), we introduce the constraint between Y and Z from the true distribution and also obtain a lower bound of I(Z; Y):

I(Z; Y) ≥ H_q(Y) + E_{q(x)q(y|x)q(z|x)}[log p(y|z)],    (3)

where H_q(·) is entropy, and the last step uses q(x, y, z) = q(x)q(y|x)q(z|x), a factorization based on the conditional dependency Y ↔ X ↔ Z (Y and Z are independent given X). Given a specific observation (x^(i), y^(i)), we define the empirical distribution q(X^(i), Y^(i)) as the product of two Delta functions, and Equation 3 can be further simplified to

I(Z; Y) ≥ E_{q(z|x^(i))}[log p(y^(i)|z)] + H_q(Y).    (4)

Similarly, for I(Z; X) under q, we obtain an upper bound of I(Z; X) by replacing q(Z) with a predefined prior distribution p_0(Z):

I(Z; X) ≤ KL[q(Z|x^(i)) || p_0(Z)],    (5)

where KL[·||·] denotes Kullback–Leibler divergence. The simplification in the last step is similar to Equation 4, using the empirical distribution q(X^(i)).
Substituting (4) and (5) into Equation 2 gives us a lower bound L of the information bottleneck objective:

L = E_{q(Z|x^(i))}[log p(y^(i)|Z)] − β · KL[q(Z|x^(i)) || p_0(Z)].    (6)

The learning objective is to maximize Equation 6 with respect to the approximation distribution q(X, Y, Z) = q(X, Y)q(Z|X). As this is a classification problem where X and Y are both observed and q(X, Y) has already been simplified to an empirical distribution, the only component left in the approximation distribution is q(Z|X). Similar to the objective function in variational inference (Alemi et al., 2016; Rezende and Mohamed, 2015), the first term in L ensures that q(Z|X) preserves the information needed for predicting Y, while the second term regularizes q(Z|X) toward a predefined prior distribution p_0(Z).
The last step in obtaining a practical objective function is to notice that, given X, the approximation factorizes over positions,

q(Z|X) = Π_{t=1}^{T} q(Z_t|x_t),    (7)

where each R_{x_t} ∈ {0, 1} follows a Bernoulli distribution. Then Z can be reparameterized as Z = x ⊙ R, and the lower bound L can be rewritten with the random variable R as

L = E_{q(R|x^(i))}[log p(y^(i)|x^(i) ⊙ R)] − β · KL[q(R|x^(i)) || p_0(R)].    (8)

Note that although β is inherited from the information bottleneck theory, in practice it is used as a tunable hyper-parameter to address the notorious posterior-collapse issue (Bowman et al., 2016; Kim et al., 2018).

Connections
The idea of modifying word embeddings with the information bottleneck method has recently shown some interesting applications in NLP. For example, Li and Eisner (2019) proposed two ways to transform word embeddings into new representations for better POS tagging and syntactic parsing. According to Equation 1, VMASK can be viewed as a simple linear transformation of word embeddings. The difference is that {R_{x_t}} is defined over the vocabulary and can therefore be used to represent the global importance of word x_t. Recalling that R_{x_t} ∈ {0, 1}, from a slightly different perspective, Equation 1 can also be viewed as a generalization of word-embedding dropout (Gal and Ghahramani, 2016). There are, however, two major differences: (1) in Gal and Ghahramani (2016) all words share the same dropout rate, while in VMASK every word has its own dropout rate specified by q(R_{x_t}|x_t); (2) the motivation of word-embedding dropout is to force a model not to rely on single words for prediction, while VMASK learns a task-specific importance for every word.
Another way to make word masks sparse is to add L0 regularization (Lei et al., 2016; Bastings et al., 2019; Cao et al., 2020), but this regularizer only distinguishes words as important or unimportant, rather than learning continuous importance scores.

Model Specification and Training
We resort to the mean-field approximation (Blei et al., 2017) to simplify the assumptions on our q distribution. For q_φ(R|x), we have q_φ(R|x) = Π_{t=1}^{T} q_φ(R_{x_t}|x_t), which means the random variables are mutually independent, each governed by x_t. We use amortized variational inference (Rezende and Mohamed, 2015) to represent the posterior distribution q_φ(R_{x_t}|x_t) with an inference network (Kingma and Welling, 2014). In this work, we adopt a single-layer feed-forward neural network as the inference network, whose parameters φ are optimized together with the model parameters θ during training.
Following the same factorization as in q_φ(R|x), we define the prior distribution p_0(R) as p_0(R) = Π_{t=1}^{T} p_0(R_{x_t}), with each p_0(R_{x_t}) = Bernoulli(0.5). Choosing this non-informative prior means every word is initialized with no preference for being important or unimportant, and thus has equal probability of being masked or selected. As p_0(R) is a uniform distribution, we can further simplify the second term in Equation 8 to a conditional entropy, giving (up to an additive constant)

L = E_{q_φ(R|x^(i))}[log p(y^(i)|x^(i) ⊙ R)] + β · H_q(R|x^(i)).    (9)

We apply stochastic gradient descent to solve the optimization problem (Equation 9). In each iteration, the first term in Equation 9 is approximated with a single sample from q_φ(R|x^(i)) (Kingma and Welling, 2014). However, sampling from a Bernoulli distribution (as from any other discrete distribution) causes difficulty in backpropagation. We adopt the Gumbel-softmax trick (Jang et al., 2016; Maddison et al., 2016) as a continuous differentiable approximation to tackle the discreteness of sampling from Bernoulli distributions (Appendix B). During training, we use Adam (Kingma and Ba, 2014) for optimization and KL cost annealing (Bowman et al., 2016) to avoid posterior collapse.
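The inference network and the Gumbel-softmax relaxation can be sketched as follows. This is a minimal numpy sketch under stated assumptions: all names are illustrative, the weights would in reality be trained, and gradients are omitted here; it only shows the forward computation of a relaxed Bernoulli sample and the entropy regularizer of Equation 9:

```python
import numpy as np

def mask_logits(embeddings, W, b):
    """Single-layer feed-forward inference network:
    word embedding -> unnormalized log-probabilities for (mask, keep)."""
    return embeddings @ W + b                       # (T, d) @ (d, 2) -> (T, 2)

def gumbel_softmax_keep(logits, tau, rng):
    """Continuous relaxation of a Bernoulli sample (Jang et al., 2016):
    the 'keep' coordinate of softmax((logits + Gumbel noise) / tau)."""
    g = -np.log(-np.log(rng.random(logits.shape)))  # Gumbel(0, 1) noise
    y = (logits + g) / tau
    y = np.exp(y - y.max(axis=1, keepdims=True))    # stabilized softmax
    return (y / y.sum(axis=1, keepdims=True))[:, 1]

def bernoulli_entropy(p, eps=1e-12):
    """H_q(R|x): the entropy regularizer in Equation 9 (maximized in training)."""
    return -(p * np.log(p + eps) + (1 - p) * np.log(1 - p + eps))

rng = np.random.default_rng(0)
emb = rng.standard_normal((5, 8))                   # T = 5 words, d = 8
W, b = rng.standard_normal((8, 2)), np.zeros(2)
soft_mask = gumbel_softmax_keep(mask_logits(emb, W, b), tau=0.5, rng=rng)
```

As the temperature tau approaches 0, the soft mask values approach hard 0/1 samples; the entropy term is maximal at p = 0.5, matching the non-informative Bernoulli(0.5) prior.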
For a given word x_t with word embedding x_t, the model at training time samples r_{x_t} from q(R_{x_t}|x_t) to decide whether to keep or zero out the corresponding word embedding. At inference time, the model takes as input the product of the word embedding and the expectation of the word mask distribution, i.e., x_t · E[q(R_{x_t}|x_t)].

Experiment Setup
The proposed method is evaluated on seven text classification tasks, ranging from sentiment analysis to topic classification, with three typical neural network models: a long short-term memory network (LSTM; Hochreiter and Schmidhuber, 1997), a convolutional neural network (CNN; Kim, 2014), and BERT (Devlin et al., 2018).
Datasets. We adopt seven benchmark datasets: movie reviews IMDB (Maas et al., 2011), Stanford Sentiment Treebank with fine-grained labels SST-1 and its binary version SST-2 (Socher et al., 2013), Yelp reviews (Zhang et al., 2015), AG's News (Zhang et al., 2015), 6-class question classification TREC (Li and Roth, 2002), and subjective/objective classification Subj (Pang and Lee, 2005). For the datasets (e.g. IMDB, Subj) without standard train/dev/test split, we hold out a proportion of training examples as the development set. Table 2 shows the statistics of the datasets.
Models. The CNN model (Kim, 2014) contains a single convolutional layer with filter sizes ranging from 3 to 5. The LSTM (Hochreiter and Schmidhuber, 1997) has a single unidirectional hidden layer. Both models are initialized with 300-dimensional pretrained word embeddings (Mikolov et al., 2013). We fix the embedding layer and tune the other parameters on each dataset to achieve the best respective performance. We use the pretrained BERT-base model 3 with 12 transformer layers, 12 self-attention heads, and a hidden size of 768. We fine-tune it on the different downstream tasks, then fix the embedding layer and train the mask layer together with the rest of the model.
Baselines and Competitive Methods. As the goal of this work is to propose a novel training method that improves both prediction accuracy and interpretability, we employ two groups of models as baselines and competitive systems. Models trained with the proposed method are named with the suffix "-VMASK". We also provide two baselines: (1) models trained by minimizing the cross-entropy loss (suffixed with "-base") and (2) models trained with ℓ2-regularization (suffixed with "-ℓ2"). The comparison with these two baseline methods mainly focuses on prediction performance, as no explicit training strategies are used to improve interpretability. In addition, we include two competitive methods: models trained with the explanation framework "Learning to Explain" (Chen et al., 2018) (suffixed with "-L2X") and with "Information Bottleneck Attribution" (Schulz et al., 2020) (suffixed with "-IBA"). L2X and IBA were originally proposed to find feature attributions as post-hoc explanations for well-trained models. We integrate them into model training as the mask layer, either directly generating mask values for input features (L2X) or restricting information flow by adding noise (IBA). In all experiments, every training method is combined with random dropout (ρ = 0.2) to avoid overfitting.
More details about experiment setup are in Appendix C, including data pre-processing, model configurations, and the implementation of L2X and IBA in our experiments.

Results and Discussion
We trained the three models on the seven datasets with the different training strategies. Table 3 shows the prediction accuracy of the different models on the test sets. The validation performance and average runtime are in Appendix D. As shown in Table 3, all base models have prediction performance similar to the numbers reported in prior work (Appendix E). The models trained with VMASK outperform those with similar network architectures but trained differently, showing that VMASK can help improve generalization. Apart from the base models and the models trained with the proposed method, the records of the other three competitors are mixed. For example, traditional ℓ2-regularization does not always improve accuracy, especially for the BERT model. Although the performance with IBA is slightly better than with L2X, training with either does not show a consistent improvement in a model's prediction accuracy.
To echo the purpose of improving model interpretability, the rest of this section will focus on evaluating the model interpretability quantitatively and qualitatively.

Quantitative Evaluation
We evaluate the local interpretability of VMASK-based models against the base models via the AOPC score (Nguyen, 2018; Samek et al., 2016), and the global interpretability against the IBA-based models via post-hoc accuracy (Chen et al., 2018). Empirically, we observed agreement between local and global interpretability, so there is no need to exhaust all possible combinations in our evaluation.

Local interpretability: AOPC
We adopt two model-agnostic explanation methods, LIME (Ribeiro et al., 2016) and SampleShapley (Kononenko et al., 2010), to generate local explanations for the base and VMASK-based models, where "local" means explaining each test example individually. The area over the perturbation curve (AOPC) (Nguyen, 2018; Samek et al., 2016) metric is utilized to evaluate the faithfulness of explanations to models. It calculates the average change of the prediction probability on the predicted class over all test data when the top n words in an explanation are deleted. We adopt this metric to evaluate a model's interpretability with respect to post-hoc explanations. Higher AOPC scores are better.
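One common formulation of this computation can be sketched as follows. The predictor and word rankings here are toy stand-ins (the paper uses the trained classifiers with LIME/SampleShapley rankings), and the exact normalization of AOPC varies across papers, so this is an illustrative sketch rather than the paper's exact implementation:

```python
def aopc(predict_prob, texts, rankings, n=5):
    """Average change in the predicted-class probability after deleting
    the top-k (k = 1..n) words ranked by a post-hoc explanation.

    predict_prob(words) -> probability of the originally predicted class
    rankings: per example, word positions sorted by decreasing importance
    """
    scores = []
    for words, ranked in zip(texts, rankings):
        p_full = predict_prob(words)
        m = min(n, len(ranked))
        drop = 0.0
        for k in range(1, m + 1):
            # delete the top-k ranked word positions and re-predict
            kept = [w for i, w in enumerate(words) if i not in ranked[:k]]
            drop += p_full - predict_prob(kept)
        scores.append(drop / (m + 1))
    return sum(scores) / len(scores)

# toy classifier: confident iff the word "good" is present
toy = lambda words: 1.0 if "good" in words else 0.2
score = aopc(toy, [["good", "movie", "plot"]], [[0, 1, 2]])  # -> 0.6
```

A faithful explanation ranks the words whose deletion hurts the prediction most, so a more interpretable model yields a larger average probability drop and thus a higher AOPC.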
For the TREC and Subj datasets, we evaluate all test data. For each other dataset, we randomly pick 1,000 examples for evaluation due to computation costs. Table 4 shows the AOPCs of the different models on the seven datasets when deleting the top 5 words identified by LIME or SampleShapley. The AOPCs of VMASK-based models are significantly higher than those of the base models on most of the datasets, indicating that VMASK can improve a model's interpretability with respect to post-hoc explanations. The results on the TREC dataset are very close, likely because the top 5 words can cover all informative words in short sentences with an average length of 10 tokens.

Global Interpretability: Post-hoc accuracy
The post-hoc accuracy measures whether a model makes the same prediction when it sees only the top k important words:

post-hoc-acc(k) = (1/M) Σ_{m=1}^{M} 1[y_m^(k) = y_m],

where M is the number of examples, y_m is the predicted label on the m-th test example, and y_m^(k) is the predicted label based on the top k important words. Figure 1 shows the results of VMASK- and IBA-based models on the seven datasets with k ranging from 1 to 10. VMASK-based models (solid lines) outperform IBA-based models (dotted lines) with higher post-hoc accuracy, indicating that our proposed method is better at capturing task-specific important features. For CNN-VMASK and LSTM-VMASK, using only the top two words achieves about 80% post-hoc accuracy, even on the IMDB dataset, where the average text length is 268 tokens. These results illustrate that VMASK can identify the words that are informative for model predictions.
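The post-hoc accuracy computation can be sketched directly from this definition. The predictor and rankings below are toy stand-ins for illustration; the paper applies the metric with its trained classifiers and learned importance scores:

```python
def posthoc_accuracy(predict, texts, rankings, k):
    """Fraction of examples where the prediction from only the top-k
    important words matches the prediction from the full text."""
    agree = sum(
        int(predict([words[i] for i in ranked[:k]]) == predict(words))
        for words, ranked in zip(texts, rankings)
    )
    return agree / len(texts)

# toy classifier: "pos" iff the word "good" appears
toy = lambda words: "pos" if "good" in words else "neg"
acc = posthoc_accuracy(toy, [["good", "a", "b"], ["x", "y"]], [[0, 2], [1, 0]], k=1)
# -> 1.0: in both examples the top-1 word reproduces the full-text prediction
```

If the ranking points at uninformative words, the reduced input yields a different prediction and the metric drops, which is exactly what separates the VMASK and IBA curves in Figure 1.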
We also noticed that BERT-VMASK has lower post-hoc accuracy than the other two models, probably because BERT tends to use larger context via its self-attention for predictions. This also explains why the post-hoc accuracies of BERT-VMASK on the IMDB and SST-1 datasets catch up only slowly as k increases.

Qualitative Evaluation
Visualizing post-hoc local explanations. Table 5 shows some examples of LIME explanations for the different models on the IMDB dataset. We highlight the top three important words identified by LIME, where the color saturation indicates word attribution. Each pair of base and VMASK-based models makes the same, correct predictions on the input texts. For VMASK-based models, LIME captures sentiment words that indicate the same sentiment polarity as the prediction. For the base models, however, LIME selects some irrelevant words (e.g. "plot", "of", "to") as explanations, which illustrates the relatively lower interpretability of base models with respect to post-hoc explanations.
Visualizing post-hoc global explanations. We adopt SP-LIME, proposed by Ribeiro et al. (2016), as a third-party assessment of the global interpretability of base and VMASK-based models. Setting aside the restriction on the number of explanations, we follow its method of computing a feature's global importance from LIME local explanations (subsubsection 5.1.1), taking the sum of all local importance scores of a feature as its global importance. To distinguish it from the global importance learned by VMASK, we call it post-hoc global importance. Table 6 lists the top three post-hoc globally important words of the base and VMASK-based models on the IMDB dataset. For VMASK-based models, the global important features selected by SP-LIME are all sentiment words. For the base models, however, some irrelevant words (e.g. "performances", "plot", "butcher") are identified as important features, which makes model predictions unreliable.
Frequency-importance correlation. We compute the Pearson correlation coefficients between word frequency and the global word importance of VMASK-based models in Appendix F. The results show that they are not significantly correlated, which indicates that VMASK is not simply learning to select high-frequency words. Figure 2 further verifies this by plotting the expectation E[q(R_{x_t}|x_t)] of word masks from the LSTM-VMASK trained on Yelp against the word frequency in the same dataset. Here, we visualize the top 10 high-frequency words and the top 10 important words based on the expectation of the word masks. The global importance scores of the sentiment words are over 0.8, even for some low-frequency words (e.g. "funnest", "craveable"), while those of the high-frequency words are all around 0.5, which means the VMASK-based models are less likely to focus on irrelevant words to make predictions.
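This check amounts to a Pearson correlation between two vectors indexed by vocabulary word: frequency counts and learned importance scores. A small sketch with synthetic values (the numbers are invented for illustration, not the paper's data):

```python
import numpy as np

def pearson(a, b):
    """Pearson correlation coefficient between two equal-length vectors."""
    a = np.asarray(a, float) - np.mean(a)
    b = np.asarray(b, float) - np.mean(b)
    return float((a * b).sum() / np.sqrt((a * a).sum() * (b * b).sum()))

freq = [5000, 3200, 12, 7, 450]       # synthetic word frequencies
imp = [0.51, 0.49, 0.85, 0.92, 0.50]  # synthetic E[q(R|x)] importance scores
r = pearson(freq, imp)                # negative here: rare words can be highly important
```

A coefficient near zero (or negative, as in this synthetic case) is what supports the claim that VMASK is not merely selecting high-frequency words.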
Task-specific important words. Figure 3 visualizes the top 10 important words for the VMASK- and IBA-based models on three datasets via word clouds. The words selected by VMASK are consistent with the corresponding topic, such as "funnest" and "awsome" for sentiment analysis, and "encyclopedia" and "spaceport" for news classification, while IBA selects some irrelevant words (e.g. "undress", "slurred").

Conclusion
In this paper, we proposed an effective method, VMASK, which learns global task-specific important features to improve both model interpretability and prediction accuracy. We tested VMASK with three different neural text classifiers on seven benchmark datasets and assessed its effectiveness via both quantitative and qualitative evaluations.

Table 5: Examples of LIME explanations on the IMDB dataset; the top three important words are highlighted in the original.

Model | Text | Prediction
CNN-base | Primary plot, primary direction, poor interpretation. | negative
CNN-VMASK | Primary plot, primary direction, poor interpretation. | negative
LSTM-base | John Leguizamo's freak is one of the funniest one man shows I've ever seen. I recommend it to anyone with a good sense of humor. | positive
LSTM-VMASK | John Leguizamo's freak is one of the funniest one man shows I've ever seen. I recommend it to anyone with a good sense of humor. | positive
BERT-base | Great story, great music. A heartwarming love story that's beautiful to watch and delightful to listen to. Too bad there is no soundtrack CD. | positive
BERT-VMASK | Great story, great music. A heartwarming love story that's beautiful to watch and delightful to listen to. Too bad there is no soundtrack CD. | positive

A Detailed Derivation

The following derivation is similar to the variational information bottleneck; the difference is that our starting point is the approximation distribution q(X, Y, Z) instead of the true distribution p(X, Y, Z).
The lower bound for I(Z; Y ).
where H q (·) represents entropy. Now, if we replace log q(y|z) with the conditional probability derived from the true distribution log p(y|z), we have where KL[· ·] denotes Kullback-Leibler divergence. Therefore, we can obtain a lower bound of the mutual information where the last step uses q(x, y, z) = q(x)q(y|x)q(z|x), which is a factorization based on the conditional dependency 4 .
Since q(X, Y , Z) is the approximation defined by ourselves, given a specific observation (x (i) , y (i) ), the empirical distribution q(X (i) , Y (i) ) 4 Y ↔ X ↔ Z: Y and Z are independent given X.
(13) Then, Equation 12 with X (i) and Y (i) can be further simplified as The upper bound for I(Z; X).
By replacing q(z) with a prior distribution of z, p 0 (z), we have (16) Then we can obtain an upper bound of the mutual information

C Supplement of Experiment Setup
Data pre-processing. We clean up the text by converting all characters to lowercase and removing extra whitespace and special characters. We tokenize the texts and remove low-frequency words to build the vocabulary. We truncate or pad sentences to the same length for mini-batching during training. Table 7 shows the pre-processing details for each dataset.
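The truncate-or-pad step can be sketched in one line (pad_id = 0 is an assumed placeholder index, not specified in the paper):

```python
def pad_or_truncate(token_ids, max_len, pad_id=0):
    """Force a sequence to exactly max_len tokens so examples
    can be stacked into a mini-batch tensor."""
    return (token_ids + [pad_id] * max_len)[:max_len]

padded = pad_or_truncate([1, 2, 3], 5)       # -> [1, 2, 3, 0, 0]
cut = pad_or_truncate([1, 2, 3, 4, 5, 6], 4) # -> [1, 2, 3, 4]
```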
Implementation of L2X and IBA.
• The explanation framework of L2X (Chen et al., 2018) is a neural network that learns to generate importance scores w = [w_1, w_2, · · · , w_T] for input features x = [x_1, x_2, · · · , x_T]. The network is optimized by maximizing the mutual information between the selected important features and the model prediction, i.e. I(x_S; y), where x_S contains a subset of features from x. In our experiments, we adopt a single-layer feed-forward neural network as the interpreter to generate importance scores for an input text, and multiply each word embedding by its importance score, x̃ = w ⊙ x. The weighted word-embedding matrix x̃ is sent to the rest of the model to produce an output y′. We optimize the interpreter network together with the original model by minimizing the cross-entropy loss between the final output and the ground-truth label, L_ce(y_t; y′).
• We adopt the Readout Bottleneck of IBA, which utilizes a neural network to predict mask values λ = [λ_1, λ_2, · · · , λ_T], where λ_t ∈ [0, 1]. The information of a feature x_t is restricted by adding noise, i.e. z_t = λ_t x_t + (1 − λ_t) ε_t, where ε_t ∼ N(μ_{x_t}, σ²_{x_t}), and z is learned by optimizing the objective function in Equation 2. By assuming the variational approximation q(z) to be a Gaussian distribution, the mutual information can be calculated explicitly (Schulz et al., 2020). We use a single-layer feed-forward neural network as the Readout Bottleneck to generate continuous mask values λ and construct z for the model to make predictions. The Readout Bottleneck is trained jointly with the original model by minimizing the sum of the cross-entropy loss L_ce(y_t; y) and an upper bound L_I = E_x[KL[p(z|x) || q(z)]] of the mutual information I(Z; X). See Schulz et al. (2020) for the proof of the upper bound.
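The per-feature noise injection used by IBA can be sketched as follows. This is an illustrative numpy sketch with hypothetical names; in the real method λ is produced by the trained Readout Bottleneck and μ, σ are estimated feature statistics:

```python
import numpy as np

def iba_restrict(x, lam, mu, sigma, rng):
    """z_t = lam_t * x_t + (1 - lam_t) * eps_t, with eps_t ~ N(mu, sigma^2).

    lam near 1 passes the feature through unchanged; lam near 0 replaces
    it with noise, restricting the information that flows from x."""
    eps = rng.normal(mu, sigma, size=x.shape)
    return lam[:, None] * x + (1 - lam[:, None]) * eps

rng = np.random.default_rng(0)
x = rng.standard_normal((3, 4))                       # T = 3 features, d = 4
z_open = iba_restrict(x, np.ones(3), 0.0, 1.0, rng)   # lam = 1: z equals x
z_shut = iba_restrict(x, np.zeros(3), 0.0, 1.0, rng)  # lam = 0: pure noise
```

The two extremes show the bottleneck behavior: with λ = 1 the feature is fully transmitted, and with λ = 0 the output carries no information about x.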

D Validation Performance and Average Runtime
The corresponding validation accuracy for each reported test accuracy is in Table 8. The average runtime for each approach on each dataset is recorded in Table 9. All experiments were performed on a single NVidia GTX 1080 GPU. Table 10 shows some results of prediction accuracy of base models reported in previous papers.