Explain by Evidence: An Explainable Memory-based Neural Network for Question Answering

Interpretability and explainability of deep neural net models are always challenging due to their size and complexity. Many previous works focused on visualizing internal components of neural networks to represent them through human-friendly concepts. On the other hand, in real life, when making a decision, humans tend to rely on similar situations in the past. Thus, we argue that one potential approach to making a model interpretable and explainable is to design it in such a way that the model explicitly connects the current sample with previously seen samples and bases its decision on these samples. In this work, we design one such model: an explainable, evidence-based memory network architecture that learns to summarize the dataset and extract supporting evidence to make its decision. The model achieves state-of-the-art performance on two popular question answering datasets, the TrecQA dataset and the WikiQA dataset. Through further analysis, we show that this model can reliably trace the errors it has made in the validation step back to the training instances that might have caused them. We believe that this error-tracing capability might be beneficial in improving dataset quality in many applications.


Introduction
Interpretability of neural networks is an active research field in machine learning. Deep neural networks might have tens if not hundreds of millions of parameters (Devlin et al., 2019; Liu et al., 2019a) organized into intricate architectures. The sheer number of parameters and the complexity of the architectures largely prevent humans from directly making sense of which concepts the network truly learns and how it learns them. The comparative lack of explainable intuition behind deep neural networks might hamper the development and adoption of those models. In certain scenarios, prediction accuracy alone is not sufficient (Caruana et al., 2015; Lapuschkin et al., 2019). For example, as discussed in (Zhang et al., 2018b; Zhang et al., 2018a), it is difficult to trust a deep model even if it has high test set performance, given the inherent biases in the dataset. Thus, we argue that interpretability is perhaps one of the keys to accelerating both the development and adoption of deep neural networks.
There have been many successful attempts from the research community to make sense of deep models' predictions. These attempts can be broadly categorized into several classes. One of the major classes concerns network visualization techniques, for example, visual saliency representations in convolutional models (Simonyan et al., 2013; Sundararajan et al., 2017). For recurrent neural networks (RNNs), Karpathy et al. (2015) focused on analyzing and visualizing the RNN to explain its ability to keep track of long-range information.
The visualization-based methods, although achieving great successes, still operate on a very high level of abstraction. It requires a great deal of machine learning knowledge to make use of those visualizations. Thus, these techniques are not always useful for a broader audience, who might not have the machine learning expertise. Looking back at classic machine learning models, one class of models stands out as being very intuitive and easy to understand: the instance-based learning algorithms. The k-nearest neighbors algorithm, a prime example, operates on a very human-like assumption: if the current circumstances are similar to those of a known situation in the past, we may very well base the current decision on the outcome of that past decision. We argue that this assumption puts the interpretability on a much lower level of abstraction compared to the visualization methods. If our model can somehow learn to link the evidence from the training data to the prediction phase, we will have a direct source of interpretability that can be appreciated by a broader audience.
The k-nearest neighbors algorithm, as an instance-based method, might not be a deep neural network technique; however, there have been many papers in the deep learning literature inspired by or related to this method. A notable example is the neural nearest neighbors network (Plötz and Roth, 2018). Moreover, there is a class of problems with strong links to k-nearest neighbors: few-shot learning. It is from two major papers in the few-shot learning literature, the prototypical network (Snell et al., 2017) and the matching network (Vinyals et al., 2016), that we find a potential realization of our ideas.
In few-shot learning, it is possible to learn the support that each instance in the support set lends to the current prediction; however, such an approach is infeasible when the training data grows larger. Inspired by the techniques discussed in (Ravi and Larochelle, 2017), we apply a training-level data summarizer based on the neural Turing machine (NTM) (Graves et al., 2014) that reads the dataset and summarizes (or writes) it into a few meta-evidence nodes. These meta-evidence nodes, in turn, lend support to each prediction, similar to a few-shot learning model. The parameters of the NTM are jointly trained with the other parameters of the network. Our final model not only has great predictive power and achieves state-of-the-art results on two popular answer selection datasets, but also shows a strong "error-tracing" capability, in which errors on the validation set can be traced to their sources in the training set.
To summarize, our contributions in this work are twofold. First, we propose a novel neural network model that achieves state-of-the-art performance on two answer selection datasets. Second, we show the utility of the error-tracing capability of our model in finding the noisy instances in the training data that degrade the performance on the validation data. This capability might be very useful in real-life machine learning scenarios where the training labels are noisy or the inter-annotator agreement is low.

Proposed Framework
Question answering (or answer selection) is the task of identifying the correct answer to a question from a pool of candidate answers. It is an active research problem with applications in many areas (Tay et al., 2018a; Tayyar Madabushi et al., 2018; Rao et al., 2019; Lai et al., 2020). Similar to most recent papers on this topic (Tay et al., 2018b; Lai et al., 2019; Garg et al., 2020), we cast the question answering problem as a binary classification problem by concatenating the question with each of the candidate answers and assigning a positive label to the concatenation containing the correct answer, as sketched below.
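As a concrete illustration, here is a minimal sketch of this casting step. The helper name and the "[SEP]" separator convention are our assumptions for illustration, not details taken from the paper.

```python
# Hypothetical helper: pair a question with each candidate answer and
# label the pairs. The "[SEP]" separator is an illustrative assumption;
# the actual separator depends on the embedding model used.
def build_instances(question, candidate_answers, correct_answers):
    instances = []
    for answer in candidate_answers:
        text = question + " [SEP] " + answer
        label = 1 if answer in correct_answers else 0  # positive = correct answer
        instances.append((text, label))
    return instances
```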
In most supervised learning scenarios, performing a full distance calculation between the current data point and every training data point would be computationally intractable. To overcome this burden, we propose a memory controller based on the NTM to summarize the dataset into meta-evidence nodes. Similar to the NTM, the controller is characterized by reading and writing mechanisms. Assume that we provide the controller with $K$ cells $e_1, \ldots, e_K$ in a memory bank (i.e., to store $K$ support/evidence vectors), and let $x_t$ denote the $t$-th data point, obtained by using a pretrained embedding model to embed the concatenation of a question and a candidate answer. The memory controller then works as follows.
Writing mechanism. The writing mechanism characterizes how the controller updates its memory given a new data point. To update the memory, however, we first need an indexing mechanism for writing. Instead of using the original indexing of the NTM, we adopt the simpler indexing procedure from the memory network, which has been proven to be useful in this task (Lai et al., 2019). At time step $t$, for each incoming data point $x_t$, we compute the attention weight $w^e_{t,i}$ for the support vector $e^t_i$ as the cosine similarity between the two:

$$w^e_{t,i} = \frac{x_t \cdot e^t_i}{\lVert x_t \rVert \, \lVert e^t_i \rVert} \quad (1)$$

From these attention weights, we find the writing index $i^*_t$ for an input $x_t$ by maximizing the cosine similarity between $x_t$ and the evidence vectors:

$$i^*_t = \arg\max_i w^e_{t,i}$$

With the writing index found, we compute the memory update weight $g_t$ via a gating mechanism:

$$g_t = \alpha \, \sigma\big(W_g [x_t ; e^t_{i^*_t}] + b_g\big)$$

where $\alpha$ is a scalar, $\sigma$ is the sigmoid function, and $W_g$ and $b_g$ are learnable parameters. The hyperparameter $\alpha$ prevents outliers from breaking the memory values. The memory update at time step $t$ is formalized as:

$$e^{t+1}_{i^*_t} = (1 - g_t)\, e^t_{i^*_t} + g_t\, x_t, \qquad e^{t+1}_i = e^t_i \;\; \text{for } i \neq i^*_t$$

Reading mechanism. The reading mechanism characterizes how the controller uses its memory and the current input to produce an output. Instead of reading one memory cell, we aim to learn the support of all meta-evidence nodes. Thus, a weighted sum is used to create a support vector $s_t$:

$$s_t = \sum_{i=1}^{K} \operatorname{softmax}(w^e_t)_i \, e^t_i$$

We then incorporate the original input with the support vector $s_t$ to produce the negative/positive class probabilities $P(x_t)$:

$$P(x_t) = \operatorname{softmax}\big(W_o [x_t ; s_t] + b_o\big)$$

where $W_o$ and $b_o$ are learnable parameters of the output layer.

The overall information flow of our model is visualized in Figure 1. Our formulation draws inspiration from the NTM and the memory network. Our indexing algorithms in the writing and reading mechanisms are similar to those of the memory network, which are simpler than the NTM's. However, the memory network only stores intermediate computation steps in its memory, and these memories can be considered internal layers of the network. Our memory, on the contrary, is external and not trained, only updated by the writing mechanism. In this regard, the memory bank of our model is more similar to the NTM's.
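To make the two mechanisms concrete, below is a minimal PyTorch sketch, assuming a gate that reads the concatenation of $x_t$ with the selected cell and a linear output layer; these are plausible reconstructions consistent with the description above, and the class and argument names are ours, not the paper's.

```python
import torch
import torch.nn.functional as F

class EvidenceMemoryController(torch.nn.Module):
    """Sketch of the writing/reading mechanisms; see the hedges above."""

    def __init__(self, dim, num_cells, alpha=0.5):
        super().__init__()
        # External memory bank of K meta-evidence cells. It is a buffer,
        # not a parameter: it changes only through write(), never
        # directly through gradient descent.
        self.register_buffer("memory", torch.randn(num_cells, dim))
        self.alpha = alpha  # scalar hyperparameter bounding the update size
        self.gate = torch.nn.Linear(2 * dim, 1)        # W_g, b_g
        self.classifier = torch.nn.Linear(2 * dim, 2)  # output layer (assumed)

    def write(self, x):
        # Attention weights: cosine similarity to every evidence cell (Eq. 1).
        w = F.cosine_similarity(x.unsqueeze(0), self.memory, dim=-1)
        i = int(torch.argmax(w))  # writing index
        # Gated update weight; g_t <= alpha, so one outlier cannot
        # overwrite a cell.
        g = self.alpha * torch.sigmoid(self.gate(torch.cat([x, self.memory[i]])))
        # Out-of-place update so that gradients can still flow into the
        # gate, which is trained jointly with the rest of the network.
        memory = self.memory.clone()
        memory[i] = (1 - g) * self.memory[i] + g * x
        self.memory = memory

    def read(self, x):
        # Support vector s_t: softmax-weighted sum over all evidence cells.
        w = F.cosine_similarity(x.unsqueeze(0), self.memory, dim=-1)
        s = torch.softmax(w, dim=0) @ self.memory
        # Class probabilities from the input joined with its support vector.
        return torch.softmax(self.classifier(torch.cat([x, s])), dim=-1)
```

At prediction time, the attention weights w in read() also expose how much each meta-evidence cell supports the decision, which is what makes the model inspectable.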

Question answering performance
In this subsection, we present our core results on the two most popular datasets for answer selection: WikiQA (Yang et al., 2015) and TrecQA (Wang et al., 2007). Due to space constraints, details of these datasets are described in the Appendix. Similar to previous work, we use two standard measures for the task: mean average precision (MAP) and mean reciprocal rank (MRR); a minimal sketch of both measures is given after this paragraph. Our models make use of the RoBERTa contextual embedding (Liu et al., 2019b), pretrained on the ANSQ dataset (Garg et al., 2020). For our model, we vary the number of memory cells from 2 to 64. The base configuration with 2 memory cells mimics the prototypical network, with one cell for each prototype class representation. Table 1 shows our model's and the baselines' performance. All our model's configurations outperform the previous state-of-the-art models. Increasing the number of memory cells beyond the basic 2 cells, one for each class, clearly helps. The performance peaks at 16 or 32 cells, depending on the dataset.
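For reference, the two measures can be computed as in the following minimal sketch, written from their standard definitions; published numbers are usually produced with the official trec_eval tooling, which may differ in details such as tie handling.

```python
def average_precision(labels_ranked):
    """Average of precision@k over the positions of correct answers.
    labels_ranked: 0/1 labels of one question's candidates, sorted by
    model score in descending order."""
    hits, precisions = 0, []
    for rank, label in enumerate(labels_ranked, start=1):
        if label == 1:
            hits += 1
            precisions.append(hits / rank)
    return sum(precisions) / max(hits, 1)

def reciprocal_rank(labels_ranked):
    """1 / rank of the first correct answer (0.0 if there is none)."""
    for rank, label in enumerate(labels_ranked, start=1):
        if label == 1:
            return 1.0 / rank
    return 0.0

def map_mrr(per_question_labels):
    """Mean AP and mean RR over questions.
    Example: map_mrr([[0, 1, 0], [1, 0]]) == (0.75, 0.75)."""
    aps = [average_precision(q) for q in per_question_labels]
    rrs = [reciprocal_rank(q) for q in per_question_labels]
    return sum(aps) / len(aps), sum(rrs) / len(rrs)
```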

Error-tracing performance
One of the main motivations behind our evidence-based model is the ability to interpret the output of the neural network. It is hard to quantify the interpretability of different models, however. To create a benchmark for interpretability, we look for a potential application of interpretability in the real-life development of a deep neural network. Data collection is one of the most important parts of a machine learning model's development cycle. In many cases, nevertheless, the collected data is not always clean and consistent, due either to errors made by annotators or to equivocal data points. For example, the popular Switchboard Dialog Act dataset (Stolcke et al., 2000) only has 84% inter-annotator agreement. Thus, we would like to test how well different models help in identifying noisy instances in the dataset.
Our model naturally learns the most supportive group of instances given a new instance; thus, we can easily use this information to trace an error in validation back to a group of training instances. Ideally, we would need to test all the training samples of that group, but the number of samples to check would quickly get out of control. Hence, we rely on a heuristic: from the most relevant group, we only test the top k most similar instances (by cosine distance in the embedding space); a sketch of this protocol is given at the end of this subsection. To create a noisy dataset from our current QA datasets, we randomly swap 10% of the labels in each training set. We then calculate the percentage of errors in validation that the model can correctly trace back to the training set perturbation. For a quantitative benchmark, we compare our proposed model with the best baseline (i.e., the RoBERTa + ANSQ transfer model), tracing through its top k most similar representations.

Table 2: Error-tracing precision.

Table 2 shows the error-tracing performance of the model compared to the baseline. Our best model shows strong error-tracing capability and outperforms the baseline by a wide margin. On both datasets, our model can trace roughly 90% of the errors to the perturbed data points. This experiment clearly shows that forcing a model to provide direct evidence helps in identifying noisy training instances.
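Below is a sketch of this evaluation protocol. The explicit grouping of training instances by their writing index, and all names, reflect our reading of the procedure rather than code from the paper.

```python
import numpy as np

def perturb_labels(labels, frac=0.10, seed=0):
    """Randomly flip a fraction of binary labels to simulate annotation noise."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels).copy()
    idx = rng.choice(len(labels), size=int(frac * len(labels)), replace=False)
    labels[idx] = 1 - labels[idx]
    return labels, idx  # idx marks the ground-truth noisy instances

def trace_error(x_err, memory, train_embs, train_cell_ids, k=5):
    """Trace a misclassified validation embedding to the top-k most similar
    training instances within its most supportive meta-evidence group."""
    def cos(a, b):  # cosine similarity of vector a to each row of b
        return b @ a / (np.linalg.norm(a) * np.linalg.norm(b, axis=-1))

    cell = int(np.argmax(cos(x_err, memory)))    # most supportive group
    group = np.where(train_cell_ids == cell)[0]  # instances written to that cell
    sims = cos(x_err, train_embs[group])
    return group[np.argsort(-sims)[:k]]          # candidates to inspect
```

Under this reading, an error counts as correctly traced when the returned candidates include at least one index perturbed by perturb_labels.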

Conclusion
In this paper, we propose a novel neural network architecture that not only achieves state-of-the-art performance on popular QA datasets, but also shows strong error-tracing performance, which we argue will be of great benefit to real-life machine learning applications. In the future, we would like to apply the model to different noisy user-generated datasets to test and further improve its interpretability.