A Global Past-Future Early Exit Method for Accelerating Inference of Pre-trained Language Models

Early exit mechanisms aim to accelerate the inference of large-scale pre-trained language models. The essential idea is to exit early without passing through all the inference layers at the inference stage. To make accurate predictions for downstream tasks, the hierarchical linguistic information embedded in all layers should be jointly considered. However, much of the research up to now has been limited to using local representations of the exit layer. Such treatment inevitably loses the information of the unused past layers as well as the high-level features embedded in future layers, leading to sub-optimal performance. To address this issue, we propose a novel Past-Future method to make comprehensive predictions from a global perspective. We first take into consideration all the linguistic information embedded in the past layers and then take a further step to engage the future information that is originally inaccessible for predictions. Extensive experiments demonstrate that our method outperforms previous early exit methods by a large margin, yielding better and more robust performance.


Introduction
Pre-trained language models (PLMs), e.g., BERT (Devlin et al., 2019), RoBERTa (Liu et al., 2019), and XLNet (Yang et al., 2019), have obtained remarkable success in a wide range of NLP tasks. Despite their impressive performance, PLMs usually come with large memory requirements and high computational costs. Such drawbacks slow down inference and further encumber the application of PLMs in scenarios where inference time and computation budget are restricted.
To address this issue, a growing number of studies focusing on improving model efficiency have emerged recently (the code for this work is available at https://github.com/lancopku/Early-Exit). Particularly, Kaya et al. (2019) point out that current over-parameterized models conduct excessive computation for simple instances, which is actually undesirable and computationally wasteful. In light of this observation, an increasing amount of work seeks various early exit methods, of which the basic idea is to exit early without passing through the entire model during inference. Concretely, for NLP tasks, these methods couple a branch classifier with each layer of the pre-trained language model and stop forward propagation at an intermediate layer. The current branch classifier then makes a prediction based on the representation of the token used as the aggregated sequence representation for classification tasks (e.g., [CLS] in BERT), which is referred to as the state of the layer in this work.
However, existing work on early exit has two major drawbacks. First, existing work (Xin et al., 2020; Zhou et al., 2020) uses only local states in the early exit framework. It inevitably loses valuable features that are captured by the passed layers but are ignored for prediction, leading to less reliable prediction results. Moreover, these methods abandon the potentially useful features captured by the future layers that have not been passed, which may hurt the performance on instances requiring the high-level features embedded in the deep layers. Consequently, their performance dramatically declines when the inference exits earlier for a higher speed-up ratio. These two major drawbacks hinder the progress of early exit research and motivate us to develop a new mechanism using the hierarchical linguistic information embedded in all layers (Jawahar et al., 2019) from a global perspective. However, up to now, a global early exit mechanism remains an under-explored and challenging problem. We extend the existing methods to their corresponding global versions and find that naive global strategies result in only poor performance. Meanwhile, the future states are originally inaccessible in the early exit framework, which also remains a bottleneck for a global prediction considering both past and future states.
In this paper, we focus on the aforementioned problems and first put into practice a global Past-Future early exit mechanism. The term global is two-fold: (1) instead of using one or several local state(s) for prediction in previous work, all the available past states are effectively incorporated in our method; (2) furthermore, to grasp the features embedded in the deep layers, the originally inaccessible future states are approximated by imitation learning and are also engaged for prediction. The comparison of the previous method and our method is illustrated in Figure 1. By combining both past and future states, our model is able to make more accurate predictions for downstream tasks.
Extensive experiments reveal that the proposal significantly outperforms previous early exit methods. Particularly, it surpasses the previous methods by a large margin when the speed-up ratio is relatively high. In addition, extensive experiments with different pre-trained language models as backbones demonstrate consistent improvement over the baseline methods, which verifies the generality of our method.
To summarize, our contributions are as follows:
• We propose a set of global strategies that effectively incorporate all available states and achieve better performance than the existing naive global strategies.
• Our early exit method first utilizes the future states which are originally inaccessible at the inference stage, enabling more comprehensive global predictions.
• Experiments show that our proposal achieves better performance compared to the previous state-of-the-art early exit methods.

Related Work
Large-scale pre-trained language models (Devlin et al., 2019) based on the Transformer (Vaswani et al., 2017) architecture demonstrate superior performance in various NLP tasks. However, this impressive performance rests on massive numbers of parameters, leading to large memory requirements and computational costs during inference. To overcome this bottleneck, a growing number of studies work on improving the efficiency of over-parameterized pre-trained language models. Knowledge distillation (Hinton et al., 2015; Turc et al., 2019; Jiao et al., 2019; Li et al., 2020a) compacts the model architecture to obtain a smaller model that remains static for all instances at the inference stage. Sanh et al. (2019) focus on reducing the number of layers, since their investigation reveals that variations in the hidden size have a smaller impact on computational efficiency. Sun et al. (2019) learn from multiple intermediate layers of the teacher model for incremental knowledge extraction instead of only learning from the last hidden representations. Further, Wang et al. (2020) design elaborate techniques to drive the student model to mimic the self-attention module of teacher models. Xu et al. (2020) compress models by progressive module replacing, showing a new perspective on model compression. However, these static model compression methods treat instances requiring different computational costs without distinction. Moreover, they have to distill a model from scratch to meet each varying speed-up ratio requirement.
To meet different constraints for acceleration, another line of work studies instance-adaptive methods that adjust the number of executed layers for different instances. Li et al. (2020b) select models of different sizes depending on the difficulty of the input instance. Besides, early exit is a practical method to adaptively accelerate inference and was first proposed for computer vision tasks (Kaya et al., 2019; Teerapittayanon et al., 2016). Elbayad et al. (2020), Xin et al. (2020), and Schwartz et al. (2020) follow this essential idea and leverage the method in NLP tasks. To prevent errors from one single classifier, Zhou et al. (2020) make the model stop inference when a cross-layer consistent prediction is achieved. However, research on the subject has been mostly restricted to using only the local states around the exit layer.

Method
We first introduce the strategies to incorporate multiple states and the imitation learning method for generating approximations of future states. Then we introduce the merging gate that adaptively fuses past and future states. Finally, we describe the training process and the exit condition during inference.

Incorporation of Past States
Existing work (Xin et al., 2020) focuses on making the exit decision based on a single branch classifier. The consequent unreliable results motivate the recent advance (Zhou et al., 2020) that uses consecutive states to improve accuracy and robustness. However, the model prediction is still limited to using several local states. In contrast, we investigate how to incorporate all the past states from a global perspective. The existing strategy using consecutive consistent prediction labels can be easily extended to a global version that counts the majority of the predicted labels, which we regard as a voting strategy. Another alternative is the commonly-used ensemble strategy that averages the output probabilities for prediction. Besides these naive solutions, we explore the following strategies to integrate multiple states into a single one:
• Max-Pooling: The max-pooling operation is performed on all available states, resulting in the integrated state.
• Avg-Pooling: The average-pooling operation is performed on all available states, resulting in the integrated state.
• Attn-Pooling: The attentive-pooling takes the weighted summation of all available states as the integrated state. The attention weights are computed with the last state as the query.
• Concatenation: All available states are concatenated and then fed into a linear transformation layer to obtain the compressed state.
• Sequential Neural Network: All available states are sequentially fed into an LSTM and the hidden output of the last time-step is regarded as the integrated state.
Formally, the state of the i-th layer is denoted as s_i. When forward propagation proceeds to the i-th intermediate layer, all the past states s_1:i are incorporated into a global past state s_p:

s_p = G(s_1:i),

where G(·) refers to one of the state incorporation strategies.
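The incorporation strategies above can be sketched as follows. This is a minimal PyTorch sketch under assumed shapes and hyperparameters (the class name, hidden size, and single-layer LSTM are illustrative, not the exact configuration used in this work); each strategy maps the stack of past states s_1..s_i to a single integrated state.

```python
import torch
import torch.nn as nn

class PastStatePooler(nn.Module):
    """Sketch of the global strategies G(.) that merge all past states
    s_1..s_i into a single past state s_p."""

    def __init__(self, hidden=768, max_layers=12, strategy="concat"):
        super().__init__()
        self.strategy = strategy
        if strategy == "concat":
            # one projection per possible exit layer i (i states concatenated)
            self.proj = nn.ModuleList(
                nn.Linear(i * hidden, hidden) for i in range(1, max_layers + 1)
            )
        elif strategy == "lstm":
            self.rnn = nn.LSTM(hidden, hidden, batch_first=True)
        elif strategy == "attn":
            self.scale = hidden ** -0.5

    def forward(self, states):  # states: (batch, i, hidden) = s_1..s_i
        if self.strategy == "max":
            return states.max(dim=1).values
        if self.strategy == "avg":
            return states.mean(dim=1)
        if self.strategy == "attn":
            q = states[:, -1:]                       # last state as the query
            w = torch.softmax(q @ states.transpose(1, 2) * self.scale, dim=-1)
            return (w @ states).squeeze(1)           # weighted sum of states
        if self.strategy == "concat":
            i = states.size(1)
            return self.proj[i - 1](states.flatten(1))
        out, _ = self.rnn(states)                    # "lstm" strategy
        return out[:, -1]                            # last time-step hidden
```

Given states of shape (batch, i, hidden), every strategy returns a (batch, hidden) tensor, so the downstream branch classifier is unchanged regardless of the chosen strategy.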

Imitation of Future States
Existing work on early exit stops inference at an intermediate layer and ignores the underlying valuable features captured by the future layers. Such treatment is partly rationalized by the recent claim (Kaya et al., 2019) that shallow layers are adequate to make a correct prediction. However, Jawahar et al. (2019) reveal that pre-trained language models capture a hierarchy of linguistic information from the lower to the upper layers, e.g., the lower layers learn surface or syntactic features while the upper layers capture high-level information like semantic features. We hypothesize that some instances not only rely on syntactic features but also require semantic features. It is actually undesirable to only consider features captured by shallow layers. Therefore, we propose to take advantage of both past and future states. Normally, we can directly fetch the past states, while using future information is intractable, since the future states are inaccessible before passing through the future layers. To bridge this gap, we propose a simple method to approximate the future states in light of imitation learning (Ross et al., 2011; Nguyen, 2016; Ho and Ermon, 2016). We couple each layer with an imitation learner. During training, the imitation learner is encouraged to mimic the representation of the real state of that layer. Through this layer-wise imitation, we can obtain approximations of the future states at minimal cost. The illustration of the future imitation learning during inference is shown in Figure 2.
To be precise, we intend to obtain a state approximation of the j-th layer if the forward pass exits at the intermediate i-th layer, for any j > i. During training, we pass through the entire n-layer model but simulate the situation that the forward pass ends at the i-th layer, for any i < n. The j-th learner corresponding to the j-th layer takes s_i as input and outputs an approximation ŝ_j^i of the real state s_j:

ŝ_j^i = Learner_j(s_i) = W_j s_i + b_j,

where Learner_j(·) is a simple feed-forward layer with learnable parameters W_j and b_j. Then s_j serves as a teacher to guide the j-th imitation learner. We adopt cosine similarity as the distance measurement and penalize the discrepancy between the real state s_j and the learned state ŝ_j^i. Let L_cos^i denote the imitation loss of the situation that the forward pass exits at the i-th layer; it is computed as the average of the similarity loss over all j > i:

L_cos^i = 1/(n−i) · Σ_{j=i+1}^{n} (1 − (s_j · ŝ_j^i) / (‖s_j‖ ‖ŝ_j^i‖)),

where ‖·‖ denotes the L2 norm. Since the exit layer i can be any number between 2 and n during inference, we go through all possible numbers i and average the corresponding L_cos^i, resulting in the overall loss L_cos:

L_cos = 1/(n−1) · Σ_{i=1}^{n−1} L_cos^i.
During training, the forward propagation is computed on all layers and all imitation learners are encouraged to generate representations close to the real states. During inference, the forward propagation proceeds to the i-th intermediate layer and the subsequent imitation learners take the i-th real state as input to generate the approximations of future states. Then the approximations are incorporated into a comprehensive future state s_f with one of the global strategies introduced before:

s_f = G(ŝ_{i+1:n}^i),

where ŝ_{i+1:n}^i denotes the approximations of the states from the (i+1)-th layer to the n-th layer.
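The layer-wise imitation scheme can be sketched as a hedged PyTorch example (class and method names are ours; the linear learner and the uniform averaging over simulated exit layers follow our reading of the text):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class FutureImitator(nn.Module):
    """Sketch of layer-wise imitation learning: Learner_j is a linear
    layer that maps the exit state s_i to an approximation of the real
    future state s_j."""

    def __init__(self, n_layers=12, hidden=768):
        super().__init__()
        self.learners = nn.ModuleList(
            nn.Linear(hidden, hidden) for _ in range(n_layers)
        )
        self.n = n_layers

    def approximate(self, s_i, i):
        # approximations of s_{i+1}..s_n produced from the real exit state s_i
        return [self.learners[j](s_i) for j in range(i, self.n)]

    def imitation_loss(self, states):
        # states: list of the n real layer states gathered during training
        loss, count = 0.0, 0
        for i in range(1, self.n):  # simulate an exit at each layer i < n
            for j, s_hat in enumerate(self.approximate(states[i - 1], i), start=i):
                # penalize cosine discrepancy between real state and imitation
                loss = loss + (1 - F.cosine_similarity(states[j], s_hat, dim=-1)).mean()
                count += 1
        return loss / count
```

At inference time only `approximate` is called, so the extra cost per exit is a handful of linear projections rather than full Transformer layers.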

Adaptive Merging Gate
We then explore how to adaptively merge the past information and the future information. Intuitively, the past state s_p and the future state s_f are of different importance, since the authentic past states are more reliable than our imitated future states. In addition, different instances depend differently on the high-level features learned by future layers. Therefore, it is indispensable to develop an adaptive method that automatically combines the past state s_p and the future state s_f. In our work, we design an adaptive merging gate for this purpose. As the forward propagation proceeds to the i-th layer, we compute the reliability of the past state s_p, and the final merged representation is a trade-off between these two states:

α_i = σ(FFN(s_p)),
z_i = α_i · s_p + (1 − α_i) · s_f,

where z_i is the merged final state, σ(·) is the sigmoid function, and FFN(·) is a linear feed-forward layer of the merging gate. During training, each layer generates the approximated future states and obtains a merged final state, which is used for prediction. The model is then updated with the layer-wise cross-entropy loss against the ground-truth label y. The merging gate adaptively learns to adjust the balance under the supervision signal given by the ground-truth labels. However, with the layer-wise optimization objectives, the shallow layers will be updated more frequently since they receive more updating signals from higher layers. To address this issue, we heuristically re-weight the cross-entropy loss of each layer depending on its depth i and get its weight w_i. The updating procedure is formalized as:

L_ce^i = CE(f_i(z_i), y),   w_i = i / Σ_{j=1}^{n} j,

where f_i(·) denotes the i-th branch classifier and CE(·) the cross-entropy loss. The overall loss is computed as follows:

L = Σ_{i=1}^{n} w_i · L_ce^i + L_cos.
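A minimal sketch of the merging gate, assuming a sigmoid gate over a scalar produced by the linear layer (the exact gating form is our reading of the text, and the class name is ours):

```python
import torch
import torch.nn as nn

class MergingGate(nn.Module):
    """Sketch of the adaptive merging gate: a linear layer scores the
    reliability of the (authentic) past state, and the final state z_i
    is a convex combination of the past and imitated future states."""

    def __init__(self, hidden=768):
        super().__init__()
        self.ffn = nn.Linear(hidden, 1)

    def forward(self, s_p, s_f):
        alpha = torch.sigmoid(self.ffn(s_p))    # reliability of the past state
        return alpha * s_p + (1 - alpha) * s_f  # merged final state z_i
```

Because the gate is conditioned on s_p, instances that need little high-level information can push alpha toward 1 and rely almost entirely on the authentic past states.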

Fine-tuning and Inference
Here we introduce the fine-tuning technique and the exit condition at the inference stage.
Fine-tuning The representations learned by shallow layers have a big impact on performance in the early exit framework, since the prediction largely depends on the states of shallow layers. Most existing work updates all of the model layers at each step during fine-tuning to adapt to the data of downstream tasks. However, we argue that such an aggressive updating strategy may undermine the well-generalized features learned in the pre-training stage. In our work, we try to balance the requirements of maintaining features learned in pre-training and adapting to the data at the fine-tuning stage. Specifically, the parameters of a layer are frozen with a probability p, and this probability linearly decreases from 1 for the first layer to 0 for the last layer.
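The stochastic freezing schedule can be sketched in plain Python; the helper names are ours, and we assume 0-indexed layers with the freezing probability interpolated linearly from 1 to 0:

```python
import random

def freeze_probabilities(n_layers):
    """Freezing probability p for each layer, decreasing linearly from
    1 (first layer) to 0 (last layer), as described in the text."""
    return [1 - l / (n_layers - 1) for l in range(n_layers)]

def sample_frozen_layers(n_layers, rng=random):
    # at each fine-tuning step, layer l is frozen with probability p_l
    return [rng.random() < p for p in freeze_probabilities(n_layers)]
```

Under this schedule the first layer is always frozen at a given step and the last layer is always updated, so shallow layers stay closest to their pre-trained features.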
Inference Following Xin et al. (2020), we quantify the prediction confidence e with the entropy of the output distribution p_i of the i-th layer:

e(p_i) = −Σ_c p_i(c) log p_i(c).

The inference stops once the confidence e(p_i) is lower than a predefined threshold τ. The hyperparameter τ is adjusted according to the required speed-up ratios. If the exit condition is never reached, our model degrades into the common case of inference in which the complete forward propagation is accomplished. Following Zhou et al. (2020), we manually adjust the exit threshold τ and calculate the speed-up ratio by comparing the number of actually executed layers in forward propagation with the number of complete layers. For an n-layer model, the speed-up ratio is:

Speed-up = (n · Σ_{i=1}^{n} m_i) / (Σ_{i=1}^{n} i · m_i),

where m_i is the number of examples that exit at the i-th layer of the model.
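The entropy-based exit condition and the speed-up computation can be sketched as follows (function names are ours; the natural logarithm is assumed for the entropy):

```python
import math

def entropy(probs):
    """Prediction confidence e(p_i): entropy of the layer's output
    distribution (lower entropy = higher confidence)."""
    return -sum(p * math.log(p) for p in probs if p > 0)

def should_exit(probs, tau):
    # exit once the entropy of the output distribution falls below tau
    return entropy(probs) < tau

def speedup_ratio(exit_counts):
    """exit_counts[i-1] = m_i, the number of examples exiting at layer i
    of an n-layer model; the ratio of full layers to executed layers."""
    n = len(exit_counts)
    executed = sum(i * m for i, m in enumerate(exit_counts, start=1))
    return n * sum(exit_counts) / executed
```

A peaked distribution such as [0.99, 0.01] has near-zero entropy and triggers an exit under a small τ, whereas a uniform [0.5, 0.5] does not; if every example runs all n layers, the speed-up ratio is exactly 1.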

Baselines
The proposed method is applicable to a range of existing pre-trained language models. Without loss of generality, we conduct experiments with several well-known PLMs as backbones, namely BERT, RoBERTa, and ALBERT (Lan et al., 2019). Both BERT and RoBERTa suffer from the problem of over-parameterization. ALBERT largely alleviates this problem and is very efficient in terms of model size; results on it verify the effectiveness of our method on such parameter-efficient models. We mainly compare our method with other methods targeting reduced model depth, including the recent early exit methods and the method of directly reducing the model depth to m layers, which is denoted as (AL)BERT-mL.

Overall Comparison
We compare our model performance with the baseline methods when different backbone models are adopted and show the results in Table 1 and Table 2. Both PABEE (Zhou et al., 2020) and DeeBERT (Xin et al., 2020) accelerate inference with at most a 2× speed-up ratio. To be consistent, we adjust the exit threshold to obtain a 2× speed-up ratio and report the results in Table 1. Our method maintains results comparable to the original models on most datasets. We also notice that directly reducing layers performs well and serves as a strong baseline. Nevertheless, our proposal significantly outperforms this method as well as the other two early exit methods.
We then adopt a more aggressive 3.00× speed-up ratio to verify the effectiveness of our method. According to Table 2, the performance of PABEE and DeeBERT deteriorates badly. In contrast, our model exhibits more robust and stable performance, showing its superiority over previous early exit methods. In particular, ALBERT is already very efficient in model size owing to its layer-sharing mechanism. Results shown at the bottom of Table 2 suggest that our model can obtain good results with minimal performance loss on such a parameter-efficient model.
The success of our proposal might be attributed to its global perspective on prediction. DeeBERT makes predictions with the state of a single branch classifier, leading to less reliable results. Although PABEE employs cross-layer prediction to prevent errors from one single classifier, it ignores much of the available information in past states as well as the high-level semantic features captured by future layers. Different from those methods, our method jointly takes into consideration the hierarchical linguistic information embedded in all layers and is thus able to produce more accurate results.

Performance-Efficiency Trade-Off
To further verify the robustness and efficiency of our method, we visualize the performance-efficiency trade-off curves in Figure 3 on a representative subset of the GLUE dev set. The backbone model is BERT. Please refer to Appendix A for results with RoBERTa and ALBERT. As can be seen from Figure 3, the performance of previous state-of-the-art early exit methods drops dramatically when the speed-up ratio increases, which limits their practicality for higher acceleration requirements. By comparison, our method demonstrates more tolerance of the speed-up ratio. It significantly improves performance compared to previous best-performing early exit models under the same speed-up ratio, especially when the speed-up ratio is high, indicating that it can be applied in a wider range of acceleration scenarios.

Effect of Global Strategies
The results of different global strategies on a representative subset of the GLUE dev set are shown in Table 3. The naive global strategies, including voting and ensemble, perform poorly, which demonstrates that existing global strategies can only achieve sub-optimal performance. In contrast, our simple yet effective global strategies for incorporating past states bring significant improvements compared to the baselines. In addition, we empirically find that the concatenation strategy works best from an overall point of view. We assume that this strategy allows interaction among different states, yielding better performance. The effect of the merging gate can be found in Appendix B.

Analysis of Future Information
To assess whether and how future information contributes to the prediction, we first evaluate the Global Future version of our early exit method, where all the approximations of future states are incorporated through the concatenation strategy. The effect of future information is supported by the results shown in Table 4. We observe that the Global Future mechanism brings improvement on most datasets for both the 2× and 3× speed-up ratios, which confirms that the approximations of future states help enhance the model's predictive ability. Beyond that, the future states can be especially advantageous for models with a higher speed-up ratio. Recall that approximations of future states complement the high-level semantic information, and an exit at shallow layers loses more semantic information than an exit at deep layers. Therefore, the benefit of future information is more significant for exits at shallow layers, which is validated by the larger improvement gap at the 3× speed-up ratio. We also investigate the effect of future information on exit time. Figure 4 demonstrates the distribution of exit layers with and without future information. When future information is engaged, we observe that the proportion of exits at shallow layers increases. The observation conforms with our intuition: with the approximations of future states supplemented for prediction, the merged state at a shallow layer is able to make a confident and correct prediction. Thus the exit time is earlier compared to situations without future states, resulting in a higher speed-up ratio. To be more specific, for MRPC, the speed-up ratios with and without future states are 1.69 and 1.99, and 1.92 and 2.04 for MNLI, respectively.

Table 4: Effect of the approximated future states. BERT-local denotes the early exit method using only the current state and Global Future represents the incorporation of future states. Results are on the GLUE dev set.
Meanwhile, we observe a performance boost with future states involved. It confirms our assumption that the high-level semantic features embedded in future states help improve performance in early exit framework.

Comparison with Distillation Methods
As an alternative method to accelerate inference, knowledge distillation also exhibits promising performance on NLP tasks. We provide a comparison with typical knowledge distillation methods in Table 5. The existing model TinyBERT (Jiao et al., 2019) exerts multiple elaborate strategies to achieve state-of-the-art results, including an expensive general distillation process and a vast amount of augmented data for fine-tuning. We remove these two techniques to exclude the effect of extra training data. Under the same settings, we observe that our method outperforms the distillation methods at the same speed-up ratio. In general, early exit and distillation methods improve inference efficiency from different perspectives. The distillation methods are more efficient in saving memory usage, but the downside is that such static methods suffer from a high computational cost to adapt to different speed-up ratios: a new student model has to be trained from scratch if the speed-up requirement changes. By contrast, dynamic methods are more flexible in meeting different acceleration requirements. Concretely, simple instances are processed by passing through fewer layers while complex instances may require more layers. Moreover, the speed-up ratio can be easily adjusted depending on the acceleration requests. Nevertheless, early exit and distillation accelerate inference from different perspectives, and the two kinds of techniques can be integrated to further compress the model size and accelerate inference.

Conclusions
We propose a novel Past-Future early exit method that makes predictions from a global perspective. Unlike previous work using only local states for prediction, our model employs all available past states and engages, via a novel imitation approach, the future states that are originally inaccessible for prediction. Experiments illustrate that our method achieves significant improvements over baseline methods with different models as backbones, suggesting the superiority of our early exit method.

A More Performance-Efficiency Trade-Off Curves
Performance-efficiency curves with RoBERTa and ALBERT as backbones are shown in Figure 5 and Figure 6, respectively. Similar to the observation with BERT as the backbone, the performance of DeeBERT and PABEE becomes progressively worse as the speed-up ratio increases. In contrast, our Past-Future early exit method shows more robust results.

Table 6: Ablation study of the merging gate. The speed-up ratio is approximately 2.00× and the model implementation is based on BERT.