Rethinking Network Pruning – under the Pre-train and Fine-tune Paradigm

Transformer-based pre-trained language models have significantly improved the performance of various natural language processing (NLP) tasks in recent years. While effective and prevalent, these models are usually prohibitively large for resource-limited deployment scenarios. A thread of research has thus been working on applying network pruning techniques under the pre-train-then-fine-tune paradigm widely adopted in NLP. However, the existing pruning results on benchmark transformers, such as BERT, are not as remarkable as the pruning results in the literature of convolutional neural networks (CNNs). In particular, common wisdom in CNN pruning states that sparse pruning compresses a model more than reducing its number of channels and layers, whereas existing works on sparse pruning of BERT yield results inferior to small dense counterparts such as TinyBERT. In this work, we aim to fill this gap by studying how knowledge is transferred and lost during the pre-train, fine-tune, and pruning process, and by proposing a knowledge-aware sparse pruning process that achieves significantly better results than the existing literature. We show for the first time that sparse pruning compresses a BERT model significantly more than reducing its number of channels and layers. Experiments on multiple data sets of the GLUE benchmark show that our method outperforms the leading competitors with a 20-times weight/FLOPs compression and negligible loss in prediction accuracy.


Introduction
Pre-trained language models, such as BERT (Devlin et al., 2019), have become the standard and effective method for improving the performance of a variety of natural language processing (NLP) tasks. These models are pre-trained in a self-supervised fashion and then fine-tuned for supervised downstream tasks. However, these models suffer from their large model size, making them impractical for resource-limited deployment scenarios and incurring cost concerns (Strubell et al., 2019).
In parallel, an emerging subfield has studied the redundancy in deep neural network models (Zhu and Gupta, 2017; Gale et al., 2019) and proposed to prune networks without sacrificing performance, such as the lottery ticket hypothesis (Frankle and Carbin, 2019). Common wisdom in the CNN literature shows that sparse pruning leads to a higher compression rate than structural pruning. For example, for the same number of parameters (0.46M), the sparse MobileNets improve accuracy by 11.2% over the dense ones (Zhu and Gupta, 2017). However, similar conclusions have not been observed for pre-trained language models.
The main question this paper attempts to answer is: how to perform sparse pruning under the pre-train and fine-tune paradigm? Answering this question correctly is challenging. First, these models adopt pre-training and fine-tuning procedures, during which the general-purpose language knowledge and the task-specific knowledge are learned respectively. Thus, it is desirable and challenging to keep the weights that are important to both knowledge during pruning. Second, unlike CNNs, pre-trained language models have a complex architecture consisting of embedding, self-attention, and feed-forward layers.
To address these challenges, we propose SparseBERT, a knowledge-aware sparse pruning method for pre-trained language models, with a special focus on the widely used BERT model. SparseBERT is executed in the fine-tuning stage. It preserves both general-purpose and task-specific language knowledge while pruning. To preserve the general-purpose knowledge learned during pre-training, SparseBERT uses the pre-trained BERT without fine-tuning as the initialized model and prunes the linear transformations in self-attention and feed-forward layers, which is inspired by recent findings that self-attention and feed-forward layers are over-parameterized (Michel et al., 2019; Voita et al., 2019) and account for most of the computation (Ganesh et al., 2020). To learn the task-specific knowledge during pruning while preserving the general-purpose knowledge at the same time, we apply knowledge distillation (Hinton et al., 2015). We adopt the task-specific fine-tuned BERT as the teacher network and the pre-trained BERT that is being pruned as the student. We feed the downstream task data into the teacher-student framework to train the student to reproduce the behaviors of the teacher. We summarize different types of BERT pruning approaches in Figure 1 (see Section 3.2 for a detailed discussion). Experimental results on the GLUE benchmark demonstrate that SparseBERT outperforms all the leading competitors and achieves 1.4% averaged loss with down to only 5% remaining weights compared to BERT-base.

[Figure 1: (a) is the general pre-training and fine-tuning procedure (Section 3.1). g is an encoder; g_L and g_{L_D} are the encoders well-trained on the pre-training and fine-tuning datasets respectively. L and D are the general-purpose language knowledge and the task-specific knowledge respectively. There is a domain error between pre-training and testing, and a generalization error between fine-tuning and testing.]

Related Work
A lot of effort has been devoted to studying network redundancy and pruning networks without accuracy loss (Gale et al., 2019; Renda et al., 2020). For example, the work on the lottery ticket hypothesis (Frankle and Carbin, 2019) showed that there exist sparse, smaller subnetworks in CNNs capable of training to full accuracy. Common wisdom in the CNN literature shows that sparse pruning leads to a much higher compression rate than structural pruning (Gale et al., 2019; Elsen et al., 2020). For example, for the same number of parameters (0.46M), the sparse MobileNets achieve 61.8% accuracy while the dense ones achieve 50.6% (Zhu and Gupta, 2017). However, similar results have not been observed in existing approaches for pre-trained language models (Fan et al., 2019; Michel et al., 2019; Chen et al., 2020; McCarley et al., 2020; Jiao et al., 2020). Our method aims to fill this gap, and we summarize these pruning strategies. There are other compression approaches for pre-trained language models, such as quantization (Zafrir et al., 2019) and weight factorization, which are out of the scope of this work.
We first formalize the knowledge transfer involved in fine-tuning pre-trained language models. Then, we introduce our SparseBERT.

Knowledge Transfer under the Pre-train and Fine-tune Paradigm

The practice of fine-tuning pre-trained language models has become prevalent in various NLP tasks. The two-stage procedure is illustrated in Figure 1(a). The language model is denoted by f = g • h, where g is a text encoder and h is a task predictor head. Text encoders, like the Transformers in BERT, map input sentences to hidden representations, and task predictors further map the representations to the label space. The pre-trained model is trained on a large amount of data examples (x_p, y_p) from the pre-training task domain via different tasks that resemble language modeling. During pre-training, the general-purpose language knowledge, denoted by L, is learned based on (x_p, y_p). L contains a subset that is related to the downstream task, denoted by L_D, and the amount of L is far greater than that of L_D (see Figure 2(a)). To transfer the knowledge L (especially L_D) from the pre-training domain to the downstream domain, the well-trained encoder g_L is used to initialize the downstream encoder. In fine-tuning, the downstream encoder is trained based on the task-specific knowledge D preserved in a small amount of data examples (x_d, y_d) from the downstream domain. Finally, the well-trained downstream encoder g_{L_D} is evaluated on test data.
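The two-stage procedure above can be sketched in plain numpy. This is a minimal illustration of the f = g • h decomposition and the initialization-based knowledge transfer, not the paper's implementation; the `Encoder` and `Head` classes and all dimensions are placeholders.

```python
import numpy as np

class Encoder:
    """Text encoder g: maps token ids to a hidden representation."""
    def __init__(self, vocab_size=100, dim=8, seed=0):
        self.emb = np.random.default_rng(seed).normal(size=(vocab_size, dim))

    def __call__(self, token_ids):
        # Crude sentence representation: mean of token embeddings.
        return self.emb[token_ids].mean(axis=0)

class Head:
    """Task predictor h: maps a representation to the label space."""
    def __init__(self, dim=8, n_labels=2, seed=1):
        self.w = np.random.default_rng(seed).normal(size=(dim, n_labels))

    def __call__(self, hidden):
        return hidden @ self.w  # logits over task labels

# Pre-training produces the encoder g_L on (x_p, y_p).
g_L = Encoder()

# Fine-tuning: initialize the downstream encoder from g_L (this is how
# the knowledge L is transferred), attach a fresh task head, and train
# both on the downstream data (x_d, y_d).
downstream_encoder = Encoder()
downstream_encoder.emb = g_L.emb.copy()
head = Head()

def f(token_ids):
    """The full model f = g . h."""
    return head(downstream_encoder(token_ids))
```

The key point the sketch captures is that only the encoder weights are transferred; the task head is trained from scratch on the downstream data.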

Two Basic Pruning Strategies
Intuitively, there are two pruning strategies. One is to apply pruning to the downstream encoder g_L during fine-tuning (see Figure 1(b)). However, because the loss used to update the weights during fine-tuning is based exclusively on the data examples (x_d, y_d) from the downstream task domain, this pruning strategy may destroy the knowledge L_D, which is learned based on (x_p, y_p) and encoded in the initialization of g_L.
The other strategy is to execute pruning during pre-training (see Figure 1(c)). The resulting pruned network preserves a subset of the knowledge L, denoted by L_pr. Unfortunately, because this strategy ignores the downstream task information and the amount of L is extremely large, i.e., L ≫ L_pr, the knowledge L_pr could be very different from the knowledge L_D that we hope to preserve (see Figure 2(a)).

The Proposed Pruning Strategy
As shown in Figure 1(d), SparseBERT executes pruning at the distilling stage. It prunes the pre-trained encoder without fine-tuning, g_L, while fine-tuning the pruned encoder based on the downstream dataset (x_d, y_d). Recent findings indicate that self-attention and feed-forward layers are over-parameterized and are the most computationally expensive parts (Michel et al., 2019; Voita et al., 2019; Ganesh et al., 2020). Thus, SparseBERT applies network pruning to the linear transformation matrices in self-attention and feed-forward layers (see Figure 3). The choice of pruning approach is flexible. We choose magnitude weight pruning (Han et al., 2015) in this paper, mainly because it is one of the most effective and popular pruning methods. More details about the pruning strategy used in SparseBERT can be found in the released code.
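Magnitude weight pruning (Han et al., 2015) simply zeroes out the weights with the smallest absolute values. A minimal numpy sketch of one pruning step on a single weight matrix (the function name and the tie-breaking behavior at the threshold are ours, not the paper's):

```python
import numpy as np

def magnitude_prune(weight, sparsity):
    """Zero out the smallest-magnitude entries of a weight matrix.

    sparsity is the fraction of weights to remove, e.g. 0.95 keeps
    only the 5% largest-magnitude weights. In SparseBERT this would
    be applied to the linear transformation matrices in the
    self-attention and feed-forward layers.
    """
    k = int(weight.size * sparsity)  # number of weights to remove
    if k == 0:
        return weight.copy()
    flat = np.abs(weight).ravel()
    # k-th smallest magnitude; everything at or below it is pruned
    # (ties at the threshold are pruned as well).
    threshold = np.partition(flat, k - 1)[k - 1]
    mask = np.abs(weight) > threshold
    return weight * mask
```

In practice the mask would be recomputed (or kept fixed) across fine-tuning steps according to a pruning schedule; this sketch shows only the masking itself.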

Knowledge Distillation Helps Pruning Preserve Task-Specific Knowledge
To mitigate the loss of L_D, we propose to utilize knowledge distillation while pruning. We use the task-specific fine-tuned BERT as the teacher network and the pre-trained BERT that is being pruned as the student (see Figure 1(d) and Figure 3). The motivation is that the task-specific fine-tuned BERT preserves L_D. By feeding the downstream task data (x_d, y_d) into the teacher-student framework, we help the student reproduce the behaviors of the teacher so as to learn both L_D and D as much as possible.
We design the distillation loss as L_distil = L_emb + L_att + L_hid + L_prd. L_emb = MSE(E^S, E^T) is the difference between the embedding layers of student and teacher. L_att = MSE(A_i^S, A_i^T) is the difference between attention matrices, where i is the layer index. L_hid = MSE(H_i^S, H_i^T) is the difference between hidden representations. L_prd = -softmax(z^T) · log_softmax(z^S / temp) is the soft cross-entropy loss between the logits of student and teacher, where temp represents the temperature value. The proposed distillation loss is inspired by (Jiao et al., 2020) and helps the student imitate the teacher's behavior as closely as possible. In addition, we perform the same data augmentation as (Jiao et al., 2020) to generate more task-specific data for teacher-student learning. Notably, the choices of distillation loss and data augmentation method are flexible, and we found the ones we adopted worked well in general.
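The combined loss can be sketched in plain numpy as follows. This is an illustrative implementation under the assumption that student and teacher layers have matching shapes; the function names are ours and the per-layer terms are summed over layers, which is one common convention.

```python
import numpy as np

def mse(a, b):
    return float(np.mean((a - b) ** 2))

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def log_softmax(z):
    z = z - z.max()
    return z - np.log(np.exp(z).sum())

def distillation_loss(emb_s, emb_t, atts_s, atts_t, hids_s, hids_t,
                      logits_s, logits_t, temp=1.0):
    """L_distil = L_emb + L_att + L_hid + L_prd.

    MSE terms align the student's embeddings, per-layer attention
    matrices, and hidden states with the teacher's; L_prd is the soft
    cross-entropy between the teacher's and student's logits.
    """
    l_emb = mse(emb_s, emb_t)
    l_att = sum(mse(a_s, a_t) for a_s, a_t in zip(atts_s, atts_t))
    l_hid = sum(mse(h_s, h_t) for h_s, h_t in zip(hids_s, hids_t))
    l_prd = -float(np.sum(softmax(logits_t) * log_softmax(logits_s / temp)))
    return l_emb + l_att + l_hid + l_prd
```

When the student matches the teacher exactly, the MSE terms vanish and the remaining loss is the entropy of the teacher's output distribution.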

GLUE Benchmark
We evaluate SparseBERT on four data sets from the GLUE benchmark (Wang et al., 2018). To test whether SparseBERT is applicable across tasks, we include both single-sentence and sentence-pair classification tasks. We report the results on the dev sets. We run 3, 20, 20, and 50 epochs for QNLI, MRPC, RTE, and CoLA respectively. The baselines include BERT-base and ELMo (Peters et al., 2018). The results are shown in Table 1. Compared to BERT-base, SparseBERT achieves a 1.4% averaged performance loss with as little as 5% of the weights remaining. In addition, SparseBERT outperforms all leading competitors while having the highest sparsity.

SparseBERT v.s. Pruning at Downstream
We compare SparseBERT with the pruning strategy described in Figure 1(b) on the question answering tasks of SQuAD v1.1 and v2.0 (Rajpurkar et al., 2016, 2018). Given a question and a passage containing the answer, the two tasks are to predict the answer text span in the passage. The difference between them is that SQuAD v2.0 allows for the possibility that no short answer exists in the passage. We follow the general setting of SparseBERT, except that we only apply logit distillation, i.e., L_distil = L_prd, and do not perform data augmentation; these are the most common distillation strategies.
The results are shown in Figure 4. SparseBERT consistently outperforms the baseline method, especially at high sparsity. The performance gain of SparseBERT decreases on SQuAD v2.0, mainly because SQuAD v2.0 is more challenging than SQuAD v1.1. These observations demonstrate the advantage of SparseBERT over pruning at the downstream stage.

SparseBERT v.s. Pruning at Pre-Training
To gain more insight into the advantage of SparseBERT over the pruning strategy described in Figure 1(c), we compare their fitting abilities. Specifically, we use TinyBERT as an example of the baseline pruning method. We compare SparseBERT with a TinyBERT with 4 layers and 312 hidden dimensions, which has a similar number of parameters as SparseBERT (sparsity = 95%). SparseBERT only distills knowledge from the same layers as TinyBERT does.
We vary the number of pruning epochs and report the results (loss on the training set and accuracy on the dev set) on RTE in Figure 5. SparseBERT consistently shows a lower training loss and higher evaluation performance, demonstrating that it has a better fitting ability during pruning than the baseline.

Hardware Performance
Sparse networks were not hardware-friendly in the past. However, hardware platforms with sparse tensor operation support have been emerging. For example, the latest release of Nvidia's high-end GPU, the A100, natively supports sparse tensor operations at up to a 2x compression rate, while startup companies such as Moffett AI have developed computing platforms with sparse tensor operation acceleration at up to a 32x compression rate. Here we deployed SparseBERT at different sparse compression ratios (1, 2, 4, 8, 16, 20) on Moffett AI's latest hardware platform ANTOM to measure the real inference speedup induced by sparse compression, where '4' indicates the model is compressed by a factor of 4, i.e., 75% of the parameters are zeros. As shown in Figure 6, sparse compression yields an almost linear speedup up to 4x and leads to more than a 10x speedup when the compression rate is 20x.

Inference/Training Time
We studied the training time and convergence speed. For example, to obtain the reported 20x pruned result (Table 1), SparseBERT required 12 epochs of fine-tuning on MRPC, and each epoch took 1.5 hours on two RTX 2080 Ti GPUs. The inference time was around 20 seconds.

Conclusion
We introduce SparseBERT, a knowledge-aware sparse pruning method for pre-trained language models, with a focus on BERT. We summarize different types of BERT pruning approaches and compare SparseBERT with leading competitors. Experimental results on GLUE and SQuAD benchmarks demonstrate the superiority of SparseBERT.