QAInfomax: Learning Robust Question Answering System by Mutual Information Maximization

Standard accuracy metrics indicate that modern reading comprehension systems have achieved strong performance on many question answering datasets. However, the extent to which these systems truly understand language remains unknown, and existing systems are poor at distinguishing distractor sentences, which look related but do not answer the question. To address this problem, we propose QAInfomax, a regularizer for reading comprehension systems that maximizes mutual information among passages, a question, and its answer. QAInfomax discourages the model from simply learning superficial correlations to answer questions. Experiments show that QAInfomax achieves state-of-the-art performance on the benchmark Adversarial-SQuAD dataset.


Introduction
Question answering tasks are widely used for training and testing machine comprehension and reasoning (Rajpurkar et al., 2016; Joshi et al., 2017). However, high performance on standard automatic metrics has been achieved with only superficial understanding, as models exploit simple correlations in the data that happen to be predictive on most test examples. Jia and Liang (2017) addressed this problem by proposing an adversarial version of the SQuAD dataset, created by adding a distractor sentence to each paragraph. The distractor sentences challenge model robustness: the resulting Adversarial-SQuAD data shows that models cannot reliably distinguish a sentence that actually answers the question from one that merely shares words with it, and almost all state-of-the-art machine comprehension systems degrade significantly on the adversarial examples. Lewis and Fan (2018) argued that over-fitting to superficial biases is partially caused by discriminative loss functions, which saturate when simple correlations allow the question to be answered confidently, leaving no incentive for further learning on the example. They therefore designed generative QA models, which use a generative loss function for question answering instead, and showed improvements on Adversarial-SQuAD.
Instead of regularizing models with generative loss functions, we propose an alternative approach named "QAInfomax" that maximizes mutual information (MI) among passages, questions, and answers, aiming to keep models from getting stuck on superficial biases in the data during learning. To estimate MI efficiently, QAInfomax incorporates the recently proposed deep infomax (DIM) (Hjelm et al., 2019), which has proven effective for representation learning on images, audio (Ravanelli and Bengio, 2018), and graphs (Veličković et al., 2018). In this work, QAInfomax further extends DIM to the text domain and encourages the question answering model to generate answers carrying information that explains not only the question but also the answer itself, and thus to be more sensitive to distractor sentences. Our contributions are summarized as follows:
• This paper is the first attempt to apply DIM-based MI estimation as a regularizer for representation learning in the NLP domain.
• The proposed QAInfomax achieves state-of-the-art performance on the Adversarial-SQuAD dataset without additional training data, demonstrating its improved robustness.

Mutual Information (MI) Estimation
In this section, we introduce how mutual information can be estimated at scale in practice via mutual information neural estimation (MINE) (Belghazi et al., 2018) and deep infomax (DIM) (Hjelm et al., 2019), described below. The mutual information between two random variables X and Y is defined as
$$I(X;Y) = D_{KL}\big(p(X,Y) \,\|\, p(X)\,p(Y)\big)$$
where $D_{KL}$ is the Kullback-Leibler (KL) divergence between the joint distribution $p(X,Y)$ and the product of marginals $p(X)p(Y)$.
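For intuition about the definition above, the small worked example below (our own illustration, restricted to discrete variables given as a joint probability table) computes the KL divergence between the joint and the product of its marginals. MI is zero when X and Y are independent and equals log 2 when Y is an exact copy of a fair binary X:

```python
import math

def mutual_information(joint):
    """I(X;Y) = KL(p(X,Y) || p(X)p(Y)) in nats, for a discrete joint given as a 2-D list."""
    px = [sum(row) for row in joint]        # marginal p(x): row sums
    py = [sum(col) for col in zip(*joint)]  # marginal p(y): column sums
    mi = 0.0
    for i, row in enumerate(joint):
        for j, pxy in enumerate(row):
            if pxy > 0:                     # terms with p(x, y) = 0 contribute nothing
                mi += pxy * math.log(pxy / (px[i] * py[j]))
    return mi

independent = [[0.25, 0.25], [0.25, 0.25]]  # p(X, Y) = p(X) p(Y)
copied = [[0.5, 0.0], [0.0, 0.5]]           # Y is an exact copy of X
print(mutual_information(independent))      # 0.0
print(mutual_information(copied))           # log 2, about 0.693
```

Neural estimators such as MINE and DIM are needed precisely because this direct computation is infeasible for high-dimensional continuous representations.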
MINE estimates mutual information by training a classifier to distinguish positive samples $(x, y)$ drawn from the joint distribution from negative samples $(x, \bar{y})$ drawn from the product of marginals. It uses the Donsker-Varadhan (DV) representation (Donsker and Varadhan, 1983) as a lower bound to estimate MI:
$$I(X;Y) \geq \hat{I}^{(DV)}(X;Y) = \mathbb{E}_{P}\big[g(x,y)\big] - \log \mathbb{E}_{N}\big[e^{g(x,\bar{y})}\big]$$
where $\mathbb{E}_{P}$ and $\mathbb{E}_{N}$ denote the expectation over positive and negative samples respectively, and $g$ is the discriminator function, modeled by a neural network, that outputs a real number.
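A minimal sketch of the DV estimate, assuming the discriminator scores $g(x,y)$ on positive pairs and $g(x,\bar{y})$ on negative pairs have already been computed. The function name `dv_lower_bound` is ours, not from the MINE code; the log-mean-exp is shifted by the maximum score so large scores do not overflow:

```python
import math

def dv_lower_bound(pos_scores, neg_scores):
    """Donsker-Varadhan estimate E_P[g(x, y)] - log E_N[exp(g(x, y_bar))].

    pos_scores: discriminator outputs on pairs drawn from the joint distribution.
    neg_scores: discriminator outputs on pairs drawn from the product of marginals.
    """
    e_pos = sum(pos_scores) / len(pos_scores)
    m = max(neg_scores)  # shift for a numerically stable log-mean-exp
    log_mean_exp = m + math.log(sum(math.exp(s - m) for s in neg_scores) / len(neg_scores))
    return e_pos - log_mean_exp

print(dv_lower_bound([2.0, 2.0], [0.0, 0.0]))  # 2.0
```

Note that the estimate is unbounded above: it grows with the gap between positive and negative scores, which is one reason the bounded estimator discussed next can be easier to optimize.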
While the DV representation is the stronger bound on mutual information, as shown in MINE, we are primarily interested in maximizing MI rather than in its precise value. DIM therefore proposes an alternative estimator based on the Jensen-Shannon divergence (JS), which can be efficiently implemented with the binary cross-entropy (BCE) loss:
$$\hat{I}^{(JS)}(X;Y) = \mathbb{E}_{P}\big[\log \sigma(g(x,y))\big] + \mathbb{E}_{N}\big[\log\big(1 - \sigma(g(x,\bar{y}))\big)\big] \tag{1}$$
where $\sigma$ is the sigmoid function. Although the two estimators should behave similarly, since both act like classifiers whose objectives maximize the expected log-ratio of the joint over the product of marginals, the BCE loss has been found empirically to work better than the DV-based objective (Hjelm et al., 2019; Ravanelli and Bengio, 2018; Veličković et al., 2018). A possible reason is that the BCE-based estimate is bounded (its maximum is zero), which makes the convergence of the network more numerically stable. In our experiments, we primarily use the JS representation to estimate mutual information.
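To illustrate why the JS estimator can be implemented with a BCE loss, the sketch below (our own helper names; scores assumed precomputed) treats positive pairs as label 1 and negative pairs as label 0, applies a numerically stable BCE-with-logits, and returns minus the sum of the two mean losses. Because every BCE term is non-negative, the estimate never exceeds zero:

```python
import math

def bce_with_logits(z, target):
    """Stable binary cross-entropy of sigmoid(z) against a 0/1 target."""
    return max(z, 0.0) - z * target + math.log1p(math.exp(-abs(z)))

def js_estimate(pos_scores, neg_scores):
    """E_P[log sigma(g)] + E_N[log(1 - sigma(g))].

    Written as minus the BCE of a classifier that labels positive pairs 1 and
    negative pairs 0, so the estimate is bounded above by zero.
    """
    pos = sum(bce_with_logits(s, 1) for s in pos_scores) / len(pos_scores)
    neg = sum(bce_with_logits(s, 0) for s in neg_scores) / len(neg_scores)
    return -(pos + neg)

print(js_estimate([0.0], [0.0]))     # -2 log 2: the classifier cannot tell pairs apart
print(js_estimate([20.0], [-20.0]))  # close to 0: pairs are well separated
```

In a deep learning framework, the same quantity is typically obtained from a built-in BCE-with-logits loss rather than hand-written arithmetic.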
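Recently, Tian et al. (2019) showed strong empirical performance with improved multiview CPC training (Oord et al., 2018), which shares many ideas with mutual information maximization. Inspired by their work, we modify (1) by switching the roles of $x$ and $y$ and summing the two resulting terms:
$$\hat{I}(X;Y) = \Big(\mathbb{E}_{P}\big[\log\sigma(g(x,y))\big] + \mathbb{E}_{N}\big[\log\big(1-\sigma(g(x,\bar{y}))\big)\big]\Big) + \Big(\mathbb{E}_{P}\big[\log\sigma(g(x,y))\big] + \mathbb{E}_{N}\big[\log\big(1-\sigma(g(\bar{x},y))\big)\big]\Big) \tag{2}$$
where $(\bar{x}, y)$ is also a negative sample drawn from the product of marginals. We empirically find that (2) gives the best performance, and leave further exploration of MI parameterizations to future work.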
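The sketch below evaluates the symmetric estimator on a toy batch, reading (2) as one JS term whose negatives replace $y$ plus one whose negatives replace $x$. Two choices are assumptions for illustration only: a plain dot product stands in for the discriminator $g$, and negatives are formed by rolling the batch by one position:

```python
import math

def log_sigmoid(z):
    """Numerically stable log(sigmoid(z))."""
    return -math.log1p(math.exp(-z)) if z >= 0 else z - math.log1p(math.exp(z))

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def mean(values):
    return sum(values) / len(values)

def symmetric_js(xs, ys, g=dot):
    """Symmetric JS estimate on a batch of positive pairs (xs[i], ys[i]).

    Negatives come from rolling the batch by one position (an assumption):
    (x_i, y_{i+1}) plays the role of (x, y_bar) and (x_{i+1}, y_i) of (x_bar, y).
    """
    n = len(xs)
    e_pos = mean([log_sigmoid(g(xs[i], ys[i])) for i in range(n)])
    # log(1 - sigmoid(z)) == log_sigmoid(-z)
    e_neg_y = mean([log_sigmoid(-g(xs[i], ys[(i + 1) % n])) for i in range(n)])
    e_neg_x = mean([log_sigmoid(-g(xs[(i + 1) % n], ys[i])) for i in range(n)])
    return (e_pos + e_neg_y) + (e_pos + e_neg_x)

xs = [[1.0, -1.0], [-1.0, 1.0]]
print(symmetric_js(xs, [[5.0, -5.0], [-5.0, 5.0]]))  # near 0: each y_i matches x_i
print(symmetric_js(xs, [[-5.0, 5.0], [5.0, -5.0]]))  # about -40: the pairing is reversed
```

With an asymmetric discriminator such as a bilinear form, the two negative terms score genuinely different pairs, so the second term adds a training signal rather than duplicating the first.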

Methodology
In extractive question answering datasets like SQuAD, the answer $A = \{a_1, \dots, a_M\}$ to the question $Q = \{q_1, \dots, q_K\}$ is guaranteed to be a span $\{p_m, \dots, p_{m+M}\}$ of the paragraph $P = \{p_1, \dots, p_N\}$. Given $Q$ and $P$, the encoded representations from the QA system $\mathcal{M}$ can be formulated as
$$r_q, r_p = \mathcal{M}(Q, P)$$
where $r_q = \{r_{q_1}, \dots, r_{q_K}\}$ and $r_p = \{r_{p_1}, \dots, r_{p_N}\}$ are the representations of the question and the passage, respectively, after the reasoning process in the QA system $\mathcal{M}$.
Most models then feed the passage representation $r_p$ to a single-layer neural network to obtain the span start and end probabilities for each passage word, and compute the loss $\mathcal{L}_{span}$, the negative sum of the log probabilities of the predicted distributions indexed by the true start and end indices.
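A minimal sketch of $\mathcal{L}_{span}$ for a single example, assuming the single-layer network has already produced raw start and end logits over the passage tokens (the helper names are hypothetical):

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_z for z in logits]

def span_loss(start_logits, end_logits, start_idx, end_idx):
    """L_span for one example: -(log p_start[true start] + log p_end[true end])."""
    return -(log_softmax(start_logits)[start_idx] + log_softmax(end_logits)[end_idx])

print(span_loss([0.0] * 4, [0.0] * 4, 1, 2))  # 2 log 4: uniform predictions over 4 tokens
```

Once the start and end logits for the gold indices dominate, this loss approaches zero, which is the saturation that Lewis and Fan (2018) point to as leaving no incentive for further learning.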
Our QAInfomax aims to keep the QA system $\mathcal{M}$ from simply exploiting superficial biases in the dataset to answer questions. Therefore, two constraints are introduced to guide model learning. Intuitively, the model is expected to choose the answer span after fully considering the entire question and paragraph. However, traditional QA models suffer from the overstability problem and tend to be fooled by distractor answers, such as one containing an unrelated person's name. Like Lewis and Fan (2018), we believe that the main reason is that QA models are only trained to predict the start and end positions of answer spans. Correlations in the dataset allow QA models to find shortcuts and ignore what the answer span looks like. The learned behavior of traditional QA models can be viewed as simple pattern matching, such as choosing the five-token span after the word "river" if a question is about a river and the context talks about European countries.
Following this intuition, two constraints, the local constraint (LC) and the global constraint (GC), are introduced to guide models toward the desired behaviors. To prevent the model from only learning to match specific word patterns to find the answer, LC forces the model to generate answer span representations that maximize mutual information between the words in the span and the context words surrounding it. By maximizing the mutual information between an answer word and all of its context words, models need to incorporate the entire context into their decision process when choosing answers, and thus can be more robust to adversarial sentences. We further require models to maximize mutual information among answer words, so that models can no longer ignore any word in the chosen answer span.
On the other hand, unlike LC, which only focuses on the answer span and its context, GC pushes the model to prefer answer representations carrying information that is globally shared across the whole input, $Q$ and $P$, because shortcuts do not necessarily appear near the answer. If the model only learns to leverage correlations specific to part of the input, the MI with any input word that lacks such a relationship would not increase.
An overview of the two proposed constraints is illustrated in Figure 1. The details of both constraints and of the complete QAInfomax regularizer are described below.

Local Constraint
As shown in Section 2, maximizing MI requires positive samples and negative samples drawn from the joint distribution and the product of marginal distributions, respectively.
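In LC, because all answer word representations are expected to carry information about each other and their contexts, we maximize the averaged MI between a sampled answer word representation and the whole answer sequence with its context words. Specifically, a positive sample is obtained by pairing the sampled answer word representation $x \in r_a = \{r_{p_m}, \dots, r_{p_{m+M}}\}$ with all other answer and context words $r_c = \{r_{p_{m-C}}, \dots, r_{p_{m+M+C}}\} \setminus \{x\}$, where the hyperparameter $C$ defines how many context words on each side are considered. Negative samples, on the other hand, are obtained by randomly sampling an answer representation $\bar{r}_a = \{\bar{r}_{p_l}, \dots, \bar{r}_{p_{l+L}}\}$ and the corresponding $\bar{r}_c$ from other training examples. Following (2), the objective for the sampled $x$, $r_c$, $\bar{x} \in \bar{r}_a$, and $\bar{r}_c$ is (2) averaged over the context words, with $(x, \bar{c})$ for $\bar{c} \in \bar{r}_c$ and $(\bar{x}, c)$ as negative pairs:
$$\mathcal{L}_{LC} = \frac{1}{|r_c|}\sum_{c \in r_c}\Big(2\log\sigma(g(x,c)) + \log\big(1-\sigma(g(\bar{x},c))\big)\Big) + \frac{1}{|\bar{r}_c|}\sum_{\bar{c} \in \bar{r}_c}\log\big(1-\sigma(g(x,\bar{c}))\big)$$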
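The index bookkeeping for LC can be sketched as follows. The function `local_pairs` is hypothetical: positions follow the span notation above with inclusive endpoints, and clipping the context window at the passage boundaries is our assumption, since the text does not say how spans near the edges are handled. The negative example's $\bar{x}$ and $\bar{r}_c$ would come from calling the same function on a different example:

```python
import random

def local_pairs(start, end, n_passage, C, rng=random):
    """Positions used by the local constraint for one example.

    The answer span covers passage positions start..end (inclusive). Returns the
    sampled answer position x and r_c: every answer/context position within C
    tokens of the span, excluding x itself. Clipping to [0, n_passage) is assumed.
    """
    x = rng.randint(start, end)  # randomly sampled answer word
    lo, hi = max(0, start - C), min(n_passage - 1, end + C)
    r_c = [i for i in range(lo, hi + 1) if i != x]  # answer + context words except x
    return x, r_c

x, r_c = local_pairs(start=10, end=12, n_passage=100, C=5, rng=random.Random(0))
print(x, r_c)
```

The returned positions index into the passage representations $r_p$; each $(x, c)$ with $c \in r_c$ then becomes a positive pair for the discriminator.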

Global Constraint
Unlike LC, GC forces the learned answer representations $r_a$ to share information with all other question and passage representations. Here, we maximize the mutual information between the summarized answer vector $s = S(r_a)$ and each $r_l \in r = \{r_q, r_p\} \setminus \{r_a\}$.
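Specifically, a positive sample here is the pair of the answer summary vector $s = S(r_a)$ and any other word representation in $r$. Negative samples are provided by sampling question, passage, and answer representations $\{\bar{r}_q, \bar{r}_p, \bar{r}_a\}$ from an alternative training example. We then pair the summary $s$ with $\bar{r} = \{\bar{r}_q, \bar{r}_p\} \setminus \{\bar{r}_a\}$, and $\bar{s} = S(\bar{r}_a)$ with $r$. Analogously to LC, the objective is
$$\mathcal{L}_{GC} = \frac{1}{|r|}\sum_{r_l \in r}\Big(2\log\sigma(g(s,r_l)) + \log\big(1-\sigma(g(\bar{s},r_l))\big)\Big) + \frac{1}{|\bar{r}|}\sum_{\bar{r}_l \in \bar{r}}\log\big(1-\sigma(g(s,\bar{r}_l))\big)$$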
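A minimal sketch of the GC estimate for one example, written in the symmetric form of (2): positive pairs $(s, r_l)$ (counted twice), negative pairs $(\bar{s}, r_l)$ and $(s, \bar{r}_l)$. Assumptions for illustration only: vectors are plain lists, the Mean summary is used for $S$, and a dot product stands in for the bilinear discriminator used in the experiments:

```python
import math

def log_sigmoid(z):
    """Numerically stable log(sigmoid(z))."""
    return -math.log1p(math.exp(-z)) if z >= 0 else z - math.log1p(math.exp(z))

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

def mean_summary(vectors):
    """S(r_a) as the element-wise mean of the answer word vectors."""
    return [sum(col) / len(vectors) for col in zip(*vectors)]

def global_constraint(r_a, r_rest, r_a_bar, r_rest_bar, g=dot):
    """GC estimate; r_rest / r_rest_bar hold all non-answer question+passage vectors
    of this example and of the negative (shuffled-in) example, respectively."""
    s, s_bar = mean_summary(r_a), mean_summary(r_a_bar)
    # log(1 - sigmoid(z)) == log_sigmoid(-z)
    over_r = sum(2 * log_sigmoid(g(s, r)) + log_sigmoid(-g(s_bar, r)) for r in r_rest) / len(r_rest)
    over_r_bar = sum(log_sigmoid(-g(s, r)) for r in r_rest_bar) / len(r_rest_bar)
    return over_r + over_r_bar

r_a, r_rest = [[3.0, 0.0]], [[3.0, 0.0], [2.0, 0.0]]
r_a_bar, r_rest_bar = [[-3.0, 0.0]], [[-3.0, 0.0], [-2.0, 0.0]]
print(global_constraint(r_a, r_rest, r_a_bar, r_rest_bar))  # near 0: answer fits its own input
```

Because the average runs over every question and passage word, this term scales with input length, which is consistent with the training-speed cost of GC reported in the experiments.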

QAInfomax
In our proposed model, we combine the two objectives into the complete QAInfomax regularizer.
For each training batch consisting of examples $\{\{Q_1, P_1, A_1\}, \dots, \{Q_B, P_B, A_B\}\}$, we pass the batch through the model $\mathcal{M}$ and obtain representations $\{\{r_{q_1}, r_{p_1}, r_{a_1}\}, \dots, \{r_{q_B}, r_{p_B}, r_{a_B}\}\}$. Note that, with a slight abuse of notation, subscripts here denote the example index within the batch.
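We then shuffle the whole batch to obtain negative examples $\{\{\bar{r}_{q_1}, \bar{r}_{p_1}, \bar{r}_{a_1}\}, \dots, \{\bar{r}_{q_B}, \bar{r}_{p_B}, \bar{r}_{a_B}\}\}$. The complete objective $\mathcal{L}_{info}$ for QAInfomax becomes
$$\mathcal{L}_{info} = \frac{1}{B}\sum_{i=1}^{B}\Big(\alpha\,\mathcal{L}_{LC}(x_i, r_{c_i}, \bar{x}_i, \bar{r}_{c_i}) + \beta\,\mathcal{L}_{GC}(s_i, r_i, \bar{s}_i, \bar{r}_i)\Big)$$
where $x_i$ and $\bar{x}_i$ are the representations sampled from $r_{a_i}$ and $\bar{r}_{a_i}$, $r_{c_i}$ and $\bar{r}_{c_i}$ are $r_{a_i}$ and $\bar{r}_{a_i}$ expanded with their context words, $s_i$ and $\bar{s}_i$ are the summary vectors of $r_{a_i}$ and $\bar{r}_{a_i}$, $r_i$ and $\bar{r}_i$ are the remaining question and passage representations of the respective examples, and $\alpha$ and $\beta$ are hyperparameters.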
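The text only says the batch is shuffled. The sketch below additionally rejects permutations that leave any example in place (our assumption), so every example is contrasted with the representations of a different example; `shuffled_negatives` is a hypothetical name:

```python
import random

def shuffled_negatives(batch, rng=random):
    """Return the batch in a shuffled order with no example left at its own index."""
    if len(batch) < 2:
        raise ValueError("need at least two examples to form negatives")
    idx = list(range(len(batch)))
    while True:  # a random permutation has no fixed point with probability about 1/e
        rng.shuffle(idx)
        if all(i != j for i, j in enumerate(idx)):
            return [batch[j] for j in idx]

batch = ["ex0", "ex1", "ex2", "ex3"]
print(shuffled_negatives(batch, random.Random(0)))
```

Reusing the batch this way needs no extra forward passes: the negative representations are simply the already-computed representations of other examples.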
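Combined with QAInfomax as a regularizer, the final objective of the model becomes
$$\mathcal{L} = \mathcal{L}_{span} - \lambda\,\mathcal{L}_{info}$$
where $\mathcal{L}_{span}$ is the answer span prediction loss and $\lambda$ is the regularization strength; the minus sign reflects that $\mathcal{L}_{info}$ is an MI estimate to be maximized while $\mathcal{L}$ is minimized. The objective can be optimized with standard gradient descent.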
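As a small arithmetic sketch of how the pieces combine, assuming per-example LC and GC estimates are already available and that $\mathcal{L}_{info}$ is the MI estimate to be maximized (so it enters the minimized loss with a minus sign). The default $\alpha$, $\beta$, and $\lambda$ are the values listed in the Setup section; `total_loss` is a hypothetical name:

```python
def total_loss(l_span, lc_estimates, gc_estimates, alpha=1.0, beta=0.5, lam=0.3):
    """L = L_span - lam * L_info, with L_info the batch mean of alpha * LC_i + beta * GC_i.

    lc_estimates / gc_estimates: per-example MI estimates (at most 0 for the
    JS-style estimator), so minimizing L pushes them toward 0, i.e. maximizes MI.
    """
    terms = [alpha * lc + beta * gc for lc, gc in zip(lc_estimates, gc_estimates)]
    l_info = sum(terms) / len(terms)
    return l_span - lam * l_info

print(total_loss(2.0, [-1.0, -1.0], [-2.0, -2.0]))  # 2.0 - 0.3 * (-2.0), about 2.6
```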

Experiments
To evaluate the effectiveness of the proposed QAInfomax, we conduct the experiments on a challenging dataset, Adversarial-SQuAD.

Setup
BERT-base (Devlin et al., 2018) is employed as our QA system $\mathcal{M}$ in the experiments, with the same hyperparameters as the released SQuAD training configuration.
We set $C$, $\alpha$, $\beta$, and $\lambda$ to 5, 1, 0.5, and 0.3, respectively, in all experiments, and add the proposed QAInfomax to the BERT model as described above. The discriminator function $g$ is a bilinear function similar to the scoring used by Oord et al. (2018):
$$g(x, y) = x^{\top} W y$$
where $W$ is a learnable scoring matrix.
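The bilinear discriminator can be sketched in a few lines of plain Python (nested lists instead of tensors, for illustration only). With $W$ equal to the identity it reduces to a dot product, and with a non-symmetric $W$ the score depends on argument order, so $g(x, y)$ and $g(y, x)$ generally differ:

```python
def bilinear_score(x, W, y):
    """g(x, y) = x^T W y for vectors x, y and a matrix W given as nested lists."""
    return sum(x[i] * W[i][j] * y[j] for i in range(len(x)) for j in range(len(y)))

identity = [[1.0, 0.0], [0.0, 1.0]]
upper = [[0.0, 1.0], [0.0, 0.0]]
print(bilinear_score([1.0, 2.0], identity, [3.0, 4.0]))  # 11.0: reduces to a dot product
print(bilinear_score([1.0, 0.0], upper, [0.0, 1.0]))     # 1.0
print(bilinear_score([0.0, 1.0], upper, [1.0, 0.0]))     # 0.0: order of arguments matters
```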
We train the BERT model with the proposed QAInfomax on the original SQuAD dataset and use Adversarial-SQuAD to test the robustness of the augmented model. Only the ADDSENT and ADDONESENT metrics are reported for comparison with previous models, because most previous models did not report their ADDANY and ADDCOMMON scores. Briefly, for each example, ADDSENT runs the model $\mathcal{M}$ on every human-approved adversarial sentence, picks the one that makes the model give the worst answer, and returns that score. ADDONESENT, on the other hand, uses a single randomly chosen human-approved adversarial sentence. All reported numbers are the best across at least three runs.
Table 1 reports model performance on Adversarial-SQuAD. QAInfomax yields substantial improvement over the vanilla BERT model and achieves state-of-the-art performance on both the ADDSENT and ADDONESENT metrics. The improvement is larger on ADDSENT, which takes the model's worst score per example. This suggests that QAInfomax is effective at keeping the model from relying on simple correlations in the data and at encouraging more human-like reasoning. It is worth noting that while QAInfomax mitigates the overstability problem and improves robustness to adversarial examples, it does not hurt the original performance of the QA system, which matters for practical use. Some example results from the Adversarial-SQuAD dataset can be found in the Appendix, where adversarial distracting sentences are shown in italic blue font.
Table 2 shows the ablation study of QAInfomax, where both proposed constraints are important for achieving these results. We also report the training speed of the proposed method, which reveals a limitation: the GC objective slows training by 28%. The reason is that GC measures the averaged MI over the whole question and passage representations, which may form a long sequence of vectors.
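The difference between the two reported metrics can be sketched as follows, assuming the model's per-adversary scores (e.g., F1) for each example have already been computed; `adversarial_eval` is a hypothetical helper, not part of the official Adversarial-SQuAD evaluation script:

```python
import random

def adversarial_eval(per_example_scores, rng=random):
    """Average ADDSENT and ADDONESENT scores over a dataset.

    per_example_scores[k] lists the model's score on example k under each of its
    human-approved adversarial sentences (assumed precomputed).
    """
    n = len(per_example_scores)
    addsent = sum(min(scores) for scores in per_example_scores) / n            # worst case
    addonesent = sum(rng.choice(scores) for scores in per_example_scores) / n  # one random sentence
    return addsent, addonesent

print(adversarial_eval([[1.0, 0.0, 1.0], [0.5, 0.8]], random.Random(0)))
```

Since ADDSENT takes the minimum per example, it is never higher than ADDONESENT, which is why gains on ADDSENT reflect improved worst-case robustness.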

Results
Considering that the summarization function $S$ plays an important role in GC, we explore different variants of it in Table 3:
• Sample: randomly sample one answer word representation from $r_a$.
• Mean: average the answer word representations in $r_a$.
• Max: take the element-wise maximum over $r_a$.
According to the experimental results, Mean performs best, while Max and Sample achieve competitive performance, suggesting that the proposed method is robust to the choice of summarization function.
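The three summary variants can be sketched over plain lists of equal-length vectors. The function `summarize` is hypothetical, and reading Mean and Max as element-wise pooling over the answer words is our interpretation of the variant names:

```python
import random

def summarize(r_a, mode="mean", rng=random):
    """Candidate summary functions S(r_a) over a list of equal-length answer word vectors."""
    if mode == "mean":    # element-wise average over answer words
        return [sum(col) / len(r_a) for col in zip(*r_a)]
    if mode == "max":     # element-wise max pooling over answer words
        return [max(col) for col in zip(*r_a)]
    if mode == "sample":  # one randomly chosen answer word vector
        return list(rng.choice(r_a))
    raise ValueError(f"unknown summary mode: {mode}")

r_a = [[1.0, 4.0], [3.0, 2.0]]
print(summarize(r_a, "mean"))  # [2.0, 3.0]
print(summarize(r_a, "max"))   # [3.0, 4.0]
```

Mean and Max use every answer word, while Sample keeps only one per step, which may explain why the pooling variants fare at least as well.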

Conclusion
This paper presents QAInfomax, a novel regularizer for question answering systems based on MI maximization, which keeps models from getting stuck on superficial correlations in the data and improves their robustness. QAInfomax can be flexibly applied to different machine comprehension models. Experiments on Adversarial-SQuAD demonstrate its effectiveness, with the augmented model achieving state-of-the-art results. In the future, we will investigate methods for reducing the limitations of QAInfomax and improving the generalization capability of QA systems.