Targeted Adversarial Training for Natural Language Understanding

We present a simple yet effective Targeted Adversarial Training (TAT) algorithm to improve adversarial training for natural language understanding. The key idea is to introspect current mistakes and prioritize adversarial training steps to where the model errs the most. Experiments show that TAT can significantly improve accuracy over standard adversarial training on GLUE and attain new state-of-the-art zero-shot results on XNLI. Our code will be released upon acceptance of the paper.


Introduction
Adversarial training has proven effective in improving model generalization and robustness in computer vision (Madry et al., 2017; Goodfellow et al., 2014) and natural language processing (NLP) (Zhu et al., 2019; Jiang et al., 2019; Cheng et al., 2019; Liu et al., 2020a; Pereira et al., 2020; Cheng et al., 2020). It works by augmenting the input with a small perturbation that steers the current model prediction away from the correct label, thus forcing subsequent training to make the model more robust and generalizable. Aside from some prior work in computer vision (Dong et al., 2018; Tramèr et al., 2017), most adversarial training approaches adopt non-targeted attacks, where the model prediction is not driven towards a specific incorrect label. In NLP, cutting-edge research in adversarial training tends to focus on making adversarial training less expensive (e.g., by reusing backward steps in FreeLB (Zhu et al., 2019)) or on regularizing rather than replacing the standard training objective (e.g., in virtual adversarial training (VAT) (Jiang et al., 2019)).
By contrast, in this paper, we investigate an orthogonal direction by augmenting adversarial training with introspection capability and adopting targeted attacks that focus on where the model errs the most. We observe that in many NLP applications, the error patterns are non-uniform. For example, on the MNLI development set (in-domain), a standard fine-tuned BERT model tends to misclassify a non-neutral instance as "neutral" more often than the converse (Figure 1a). We thus propose Targeted Adversarial Training (TAT), a simple yet effective algorithm for adversarial training. For each instance, instead of taking adversarial steps away from the gold label, TAT samples an incorrect label in proportion to how often the current model makes the same error in general, and takes adversarial steps towards the chosen incorrect label. To our knowledge, this is the first attempt to apply targeted adversarial training to NLP tasks. In our experiments, this leads to significant improvements over standard non-adversarial and adversarial training alike. For example, on the MNLI development set, TAT produces an accuracy gain of 1.7 absolute points (Figure 1b). On the overall GLUE benchmark, TAT outperforms state-of-the-art non-targeted adversarial training methods such as FreeLB and VAT, and enables the BERT-BASE model to perform comparably to the BERT-LARGE model with standard training. The benefit of TAT is particularly pronounced in out-domain settings, such as zero-shot learning in natural language inference, where it attains new state-of-the-art cross-lingual results on XNLI.

Figure 1: Comparison of confusion matrices on the MNLI development set (in-domain) for (a) BERT with standard fine-tuning and (b) BERT with TAT fine-tuning. The x-axis and y-axis represent the predicted and gold labels, respectively. TAT produces an accuracy gain of 1.7 absolute points.

Targeted Adversarial Training (TAT)
In this paper, we focus on fine-tuning BERT models (Devlin et al., 2018) in our investigation of targeted adversarial training, as this approach has proven very effective for a wide range of NLP tasks.
The training algorithm seeks to learn a function f(x; θ): x → C parametrized by θ, where C is the class label set. Given a training dataset D of input-output pairs (x, y) and a loss function l(·, ·) (e.g., cross entropy), standard training minimizes the empirical risk:

$$\min_{\theta} \; \mathbb{E}_{(x,y)\sim \mathcal{D}}\left[\, l(f(x;\theta),\, y) \,\right]$$

By contrast, in adversarial training, as pioneered in computer vision (Goodfellow et al., 2014; Hsieh et al., 2019; Madry et al., 2017; Jin et al., 2019), the input is augmented with a small perturbation δ that maximizes the adversarial loss:

$$\min_{\theta} \; \mathbb{E}_{(x,y)\sim \mathcal{D}}\left[\, \max_{\|\delta\| \leq \epsilon} l(f(x+\delta;\theta),\, y) \,\right]$$

where the inner maximization can be solved by projected gradient descent (Madry et al., 2017).
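As a concrete illustration (not the paper's released code), the following sketch approximates the inner maximization with K steps of projected gradient ascent on an embedding-level perturbation. The function and parameter names are our own, and we assume a model that accepts input embeddings directly.

```python
import torch
import torch.nn.functional as F

def pgd_perturbation(model, embeds, labels, eps=1e-5, eta=1e-4, K=1, sigma=1e-5):
    """Approximate max_{||delta|| <= eps} l(f(x + delta), y) by projected
    gradient ascent on an embedding-space perturbation (illustrative sketch)."""
    delta = torch.zeros_like(embeds).normal_(0, sigma).requires_grad_(True)
    for _ in range(K):
        loss = F.cross_entropy(model(embeds + delta), labels)
        (grad,) = torch.autograd.grad(loss, delta)
        # Ascent step: move delta in the direction that increases the loss.
        step = delta + eta * grad / (grad.norm() + 1e-12)
        # Projection: rescale delta back into the epsilon-ball if it left it.
        delta = (step * eps / step.norm().clamp(min=eps)).detach().requires_grad_(True)
    return delta.detach()
```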
Recently, adversarial training has been successfully applied to NLP as well (Zhu et al., 2019; Jiang et al., 2019; Pereira et al., 2020). In particular, FreeLB (Zhu et al., 2019) leverages the free adversarial training idea (Shafahi et al., 2019) by reusing the backward pass in gradient computation to carry out the inner ascent and outer descent steps simultaneously. SMART (Jiang et al., 2019) instead regularizes the standard training objective using virtual adversarial training (Miyato et al., 2018):

$$\min_{\theta} \; \mathbb{E}_{(x,y)\sim \mathcal{D}}\Big[\, l(f(x;\theta),\, y) + \alpha \max_{\|\delta\| \leq \epsilon} l\big(f(x+\delta;\theta),\, f(x;\theta)\big) \,\Big] \qquad (1)$$

Effectively, the adversarial term encourages smoothness in the input neighborhood, and α is a hyperparameter that controls the trade-off between standard errors and adversarial errors.

In standard adversarial training, the algorithm simply tries to perturb the input x away from the gold label y given the current parameters θ; it is agnostic to which incorrect label f(x) might be steered towards. By contrast, in Targeted Adversarial Training (TAT), we explicitly pick a target y_t ≠ y and try to steer the model towards y_t. Intuitively, we would like to focus training on where the model currently errs the most. We accomplish this by keeping a running tally of e(y, y_t), the current expected error of predicting y_t when the gold label is y, and sampling y_t from C \ {y} in proportion to e(y, y_t). See Algorithm 1 for details.

Algorithm 1 (TAT). Input: T: the total number of iterations; X = {(x_1, y_1), ..., (x_n, y_n)}: the dataset; f(x; θ): the machine learning model parametrized by θ; σ²: the variance of the random initialization of the perturbation δ; ε: the perturbation bound; K: the number of iterations for perturbation estimation; η: the step size for updating the perturbation; τ: the global learning rate; α: the smoothing proportion of adversarial training in the augmented learning objective; Π: the projection operation; C: the set of classes. Output: θ.

TAT can be applied directly to classification tasks. We conducted an oracle experiment where e(y, y_t) was taken from the confusion matrix of standard training, and found that it performed similarly to our online version.
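To make the targeted step concrete, here is a minimal sketch of one TAT update as we read the description above. The class and function names are ours, the model is assumed to take input embeddings, and the adversarial term in the augmented objective is one plausible instantiation rather than the authors' exact choice.

```python
import torch
import torch.nn.functional as F

class ErrorTally:
    """Running tally e(y, y_t): how often the model predicts y_t when the
    gold label is y. Used to sample the targeted label for each instance."""
    def __init__(self, num_classes, init=1.0):
        # Uniform positive initialization so every incorrect label can be sampled early on.
        self.e = torch.full((num_classes, num_classes), init)

    def update(self, gold, pred):
        for y, y_hat in zip(gold.tolist(), pred.tolist()):
            if y_hat != y:
                self.e[y, y_hat] += 1.0

    def sample_targets(self, gold):
        weights = self.e[gold].clone()                    # (batch, |C|)
        weights[torch.arange(len(gold)), gold] = 0.0      # never target the gold label
        return torch.multinomial(weights, 1).squeeze(-1)  # y_t ~ e(y, .)

def tat_loss(model, embeds, gold, tally, eps=1e-5, eta=1e-4, K=1, alpha=1.0, sigma=1e-5):
    y_t = tally.sample_targets(gold)
    delta = torch.zeros_like(embeds).normal_(0, sigma).requires_grad_(True)
    for _ in range(K):
        # Targeted step: *decrease* the loss w.r.t. the incorrect label y_t,
        # i.e., steer the prediction towards y_t instead of away from gold.
        loss_t = F.cross_entropy(model(embeds + delta), y_t)
        (grad,) = torch.autograd.grad(loss_t, delta)
        step = delta - eta * grad / (grad.norm() + 1e-12)
        delta = (step * eps / step.norm().clamp(min=eps)).detach().requires_grad_(True)
    clean_logits = model(embeds)
    tally.update(gold, clean_logits.argmax(dim=-1))
    adv_logits = model(embeds + delta.detach())
    # Augmented objective: standard loss plus the alpha-weighted adversarial term.
    return F.cross_entropy(clean_logits, gold) + alpha * F.cross_entropy(adv_logits, gold)
```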
It is more challenging to apply TAT to regression tasks, as we would need to keep track of a continuous error distribution. To address this problem, we quantize the value range into ten bins and apply TAT as in the classification setting (once a bin is chosen, a target value is sampled uniformly within it).
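A small sketch of this quantization, assuming the STS-B value range [0, 5] and a bin-level error tally analogous to the classification case (the names, range, and tally format are our own illustration):

```python
import random

def sample_regression_target(gold_value, bin_tally, lo=0.0, hi=5.0, n_bins=10):
    """Quantize the value range into n_bins; sample a target bin other than
    the gold bin in proportion to the error tally (assumed positive); then
    draw the target value uniformly within the chosen bin."""
    width = (hi - lo) / n_bins
    gold_bin = min(int((gold_value - lo) / width), n_bins - 1)
    weights = [0.0 if b == gold_bin else bin_tally[gold_bin][b]
               for b in range(n_bins)]
    target_bin = random.choices(range(n_bins), weights=weights, k=1)[0]
    return lo + (target_bin + random.random()) * width
```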

Experiments
We compare targeted adversarial training (TAT) with standard training and with state-of-the-art adversarial training methods such as FreeLB (Zhu et al., 2019) and VAT (Miyato et al., 2018; Jiang et al., 2019). We use the standard uncased BERT-BASE model (Devlin et al., 2018) unless noted otherwise. Due to the additional overhead incurred during training, adversarial methods are somewhat slower than standard training. Like VAT, TAT requires K additional adversarial steps compared to standard training. In practice, K = 1 suffices for TAT and VAT, so they are only slightly slower (roughly 2 times the cost of standard training). FreeLB, by contrast, typically requires 2-5 steps to attain good performance, and is thus significantly slower.

Implementation Details
Our implementation is based on the MT-DNN toolkit (Liu et al., 2020b). We follow the default hyperparameters used for fine-tuning the uncased BERT-BASE model (Devlin et al., 2018; Liu et al., 2020b). Specifically, we use a dropout rate of 0.1 (0.2 for MNLI), a weight decay rate of 0.01, and the Adamax optimizer (Kingma and Ba, 2014) with Lookahead (Zhang et al., 2019) to stabilize training. We select the learning rate from {5e−5, 1e−4} for all models. The maximum number of training epochs is set to 6, and we follow Jiang et al. (2019) in setting the adversarial training hyperparameters: ε = 1e−5 and η = 1e−4. In our experiments, we simply set α = 1 in Eq. 1.
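For reference, the setup above can be summarized as follows (the key names are our own shorthand, not MT-DNN's actual configuration schema):

```python
# Fine-tuning configuration described above (illustrative summary).
config = {
    "optimizer": "adamax",           # Kingma and Ba, 2014
    "lookahead": True,               # Zhang et al., 2019, to stabilize training
    "dropout": 0.1,                  # 0.2 for MNLI
    "weight_decay": 0.01,
    "learning_rates": [5e-5, 1e-4],  # selected per task
    "max_epochs": 6,
    "adv_epsilon": 1e-5,             # perturbation bound, following Jiang et al., 2019
    "adv_step_size": 1e-4,           # eta
    "adv_alpha": 1.0,                # smoothing weight alpha in Eq. 1
}
```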

Standard GLUE Evaluation
We first compare adversarial training methods on the standard GLUE benchmark (Wang et al., 2018). See Table 1 for the results. TAT consistently outperforms both standard training and the state-of-the-art adversarial training methods FreeLB and VAT. Remarkably, BERT-BASE with targeted adversarial training performs on par with BERT-LARGE with standard training overall, and outperforms the latter by a large margin on tasks with smaller datasets such as RTE, MRPC and STS-B, which illustrates the benefit of TAT in improving model generalizability.

Zero-Shot Evaluation on Out-Domain NLI

We also evaluated MNLI-trained models in a zero-shot setting on SNLI (Bowman et al., 2015), SciTail (Khot et al., 2018), HANS (McCoy et al., 2019), and MedNLI (Romanov and Shivade, 2018). See Table 2 for the results. TAT substantially outperforms standard training and state-of-the-art adversarial training methods. Interestingly, the gains are particularly pronounced on the two hardest datasets, HANS and MedNLI. HANS used heuristic rules to identify easy instances for MNLI-trained BERT models and introduced modifications to make them harder. MedNLI is from the biomedical domain, which is substantially different from the general domain of MNLI. This provides additional evidence that targeted adversarial training is especially effective in enhancing generalizability in out-domain settings.

Zero-Shot Learning on Cross-Lingual Natural Language Inference
We also conducted a zero-shot evaluation in the cross-lingual setting by comparing standard and adversarial training on XNLI, where models are trained on English MNLI data only and evaluated on all 15 XNLI languages. See Table 3 for the results. Targeted adversarial training (TAT) demonstrates a clear advantage in improving zero-shot transfer learning across languages, especially for the languages most different from English, such as Urdu. Overall, TAT produces a new state-of-the-art result of 81.7% averaged over the 15 languages of XNLI.

As we saw in Figure 1 earlier, TAT reduces errors across the board on the MNLI development set. To understand how TAT improves performance, we conducted a more detailed analysis, subdividing the dataset by the degree of human agreement. There are three label classes and each sample has 5 human annotations, so the samples fall into four categories by vote split: 5-0-0, 4-1-0, 3-2-0, and 3-1-1. E.g., 3-1-1 signifies three votes for one label and one vote for each of the other two. In Figure 2, we see that TAT outperforms the baseline consistently over all categories, with larger improvements on the more ambiguous samples, especially for out-domain samples. This suggests that TAT is most helpful for challenging instances that exhibit higher ambiguity and differ more from the training examples.
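For clarity, the vote-split categories can be computed as follows (a small helper of our own, not from the paper):

```python
from collections import Counter

def agreement_category(votes):
    """Map five annotator labels to their vote-split category,
    e.g. ['ent', 'ent', 'ent', 'neu', 'con'] -> '3-1-1'."""
    counts = sorted(Counter(votes).values(), reverse=True)
    counts += [0] * (3 - len(counts))  # pad to the three NLI label classes
    return "-".join(str(c) for c in counts)

assert agreement_category(["e"] * 5) == "5-0-0"
assert agreement_category(["e", "e", "e", "n", "c"]) == "3-1-1"
```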

Analysis
We also visualize the loss landscapes of standard training and TAT in Figure 3. TAT has a wider and flatter loss surface, which generally indicates better generalization (Hochreiter and Schmidhuber, 1997; Hao et al., 2019; Li et al., 2018).
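One common way to produce such a visualization is to evaluate the loss along random directions in parameter space (Li et al., 2018). The sketch below is a simplified one-dimensional version under our own naming, not the paper's plotting code; loss_fn is assumed to map a model and a batch to a scalar loss tensor.

```python
import torch

@torch.no_grad()
def loss_along_direction(model, loss_fn, batch, span=1.0, steps=21):
    """Probe the loss surface along one random direction in parameter
    space, in the spirit of Li et al. (2018); a simplified sketch."""
    params = [p for p in model.parameters() if p.requires_grad]
    origin = [p.detach().clone() for p in params]
    # Random direction, rescaled so each tensor matches its weight norm
    # (a simplified form of filter normalization).
    dirs = []
    for p in origin:
        d = torch.randn_like(p)
        dirs.append(d * p.norm() / (d.norm() + 1e-12))
    losses = []
    for alpha in torch.linspace(-span, span, steps):
        for p, o, d in zip(params, origin, dirs):
            p.copy_(o + alpha * d)
        losses.append(loss_fn(model, batch).item())
    # Restore the original weights before returning.
    for p, o in zip(params, origin):
        p.copy_(o)
    return losses
```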

Conclusion
We present the first study to apply targeted attacks in adversarial training for natural language understanding. Our TAT algorithm is simple yet effective in improving model generalizability for various NLP tasks, especially in zero-shot learning and on out-domain data. Future directions include applying TAT in pretraining and in other NLP tasks (e.g., sequence labeling), and exploring alternative approaches for target sampling.
Datasets

In the experiments, GLUE is used for the normal (in-domain) setting, while the datasets below are used for the zero-shot setting.
• SNLI. The Stanford Natural Language Inference (SNLI) dataset contains 570k human-annotated sentence pairs, in which the premises are drawn from the captions of the Flickr30k corpus and the hypotheses are manually annotated (Bowman et al., 2015). This is the most widely used entailment dataset for NLI.
• SciTail. This is a textual entailment dataset derived from a science question answering (SciQ) dataset (Khot et al., 2018). The task involves assessing whether a given premise entails a given hypothesis. In contrast to other entailment datasets mentioned previously, the hypotheses in SciTail are created from science questions while the corresponding answer candidates and premises come from relevant web sentences retrieved from a large corpus. As a result, these sentences are linguistically challenging and the lexical similarity of premise and hypothesis is often high, thus making SciTail particularly difficult.
• MedNLI. This is a textual entailment dataset in the clinical domain, derived from the medical histories of patients and annotated by doctors. The task involves assessing whether a given premise entails a given hypothesis. The hypothesis sentences were generated by clinicians, while the corresponding answer candidates and premises come from MIMIC-III v1.3 (Johnson et al., 2016), a database containing 2,078,705 clinical notes written by healthcare professionals. Its specialized domain makes MedNLI a challenging dataset.
• HANS. This is an NLI evaluation set that tests three invalid heuristics that NLI models are likely to learn: lexical overlap (assuming that a premise entails any hypothesis constructed from words in the premise), subsequence (assuming that a premise entails all of its contiguous subsequences), and constituent (assuming that a premise entails all complete subtrees in its parse tree). HANS is a challenging dataset that aims to test how vulnerable models are to such heuristics, and standard training often results in models failing catastrophically on it, even models such as BERT (McCoy et al., 2019).
• XNLI. This is a cross-lingual natural language inference dataset built by extending the development and test sets of the Multi-Genre Natural Language Inference Corpus (Williams et al., 2018) to 15 languages, including low-resource languages such as Swahili. This corpus was designed to evaluate cross-language sentence understanding, where models are supposed to be trained in one language and tested in different ones. Validation and test sets are translated from English to 14 languages by professional translators, making results across different languages directly comparable (Artetxe and Schwenk, 2019).