DeFormer: Decomposing Pre-trained Transformers for Faster Question Answering

Transformer-based QA models use input-wide self-attention, i.e., across both the question and the input passage, at all layers, causing them to be slow and memory-intensive. It turns out that we can get by without input-wide self-attention at all layers, especially in the lower layers. We introduce DeFormer, a decomposed transformer, which substitutes the full self-attention with question-wide and passage-wide self-attentions in the lower layers. This allows for question-independent processing of the input text representations, which in turn enables pre-computing passage representations, drastically reducing runtime compute. Furthermore, because DeFormer is largely similar to the original model, we can initialize DeFormer with the pre-training weights of a standard transformer and directly fine-tune on the target QA dataset. We show that DeFormer versions of BERT and XLNet can be used to speed up QA by over 4.3x, and with simple distillation-based losses they incur only a 1% drop in accuracy. We open source the code at https://github.com/StonyBrookNLP/deformer.


Introduction
There is an increasing need to push question answering (QA) models into large-volume, web-scale services and also to push them to resource-constrained mobile devices for privacy and other performance reasons (Cao et al., 2019). State-of-the-art QA systems, like many other NLP applications, are built using large pre-trained Transformers (e.g., BERT (Devlin et al., 2019), XLNet (Yang et al., 2019), RoBERTa). However, inference in these models requires prohibitively high levels of runtime compute and memory, making it expensive to support large-volume deployments in data centers and infeasible to run on resource-constrained mobile devices.
Our goal is to take pre-trained Transformer-based models and modify them to enable faster inference for QA without having to repeat the pre-training. This is a critical requirement if we want to explore many points in the accuracy versus speed trade-off, because pre-training is expensive. The main compute bottleneck in Transformer-based models is the input-wide self-attention computation at each layer. In reading-comprehension-style QA, this amounts to computing self-attention over the question and the context text together. This helps the models create highly effective question-dependent context representations and vice versa. Of these, building representations of the context takes more time because the context is typically much longer than the question. If the context can be processed independently of the question, then this expensive compute can be pushed offline, saving significant runtime latency.
Can we process the context independently of the question, at least in some of the layers, without too much loss in effectiveness? Two empirical observations indicate that this is possible. First, previous studies have demonstrated that lower layers tend to focus on local phenomena such as syntactic aspects, while higher layers focus on global (long-distance) phenomena such as semantic aspects relevant for the target task (Tenney et al., 2019; Hao et al., 2019; Clark et al., 2019b). Second, as we show later (see Section 2), in a standard BERT-based QA model there is less variance in the lower-layer representations of text when we vary the question. This means that in the lower layers, information from the question is not as critical for forming text representations. Together, these suggest that considering only local context in the lower layers of a Transformer and full global context in the upper layers can provide a speedup at a very small cost in effectiveness.
Based on these observations, we introduce DeFormer, a simple decomposition of pre-trained Transformer-based models, where the lower layers of the decomposed model process the question and context text independently and the higher layers process them jointly (see Figure 1 for a schematic illustration). Suppose we allow the k lower layers of an n-layer model to process the question and context text independently. DeFormer processes the context texts through the k lower layers offline and caches the output of the k-th layer. At runtime, the question is first processed through the k lower layers of the model, and the text representation for the k-th layer is loaded from the cache. These two k-th layer representations are fed to the (k + 1)-th layer as input, and further processing continues through the higher layers as in the original model. In addition to directly reducing the amount of runtime compute, this also reduces memory significantly, as the intermediate text representations for the context are no longer held in memory.
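The offline/online split described above can be sketched in a few lines of pure Python. This is a minimal illustration under stated assumptions, not the DeFormer codebase: `toy_layer`, `run_layers`, `DecomposedModel` and the passage-id cache are hypothetical names, and the toy layer (each token averaged with the mean of the tokens it can see) merely stands in for a real Transformer block.

```python
from typing import Callable, Dict, List

Vector = List[float]
Layer = Callable[[List[Vector]], List[Vector]]


def toy_layer(tokens: List[Vector]) -> List[Vector]:
    """Hypothetical stand-in for a Transformer layer: each token is averaged
    with the mean of all tokens it can see (a crude form of global mixing)."""
    dim = len(tokens[0])
    mean = [sum(t[d] for t in tokens) / len(tokens) for d in range(dim)]
    return [[0.5 * t[d] + 0.5 * mean[d] for d in range(dim)] for t in tokens]


def run_layers(layers: List[Layer], tokens: List[Vector]) -> List[Vector]:
    """Apply a stack of layers in order."""
    for layer in layers:
        tokens = layer(tokens)
    return tokens


class DecomposedModel:
    def __init__(self, layers: List[Layer], k: int):
        self.lower = layers[:k]  # layers 1..k: segments processed separately
        self.upper = layers[k:]  # layers k+1..n: joint processing
        self.cache: Dict[str, List[Vector]] = {}

    def precompute(self, passage_id: str, passage: List[Vector]) -> None:
        """Offline: run the passage through the k lower layers once and cache it."""
        self.cache[passage_id] = run_layers(self.lower, passage)

    def forward(self, question: List[Vector], passage_id: str) -> List[Vector]:
        """Online: only the question passes through the lower layers."""
        question_k = run_layers(self.lower, question)
        passage_k = self.cache[passage_id]  # loaded from the cache, not recomputed
        # Concatenate the two layer-k outputs and continue with layers k+1..n.
        return run_layers(self.upper, question_k + passage_k)


layers = [toy_layer] * 4                      # a toy n = 4 layer model
model = DecomposedModel(layers, k=3)
passage = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
model.precompute("p1", passage)
output = model.forward([[2.0, 2.0]], "p1")    # 1 question token + 3 passage tokens
```

Inside `forward`, the passage never passes through the lower layers; that is where the runtime saving comes from, at the cost of storing the cached layer-k outputs.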
A key strength of this approach is that one can make any pre-trained Transformer-based QA model faster by creating a corresponding DeFormer version that is directly fine-tuned on the target QA datasets, without having to repeat the expensive pre-training. Our empirical evaluation on multiple QA datasets shows that with direct fine-tuning the decomposed model incurs only a small loss in accuracy compared to the full model. This loss in accuracy can be reduced further by learning from the original model: we want DeFormer to behave more like the original model. In particular, the upper layers of DeFormer should produce representations that capture the same kinds of information as the corresponding layers in the original model. We add two distillation-like auxiliary losses (Hinton et al., 2015), which minimize the output-level and the layer-level divergences between the decomposed and original models.
We evaluate DeFormer versions of two Transformer-based models, BERT and XLNet, on three different QA tasks and two sentence-pair tasks. DeFormer achieves substantial speedup (2.7 to 4.3x) and memory reduction (65.8% to 72.9%) for only a small loss in effectiveness (0.6 to 1.8 points) for QA. Moreover, we find that the DeFormer version of BERT-large is faster than the original version of the smaller BERT-base model, while still being more accurate. Ablations show that the supervision strategies we introduce provide valuable accuracy improvements, and further analysis illustrates that DeFormer provides good runtime vs. accuracy trade-offs.

Decomposing Transformers for Faster Inference
The standard approach to using Transformers for question answering is to compute self-attention over both the question and the input text (typically a passage). This yields highly effective representations of the input pair, since what information to extract from the text often depends on the question and vice versa. If we want to reduce complexity, one natural question is whether we can decompose the Transformer function over each segment of the input, trading some representational power for the ability to push processing of the text segment offline.
The trade-off depends on how important it is to have attention from question tokens when forming text representations (and vice versa) in the lower layers. To assess this, we measured how the text representation changes when paired with different questions. In particular, we computed the average passage representation variance when the passage is paired with different questions, measured using the cosine distance between the passage vectors and their centroid. As Figure 2 shows, the text representation changes less in the lower layers than in the upper layers, suggesting that ignoring attention from question tokens in the lower layers may not be a bad idea. This also agrees with results on probing tasks, which suggest that lower layers tend to model mostly local phenomena (e.g., POS, syntactic categories), while higher layers tend to model more semantic phenomena that are task dependent (e.g., entity co-reference) and rely on wider contexts.
Figure 2: We define the representation variance as the average cosine distance from the centroid to all representation vectors. In this figure, the variance is averaged over 100 paragraphs (each paired with 5 different questions) and normalized to [0, 1]. Smaller variance in the lower layers indicates that the passage representation depends less on the question, while higher variance in the upper layers shows that the passage representation relies more on the interaction with the question.
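The variance measure above can be sketched as follows. This is a hedged illustration: it assumes each passage is summarized by a single pooled vector per layer, and it assumes the normalization to [0, 1] is min-max scaling across layers (the text does not specify either); `representation_variance` and `min_max_normalize` are hypothetical helper names.

```python
import math
from typing import List

Vector = List[float]


def cosine_distance(u: Vector, v: Vector) -> float:
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return 1.0 - dot / (norm_u * norm_v)


def representation_variance(reps: List[Vector]) -> float:
    """Average cosine distance from the centroid to each representation.

    `reps` holds one (pooled) passage vector per question paired with the passage.
    """
    dim = len(reps[0])
    centroid = [sum(r[d] for r in reps) / len(reps) for d in range(dim)]
    return sum(cosine_distance(r, centroid) for r in reps) / len(reps)


def min_max_normalize(values: List[float]) -> List[float]:
    """Rescale per-layer variances to [0, 1] (assumed min-max scaling)."""
    lo, hi = min(values), max(values)
    if hi == lo:
        return [0.0 for _ in values]
    return [(v - lo) / (hi - lo) for v in values]


# Toy passage vectors at two layers, each paired with two different questions.
lower_layer = [[1.0, 1.0], [1.0, 1.0]]   # the question has no effect
upper_layer = [[1.0, 0.0], [0.0, 1.0]]   # the question changes the passage vector
per_layer = [representation_variance(lower_layer), representation_variance(upper_layer)]
normalized = min_max_normalize(per_layer)
```

In the paper's setting, `per_layer` would be averaged over 100 passages before normalization; the toy values here only show that identical representations give (numerically) zero variance.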
Here we formally describe our approach for decomposing attention in the lower layers to allow question independent processing of the contexts.

DeFormer
First, we formally define the computation of a Transformer for a paired-input task containing two segments of text, $T_a$ and $T_b$ (for QA, $T_a$ is the question of length $q$ and $T_b$ is the passage of length $p$). Let the token embedding representations of segment $T_a$ be $\mathbf{A} = [\mathbf{a}_1; \mathbf{a}_2; \ldots; \mathbf{a}_q]$ and of segment $T_b$ be $\mathbf{B} = [\mathbf{b}_1; \mathbf{b}_2; \ldots; \mathbf{b}_p]$. The full input sequence $\mathbf{X}$ can be expressed by concatenating the token representations from segments $T_a$ and $T_b$ as $\mathbf{X} = [\mathbf{A}; \mathbf{B}]$. The Transformer encoder has $n$ layers (denoted $L_i$ for layer $i$), which transform this input sequentially: $\mathbf{X}^{i} = L_i(\mathbf{X}^{i-1})$, with $\mathbf{X}^{0} = \mathbf{X}$. For the details of the Transformer layer, we refer the reader to Vaswani et al. (2017). We denote the application of a stack of layers from layer $i$ to layer $j$ as $L_{i:j}$. The output representations of the full Transformer, $\mathbf{A}^n$ and $\mathbf{B}^n$, can be written as:
$$[\mathbf{A}^n; \mathbf{B}^n] = L_{1:n}([\mathbf{A}; \mathbf{B}]) \qquad (1)$$
Figure 3 shows a schematic of our model. We decompose the computation of the lower layers (up to layer $k$) by simply removing the cross-interactions between the $T_a$ and $T_b$ representations. Here $k$ is a hyper-parameter. The output representations of the decomposed Transformer, $\hat{\mathbf{A}}^n$ and $\hat{\mathbf{B}}^n$, can be expressed as:
$$[\hat{\mathbf{A}}^n; \hat{\mathbf{B}}^n] = L_{k+1:n}([L_{1:k}(\mathbf{A}); L_{1:k}(\mathbf{B})]) \qquad (2)$$
Transformer-based QA systems process the input question and context together through a stack of self-attention layers. Applying this decomposition to a Transformer for QA therefore allows us to process the question and the context text independently, which in turn allows us to compute the context text representations for the lower layers offline. With this change, the runtime complexity of each lower layer is reduced from $O((p + q)^2)$ to $O(q^2 + c)$, where $c$ denotes the cost of loading the cached representation.
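Within a Transformer layer, self-attention is the only operation that mixes tokens, so removing the cross-interactions between $T_a$ and $T_b$ amounts to masking attention with a block-diagonal pattern. The sketch below checks this for a single attention head in pure Python. It is a simplification under stated assumptions: no learned query/key/value projections (Q = K = V = x), no positional embeddings, and the function names are hypothetical. The last helper counts only the query-key score entries computed online, ignoring the cache-loading cost $c$.

```python
import math
from typing import Callable, List, Optional

Matrix = List[List[float]]


def self_attention(x: Matrix, allowed: Optional[Callable[[int, int], bool]] = None) -> Matrix:
    """Single-head scaled dot-product self-attention with Q = K = V = x.

    If allowed(i, j) is False, token i cannot attend to token j.
    """
    n, d = len(x), len(x[0])
    out = []
    for i in range(n):
        scores = [
            sum(a * b for a, b in zip(x[i], x[j])) / math.sqrt(d)
            if allowed is None or allowed(i, j) else -math.inf
            for j in range(n)
        ]
        m = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(s - m) for s in scores]  # masked entries become 0.0
        z = sum(weights)
        out.append([sum(w * x[j][c] for j, w in enumerate(weights)) / z for c in range(d)])
    return out


def decomposed_attention(a: Matrix, b: Matrix) -> Matrix:
    """Attention over [A; B] with cross-segment interactions removed."""
    q = len(a)

    def same_segment(i: int, j: int) -> bool:
        return (i < q) == (j < q)  # block-diagonal mask

    return self_attention(a + b, allowed=same_segment)


def attention_scores_at_runtime(q: int, p: int, decomposed: bool) -> int:
    """Query-key score entries computed online in one layer."""
    return q * q if decomposed else (p + q) ** 2


question = [[1.0, 0.0], [0.5, 0.5]]
passage = [[0.0, 1.0], [1.0, 1.0], [0.2, 0.8]]
joint = decomposed_attention(question, passage)
separate = self_attention(question) + self_attention(passage)  # row concatenation
```

Because `joint` matches `separate`, the passage rows can be computed once offline; with 32 question tokens and 320 passage tokens, the online score count per layer drops from 123,904 to 1,024.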

Auxiliary Supervision for DeFormer
DeFormer can be used in the same way as the original Transformer. Since DeFormer retains much of the original structure, we can initialize this model with the pre-trained weights of the original Transformer and fine-tune directly on downstream tasks. However, DeFormer loses some information in the representations of the lower layers. The upper layers can learn to compensate for this during fine-tuning. Beyond that, we can go further and use the original model's behavior as an additional source of supervision.
Towards this end, we first initialize the parameters of DeFormer with the parameters of a pretrained full Transformer, and fine-tune it on the downstream tasks. We also add auxiliary losses that make DeFormer predictions and its upper layer representations closer to the predictions and corresponding layer representations of the full Transformer.
Knowledge Distillation Loss: We want the prediction distribution of DeFormer to be closer to that of the full Transformer. We minimize the Kullback-Leibler divergence between the decomposed Transformer prediction distribution $P_A$ and the full Transformer prediction distribution $P_B$:
$$\mathcal{L}_{kd} = D_{KL}(P_A \,\|\, P_B)$$
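A minimal sketch of this loss on a single prediction distribution, assuming both models output logits over the same set of labels. The function names are hypothetical, and no distillation temperature is applied because the text does not mention one. For span-extraction QA, the prediction is typically a pair of start and end distributions; applying the loss to each and summing is one plausible choice, not something the text specifies.

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kd_loss(decomposed_logits: List[float], full_logits: List[float]) -> float:
    """KL(P_A || P_B): P_A from the decomposed model, P_B from the full model."""
    p_a = softmax(decomposed_logits)
    p_b = softmax(full_logits)
    # Terms with p_a == 0 contribute nothing to the KL divergence.
    return sum(pa * math.log(pa / pb) for pa, pb in zip(p_a, p_b) if pa > 0.0)
```

KL divergence is asymmetric, so the argument order matters; the sketch follows the order given in the text ($P_A$ first).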

Layerwise Representation Similarity Loss:
We want the upper layer representations of DeFormer to be closer to those of the full Transformer. We minimize the (squared) Euclidean distance between token representations of the upper layers of the decomposed Transformer and the full Transformer. Let $\mathbf{v}_i^j$ be the representation of the $j$-th token in the $i$-th layer of the full Transformer, and let $\mathbf{u}_i^j$ be the corresponding representation in DeFormer. For each of the upper layers $k+1$ through $n$, we compute a layerwise representation similarity (lrs) loss as follows, where $m = p + q$ is the number of input tokens:
$$\mathcal{L}_{lrs} = \sum_{i=k+1}^{n} \sum_{j=1}^{m} \lVert \mathbf{v}_i^j - \mathbf{u}_i^j \rVert^2$$
We add the knowledge distillation loss ($\mathcal{L}_{kd}$) and the layerwise representation similarity loss ($\mathcal{L}_{lrs}$) to the task-specific supervision loss ($\mathcal{L}_{ts}$) and learn their relative importance via hyper-parameter tuning:
$$\mathcal{L}_{total} = \gamma \mathcal{L}_{ts} + \alpha \mathcal{L}_{kd} + \beta \mathcal{L}_{lrs}$$
We use Bayesian Optimization (Močkus, 1975) to tune $\gamma$, $\alpha$ and $\beta$ instead of simple trial-and-error or grid/random search. This aims to reduce the number of steps required to find a combination of hyper-parameters that is close to the optimal one.
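The layerwise loss and the weighted combination can be sketched as below. This is a hedged illustration in pure Python: it assumes the squared Euclidean distance summed over tokens and over upper layers $k+1$ to $n$, and it uses plain lists where a real implementation would operate on tensors inside the training graph; `lrs_loss` and `total_loss` are hypothetical names.

```python
from typing import List

Vector = List[float]
LayerReps = List[List[Vector]]  # [layer][token] -> vector; list index i - 1 holds layer i


def lrs_loss(full: LayerReps, decomposed: LayerReps, k: int) -> float:
    """Sum of squared Euclidean distances over all tokens in layers k+1..n."""
    n = len(full)
    total = 0.0
    for i in range(k, n):  # list indices k..n-1 correspond to layers k+1..n
        for v, u in zip(full[i], decomposed[i]):
            total += sum((vd - ud) ** 2 for vd, ud in zip(v, u))
    return total


def total_loss(l_ts: float, l_kd: float, l_lrs: float,
               gamma: float, alpha: float, beta: float) -> float:
    """Weighted sum of the three losses; gamma, alpha, beta are tuned hyper-parameters."""
    return gamma * l_ts + alpha * l_kd + beta * l_lrs


# n = 2 layers, k = 1: only layer 2 is compared; layer 1 may differ by design.
full_reps = [[[0.0, 0.0]], [[1.0, 1.0]]]
deformer_reps = [[[5.0, 5.0]], [[0.0, 0.0]]]
l_lrs = lrs_loss(full_reps, deformer_reps, k=1)
loss = total_loss(l_ts=1.0, l_kd=0.5, l_lrs=l_lrs, gamma=1.0, alpha=2.0, beta=0.5)
```

The lower layers are deliberately excluded: they lack cross-attention, so forcing them to match the full model would conflict with the decomposition.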

Datasets
We use the pre-trained uncased BERT-base and BERT-large (Whole Word Masking version) models on five different paired-input problems: three QA tasks and, in addition, two sentence-pair tasks. We pick the latter two as additional datasets to show the utility of decomposition in other information-seeking applications similar to QA, where one of the inputs can be assumed to be available offline.
SQuAD v1.1 (Stanford Question Answering Dataset) (Rajpurkar et al., 2016) is an extractive question answering dataset containing more than 100,000 question and answer pairs generated by crowd workers on Wikipedia articles.
RACE (Lai et al., 2017) is a reading comprehension dataset collected from English exams designed to evaluate the reading and reasoning ability of middle and high school Chinese students. It has over 28,000 passages and more than 100,000 questions.
BoolQ (Clark et al., 2019a) consists of 15,942 yes/no questions that occur naturally in unprompted and unconstrained settings.
MNLI (Multi-Genre Natural Language Inference) (Williams et al., 2018) is a crowd-sourced corpus of 433k sentence pairs annotated with textual entailment information.
QQP (Quora Question Pairs) (Iyer et al., 2019) consists of over 400,000 potential duplicate question pairs from Quora.
For all 5 tasks, we use the standard splits provided with the datasets, but in addition divide the original training data further to obtain a 10% split for tuning hyper-parameters (tune split), and use the original development split for reporting efficiency (FLOPs, memory usage) and effectiveness metrics (accuracy or F1, depending on the task).
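A reproducible 10% tune split can be carved out of a training set as sketched below. The text does not say how the split was drawn; this sketch assumes a uniform random shuffle with a fixed seed, and `train_tune_split` is a hypothetical helper.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def train_tune_split(examples: Sequence[T], tune_fraction: float = 0.1,
                     seed: int = 0) -> Tuple[List[T], List[T]]:
    """Hold out `tune_fraction` of the training examples as a tune split."""
    indices = list(range(len(examples)))
    random.Random(seed).shuffle(indices)  # fixed seed gives a reproducible split
    tune_idx = set(indices[:round(len(examples) * tune_fraction)])
    train = [ex for i, ex in enumerate(examples) if i not in tune_idx]
    tune = [ex for i, ex in enumerate(examples) if i in tune_idx]
    return train, tune


train, tune = train_tune_split(list(range(1000)))
```

For datasets where several questions share a passage (SQuAD, RACE), splitting by passage id rather than by example would keep the same passage out of both splits; whether that was done here is not stated.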

Implementation Details
We implement all models in TensorFlow 1.15 (Abadi et al., 2015) based on the original BERT (Devlin et al., 2019) and XLNet (Yang et al., 2019) codebases. We perform all experiments on one TPU v3-8 node (8 cores, 128GB memory) with the bfloat16 format enabled. We measure FLOPs and memory consumption with the TensorFlow Profiler (https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/profiler/profile). For DeFormer models, we tune the hyper-parameters for weighting the different losses using a Bayesian optimization library (Nogueira, Fernando, 2019) with 50 iterations on the tune split (10% of the original training sets) and report performance numbers on the original dev sets. The search range is [0.1, 2.0] for all 3 hyper-parameters. Detailed hyper-parameters are listed in Appendix A.
For DeFormer-BERT and DeFormer-XLNet, we compute the representations for one of the input segments offline and cache them. For QA we cache the passages; for natural language inference we cache the premise (one use case is finding premise sentences in a collection that support information contained in a query (hypothesis) sentence); and for question similarity we cache the first question (one use case is FAQ retrieval, where a user question is compared against a collection of previously asked questions). Table 1 shows the main results comparing performance, inference speed and memory requirements of BERT-base and DeFormer-BERT-base when using nine lower layers and three upper layers (see Subsection 3.4 for the impact of the choice of upper/lower splits). We observe a substantial speedup and significant memory reduction on all the datasets, while retaining most of the original model's effectiveness (as much as 98.4% on SQuAD and 99.8% on QQP). The XLNet results in the same table demonstrate that decomposition is effective across different pre-trained Transformer architectures. Table 2 shows that decomposition brings a 2x inference speedup and reduces memory by more than half on both QQP and MNLI, which take pairwise input sequences. The effectiveness of decomposition thus generalizes beyond QA tasks, as long as the input sequences are paired. Efficiency improvements increase with the size of the text segment that can be cached.

Results
Small Distilled or Large Decomposed? Table 3 compares the performance, speed and memory of BERT-base, BERT-large and DeFormer-BERT-large. DeFormer-BERT-large is 1.6 times faster than the smaller BERT-base model. Decomposing the larger model also turns out to be more effective than using the smaller base model (+2.3 points). This shows that with decomposition, a large Transformer can run faster than a smaller one that is half its size, while also being more accurate. Distilling a larger model into a smaller one can yield better accuracy than training a smaller model from scratch. As far as we know, there are two related but not fully comparable results. (1) Tang et al. (2019) distill BERT into a small LSTM-based model, achieving a 15x speedup but with a significant accuracy drop of more than 13 points on MNLI.
(2) Sanh et al. (2019) distill BERT into a smaller six-layer Transformer, which provides a 1.6x speedup but gives a >2 point accuracy drop on MNLI and a >3 point F1 drop on SQuAD. A fair comparison requires more careful experimentation exploring different distillation sizes, which requires repeating pre-training or data augmentation, an expensive proposition.
Device Results: To evaluate the impact on different devices, we deployed the models on three different machines (a GPU, a CPU, and a mobile phone). Table 4 shows the average latency in answering a question, measured on a subset of the SQuAD dataset. On all devices, we get more than a three-times speedup.
Table 1: (i) Performance of the original fine-tuned models vs. the fine-tuned DeFormer-BERT-base and DeFormer-XLNet-base models; (ii) performance drop, inference speedup and inference memory reduction of DeFormer over the original models for 3 QA tasks. DeFormer-BERT-base uses nine lower layers and three upper layers with caching enabled; DeFormer-XLNet-base uses eight lower layers and four upper layers with caching enabled. For SQuAD and RACE we also train with the auxiliary losses, and for the others we use only the main supervision loss, the settings that give the best effectiveness during training. Note that the choice of loss does not affect the efficiency metrics.
Table 4: Inference latency (in seconds) on the SQuAD dataset for BERT-base vs. DeFormer-BERT-base, as an average measured in batch mode. On the GPU and CPU the batch size is 32, and on the phone (marked by *) the batch size is 1.

Ablation Study
Table 5 shows the contribution of the auxiliary losses for fine-tuning DeFormer-BERT on the SQuAD dataset. The drop in effectiveness when not using the Layerwise Representation Similarity (LRS in the table) and Knowledge Distillation (KD) losses shows the utility of auxiliary supervision.
Table 5: Ablation analysis on the SQuAD dataset for DeFormer-BERT-base and DeFormer-BERT-large models. LRS is the layerwise representation similarity loss. KD is the knowledge distillation loss on the prediction distributions.
Figures 4a and 4b show how the effectiveness and inference speed of DeFormer-BERT change as we change the separation layer. Inference speedup scales roughly quadratically with respect to the number of layers with decomposed attention. The drop in effectiveness, on the other hand, is negligible when separating at lower layers (up to layer 3 for the base model and up to layer 13 for the large model) and increases slowly after that, with a dramatic increase in the last layers closest to the output. The choice of separation layer thus allows trading effectiveness for inference speed.

Divergence of DeFormer and original BERT representations
The main difference between the original BERT and DeFormer-BERT is the absence of cross-attention in the lower layers. We analyze the differences between the representations of the two models across all layers. To this end, we randomly select 100 passages from the SQuAD dev set and, for each passage, randomly select 5 different questions already associated with it in the dataset. For each passage, we encode all 5 question-passage pair sequences using both the fine-tuned original BERT-base model and the DeFormer-BERT-base model, and compute the distance between their vector representations at each layer. Figure 5 shows the averaged distances for both the question and the passage at different layers. The lower-layer representations of the passage and questions remain similar for both models, but the upper-layer representations differ significantly, supporting the idea that the lack of cross-attention has less impact in the lower layers than in the higher ones. Also, using auxiliary supervision of the upper layers has the desired effect of forcing DeFormer to produce representations that are closer to the original model. This effect is less pronounced for the question representations.

Inference Cost
DeFormer enables caching of text representations that can be computed offline. While a full-scale analysis of the detailed trade-offs in storage versus latency is beyond the scope of this paper, we present a set of basic calculations to illustrate that the storage cost of caching can be substantially smaller than the inference cost. Assuming a use case of evaluating one million question-passage pairs daily, we first compute the storage requirements of the representations of these passages. With BERT-base representations, we estimate this to be 226 KB per passage and 226 GB in total for 1 million passages. The cost of storing this data, plus the added compute cost of reading these passages, amounts to a total of $61.7 per month at current vendor rates. To estimate the inference cost, we use the compute times obtained from our calculations and current vendor rates for GPU workloads, which amounts to $148.5 to support the 1 million question-passage pair workload. The substantial reduction in cost arises because storage is many orders of magnitude cheaper than GPU compute. Details of these calculations are listed in the Appendix.

Related work
Speeding up inference in a model requires reducing the amount of compute involved. There are two broad related directions of prior work: (i) Compression techniques can be used to reduce model size through low-rank approximation (Kim et al., 2015; Tai et al., 2015) and model weight pruning (Guo et al., 2016; Han et al., 2015), which have been shown to help speed up inference in CNN- and RNN-based models. For Transformers, Michel et al. (2019) explore pruning attention heads to gain inference speedup. This is an orthogonal approach that can be combined with our decomposition idea. However, for the paired-input tasks we consider, pruning heads provides only limited speedup. In more recent work, Ma et al. (2019) propose approximating the quadratic attention computation with a tensor-decomposition-based multi-linear attention model. However, it is not clear how this multi-linear approximation can be applied to pre-trained Transformers like BERT.
(ii) Distillation techniques can be used to train smaller student networks to speed up inference. Tang et al. (2019) show that BERT can be used to guide the design of smaller models (such as a single-layer BiLSTM) for multiple tasks. But for the tasks we study, such very small models suffer a significant performance drop; for instance, there is a 13% accuracy degradation on the MNLI task. Another closely related recent work is DistilBERT (Sanh et al., 2019), which trains a smaller BERT model (half the size of BERT-base) that runs 1.5 times faster than the original BERT-base. However, the distilled model incurs a significant drop in accuracy. While more recent distillation works (Jiao et al., 2019; Sun et al., 2020) further improve the speedups, our decomposition also achieves similar accuracy. More importantly, these distilled models usually undergo expensive pre-training on language modeling tasks before they can be fine-tuned for the downstream tasks.
Previous neural QA models like BiDAF (Seo et al., 2016), QANet (Yu et al., 2018) and many others contain decomposition as part of their neural architecture design. In contrast, the focus of our work is to show that large pre-trained Transformer models can be decomposed at the fine-tuning stage, bringing the effectiveness of state-of-the-art pre-trained Transformers at much lower inference latency.
In this work, we ask whether we can speed up the inference of Transformer models without compressing or removing model parameters. Part of the massive success of pre-trained Transformer models on many NLP tasks is due to their large parameter capacity, which enables complex language representations. The decomposition we propose makes minimal changes, retaining the overall capacity and structure of the original model, but allows for faster inference by enabling parallel processing and caching of segments.
DeFormer applies to settings where the underlying model relies on input-wide self-attention layers. Even for models that propose alternate ways to improve efficiency, as long as they use input-wide self-attention, DeFormer can be applied as a complementary mechanism to further improve inference efficiency. We leave an evaluation of applying DeFormer on top of other recent efficiency-optimized models for future work.

Conclusion
Transformers have improved the effectiveness of NLP tools through their ability to incorporate large contexts effectively in multiple layers. This, however, imposes a significant complexity cost. In this work, we showed that modeling such large contexts may not always be necessary. We build a decomposition of the Transformer model that provides substantial improvements in inference speed and memory usage while retaining most of the original model's accuracy. A key benefit of the model is that its architecture remains largely the same as the original model, which allows us to avoid repeating pre-training and to use the original model weights for fine-tuning. The distillation techniques further reduce the performance gap with respect to the original model. This decomposition provides a simple yet strong starting point for efficient QA models as NLP moves towards increasingly larger models handling wider contexts.