TR-BERT: Dynamic Token Reduction for Accelerating BERT Inference

Existing pre-trained language models (PLMs) are often computationally expensive at inference time, making them impractical in various resource-limited real-world applications. To address this issue, we propose a dynamic token reduction approach to accelerate PLMs' inference, named TR-BERT, which can flexibly adapt the number of layers each token passes through during inference to avoid redundant computation. Specifically, TR-BERT formulates the token reduction process as a multi-step token selection problem and automatically learns the selection strategy via reinforcement learning. Experimental results on several downstream NLP tasks show that TR-BERT is able to speed up BERT by 2-5 times to satisfy various performance demands. Moreover, TR-BERT can also achieve better performance with less computation in a suite of long-text tasks, since its token-level layer-number adaptation greatly accelerates the self-attention operation in PLMs. The source code and experiment details of this paper can be obtained from https://github.com/thunlp/TR-BERT.


Introduction
Large-scale pre-trained language models (PLMs) such as BERT , XLNet (Yang et al., 2019) and RoBERTa  have shown great competence in learning contextual representation of text from large-scale corpora. With appropriate fine-tuning on labeled data, PLMs have achieved promising results on various NLP applications, such as natural language inference (Zhang et al., 2020b), text classification (Sun et al., 2019a) and question answering (Talmor and Berant, 2019).
Along with the significant performance improvements, PLMs usually incur substantial computational cost and high inference latency, which presents challenges to their practicality in resource-limited real-world applications, such as real-time applications and hardware-constrained mobile applications. Even worse, these drawbacks become more severe in long-text scenarios, because the self-attention operation in PLMs scales quadratically with the sequence length. Therefore, researchers have recently made intensive efforts to accelerate PLM inference. The mainstream approach is to reduce the number of layers in PLMs, as in knowledge distillation models (Sanh et al., 2019; Sun et al., 2019b) and adaptive inference models. Such layer-wise pruning removes a tremendous amount of computation, but it sacrifices the model's capability for complex reasoning. Previous works (Sanh et al., 2019; Sun et al., 2019b) have found that shallow models usually perform much worse on relatively complicated question answering tasks than on text classification tasks. Hence, pruning entire layers of PLMs may not be an optimal solution in all scenarios.
* Corresponding author: M. Sun (sms@tsinghua.edu.cn)
In this paper, we introduce a dynamic token reduction method, TR-BERT, which identifies well-encoded tokens during the layer-by-layer inference process and saves their computation in subsequent layers. The idea is inspired by recent findings that PLMs capture different information about words at different layers (e.g., BERT focuses on word order information (Lin et al., 2019) in the bottom layers, obtains syntactic information (Hewitt and Manning, 2019) in the middle layers, and computes task-specific information in the top layers (Rogers et al., 2020)). Hence, we can route different tokens through different numbers of layers according to their specific roles in the context.
As shown in Figure 1, TR-BERT formulates the token reduction process as a multi-step selection problem. Specifically, in each selection phase, TR-BERT identifies the words that require high-level semantic representations and passes them on to the higher layers. The main challenge in TR-BERT is how to determine each token's importance for text understanding during token selection. This is highly task-dependent and requires considering the correlation and redundancy among tokens. TR-BERT employs reinforcement learning (RL) to learn the dynamic token selection strategy automatically. After the token reduction, the RL reward uses the confidence of the classifier's prediction from the pruned network to reflect the quality of the token selection. Moreover, we also add a penalty term on the number of selected tokens to the reward; by adjusting it, TR-BERT can apply different pruning intensities in response to various performance requirements. In TR-BERT, since only a few important tokens go through the entire pipeline, inference becomes much faster and no longer grows quadratically with the sequence length.
We conduct experiments on eleven NLP benchmarks. Experimental results show that TR-BERT can accelerate BERT inference by 2-5 times to meet various performance demands, and significantly outperforms previous baseline methods on question answering tasks. This verifies the effectiveness of the dynamic token reduction strategy. Moreover, benefiting from long-distance token interaction, TR-BERT with a 1,024-token input length reaches higher performance with less inference time than the vanilla BERT on a suite of long-text tasks.

Background and Pilot Analysis
To investigate the potential impact of token reduction in PLMs, we first introduce the Transformer architecture. After that, we conduct pilot experiments and empirical analyses for the lower and upper bounds of token reduction in this section.
The Transformer architecture (Vaswani et al., 2017) has been widely adopted by pre-trained language models (PLMs) for its high capacity. Basically, each Transformer layer wraps a Self-Attention module (Self-ATT) and a Feed-Forward-Network module (FFN) with residual connections and layer normalization. Formally, given a sequence of n words, the hidden state of the i-th layer, H^i = (h_1, h_2, ..., h_n), is computed from the previous layer state:

A^i = LN(H^{i-1} + Self-ATT(H^{i-1})),
H^i = LN(A^i + FFN(A^i)),

where i ∈ [1, L], L is the number of stacked Transformer layers, and LN denotes the LayerNorm layer. For each Transformer layer, the complexity of the Self-Attention module scales quadratically with the sequence length. Therefore, the speed of the Transformer architecture declines heavily as sequences become longer. Previous findings (Rogers et al., 2020) reveal that some words, such as function words, do not require high-layer modeling, since they store little information and have already been well handled by PLMs in the bottom layers. Hence, selecting only the important words for high-layer computation may be a viable way to accelerate PLM inference. To verify this assumption, we conduct a theoretical token elimination experiment in question answering (on SQuAD 2.0 (Rajpurkar et al., 2018)) and text classification (on IMDB (Maas et al., 2011)). We use the full-layer representations of the selected tokens and the early-layer representations of the deleted tokens for prediction. To be specific, we eliminate tokens immediately after the l = 4-th layer and adopt the following three strategies to select the retained tokens:

Random Strategy (Lower Bound) selects tokens randomly, assuming that all tokens are equivalent for understanding.
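The layer update described above can be sketched in plain numpy. This is a minimal single-head, single-example illustration of the computation and of where the O(n^2 d) and O(n d^2) costs come from; all function and parameter names here are our own, not from the actual BERT implementation:

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # LN: normalize each token's hidden vector to zero mean, unit variance
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

def self_attention(H, Wq, Wk, Wv):
    # Single-head Self-ATT; the n x n score matrix is the quadratic term
    Q, K, V = H @ Wq, H @ Wk, H @ Wv
    scores = Q @ K.T / np.sqrt(K.shape[-1])
    return softmax(scores) @ V

def ffn(A, W1, W2):
    # Position-wise feed-forward network (ReLU here for brevity)
    return np.maximum(A @ W1, 0.0) @ W2

def transformer_layer(H, attn_params, ffn_params):
    # A^i = LN(H^{i-1} + Self-ATT(H^{i-1})); H^i = LN(A^i + FFN(A^i))
    A = layer_norm(H + self_attention(H, *attn_params))
    return layer_norm(A + ffn(A, *ffn_params))
```

The token reduction idea below simply shrinks the number of rows of H that later layers see, which shrinks the n x n attention matrix accordingly.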
Residual Strategy (Upper Bound) directly utilizes the predictions of the original model to guide token selection. Specifically, we define a token's importance according to its influence on the model prediction when it is not selected. When substituting the r-th layer representation H_r with the l-th layer representation H_l (r > l), we define the approximate variation of the model loss as the token importance: I = (∂loss/∂H_r) · (H_r − H_l). Here, we set r = 9 since other values yield slightly worse results. Note that we cannot obtain the model loss at prediction time. Hence, the Residual Strategy can be viewed as an upper bound of token selection, to some extent, when we ignore the correlation and redundancy among the selected tokens.
Attention Strategy is adopted by PoWER-BERT (Goyal et al., 2020) and L-Adaptive (Kim and Cho, 2020). It accumulates the attention values from all other tokens to a given token and selects the tokens receiving the most attention, considering them responsible for retaining and disseminating the primary information of the context. As shown in Figure 2, both the Attention Strategy and the Residual Strategy achieve considerable results, which demonstrates that selecting important tokens is a feasible way to accelerate the inference of PLMs. Besides, the Residual Strategy outperforms the Attention Strategy by a margin, especially at low token-remaining proportions (+31.8% F1 on SQuAD 2.0 and +9.5% accuracy on IMDB when selecting 10% of tokens). This suggests that accumulated attention values still cannot fully reflect tokens' importance for text understanding, which requires further exploration.
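The three selection strategies can be sketched as scoring functions over per-token statistics. This is a hedged numpy sketch; the function names and tensor shapes are our own choices, and the Residual score uses the first-order approximation defined above:

```python
import numpy as np

def random_scores(n, rng):
    # Random Strategy: every token is equally likely to be kept.
    return rng.random(n)

def residual_scores(grad_hr, h_r, h_l):
    # Residual Strategy: first-order estimate of the loss change when a
    # token's layer-r representation is replaced by its layer-l one,
    #     I_i = (d loss / d h_r[i]) . (h_r[i] - h_l[i]).
    # All arguments have shape (n_tokens, hidden_dim).
    return np.einsum("nd,nd->n", grad_hr, h_r - h_l)

def attention_scores(attn):
    # Attention Strategy: total attention each token receives, summed
    # over heads and query positions; attn: (n_heads, n_query, n_key).
    return attn.sum(axis=(0, 1))

def retain_top_k(scores, k):
    # Keep the k highest-scoring tokens (indices in original order).
    return np.sort(np.argsort(-scores)[:k])
```

In the pilot experiment, `retain_top_k` would be applied once after layer l = 4; the remaining tokens keep their layer-4 representations for prediction.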

Methodology
In this section, we present TR-BERT, which adopts cascade token reduction to prune the BERT model dynamically at token-level granularity. In a one-step token reduction process, TR-BERT estimates the importance of each token, reserves the important ones, and delivers them to the higher layer. To better select important tokens for text understanding while satisfying various acceleration requirements, we employ reinforcement learning (RL) to learn a dynamic token selection strategy automatically. Figure 1 shows the model architecture of TR-BERT. To inherit the high capacity of the PLMs, TR-BERT keeps the same architecture as BERT. Differently, as the layers get deeper, TR-BERT gradually shortens the sequence length via token reduction modules, aiming to reduce the computational redundancy of unimportant tokens.

Model Architecture
The token reduction modules are required to measure the importance of tokens and offer an integral selection scheme. Due to the lack of direct supervision, we employ a policy network to train the modules, which adopts a stochastic policy and uses a delayed reward to guide policy learning. In a one-step reduction, we perform action sampling for the current sequence. The selected tokens are conveyed to the next Transformer layer for further computation. In contrast, the unselected tokens are terminated, with their representations remaining unchanged. After all actions are decided, we fetch each token's representation from the layer where it terminated and compute the gold label's likelihood as the reward. To be specific, we introduce the state, action, reward, and objective function as follows:

State State s_t consists of the token representations inherited from the previous layer before the t-th token reduction layer.
Action We adopt two alternative actions for each token, {Select, Skip}: a token can be selected for further computation or skipped to the final layer. We implement the policy network as a two-layer feed-forward network with GeLU activation (Hendrycks and Gimpel, 2017):

π(a_t | H^{s_t}) = σ(W_2 GeLU(W_1 h_i + b_1) + b_2), (2)

where a_t denotes the action at state s_t for the sequence representation H^{s_t} = {h_1, h_2, ..., h_n} at the t-th reduction, θ = {W_1, W_2, b_1, b_2} are trainable parameters, and σ(·) is the sigmoid activation function. For the selected token set {t_1, t_2, ..., t_{n*}}, where n* ≤ n, we apply a Transformer layer to their corresponding representations:

H' = Transformer([h_{t_1}, h_{t_2}, ..., h_{t_{n*}}]).

For the selected tokens, the representation H' is conveyed to the next layer for further feature extraction and information aggregation. For the skipped tokens, their representations in the current layer are regarded as their final representations.
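The per-token policy in Eq. (2) can be sketched as follows. This is a numpy sketch under our own naming conventions (the paper does not publish exact shapes for W_1 and W_2 here; we assume a standard (d, d_ff) / (d_ff, 1) shape):

```python
import numpy as np

def gelu(x):
    # tanh approximation of GeLU (Hendrycks and Gimpel, 2017)
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def policy_probs(H, W1, b1, W2, b2):
    # pi(a_i = Select | s_t) for every token representation h_i in H.
    # H: (n, d); W1: (d, d_ff); b1: (d_ff,); W2: (d_ff, 1); b2: (1,)
    return sigmoid(gelu(H @ W1 + b1) @ W2 + b2).squeeze(-1)

def sample_actions(probs, rng):
    # Stochastic policy: sample Select (1) / Skip (0) for each token.
    return (rng.random(probs.shape) < probs).astype(int)
```

During training the actions are sampled; at inference a deterministic threshold (e.g., 0.5) on `probs` could be used instead.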
Reward Aiming to select significant tokens for making a precise decision at the prediction layer, we adopt the likelihood of predicting the gold label as the reward. For example, when classifying an input sequence X, we use the model's predicted probability of the ground-truth label Y to reflect the quality of the token selection. In addition, to encourage the model to delete more redundant tokens for acceleration, we include an additional punitive term counting the number of selected tokens. Hence, the overall reward R is defined as:

R = p(Y | X) − λ Σ_t |{a_t = Select}|,

where Σ_t |{a_t = Select}| denotes the total number of selected tokens across all token reduction modules, and λ is a harmonic coefficient balancing the two reward terms.
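Concretely, the reward trades prediction quality against the token budget; a small sketch (our own function name) makes the role of λ explicit:

```python
def reward(gold_prob, num_selected, lam):
    # R = p(Y|X) - lam * (#tokens selected across all reduction modules).
    # A larger lam encourages more aggressive pruning at some accuracy cost.
    return gold_prob - lam * num_selected
```

Sweeping `lam` is exactly how the different speed/quality operating points in the experiments are obtained.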
Objective Function We optimize the policy network to maximize the expected reward. Formally, our objective function is defined as:

J(θ) = Σ_{t=1}^{T} E_{π(a_t | s_t; θ)} [R],

where T is the number of states. According to the REINFORCE algorithm (Williams, 1992) and the policy gradient method (Sutton et al., 1999), we update the network with the policy gradient as below:

∇_θ J(θ) = Σ_{t=1}^{T} E [R ∇_θ log π(a_t | s_t; θ)].
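For a per-token Bernoulli policy with sigmoid output, the REINFORCE gradient has a simple closed form with respect to the policy logits. The following sketch (names and the baseline argument are ours) uses the identity d log π(a)/dz = a − σ(z) for π(Select) = σ(z):

```python
import numpy as np

def reinforce_grad_logits(actions, select_probs, reward, baseline=0.0):
    # Gradient of -J w.r.t. the policy logits z for each token:
    #   d log pi(a) / dz = a - pi(Select)   (sigmoid Bernoulli policy)
    # so a gradient-descent step uses -(a - pi) * (R - baseline).
    return -(actions - select_probs) * (reward - baseline)
```

Subtracting a baseline from R (e.g., the mean reward over several sampled action sequences) leaves the gradient unbiased while reducing its variance.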

Model Training
Our policy network is integrated into the original Transformer network, and we train both simultaneously. The entire training process involves three steps: (1) fine-tune the PLM on the downstream task with the task-specific objective; (2) freeze all parameters except those of the policy network, and conduct reinforcement learning (RL) to update the policy network and learn the token reduction strategy; (3) unfreeze all parameters and train the entire network with the task-specific objective and the RL objective simultaneously.
Due to the large search space, RL training is difficult to converge. We therefore adopt imitation learning (Hussein et al., 2017) to warm up the training of the policy network. To be specific, during RL training we sample several action sequences via the policy network to compute rewards, and in the early training period we guide the optimization direction by providing heuristic action sequences sampled by the Residual Strategy, which can roughly select the most important tokens. A heuristic action sequence selects the top-K most important tokens and skips the others, where K is the expected number of tokens selected by the current policy network. In our preliminary experiments, both the heuristic action sequences and the expected-selected-number mechanism are beneficial to stable training.
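The warm-up target can be sketched as follows: a sketch under our own naming, assuming the Residual-Strategy importance scores and the policy's Select probabilities are available as arrays:

```python
import numpy as np

def heuristic_actions(importance, select_probs):
    # Imitation-learning target: select the top-K tokens ranked by
    # Residual-Strategy importance, where K is the expected number of
    # tokens the current policy would select (sum of Select probabilities).
    k = int(round(float(select_probs.sum())))
    actions = np.zeros(len(importance), dtype=int)
    actions[np.argsort(-importance)[:k]] = 1
    return actions
```

Tying K to the policy's own expected selection count keeps the imitation targets consistent with the pruning intensity the policy is currently learning.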
To further improve the performance of our pruned model, we also adopt Knowledge Distillation (KD) (Hinton et al., 2015) to transfer knowledge from the intact original fine-tuned model.

Complexity Analysis
For a Transformer layer with hidden size d and an input sequence of n tokens, the Self-Attention module takes O(n^2 d) time and memory, while the Feed-Forward Network takes O(n d^2). Hence, our token reduction gains a near-linear speedup when n is relatively small compared to d. When the input sequence gets longer, e.g., up to 1,024 tokens, our method enjoys an even more effective speedup.
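This cost model can be made concrete with a small sketch (constants dropped; function names are ours):

```python
def transformer_layer_flops(n, d):
    # Rough per-layer cost: n^2 * d for self-attention, n * d^2 for the FFN.
    return n * n * d + n * d * d

def layer_speedup(n, d, keep_ratio):
    # Speedup of one layer when only keep_ratio of the tokens remain.
    m = int(n * keep_ratio)
    return transformer_layer_flops(n, d) / transformer_layer_flops(m, d)
```

For example, with d = 768, halving the tokens always gives at least a 2x per-layer speedup, and about 2.8x at n = 1,024, since the quadratic attention term dominates for long inputs.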
In the RL training, we compute the loss on the pruned model, so the acceleration also applies to this stage. Since we focus on accelerating BERT inference, we consider the extra training cost of the pruned model acceptable.

Experiment
In this section, we first introduce the baseline models and the evaluation datasets. After that, we verify the effectiveness of TR-BERT on eleven NLP benchmarks. Finally, we conduct a detailed analysis and case study on TR-BERT to investigate the selected tokens' characteristics.

Baselines
We adopt two pre-trained models and three pruned networks as baselines for comparison: BERT is a Transformer-based pre-trained model. We use the BERT BASE model, which consists of 12 Transformer layers and supports a maximum sequence length of 512.
BERT L is our implementation of BERT that supports input sequences of up to 1,024 tokens. We initialize the parameters of BERT L with those of BERT, where the additional 512 position embeddings are initialized by copying the first 512. After that, we continue training it on Wikipedia for 22k steps.
DistilBERT (Sanh et al., 2019) is the most popular distilled version of BERT, which leverages knowledge distillation to learn from the BERT model. We use the 6-layer DistilBERT released by Hugging Face. In addition, we use the same method to distill BERT into 3 layers to obtain DistilBERT 3.
DeFormer (Cao et al., 2020) is designed for question answering, which encodes questions and passages separately in lower layers. It precomputes all the passage representation and reuses them to speed up the inference. In our experiments, we do not count DeFormer's pre-computation.
PoWER-BERT (Goyal et al., 2020) is mainly designed for text classification and also decreases the sequence length as the layers get deeper. It adopts the Attention Strategy to measure the significance of each token and always selects the tokens with the highest attention. Given a length penalty, PoWER-BERT searches for a fixed-length pruning configuration shared by all examples.
DynaBERT can not only adjust the model's width by varying the number of attention heads, but also provides adaptive layer depth to satisfy different requirements. For a given speed demand, we report its best performance among all feasible width and depth combinations.

Datasets
To verify the effectiveness of reducing the sequence length, we evaluate TR-BERT on several tasks with relatively long context, including question answering and text classification. Table 1 shows the context length of these datasets. We adopt seven question-answering datasets: SQuAD 2.0 (Rajpurkar et al., 2018), NewsQA (Trischler et al., 2017), NaturalQA (Kwiatkowski et al., 2019), RACE (Lai et al., 2017), HotpotQA (Yang et al., 2018), TriviaQA (Joshi et al., 2017) and WikiHop (Welbl et al., 2018). We also evaluate models on four text classification datasets: YELP.F (Zhang et al., 2015), IMDB (Maas et al., 2011), 20NewsGroups (20News.) (Lang, 1995), and Hyperpartisan (Hyperp.) (Kiesel et al., 2019). Among them, HotpotQA, TriviaQA and WikiHop possess abundant contexts for reading, and the performance of question answering (QA) models heavily relies on the amount of text they read. To fairly compare BERT and BERT L, we split the context into slices and apply a shared-normalization training objective (Clark and Gardner, 2018) to produce a global answer candidate comparison across slices for the two extractive QA datasets, and we average the candidate scores over all slices for WikiHop. Details of all datasets are shown in the Appendix.

Experimental Settings
We adopt a maximum input sequence length of 384 for SQuAD 2.0, 1,024 for long-text tasks and 512 for others. We use the Adam optimizer (Kingma and Ba, 2015) to train all models. The detailed training configuration is shown in the Appendix.
For the RL training, we sample 8 action sequences each time and average their rewards as the reward baseline. In the second training step, which warms up the policy network, we employ 20% imitation learning steps for question answering tasks and 50% for text classification tasks. We search over the number of token reduction modules T ∈ {1, 2, 3}. We find that models with T = 2 get similar quality-speed trade-offs to models with T = 3, and both perform better than models with T = 1; thus we adopt T = 2 for simplicity. We denote the pruned models derived from BERT, BERT L and DistilBERT 6 as TR-BERT 12, TR-BERT L and TR-BERT 6, respectively. For BERT and BERT L, we attach the token reduction modules before the second and sixth layers. For DistilBERT 6, we insert the token reduction modules before the second and fourth layers.
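The reward baseline described above is a standard variance-reduction device; a small sketch (our own naming) shows how the 8 sampled rewards become advantages for the policy gradient:

```python
import numpy as np

def advantages_from_samples(sample_rewards):
    # Average the rewards of the k sampled action sequences and use the
    # mean as the baseline; the policy gradient is then weighted by
    # (R - baseline) per sample, which reduces variance without bias.
    r = np.asarray(sample_rewards, dtype=float)
    return r - r.mean()
```

By construction the advantages sum to zero, so samples better than average push their actions up and worse-than-average samples push theirs down.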
To avoid pseudo improvements from pruning padding in TR-BERT, we evaluate all models on input sequences without padding to the maximum length. For each dataset, we report F1 scores or accuracy (Acc.), and the FLOPs speedup ratio relative to the BERT model. A model's FLOPs are consistent across operating environments, so FLOPs provide a convenient way to estimate and compare models' inference time.

Overall Results
The comparison between TR-BERT and the baselines is shown in Table 2 and Figure 3. We adjust the length penalty coefficient of TR-BERT for an intuitive comparison. From the experimental results, we have the following observations: (1) TR-BERT 12 achieves higher performance with less computation than all baselines on all span-extraction QA datasets. For example, TR-BERT 12 outperforms DynaBERT by 1.8 F1 at a faster speed. TR-BERT 12 even achieves better performance than BERT at low speedup ratios, which demonstrates that discarding some redundant information in the top layers helps to find the correct answer. For multiple-choice RACE, TR-BERT 12 achieves better performance than DeFormer while not needing to pre-compute passage representations.
(2) TR-BERT 6 outperforms PoWER-BERT by a margin on text classification tasks. This shows that the fixed pruning configuration and attention-based selection strategy adopted by PoWER-BERT may not be flexible enough to accelerate inference for various input sequences. In contrast, our dynamic token selection automatically determines the proper pruning length and tokens for each example according to the actual situation, which leads to more effective model acceleration.
Overall, TR-BERT retains most of BERT's performance even though it omits many token interactions in the top layers. This shows that TR-BERT learns a satisfactory token selection strategy through reinforcement learning and can effectively reduce redundant computation for tokens from which enough information has already been extracted in the bottom layers.

Fuse Layer-wise and Token-wise Pruning
Since layer-wise pruning and token-wise pruning are compatible, we also explore combining the two pruning strategies. We apply our dynamic token reduction to the 6-layer DistilBERT to obtain TR-BERT 6. The trade-off comparison of TR-BERT 12 and TR-BERT 6 is shown in Figure 3, from which we have the following findings: (1) In general, as the speedup ratio increases, the performance of all models decreases, which indicates that retaining more token information usually results in a more potent model.
(2) TR-BERT 6 consistently outperforms TR-BERT 12 on all tasks at high speedup ratios. In this situation, the budget does not allow enough tokens to go through the top layers, and TR-BERT 6 performs more elaborate pruning than TR-BERT 12 in the bottom layers, obtaining better effectiveness.
(3) At low speedup ratios, TR-BERT 12 performs better than TR-BERT 6 on the question answering tasks but worse on the text classification tasks. In general, a deep Transformer architecture offers multi-turn feature extraction and information propagation, which meets the complex reasoning requirements of question answering. In contrast, the result of text classification usually depends on keywords in the context, for which a shallow model is an affordable solution. To obtain a better trade-off, we can flexibly employ a deep, narrow model for question answering and a shallow, wide model for text classification.

Results on Long-text Tasks
With token pruning, TR-BERT is able to process longer sequences. We apply our dynamic token pruning strategy to BERT L, which can process sequences of up to 1,024 tokens, to obtain TR-BERT L, and conduct experiments on four datasets with longer documents: HotpotQA, TriviaQA, WikiHop and Hyperpartisan. Results on long-text tasks are shown in Table 3, from which we have the following observations: (1) BERT L achieves better performance than BERT, especially on HotpotQA and WikiHop, which require long-range multi-hop reasoning. (2) Compared to the vanilla BERT, TR-BERT L achieves an 8.2% F1 improvement with a 1.56x speedup on HotpotQA, a 1.7% F1 improvement with a 1.24x speedup on TriviaQA, and a 4.65x speedup on WikiHop and a 1.96x speedup on Hyperpartisan without performance drops. Compared to BERT, which can only process up to 512 tokens at a time, BERT L considers longer-range token interactions and obtains more complete reasoning chains. However, the running time of BERT L also increases as the input sequence gets longer, which poses a challenge to utilizing longer text. TR-BERT L inherits the broader view of BERT L and obtains better performance with faster inference. Moreover, the inference acceleration of TR-BERT L is relatively better than that of TR-BERT within 512 tokens, which is consistent with the complexity analysis above. With longer sequences, TR-BERT achieves extra speedup because it significantly saves time in the Self-Attention module, which suggests that TR-BERT can be further applied to process much longer inputs with limited computation.

Case Study
To investigate the characteristics of the selected tokens, we conduct a detailed case study on various datasets. As shown in Table 4, TR-BERT abandons function words, such as the, and, with, in the first token reduction module, which is placed at the bottom layers of BERT. The second token reduction module is placed at the middle layers of BERT, and we observe that it retains task-specific tokens. In the first example, on question answering, the second token reduction module keeps the whole question and the question-related tokens from the context for further message propagation. In the second and third examples, on movie review sentiment classification, the second token reduction module selects sentiment words, such as great, excited, disappointed, to determine whether the given sequence is positive or negative.
Although we train the token reduction modules without direct human annotations, TR-BERT retains generally meaningful tokens in the bottom layers and selects task-relevant tokens in the higher layers. This demonstrates that the pruned network's ground-truth probability is an effective signal for guiding the reinforcement learning of token selection.

Related Work
Researchers have made various attempts to accelerate the inference of PLMs, such as quantization (Shen et al., 2020; Zhang et al., 2020a), attention head pruning (Michel et al., 2019), dimension reduction (Sun et al., 2020; Chen et al., 2020), and layer reduction (Sanh et al., 2019; Sun et al., 2019b; Jiao et al., 2019). In current studies, one of the mainstream methods is to dynamically select the number of Transformer layers to obtain an on-demand lighter model (Fan et al., 2020). However, these methods operate on the whole text and cannot perform pruning at a smaller granularity, such as the token level. Given the deficiencies of layer-level pruning, researchers have sought solutions from a more meticulous perspective by developing methods to extend or accelerate the self-attention mechanism of the Transformer. For example, Sparse Transformer (Child et al., 2019), Longformer (Beltagy et al., 2020) and Big Bird (Zaheer et al., 2020) employ sparse attention to allow the model to handle long sequences. However, these methods only reduce the CUDA memory and are not faster than full attention. Besides, researchers have also explored reducing the number of involved tokens. For example, Funnel-Transformer (Dai et al., 2020) reduces the sequence length with pooling for less computation and finally up-samples it to the full-length representation. Universal Transformer (Dehghani et al., 2019) builds a self-attentive recurrent sequence model where each token uses a dynamic halting layer. And DynSAN (Zhuang and Wang, 2019) applies a gate mechanism to measure the importance of tokens for selection. Spurred by these attempts and positive results, we introduce TR-BERT, which prunes the network at the token level. Specifically, our work aims to accelerate the Transformer by deleting tokens gradually as the layers get deeper.
Compared with these models, TR-BERT is easy to adapt to current PLMs without a significant amount of pre-training, and it can flexibly adjust the model speed according to different performance requirements.
The main idea of TR-BERT is to select essential elements and devote more computation to them, an idea widely adopted in various NLP tasks. ID-LSTM selects important and task-relevant words to build sentence representations for text classification. SR-MRS (Nie et al., 2019) retrieves question-related sentences to reduce the size of the reading material for question answering. TR-BERT can be viewed as a unified framework on the Transformer for important-element selection, which can easily be applied to a wide range of tasks.

Conclusion and Future Work
In this paper, we propose a novel method for accelerating BERT inference, called TR-BERT, which prunes BERT at token-level granularity. Specifically, TR-BERT utilizes reinforcement learning to learn a token selection policy, which selects generally meaningful tokens in the bottom layers and task-relevant tokens in the top layers. Experiments on eleven NLP tasks demonstrate the effectiveness of TR-BERT, as it accelerates BERT inference by 2-5 times for various performance demands. Besides, TR-BERT achieves a better quality-speed trade-off on long-text tasks, which shows its potential for processing large amounts of information in real-world applications.
In the future, we would like to apply TR-BERT in the pre-training process of PLMs. Through the automatically learned token reduction modules, it may be possible to reveal how BERT stores syntactic and semantic information in various tokens and different layers. It is also worth speeding up the time-consuming pre-training process itself.