Learning to Contrast the Counterfactual Samples for Robust Visual Question Answering

In the task of Visual Question Answering (VQA), most state-of-the-art models tend to learn spurious correlations in the training set and achieve poor performance on out-of-distribution test data. Some methods of generating counterfactual samples have been proposed to alleviate this problem. However, the counterfactual samples generated by most previous methods are simply added to the training data for augmentation and are not fully utilized. Therefore, we introduce a novel self-supervised contrastive learning mechanism to learn the relationship between original samples, factual samples and counterfactual samples. With the better cross-modal joint embeddings learned from the auxiliary training objective, the reasoning capability and robustness of the VQA model are boosted significantly. We evaluate the effectiveness of our method by surpassing current state-of-the-art models on the VQA-CP dataset, a diagnostic benchmark for assessing the VQA model's robustness.


Introduction
To develop human-like visual and language understanding in AI, the task of answering a question about given visual content has been proposed, i.e., Visual Question Answering (VQA) (Antol et al., 2015). Although current state-of-the-art methods (Fukui et al., 2016; Cadene et al., 2019a) achieve good results on VQA benchmarks such as VQA v2 (Goyal et al., 2017), recent studies (Kafle and Kanan, 2017) have found that these methods tend to exploit superficial correlations in the training set and perform poorly when transferred to real-world settings. Specifically, given the question "What color is the banana?", the models prefer to take the shortcut and "assume" that the answer should be "yellow", since it is the most common answer in the training set, rather than grounding the answer in the image.
To overcome this language bias in VQA, the VQA-CP dataset was proposed, in which the answer distribution of the training set differs vastly from that of the test set. The performance of most current state-of-the-art models (Andreas et al., 2016; Shrestha et al., 2019) drops significantly on VQA-CP due to the language bias. Hence, it has become the standard out-of-distribution benchmark for VQA.
A successful, robust and unbiased VQA system should be able to deduce the right answer from the right region of the image. Lately, some studies have proposed synthesizing counterfactual samples to improve the robustness of VQA models. Agarwal et al. (2019) and Pan et al. (2019) apply GANs (Goodfellow et al., 2014) to generate images. The CSS algorithm proposed by Chen et al. (2020) generates counterfactual samples by masking the critical objects in images or words in questions, as shown in Figure 1; the critical objects or words can be obtained from CSS as by-products. Nevertheless, these counterfactual samples are simply added to the training data for augmentation, ignoring that the relationship between original samples and counterfactual samples is vital for the reasoning of VQA models. Specifically, the model should be able to learn why the correct answer can no longer be inferred after the original sample is changed into the counterfactual sample. We posit that modeling the relationship between original samples, factual samples and counterfactual samples can bring more self-supervised signals to improve the reasoning ability of the model.
To enable the VQA model to understand the impact of changing a sample from original to counterfactual, we introduce a novel contrastive learning mechanism into training with counterfactual samples; to the best of our knowledge, this is the first such attempt in the field of learning with counterfactual samples. The auxiliary contrastive training objective models the relationship between original samples, factual samples and counterfactual samples in the cross-modal joint embedding space. With the better cross-modal representations, both the reasoning ability and robustness of the VQA model are improved effectively.
Overall, the contributions of this paper are as follows:
• We are the first to introduce a self-supervised contrastive learning mechanism for counterfactual samples in VQA. Our method not only helps the VQA model learn the relationship between original samples, factual samples and counterfactual samples but also improves the generalization ability of the model significantly.
• Experimental results show that our method brings significant improvements and achieves state-of-the-art performance on the VQA-CP dataset. Furthermore, the effectiveness of the contrastive mechanism in counterfactual sample learning is not limited to one particular form of contrastive loss.
Related Work

Language Bias in VQA
As the issue of language bias in VQA models has been pointed out (Jabri et al., 2016; Goyal et al., 2017), creating a more balanced dataset is a straightforward way to alleviate it. To this end, the VQA v2 dataset (Goyal et al., 2017) rearranges the sample distribution so that, for the same question, there is at least one pair of similar images leading to different answers. Since the statistical bias problem remains, the VQA-CP dataset was introduced, in which the answer distributions of the training and test splits are re-distributed, making it the standard benchmark for evaluating the robustness of VQA models.

Counterfactual Samples for VQA
Recently, drawing on insights from causal inference (Neuberg, 2003), several studies synthesize counterfactual samples to augment the training of VQA models (Agarwal et al., 2019; Pan et al., 2019; Chen et al., 2020). Most similar to our work, Teney et al. (2020a) propose a training objective named Gradient Supervision (GS) to exploit the relational information between original training samples and additional counterfactual samples. GS encourages the gradient of the model to align with a "ground-truth" gradient, defined as the translation from the original sample to the counterfactual sample in the input space. In contrast, we employ a novel contrastive learning strategy to simultaneously learn the triplet relationship between original training samples, factual samples and counterfactual samples.

Contrastive Learning
Contrastive learning techniques have achieved great success in unsupervised learning (Oord et al., 2018; He et al., 2019). The main idea of unsupervised contrastive learning is to maximize the mutual information between input samples and positive samples so as to learn better representations. Inspired by this, we apply the contrastive mechanism, for the first time, to learn self-supervised signals from counterfactual samples and improve the robustness of VQA models.

Methodology
In this section, we introduce our technical realization. The flowchart of our proposed method is illustrated in Figure 2. Our method consists of three parts: (1) a base VQA model; (2) a factual and Counterfactual Samples Synthesizing (CSS) module; and (3) a Contrastive Learning (CL) objective.

Baseline VQA Model
We adopt the Bottom-Up Top-Down (UpDn) model (Anderson et al., 2018) in our method, which follows the common formulation of the VQA task as a multi-class classification problem. Given a dataset $\mathcal{D} = \{(I_i, Q_i, a_i)\}_{i=1}^{N}$ consisting of $N$ triplets of image $I_i \in \mathcal{I}$, question $Q_i \in \mathcal{Q}$ and answer $a_i \in \mathcal{A}$, the task aims to learn a mapping function $f_{vqa}: \mathcal{I} \times \mathcal{Q} \rightarrow [0, 1]^{|\mathcal{A}|}$ that produces an answer distribution for the given image and question. In the following sections, we omit the subscript $i$ for simplicity. For each question $Q$, UpDn uses a question encoder $e_q$ to extract a set of word embeddings $\mathbf{Q}$. For each image $I$, UpDn uses an object detector $e_v$ to extract a set of visual object embeddings $\mathbf{V}$. Both $\mathbf{Q}$ and $\mathbf{V}$ are then fed into attention and fusion modules to generate the joint embedding $mm(\mathbf{Q}, \mathbf{V})$, which is fed into a classifier $C$ to predict the answer:

$$\hat{a} = C\big(mm(\mathbf{Q}, \mathbf{V})\big) \quad (1)$$
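For concreteness, the sketch below shows how such a base model can be assembled in PyTorch, the framework used in our implementation. The class name UpDnVQA, the element-wise fusion, and any layer sizes not stated in Appendix A.1 are illustrative assumptions rather than the exact UpDn implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class UpDnVQA(nn.Module):
    """Minimal UpDn-style VQA model: f_vqa(I, Q) -> answer distribution."""

    def __init__(self, vocab_size, num_answers, q_dim=512, v_dim=2048, joint_dim=2048):
        super().__init__()
        # e_q: word embeddings + single-layer GRU question encoder
        self.word_emb = nn.Embedding(vocab_size, 300, padding_idx=0)
        self.gru = nn.GRU(300, q_dim, batch_first=True)
        # top-down attention over the K object features
        self.att = nn.Linear(q_dim + v_dim, 1)
        # projections used to build the joint embedding mm(Q, V)
        self.q_proj = nn.Linear(q_dim, joint_dim)
        self.v_proj = nn.Linear(v_dim, joint_dim)
        # classifier C over the answer set A
        self.classifier = nn.Sequential(
            nn.Linear(joint_dim, joint_dim), nn.ReLU(),
            nn.Linear(joint_dim, num_answers))

    def joint_embedding(self, v_feats, q_tokens):
        # question embedding from the last hidden state of the GRU
        _, h = self.gru(self.word_emb(q_tokens))
        q = h.squeeze(0)                                      # [B, q_dim]
        # attention over objects, conditioned on the question
        q_tiled = q.unsqueeze(1).expand(-1, v_feats.size(1), -1)
        att = F.softmax(self.att(torch.cat([q_tiled, v_feats], dim=-1)), dim=1)
        v = (att * v_feats).sum(dim=1)                        # [B, v_dim]
        # element-wise fusion -> joint embedding mm(Q, V)
        return self.q_proj(q) * self.v_proj(v)

    def forward(self, v_feats, q_tokens):
        # answer distribution in [0, 1]^{|A|}
        return torch.sigmoid(self.classifier(self.joint_embedding(v_feats, q_tokens)))


# Toy usage: a batch of 2 samples, K = 36 object features, questions of 14 tokens.
model = UpDnVQA(vocab_size=10000, num_answers=3129)
v = torch.randn(2, 36, 2048)
q = torch.randint(1, 10000, (2, 14))
print(model(v, q).shape)  # torch.Size([2, 3129])
```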

Synthesizing Counterfactual Samples
There are several ways to synthesize counterfactual samples for the given image-question pairs in our pipeline. For instance, Teney et al. (2020a) build counterfactual samples using annotations of human attention (Das et al., 2016): they generate the counterfactual image by masking the features whose bounding boxes overlap with the human attention map beyond a certain threshold. In contrast to using extra manual annotations, the CSS algorithm proposed by Chen et al. (2020) identifies the critical objects ($I^+$) in the image or words ($Q^+$) in the question with a modified Grad-CAM (Selvaraju et al., 2017) and masks them to generate the counterfactual samples. Since the latter is more practical, we adopt the CSS algorithm in our pipeline and obtain the factual ($I^+$, $Q^+$) and counterfactual ($I^-$, $Q^-$) samples:

$$(I^+, I^-, Q^+, Q^-) = \mathrm{CSS}\big(f_{vqa}, (I, Q, a)\big) \quad (2)$$
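The snippet below is a simplified sketch of the V-CSS idea under the assumptions of the previous snippet. The function name and the fixed top-k selection are illustrative, and the full CSS algorithm of Chen et al. (2020) includes further steps (e.g., reassigning answers for the synthesized samples) that are omitted here.

```python
import torch

def synthesize_visual_counterfactual(model, v_feats, q_tokens, answer_targets, top_k=3):
    """Simplified V-CSS sketch: score each object's contribution to the ground-truth
    answer in a Grad-CAM-like way, keep the top-k objects as the critical (factual)
    part I+, and mask them to obtain the counterfactual image features I-."""
    v_feats = v_feats.clone().requires_grad_(True)
    probs = model(v_feats, q_tokens)                           # [B, |A|]
    # gradient of the ground-truth answer score w.r.t. the object features
    score = (probs * answer_targets).sum()
    grads, = torch.autograd.grad(score, v_feats)
    contrib = (grads * v_feats).sum(dim=-1)                    # [B, K] per-object contribution
    critical = contrib.topk(top_k, dim=-1).indices             # indices of critical objects

    keep = torch.ones_like(contrib)
    keep.scatter_(1, critical, 0.0)                            # zero out the critical objects
    v_minus = v_feats.detach() * keep.unsqueeze(-1)            # I-: critical objects masked
    v_plus = v_feats.detach() * (1.0 - keep).unsqueeze(-1)     # I+: only critical objects kept
    return v_plus, v_minus
```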

Contrastive Learning Objective
With the causal triplets $(I, I^+, I^-)$ and $(Q, Q^+, Q^-)$ obtained from CSS, we can apply the contrastive learning mechanism. We take a specific triplet $(I, I^+, I^-)$, shown in Figure 2, as an example to illustrate the contrastive learning method. First, $I$, $I^+$ and $I^-$, each paired with $Q$, are fed into the VQA model to generate their joint embeddings. We denote the joint embedding $mm(\mathbf{Q}, \mathbf{V})$ of the original sample as the anchor $a$, the embedding $mm(\mathbf{Q}, \mathbf{V}^+)$ of the factual sample as the positive $p$, and the embedding $mm(\mathbf{Q}, \mathbf{V}^-)$ of the counterfactual sample as the negative $n$. Before defining the contrastive loss, we first define a scoring function $s$ that outputs high values for the positive sample and low values for the negative sample. We take the cosine similarity of the representations in the joint embedding space as our scoring function because it implicitly normalizes the embeddings. The score between the anchor and the positive is:

$$s(a, p) = \frac{a \cdot p}{\|a\| \, \|p\|} \quad (3)$$

Similarly, the score between the anchor and the negative is defined as $s(a, n)$. Then, following recent work in unsupervised learning (Oord et al., 2018), the contrastive loss is formulated as:

$$L_c = -\log \frac{\exp\big(s(a, p)\big)}{\exp\big(s(a, p)\big) + \exp\big(s(a, n)\big)} \quad (4)$$

Minimizing $L_c$ maximizes a lower bound on the mutual information between the factual sample and the original sample, enabling the model to learn the relationship between them and predict the right answer from a more causal perspective. The weighted sum of this contrastive loss and the base VQA classification loss $L_{vqa}$ makes up the overall loss:

$$L = \lambda_{vqa} L_{vqa} + \lambda_c L_c \quad (5)$$

where $\lambda_{vqa}$ and $\lambda_c$ are the loss weights.
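A minimal PyTorch sketch of Eqs. (3)–(5) is given below, assuming the joint embeddings arrive as batched vectors. The function names and the default loss weights (taken from Appendix A.1) are the only assumptions beyond the text.

```python
import torch
import torch.nn.functional as F

def contrastive_loss(anchor, positive, negative):
    """InfoNCE-style loss (Eq. 4) with one positive and one negative,
    using cosine similarity as the scoring function s (Eq. 3)."""
    s_ap = F.cosine_similarity(anchor, positive, dim=-1)     # s(a, p)
    s_an = F.cosine_similarity(anchor, negative, dim=-1)     # s(a, n)
    logits = torch.stack([s_ap, s_an], dim=-1)               # [B, 2]
    # the positive is always index 0: -log( exp(s_ap) / (exp(s_ap) + exp(s_an)) )
    targets = torch.zeros(anchor.size(0), dtype=torch.long, device=anchor.device)
    return F.cross_entropy(logits, targets)

def overall_loss(l_vqa, l_c, lambda_vqa=1.0, lambda_c=2.0):
    # weighted sum of the VQA classification loss and the contrastive loss (Eq. 5)
    return lambda_vqa * l_vqa + lambda_c * l_c
```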

Datasets
We conduct experiments on the VQA-CP dataset¹ (v1 and v2), in which the answer distributions of the training and test splits differ, making it the standard benchmark for evaluating robustness.

¹https://www.cc.gatech.edu/~aagrawal307/vqa-cp/

Settings and Comparisons with SoTA
We validate the effectiveness of our method on the VQA-CP (both v1 and v2) datasets. Results on VQA v2 are also reported in the appendices for completeness. We use the standard VQA evaluation metric (Antol et al., 2015) to report accuracy. All implementation details are given in the appendices.

Performance on VQA-CP v2
Table 1 compares our results with the state-of-the-art models on VQA-CP v2. According to their backbones, we group these models into: 1) SAN-based methods, including GVQA; 2) Unshuffling-based methods, including CF and CF+GS; 3) UpDn-based methods, including AReg, GRL, RUBi, LMH, CSS, HINT and SCR. The results show that our Contrastive Learning (CL) built on top of UpDn+LMH+CSS outperforms these previous results, improving the overall accuracy from 57.74% to 59.18% (+1.44%). In contrast, the Gradient Supervision (GS) for the counterfactuals brings a smaller gain (+0.80%) over Unshuffling+CF. We further explore the performance of Gradient Supervision when applied with the same set of counterfactual samples (CSS). From Table 1, we observe that our method still outperforms LMH+CSS+GS by 1.88%, indicating that our method extracts more self-supervision from the counterfactual samples than GS does.

Different Forms of Contrastive Loss
To explore whether different forms of contrastive loss are effective in learning from counterfactual samples in VQA, we conduct experiments on VQA-CP v2 using a variant of the Margin-based Contrastive Loss (MarginCL) proposed by Hadsell et al. (2006), which is formulated as:

$$L_{mc} = D(a, p) + \max\big(0,\ m - D(a, n)\big) \quad (6)$$

where $D(a, p) = 1 - s(a, p)$ is the cosine distance between $a$ and $p$, and the margin $m$ between $a$ and $n$ is set to 0.3. Table 3 shows the experimental results. The improvements on two different VQA models demonstrate that our method is generic.
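Under the same assumptions as the earlier sketch, this variant can be written as follows; it follows the Hadsell-style form reconstructed in Eq. (6), and the averaging over the batch is an assumption.

```python
import torch.nn.functional as F

def margin_contrastive_loss(anchor, positive, negative, margin=0.3):
    """Margin-based variant (Eq. 6): pull the positive close in cosine distance
    and push the negative at least `margin` away."""
    d_ap = 1.0 - F.cosine_similarity(anchor, positive, dim=-1)   # D(a, p)
    d_an = 1.0 - F.cosine_similarity(anchor, negative, dim=-1)   # D(a, n)
    return (d_ap + F.relu(margin - d_an)).mean()
```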

Performance on counterfactual samples and factual samples
To further explore whether our method improves the generalization capability of the VQA model, we evaluate the VQA performance on the counterfactual samples and factual samples of VQA-CP v2 and report the results in Table 4. Compared with CSS and CSS+GS, our method achieves the best performance, which demonstrates that the VQA model benefits from the contrastive learning mechanism and accordingly generalizes better to counterfactual samples and factual samples.

Case Study
To validate the effect of our contrastive training objective, we visualize the joint embeddings of two examples and their synthesized samples using t-SNE (Maaten and Hinton, 2008). As Figure 3 shows, compared with LMH+CSS, our auxiliary training objective helps not only to pull the original sample and the factual sample together but also to push the original sample and the counterfactual sample apart in the embedding space, which may lead to a more causal VQA model.
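A small sketch of how such a visualization can be produced with scikit-learn is shown below; the function name and plotting details are illustrative and not tied to the paper's code.

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.manifold import TSNE

def plot_joint_embeddings(embeddings, labels, out_path="tsne.png"):
    """Project joint embeddings (original / factual / counterfactual) to 2-D
    with t-SNE and color the points by their sample type."""
    emb = np.asarray(embeddings)
    perplexity = min(30, len(emb) - 1)   # t-SNE requires perplexity < n_samples
    points = TSNE(n_components=2, perplexity=perplexity, random_state=0).fit_transform(emb)
    for label in sorted(set(labels)):
        idx = [i for i, l in enumerate(labels) if l == label]
        plt.scatter(points[idx, 0], points[idx, 1], label=label)
    plt.legend()
    plt.savefig(out_path)
```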

Conclusion
In order to fully utilize the supervision information in synthesized counterfactual samples for robust VQA, we introduce a self-supervised contrastive learning mechanism to learn the relationship between factual samples and counterfactual samples. The experimental results demonstrate that our method improves the reasoning ability and robustness of VQA models.

A Appendices
A.1 Implementation Details

The UpDn model uses a pretrained Faster R-CNN (Ren et al., 2015) to extract the top K object feature embeddings. We set K = 36 in our implementation, and the dimension of each object feature is 2048. For question embeddings, we preprocess the questions to a maximum of 14 words. The word embeddings are initialized with pretrained GloVe (Pennington et al., 2014) vectors with a dimension of 300. A single-layer GRU (Cho et al., 2014) is used to obtain question embedding vectors with a dimension of 512. The dimension of the joint embedding is 2048. The initial learning rate of the Adamax optimizer and the learning rate decay schedule follow the public reimplementation 2. The entire system is trained end-to-end with both $L_{vqa}$ and $L_c$. The parameters are initialized from scratch and the random seed is set to 0. The loss weights $\lambda_{vqa}$ and $\lambda_c$ are set to 1 and 2, respectively. We set the batch size to 512. The model, developed on the official public PyTorch codebase 3, takes about 5 hours (∼30 epochs) to train on an Nvidia RTX 2080Ti. Both Q-CSS and V-CSS are used to generate $(Q, Q^+, Q^-)$ and $(I, I^+, I^-)$.

A.2 Results on VQA v2

The results on VQA v2 are also reported in Table 5 for completeness. We observe that our