Patient Knowledge Distillation for BERT Model Compression

Pre-trained language models such as BERT have proven to be highly effective for natural language processing (NLP) tasks. However, the high demand for computing resources in training such models hinders their application in practice. In order to alleviate this resource hunger in large-scale model training, we propose a Patient Knowledge Distillation approach to compress an original large model (teacher) into an equally-effective lightweight shallow network (student). Different from previous knowledge distillation methods, which only use the output from the last layer of the teacher network for distillation, our student model patiently learns from multiple intermediate layers of the teacher model for incremental knowledge extraction, following two strategies: (i) PKD-Last: learning from the last k layers; and (ii) PKD-Skip: learning from every k layers. These two patient distillation schemes enable the exploitation of rich information in the teacher’s hidden layers, and encourage the student model to patiently learn from and imitate the teacher through a multi-layer distillation process. Empirically, this translates into improved results on multiple NLP tasks with a significant gain in training efficiency, without sacrificing model accuracy.


Introduction
Language model pre-training has proven to be highly effective in learning universal language representations from large-scale unlabeled data. ELMo (Peters et al., 2018), GPT (Radford et al., 2018) and BERT (Devlin et al., 2018) have achieved great success in many NLP tasks, such as sentiment classification (Socher et al., 2013), natural language inference (Williams et al., 2017), and question answering (Lai et al., 2017).
Despite its empirical success, BERT's computational efficiency is a widely recognized issue because of its large number of parameters. For example, the original BERT-Base model has 12 layers and 110 million parameters. Training from scratch typically takes four days on 4 to 16 Cloud TPUs. Even fine-tuning the pre-trained model on a task-specific dataset may take several hours to finish one epoch. Thus, reducing computational costs for such models is crucial for their application in practice, where computational resources are limited.
Motivated by this, we investigate the redundancy issue of learned parameters in large-scale pre-trained models, and propose a new model compression approach, Patient Knowledge Distillation (Patient-KD), to compress original teacher (e.g., BERT) into a lightweight student model without performance sacrifice. In our approach, the teacher model outputs probability logits and predicts labels for the training samples (extendable to additional unannotated samples), and the student model learns from the teacher network to mimic the teacher's prediction.
Different from previous knowledge distillation methods (Hinton et al., 2015; Sau and Balasubramanian, 2016; Lu et al., 2017), we adopt a patient learning mechanism: instead of learning parameters from only the last layer of the teacher, we encourage the student model to extract knowledge also from previous layers of the teacher network. We call this 'Patient Knowledge Distillation'. This patient learner has the advantage of distilling rich information through the deep structure of the teacher network for multi-layer knowledge distillation.
We also propose two different strategies for the distillation process: (i) PKD-Last: the student learns from the last k layers of the teacher, under the assumption that the top layers of the original network contain the most informative knowledge to teach the student; and (ii) PKD-Skip: the student learns from every k layers of the teacher, suggesting that the lower layers of the teacher network also contain important information and should be passed along for incremental distillation.
We evaluate the proposed approach on several NLP tasks, including Sentiment Classification, Paraphrase Similarity Matching, Natural Language Inference, and Machine Reading Comprehension. Experiments on seven datasets across these four tasks demonstrate that the proposed Patient-KD approach achieves superior performance and better generalization than standard knowledge distillation methods (Hinton et al., 2015), with significant gain in training efficiency and storage reduction while maintaining comparable model accuracy to original large models. To the authors' best knowledge, this is the first known effort for BERT model compression.

Related Work
Language Model Pre-training Pre-training has been widely applied to universal language representation learning. Previous work can be divided into two main categories: (i) feature-based approach; (ii) fine-tuning approach. Feature-based approaches, such as ELMo (Peters et al., 2018), provide pre-trained representations as additional input features to a task-specific architecture.
On the other hand, fine-tuning approaches mainly pre-train a language model (e.g., GPT (Radford et al., 2018), BERT (Devlin et al., 2018)) on a large corpus with an unsupervised objective, and then fine-tune the model with in-domain labeled data for downstream applications (Dai and Le, 2015; Howard and Ruder, 2018). Specifically, BERT is a large-scale language model consisting of multiple layers of Transformer blocks (Vaswani et al., 2017). BERT-Base has 12 layers of Transformer and 110 million parameters, while BERT-Large has 24 layers of Transformer and 340 million parameters. By pre-training via masked language modeling and next sentence prediction, BERT has achieved state-of-the-art performance on a wide range of NLU tasks, such as the GLUE benchmark (Wang et al., 2018) and SQuAD (Rajpurkar et al., 2016).
However, these modern pre-trained language models contain millions of parameters, which hinders their application in practice where computational resource is limited. In this paper, we aim at addressing this critical and challenging problem, taking BERT as an example, i.e., how to compress a large BERT model into a shallower one without sacrificing performance. Besides, the proposed approach can also be applied to other large-scale pre-trained language models, such as recently proposed XLNet (Yang et al., 2019) and RoBERTa (Liu et al., 2019b).

Model Compression & Knowledge Distillation
Our focus is model compression, i.e., making deep neural networks more compact (Han et al., 2016;Cheng et al., 2015). A similar line of work has focused on accelerating deep network inference at test time (Vetrov et al., 2017) and reducing model training time (Huang et al., 2016).
A conventional understanding is that a large number of connections (weights) is necessary for training deep networks (Denil et al., 2013; Zhai et al., 2016). However, once the network has been trained, there will be a high degree of parameter redundancy. Network pruning (Han et al., 2015; He et al., 2017), in which network connections are reduced or sparsified, is one common strategy for model compression. Another direction is weight quantization (Gong et al., 2014; Polino et al., 2018), in which connection weights are constrained to a set of discrete values, allowing weights to be represented by fewer bits. However, most of these pruning and quantization approaches are applied to convolutional networks. Only a few works are designed for models with rich structural information, such as deep language models (Changpinyo et al., 2017).
Knowledge distillation (Hinton et al., 2015) aims to compress a network with a large set of parameters into a compact and fast-to-execute model. This can be achieved by training a compact model to imitate the soft output of a larger model. Romero et al. (2015) further demonstrated that intermediate representations learned by the large model can serve as hints to improve the training process and the final performance of the compact model. Chen et al. (2015) introduced techniques for efficiently transferring knowledge from an existing network to a deeper or wider network. More recently, Liu et al. (2019a) used knowledge from ensemble models to improve single model performance on NLU tasks. Tan et al. (2019) tried knowledge distillation for multilingual translation. Different from the above efforts, we investigate the problem of compressing large-scale language models, and propose a novel patient knowledge distillation approach to effectively transfer knowledge from a teacher to a student model.

Patient Knowledge Distillation
In this section, we first introduce a vanilla knowledge distillation method for BERT compression (Section 3.1), then present the proposed Patient Knowledge Distillation (Section 3.2) in detail.

Problem Definition
The original large teacher network is represented by a function f(x; θ), where x is the input to the network, and θ denotes the model parameters. The goal of knowledge distillation is to learn a new set of parameters θ′ for a shallower student network g(x; θ′), such that the student network achieves similar performance to the teacher, with much lower computational cost. Our strategy is to force the student model to imitate outputs from the teacher model on the training dataset with a defined objective L_KD.

Distillation Objective
In our setting, the teacher f(x; θ) is defined as a deep bidirectional encoder, e.g., BERT, and the student g(x; θ′) is a lightweight model with fewer layers. For simplicity, we use BERT k to denote a model with k layers of Transformers. Following the original BERT paper (Devlin et al., 2018), we also use BERT-Base and BERT-Large to denote BERT 12 and BERT 24, respectively.
Assume {x_i, y_i}_{i∈[N]} are N training samples, where x_i is the i-th input instance for BERT, and y_i is the corresponding ground-truth label. BERT first computes a contextualized embedding h_i = BERT(x_i) ∈ R^d. Then, a softmax layer ŷ_i = P(y_i|x_i) = softmax(W h_i) is applied to the BERT output embedding for classification, where W is a weight matrix to be learned.
To apply knowledge distillation, first we need to train a teacher network. For example, to train a 12-layer BERT-Base as the teacher model, the learned parameters are denoted as:

θ̂^t = arg min_θ Σ_{i∈[N]} L^t_CE(x_i, y_i; [θ_BERT12, W])   (1)

where the superscript t denotes parameters in the teacher model, [N] denotes the set {1, 2, . . . , N}, L^t_CE denotes the cross-entropy loss for the teacher training, and θ_BERT12 denotes the parameters of BERT 12.
The output probability for any given input x_i can be formulated as:

ŷ_i = P^t(y_i|x_i) = softmax( W · BERT_12(x_i; θ̂^t) / T )   (2)

where P^t(·|·) denotes the probability output from the teacher, ŷ_i is fixed as the soft label, and T is the temperature used in KD, which controls how much to rely on the teacher's soft predictions. A higher temperature produces a more diverse probability distribution over classes (Hinton et al., 2015). Similarly, let θ^s denote the parameters to be learned for the student model, and P^s(·|·) denote the corresponding probability output from the student model. Thus, the distance between the teacher's prediction and the student's prediction can be defined as:

L_DS = − Σ_{i∈[N]} Σ_{c∈C} [ P^t(y_i = c | x_i; θ̂^t) · log P^s(y_i = c | x_i; θ^s) ]   (3)

where c is a class label and C denotes the set of class labels.
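As a toy illustration of the temperature T above (a plain-Python sketch, not the authors' implementation; the logits are made-up values):

```python
import math

def softmax_with_temperature(logits, T=1.0):
    # Scale logits by 1/T before normalizing; a larger T flattens the
    # distribution, yielding "softer" teacher labels.
    scaled = [z / T for z in logits]
    m = max(scaled)                      # subtract max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

logits = [3.0, 1.0, 0.2]
p1 = softmax_with_temperature(logits, T=1.0)
p10 = softmax_with_temperature(logits, T=10.0)
# With T=10 the distribution is visibly flatter than with T=1.
```

Both outputs are valid probability distributions; only their sharpness differs, which is what lets the student see the teacher's relative preferences over the wrong classes.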
Besides encouraging the student model to imitate the teacher's behavior, we can also fine-tune the student model on the target task, where a task-specific cross-entropy loss is included for model training:

L^s_CE = − Σ_{i∈[N]} Σ_{c∈C} 1[y_i = c] · log P^s(y_i = c | x_i; θ^s)   (4)

Thus, the final objective function for knowledge distillation can be formulated as:

L_KD = (1 − α) L^s_CE + α L_DS   (5)

where α is the hyper-parameter that balances the importance of the cross-entropy loss and the distillation loss.
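The combined objective above can be sketched for a single training example in plain Python. This is a simplified, unbatched illustration rather than the training code; the probability values are made up:

```python
import math

def kd_loss(teacher_probs, student_probs, gold, alpha):
    """Single-example version of the KD objective: a convex combination of
    the task cross-entropy (hard label) and the distillation cross-entropy
    against the teacher's soft labels."""
    # Distillation term: cross-entropy against the teacher's soft predictions.
    l_ds = -sum(pt * math.log(ps) for pt, ps in zip(teacher_probs, student_probs))
    # Task term: standard cross-entropy against the ground-truth class index.
    l_ce = -math.log(student_probs[gold])
    return (1 - alpha) * l_ce + alpha * l_ds

loss = kd_loss([0.7, 0.2, 0.1], [0.6, 0.3, 0.1], gold=0, alpha=0.5)
```

Setting alpha to 0 recovers plain fine-tuning, and alpha of 1 trains purely against the teacher's soft labels.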

Patient Teacher for Model Compression
Using a weighted combination of ground-truth labels and soft predictions from the last layer of the teacher network, the student network can achieve comparable performance to the teacher model on the training set. However, as the number of epochs increases, the student model learned with this vanilla KD framework quickly reaches saturation on the test set (see Figure 2 in Section 4). One hypothesis is that overfitting during knowledge distillation may lead to poor generalization. To mitigate this issue, instead of forcing the student to learn only from the logits of the last layer, we propose a "patient" teacher-student mechanism to distill knowledge from the teacher's intermediate layers as well. Specifically, we investigate two patient distillation strategies: (i) PKD-Skip: the student learns from every k layers of the teacher (Figure 1: Left); and (ii) PKD-Last: the student learns from the last k layers of the teacher (Figure 1: Right).
Learning from the hidden states of all the tokens is computationally expensive, and may introduce noise. In the original BERT implementation (Devlin et al., 2018), prediction is performed by only using the output from the last layer's [CLS] token. In some variants of BERT, like SDNet (Zhu et al., 2018), a weighted average of all layers' [CLS] embeddings is applied. In general, the final logit can be computed based on h_final = Σ_{j∈[k]} w_j · h_j, where w_j could be either learned parameters or a pre-defined hyper-parameter, h_j is the embedding of [CLS] from hidden layer j, and k is the number of hidden layers. Derived from this, if the compressed model can learn from the representation of [CLS] in the teacher's intermediate layers for any given input, it has the potential of gaining a generalization ability similar to the teacher model.
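The weighted [CLS] pooling just described can be illustrated with a minimal sketch; the two 3-dimensional layer vectors and the uniform weights are toy values chosen for illustration:

```python
def weighted_cls_pool(cls_by_layer, weights):
    """h_final = sum_j w_j * h_j over per-layer [CLS] vectors
    (SDNet-style pooling over hidden layers)."""
    dim = len(cls_by_layer[0])
    return [sum(w * h[d] for w, h in zip(weights, cls_by_layer))
            for d in range(dim)]

# Two layers, 3-dim toy [CLS] vectors, uniform weights w_j = 0.5:
h = weighted_cls_pool([[1.0, 0.0, 2.0], [3.0, 4.0, 0.0]], [0.5, 0.5])
# h == [2.0, 2.0, 1.0]
```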
Motivated by this, in our Patient-KD framework, the student is cultivated to imitate the representations only for the [CLS] token in the intermediate layers, following the aforementioned intuition that the [CLS] token is important in predicting the final labels. For an input x_i, the outputs of the [CLS] tokens for all the layers are denoted as:

h_i = [h_{i,1}, h_{i,2}, . . . , h_{i,k}] = BERT_k(x_i) ∈ R^{k×d}   (6)

We denote the set of intermediate layers to distill knowledge from as I_pt. Take distilling from BERT 12 to BERT 6 as an example. For the PKD-Skip strategy, I_pt = {2, 4, 6, 8, 10}; and for the PKD-Last strategy, I_pt = {7, 8, 9, 10, 11}. Note that k = 5 for both cases, because the output from the last layer (e.g., Layer 12 for BERT-Base) is omitted, since its hidden states are connected to the softmax layer and are already included in the KD loss defined in Eqn. (5). In general, for a BERT student with n layers, k always equals n − 1.
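The layer-selection rule above can be written as a pair of small helpers. The 12-to-6 index sets below match the example in the text; the generalization to other teacher/student depths is our own assumption, not something the text specifies:

```python
def pkd_skip_layers(n_teacher, n_student):
    # Every (n_teacher // n_student)-th teacher layer, excluding the
    # teacher's last layer (already covered by the KD loss on the logits).
    step = n_teacher // n_student
    return [step * j for j in range(1, n_student)]

def pkd_last_layers(n_teacher, n_student):
    # The n_student - 1 teacher layers directly below the top layer.
    return list(range(n_teacher - n_student + 1, n_teacher))

pkd_skip_layers(12, 6)  # [2, 4, 6, 8, 10]
pkd_last_layers(12, 6)  # [7, 8, 9, 10, 11]
```

Both helpers return n_student − 1 indices, consistent with k = n − 1 above.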
The additional training loss introduced by the patient teacher is defined as the mean-square loss between the normalized hidden states:

L_PT = Σ_{i=1}^{N} Σ_{j=1}^{M} || h^s_{i,j} / ||h^s_{i,j}||_2 − h^t_{i,I_pt(j)} / ||h^t_{i,I_pt(j)}||_2 ||²_2   (7)

where M denotes the number of layers in the student network, N is the number of training samples, and the superscripts s and t on h indicate the student and the teacher model, respectively. Combined with the KD loss introduced in Section 3.1, the final objective function can be formulated as:

L_PKD = (1 − α) L^s_CE + α L_DS + β L_PT   (8)

where β is another hyper-parameter that weights the importance of the features for distillation in the intermediate layers.
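A per-pair sketch of the normalized mean-square term behind L_PT, in plain Python (illustrative only; a real implementation would operate on batched tensors and sum over samples and matched layers):

```python
import math

def normalize(v):
    # L2-normalize a vector.
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]

def pkd_term(student_cls, teacher_cls):
    """Squared L2 distance between the L2-normalized [CLS] vectors of one
    (student layer, matched teacher layer) pair; L_PT sums this quantity
    over training samples and student layers."""
    s, t = normalize(student_cls), normalize(teacher_cls)
    return sum((a - b) ** 2 for a, b in zip(s, t))

# Identical directions give zero loss regardless of scale:
pkd_term([1.0, 2.0], [2.0, 4.0])  # ~0.0
```

Because both vectors are normalized first, the term penalizes differences in direction of the [CLS] representations, not differences in magnitude.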

Experiments
In this section, we describe our experiments on applying the proposed Patient-KD approach to four different NLP tasks. Details on the datasets and experimental results are provided in the following sub-sections.

Datasets

We evaluate the proposed approach on seven datasets: SST-2, MRPC, QQP, MNLI, QNLI and RTE from the General Language Understanding Evaluation (GLUE) benchmark (Wang et al., 2018), plus RACE. More specifically, SST-2 is a movie review dataset with binary annotations, where the binary label indicates positive and negative reviews. MRPC contains pairs of sentences and corresponding labels, which indicate the semantic equivalence relationship between each pair. QQP is designed to predict whether a pair of questions is duplicate or not, provided by the popular online question-answering website Quora. MNLI is a multi-domain NLI task for predicting whether a given premise-hypothesis pair is entailment, contradiction or neutral. Its test and development datasets are further divided into in-domain (MNLI-m) and cross-domain (MNLI-mm) splits to evaluate the generality of tested models. QNLI is a task for predicting whether a question-answer pair is entailment or not. Finally, RTE is based on a series of textual entailment challenges.
For the Machine Reading Comprehension task, we evaluate on RACE (Lai et al., 2017), a large-scale dataset collected from English exams, containing 25,137 passages and 87,866 questions. For each question, four candidate answers are provided, only one of which is correct. The dataset is further divided into RACE-M and RACE-H, containing exam questions for middle school and high school students, respectively. (The QQP data is released by Quora at https://data.quora.com/First-Quora-Dataset-Release-Question-Pairs; QNLI is derived from the Stanford Question Answering Dataset (SQuAD).)

Baselines and Training Details
For experiments on the GLUE benchmark, since all the tasks can be considered as sentence (or sentence-pair) classification, we use the same architecture as the original BERT (Devlin et al., 2018), and fine-tune each task independently.
For experiments on RACE, we denote the input passage as P, the question as q, and the four answers as a_1, . . . , a_4. We first concatenate the tokens in q and each a_i, and arrange the input of BERT as [CLS] P [SEP] q + a_i [SEP] for each input pair (P, q + a_i), where [CLS] and [SEP] are the special tokens used in the original BERT. In this way, we can obtain a single logit value for each a_i. Finally, a softmax layer is placed on top of these four logits to obtain the normalized probability of each answer a_i being correct, which is then used to compute the cross-entropy loss for model training.
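The input arrangement described above can be sketched as follows; the passage, question, and candidate answers are made-up toy strings, and real inputs would of course be tokenized rather than kept as raw text:

```python
import math

def arrange_race_inputs(passage, question, answers):
    # One BERT-style input string per candidate answer:
    # [CLS] P [SEP] q + a_i [SEP]
    return [f"[CLS] {passage} [SEP] {question} {a} [SEP]" for a in answers]

def answer_probs(logits):
    # Softmax over the four per-answer logits to get a normalized
    # probability of each answer being correct.
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    s = sum(exps)
    return [e / s for e in exps]

inputs = arrange_race_inputs("The cat sat.", "Where did the cat sit?",
                             ["mat", "roof", "car", "tree"])
probs = answer_probs([2.0, 0.5, 0.1, -1.0])
```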
We fine-tune BERT-Base (denoted as BERT 12) as the teacher model to compute soft labels for each task independently, where the pre-trained model weights are obtained from Google's official BERT repo, and use 3 and 6 layers of Transformers as the student models (BERT 3 and BERT 6), respectively. We initialize BERT k with the first k layers of parameters from pre-trained BERT-Base, where k ∈ {3, 6}. To validate the effectiveness of our proposed approach, we first conduct direct fine-tuning on each task without using any soft labels. In order to reduce the hyper-parameter search space, we fix the number of hidden units in the final softmax layer as 768, the batch size as 32, and the number of epochs as 4 for all the experiments, with a learning rate from {5e-5, 2e-5, 1e-5}. The model with the best validation accuracy is selected for each setting.
Besides direct fine-tuning, we further implement a vanilla KD method on all the tasks by optimizing the objective function in Eqn. (5). We set the temperature T as {5, 10, 20}, α = {0.2, 0.5, 0.7}, and perform grid search over T, α and learning rate, to select the model with the best validation accuracy. For our proposed Patient-KD approach, we conduct an additional search over β from {10, 100, 500, 1000} on all the tasks. Since there are many hyper-parameters to tune for Patient-KD, we fix α and T to the values used in the model with the best performance from the vanilla KD experiments, and only search over β and learning rate.
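The search procedure above amounts to two small grids, sketched here with Python's itertools; the `best_T` and `best_alpha` values are placeholders for illustration, not the values actually selected in the experiments:

```python
from itertools import product

# Search spaces mirroring the ones described above.
temperatures = [5, 10, 20]
alphas = [0.2, 0.5, 0.7]
learning_rates = [5e-5, 2e-5, 1e-5]

# Vanilla KD: full grid over T, alpha, and learning rate.
vanilla_kd_grid = list(product(temperatures, alphas, learning_rates))

# Patient-KD: T and alpha are frozen to the best vanilla-KD values
# (placeholders here), so only beta and the learning rate are searched.
best_T, best_alpha = 10, 0.5
betas = [10, 100, 500, 1000]
patient_kd_grid = list(product([best_T], [best_alpha], betas, learning_rates))

len(vanilla_kd_grid)   # 27 configurations
len(patient_kd_grid)   # 12 configurations
```

Freezing T and α cuts the Patient-KD search from a 4-dimensional grid to 12 runs per task.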

Experimental Results
We submitted our model predictions to the official GLUE evaluation server to obtain results on the test data. Results are summarized in Table 1. Compared to direct fine-tuning and vanilla KD, our Patient-KD models with BERT 3 and BERT 6 students perform the best on almost all the tasks except MRPC. For MNLI-m and MNLI-mm, our 6-layer model improves 1.1% and 1.3% over the fine-tuning (FT) baselines; for QNLI and QQP, even though the gap between BERT 6 -KD and the BERT 12 teacher is relatively small, our approach still succeeded in improving over both FT and KD baselines and further closing the gap between the student and the teacher models. Furthermore, in 5 tasks out of 7 (SST-2 (-2.3% compared to BERT-Base teacher), QQP (-0.1%), MNLI-m (-2.2%), MNLI-mm (-1.8%), and QNLI (-1.4%)), the proposed 6-layer student coached by the patient teacher achieved similar performance to the original BERT-Base, demonstrating the effectiveness of our approach. Interestingly, all those 5 tasks have more than 60k training samples, which indicates that our method tends to perform better when there is a large amount of training data.
For the QQP task, we can further reduce the model size to 3 layers, where BERT 3 -PKD still achieves performance similar to the teacher model. The learning curves on the QNLI and MNLI datasets are provided in Figure 2. The student model learned with vanilla KD quickly saturated on the dev set, while the proposed Patient-KD keeps learning from the teacher and improving accuracy, only starting to plateau in a later stage. For the MRPC dataset, one hypothesis for why vanilla KD outperforms our model is that the lack of enough training samples may lead to overfitting on the dev set. To further investigate, we repeat the experiments three times and compute the average accuracy on the dev set. We observe that fine-tuning and vanilla KD have a mean dev accuracy of 82.23% and 82.84%, respectively. Our proposed method achieves a higher mean dev accuracy of 83.46% despite its lower test accuracy, indicating that our Patient-KD method slightly overfitted to the dev set of MRPC due to the small amount of training data. This can also be observed in the performance gap between teacher and student on RTE in Table 5, which also has a small training set.
We further investigate the performance gain from two different patient teacher designs: PKD-Last vs. PKD-Skip. Results of both PKD variants on the GLUE benchmark (with BERT 6 as the student) are summarized in Table 2. Although both strategies achieved improvement over the vanilla KD baseline (see Table 1), PKD-Skip performs slightly better than PKD-Last. Presumably, this might be due to the fact that distilling information across every k layers captures more diverse representations of richer semantics from low-level to high-level, while focusing on the last k layers tends to capture relatively homogeneous semantic information.
Results on RACE are reported in Table 3, which shows that the vanilla KD method outperforms direct fine-tuning by 4.42%, and our proposed patient teacher achieves a further 1.6% performance lift, which again demonstrates the effectiveness of Patient-KD.

Analysis of Model Efficiency
We have demonstrated that the proposed Patient-KD method can effectively compress BERT 12 into BERT 6 models without performance sacrifice. In this section, we further investigate the efficiency of Patient-KD in terms of storage saving and inference-time speedup. Parameter statistics and inference time are summarized in Table 4. All the models share the same embedding layer, with 24 million parameters that map a 30k-word vocabulary to a 768-dimensional vector, which leads to 1.64× and 2.4× savings in machine memory for BERT 6 and BERT 3, respectively.
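These savings can be sanity-checked with back-of-the-envelope arithmetic, assuming a rough per-layer parameter count for a BERT-Base-style Transformer layer (hidden size 768, 4× feed-forward expansion; layer-norm and pooler parameters ignored, so the totals are approximate):

```python
def transformer_layer_params(h=768, ffn=4):
    # Self-attention projections (4*h^2 weights + 4*h biases) plus the
    # feed-forward block (2*ffn*h^2 weights + (ffn + 1)*h biases).
    return 4 * h * h + 4 * h + 2 * ffn * h * h + (ffn + 1) * h

def approx_model_params(n_layers, emb=24_000_000):
    # Shared ~24M-parameter embedding layer plus n Transformer layers.
    return emb + n_layers * transformer_layer_params()

ratio_6 = approx_model_params(12) / approx_model_params(6)
ratio_3 = approx_model_params(12) / approx_model_params(3)
# ratio_6 comes out near 1.64 and ratio_3 near 2.41, close to the
# 1.64x and 2.4x savings reported above.
```

The key point the arithmetic makes visible is that the shared embedding layer does not shrink, so halving the Transformer stack saves less than half the total parameters.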
To test the inference speed, we ran experiments on 105k samples from QNLI training set (Rajpurkar et al., 2016). Inference is performed on a single Titan RTX GPU with batch size set to 128, maximum sequence length set to 128, and FP16 activated. The inference time for the embedding layer is negligible compared to the Transformer layers. Results in Table 4 show that the proposed Patient-KD approach achieves an almost linear speedup, 1.94 and 3.73 times for BERT 6 and BERT 3 , respectively.

Does a Better Teacher Help?
To evaluate the effectiveness of the teacher model in our Patient-KD framework, we conduct additional experiments to measure the difference between BERT-Base teacher and BERT-Large teacher for model compression.
Each Transformer layer in BERT-Large has 12.6 million parameters, which is much larger than the Transformer layer used in BERT-Base. For a compressed BERT model with 6 layers, BERT 6 with BERT-Base Transformer (denoted as BERT 6 [Base]) has only 67.0 million parameters, while BERT 6 with BERT-Large Transformer (denoted as BERT 6 [Large]) has 108.4 million parameters. Results are summarized in Table 5. When the teacher changes from BERT 12 to BERT 24 (i.e., Setting #1 vs. #2), there is not much difference between the students' performance. Specifically, the BERT 12 teacher performs better on SST-2, QQP and QNLI, while BERT 24 performs better on MNLI-m, MNLI-mm and RTE. Presumably, distilling knowledge from a larger teacher requires a larger training dataset, thus better results are observed on MNLI-m and MNLI-mm.
We also report results on using BERT-Large as the teacher and BERT 6 [Large] as the student. Interestingly, when comparing Setting #1 with #3, BERT 6 [Large] performs much worse than BERT 6 [Base] even though a better teacher is used in the former case. The BERT 6 [Large] student also has 1.6 times more parameters than BERT 6 [Base]. One intuition behind this is that the compression ratio for the BERT 6 [Large] model is 4:1 (24:6), which is larger than the ratio used for the BERT 6 [Base] model (2:1 (12:6)). The higher compression ratio renders it more challenging for the student model to absorb important weights.
When comparing Setting #2 and #3, we observe that even when the same large teacher is used, BERT 6 [Large] still performs worse than BERT 6 [Base]. Presumably, this may be due to initialization mismatch. Ideally, we should pre-train BERT 6 [Large] and BERT 6 [Base] from scratch, and use the weights learned from the pre-training step for weight initialization in KD training. However, due to the computational limits of training BERT 6 from scratch, we only initialize the student model with the first six layers of BERT 12 or BERT 24. Therefore, the first six layers of BERT 24 may not be able to capture high-level features, leading to worse KD performance.
Finally, when comparing Setting #3 vs. #4, where Setting #4 uses Patient-KD-Skip instead of vanilla KD, we observe a performance gain on almost all the tasks, which indicates that Patient-KD is a generic approach, independent of the selection of the teacher model (BERT 12 or BERT 24).

Conclusion
In this paper, we propose a novel approach to compressing a large BERT model into a shallow one via Patient Knowledge Distillation. To fully utilize the rich information in deep structure of the teacher network, our Patient-KD approach encourages the student model to patiently learn from the teacher through a multi-layer distillation process. Extensive experiments over four NLP tasks demonstrate the effectiveness of our proposed model.
For future work, we plan to pre-train BERT from scratch to address the initialization mismatch issue, and potentially modify the proposed method such that it could also help during pre-training. Designing more sophisticated distance metrics for loss functions is another exploration direction. We will also investigate Patient-KD in more complex settings such as multi-task learning and meta learning.