Exploring the Limits of Simple Learners in Knowledge Distillation for Document Classification with DocBERT

Fine-tuned variants of BERT achieve state-of-the-art accuracy on many natural language processing tasks, although at significant computational cost. In this paper, we verify BERT's effectiveness for document classification and investigate the extent to which BERT-level effectiveness can be obtained by different baselines, combined with knowledge distillation, a popular model compression method. The results show that BERT-level effectiveness can be achieved by a single-layer LSTM with at least 40× fewer FLOPS and only ∼3% of the parameters. More importantly, this study analyzes the limits of knowledge distillation as we distill BERT's knowledge all the way down to linear models, a relevant baseline for the task. We report substantial improvements in effectiveness for even the simplest models, as they capture the knowledge learnt by BERT.


Introduction
Transformer-based (Vaswani et al., 2017) pretrained contextual word embedding models such as BERT (Devlin et al., 2019) and XLNet (Yang et al., 2019) currently power many of the state-of-the-art models across various natural language processing (NLP) tasks. However, these models consume immense computational resources (Strubell et al., 2019). With the surge of such pre-trained models being developed in quick succession, there is a need for effective compression techniques for their inexpensive deployment.
Knowledge distillation (KD; Hinton et al., 2015; Ba and Caruana, 2014) has been shown to be a fairly straightforward and effective model-agnostic compression method, which transfers knowledge learnt by huge models into more efficient models. In this work, we investigate whether BERT-level effectiveness can be achieved by more efficient models using KD. More importantly, if so, how simple can these models be?
We investigate these questions through the lens of document classification, a setting where these computational concerns are particularly relevant due to potentially long document lengths. Further, in previous work, neural networks as an architectural choice have been questioned owing to the effectiveness of simple bag-of-words baselines (Adhikari et al., 2019).
We first confirm that a fine-tuned BERT model achieves state-of-the-art model quality by a substantial margin on standard document classification benchmarks. Following this, we investigate the extent to which BERT-level effectiveness can be obtained by various baselines, combined with KD. We demonstrate, quite surprisingly, that it is possible to apply KD successfully to impoverished student models, such as a single-layer convolutional neural network (CNN) (Kim, 2014) and even linear models. The key contributions of this work are as follows: (1) we develop and release* a fine-tuned BERT model (DocBERT), which achieves state-of-the-art model quality for document classification; while this finding is perhaps obvious, we carefully document the experimental results.

Background and Methods
Typically, the task of document classification deals with classifying long texts (documents). More often than not, a document may be associated with more than one label, thus exposing the classifiers to multi-label classification and class imbalance.
Here, we review a subset of approaches developed to solve the task and highlight the methods that we compare and build upon in this work.

Document Classification Models
Neural network-based models. In recent years, neural network-based architectures have dominated the task of document classification. Many researchers (Kim, 2014; Conneau et al., 2017; Johnson and Zhang, 2017) show convolutional neural networks to be effective for classifying single-label short texts. Furthermore, Liu et al. (2017) develop XML-CNN, a variant of the popular KimCNN (Kim, 2014), for addressing the multi-label nature of document classification, which they call extreme classification. Alternatively, others (Yang et al., 2016; Adhikari et al., 2019; Yang et al., 2018) show effective use of recurrent neural networks to exploit semantic representations by treating documents as a sequence of words or sentences for classification. In this work, we explore several neural baseline models and use both LSTM (Hochreiter and Schmidhuber, 1997) and KimCNN architectures for knowledge distillation experiments.

Non-neural models. Logistic regression (LR) and support vector machines (SVMs) trained on tf-idf vectors form efficient and effective baselines for document classification. Adhikari et al. (2019) show that LR and SVMs surpass most of the neural baselines on multiple datasets, questioning the need for employing neural networks to model syntactic structure for document classification. Here, we explore both LR and SVMs, and we perform knowledge distillation experiments using an LR model.

Large-scale pre-training. Recent work (Howard and Ruder, 2018; Devlin et al., 2019; Yang et al., 2019) has demonstrated the effectiveness of large-scale pre-training for NLP tasks. In this work, we use BERT as a representative of this approach and demonstrate the power of fine-tuned BERT on document classification (termed DocBERT).
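To make the non-neural baselines concrete, the sketch below computes the tf-idf representation on which the LR and SVM baselines operate. This is a minimal, self-contained illustration (raw term frequency, smoothed idf); the actual baselines use library implementations with sublinear tf scaling and vector normalization.

```python
import math
from collections import Counter

def tfidf_vectors(docs):
    """Compute sparse tf-idf vectors (raw tf, idf = log(N/df) + 1)
    for a list of tokenized documents. Illustrative only."""
    n = len(docs)
    df = Counter()                       # document frequency per term
    for doc in docs:
        df.update(set(doc))
    idf = {t: math.log(n / df[t]) + 1.0 for t in df}
    return [{t: tf * idf[t] for t, tf in Counter(doc).items()}
            for doc in docs]

# Toy corpus: terms shared across documents receive lower idf weight.
docs = [["cheap", "oil", "price"], ["oil", "market"], ["movie", "review"]]
vecs = tfidf_vectors(docs)
```

A classifier such as LR then learns one weight per vocabulary term over these vectors, which is why the baseline is both fast and surprisingly competitive.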

Knowledge Distillation
Knowledge distillation (KD; Hinton et al., 2015; Ba and Caruana, 2014) is an effective model-agnostic approach to model compression, where an efficient student model captures the knowledge learnt by a privileged but cumbersome teacher model (or models). The knowledge transfer takes place by forcing the student to mimic the soft target probabilities of the teacher. Hinton et al. (2015) highlight that it is in the interest of the generalizability of the student model to capture the exact class probabilities from a better model, the teacher. In supervised settings, the student is trained using a distillation objective in combination with the classification objective:

L = L_classification + λ · L_distill    (1)

where λ is a hyperparameter chosen to weigh the two optimization objectives. The L_classification term is the task-specific classification loss, most often the cross-entropy loss between the predictions of the student model and the target labels, while the distillation term L_distill quantifies the difference between the predictions of the student and those of the teacher. In this work, we use a fine-tuned BERT model as the teacher and experiment with various baseline architectures for the students. Following Hinton et al. (2015), we set L_distill to be the Kullback-Leibler divergence between the class probabilities output by the student and the teacher BERT model.
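A minimal NumPy sketch of Equation (1), using made-up logits; the KL direction (student p against teacher q) and the role of λ follow the description above:

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def kd_loss(student_logits, teacher_logits, target, lam=1.0):
    """Equation (1): L = L_classification + lam * L_distill.
    L_classification is cross-entropy against the gold label;
    L_distill is KL(p || q), with p the student's and q the
    teacher's class probabilities."""
    p = softmax(student_logits)        # student class probabilities
    q = softmax(teacher_logits)        # teacher class probabilities
    l_cls = -np.log(p[target])         # cross-entropy with one-hot target
    l_kd = np.sum(p * np.log(p / q))   # KL divergence KL(p || q)
    return l_cls + lam * l_kd
```

Because KL(p||q) ≥ 0 and vanishes only when the two distributions match, the distillation term pulls the student's soft predictions toward the teacher's without overriding the supervised signal.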

Datasets
We use the following four datasets to evaluate BERT: Reuters-21578 (Reuters; Apté et al., 1994), arXiv Academic Paper dataset (AAPD; Yang et al., 2018), IMDB reviews, and Yelp 2014 reviews. Reuters and AAPD are multi-label datasets while documents in IMDB and Yelp '14 contain only a single label per document. Table 1 summarizes the statistics of these datasets. For Reuters, we use the standard ModApté splits (Apté et al., 1994); for AAPD, we use the splits provided by Yang et al. (2018); for IMDB and Yelp, following Yang et al. (2016), we randomly sample 80% of the data for training and 10% each for validation and test.
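The random 80/10/10 splits used for IMDB and Yelp can be sketched as follows (the seed and function name are illustrative; the splits in the experiments follow Yang et al. (2016)):

```python
import random

def split_80_10_10(examples, seed=0):
    """Randomly sample 80% of the data for training and 10% each
    for validation and test."""
    items = list(examples)
    random.Random(seed).shuffle(items)
    n = len(items)
    n_train, n_val = int(0.8 * n), int(0.1 * n)
    return (items[:n_train],
            items[n_train:n_train + n_val],
            items[n_train + n_val:])

train, val, test_split = split_80_10_10(range(100))
```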

Training and Hyperparameters
As a simple and straightforward adaptation of BERT models (Devlin et al., 2019) for document classification, we introduce a fully-connected layer over the final hidden state corresponding to the [CLS] input token. During fine-tuning, we optimize the entire model end-to-end, together with the additional softmax classifier parameters W ∈ R^{K×H}, where H is the dimension of the hidden state vectors and K is the number of classes. We minimize the cross-entropy and binary cross-entropy loss for single-label and multi-label tasks, respectively. While fine-tuning BERT, we optimize the number of epochs, batch size, learning rate, and maximum sequence length (MSL; i.e., the number of tokens to which documents are truncated).
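The classification head described above reduces to a matrix-vector product followed by a softmax. The sketch below uses toy dimensions (H = 8, K = 4) and random values in place of actual BERT hidden states:

```python
import numpy as np

# Hypothetical toy dimensions: H = hidden size, K = number of classes.
H, K = 8, 4
rng = np.random.default_rng(0)
W = rng.normal(size=(K, H))   # classifier parameters, W in R^{K x H}
h_cls = rng.normal(size=H)    # final hidden state of the [CLS] token

logits = W @ h_cls            # one logit per class
probs = np.exp(logits - logits.max())
probs /= probs.sum()          # softmax: class probabilities, shape (K,)
```

For the multi-label datasets, the softmax would be replaced by an element-wise sigmoid, matching the binary cross-entropy loss.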
For knowledge distillation, we train the LSTM, KimCNN, and LR models to capture the learnt representations from BERT_large using an objective of the type shown in Equation (1). Depending on the dataset, we use cross-entropy or binary cross-entropy loss as L_classification in Equation (1). For L_distill, following Hinton et al. (2015), we minimize the Kullback-Leibler (KL) divergence KL(p||q), where p and q are the class probabilities produced by the student and the teacher models, respectively.
To build an effective transfer set for distillation, as suggested by Ba and Caruana (2014), we augment the training splits of the datasets by applying POS-guided word swapping and random masking, along with randomizing the order of the sentences of documents in the training set. The transfer set sizes for Reuters, IMDB, and AAPD are 3×, 4×, and 4× their training splits, respectively; for Yelp 2014, no data augmentation was performed due to computational restrictions. Refer to the appendix for further details regarding the training hyperparameters.
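Two of the augmentations above, random masking and sentence-order randomization, can be sketched as follows (the function name, mask probability, and mask token are illustrative; POS-guided word swapping additionally requires a POS tagger and is omitted):

```python
import random

def augment(doc_sentences, mask_prob=0.1, mask_token="[MASK]", seed=0):
    """Produce one augmented copy of a document, given as a list of
    tokenized sentences: mask each word with probability mask_prob,
    then shuffle the sentence order."""
    rng = random.Random(seed)
    sentences = [[mask_token if rng.random() < mask_prob else w
                  for w in sent]
                 for sent in doc_sentences]
    rng.shuffle(sentences)   # randomize sentence order
    return sentences

doc = [["the", "oil", "price", "rose"], ["markets", "reacted"]]
augmented = augment(doc, mask_prob=0.5, seed=1)
```

Each augmented copy is then labeled with the teacher's soft predictions, enlarging the transfer set without requiring new gold annotations.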

Results and Discussion
In Table 2, which shows our main results, we report the mean F1 scores for the multi-label datasets and accuracy for the single-label datasets, along with the corresponding standard deviation across five runs. Due to their higher computational costs, we report the scores from only a single run per task for BERT_base and BERT_large.
Rows 1-7 report the model quality of pre-BERT models (those that do not take advantage of pre-training). As observed by Adhikari et al. (2019), LR and SVM trained with tf-idf vectors form effective baselines, as they challenge many neural network-based baselines (e.g., HAN) on multiple datasets. This raises the question of whether neural networks are a suitable architectural choice for document classification. However, at a much higher computational cost, the regularized LSTM (Adhikari et al., 2019) (row 7) achieves the best numbers among the models that do not exploit pre-training.
Consistent with Devlin et al. (2019), the BERT-based models achieve state-of-the-art results on all four datasets (see Table 2, rows 8 and 9), with BERT_large consistently achieving the highest model quality (compared to BERT_base).
Surprisingly, the distilled LSTM (KD-LSTM, row 10) achieves parity with BERT_base on average for Reuters, AAPD, and IMDB. In fact, it outperforms BERT_base (on both dev and test) in at least one of the five runs. For Yelp, KD-LSTM reduces the gap between BERT_base and the LSTM, but not to the same extent as on the other datasets.
Next, we explore the limits of KD by further distilling BERT_base all the way down to KimCNN (Kim, 2014), a single-layer CNN, and LR. It is not surprising that these models do not come close to BERT_base, owing to their limited expressivity. However, interestingly, we see massive leaps in the model quality of these models after distillation (rows 1-3; 11-12). Specifically, for the multi-label datasets, both of these models beat or come close to HAN and SGM, which are far more complex models. To put things in perspective, LR is a simple fully-connected layer, and KimCNN contains merely ∼0.4% of the parameters of BERT_base. These results demonstrate that KD can yield a broad spectrum of baselines at varying computational costs, all of which can be useful depending on the requirements. Model names of type "KD-X" in Table 2 (rows 10-12) refer to X trained using knowledge distillation from the fine-tuned BERT_large (row 9).

Conclusion and Future Work
In this paper, we improve baselines for document classification by fine-tuning BERT (DocBERT). Using DocBERT, we show the effectiveness of KD over a range of efficient models: a single-layer LSTM, a single-layer CNN, and a logistic regression model trained on tf-idf vectors. This provides us with a spectrum of baselines with varying tradeoffs between classification accuracy and complexity. In fact, we show that the distilled LSTM model achieves BERT_base parity on a majority of the datasets, using only ∼3% of the latter's parameters. While distillation is an effective way to reduce computational cost during inference, it does not help reduce the resources needed for training. Thus, methods for reducing the computational resources required during training deserve attention in future work.

A.1 Training Hyperparameters
While fine-tuning BERT, we optimize the number of epochs, batch size, learning rate, and maximum sequence length (MSL), the number of tokens to which documents are truncated. We observe that model quality is quite sensitive to the number of epochs, which must therefore be tailored to each dataset. We train on Reuters, AAPD, and IMDB for 30, 20, and 4 epochs, respectively. Due to resource constraints, we train on Yelp for only one epoch. As in Devlin et al. (2019), we find that choosing a batch size of 16, a learning rate of 2×10^{-5}, and an MSL of 512 tokens yields optimal model quality on the validation sets for all the datasets.
For distillation, we train an LSTM to capture the learnt representations from BERT_large using the objective shown in Equation (1). We use a batch size of 128 for the multi-label tasks and 64 for the single-label tasks. We find the learning rates and dropout rates used by Adhikari et al. (2019) to be optimal even for the distillation process.
To build an effective transfer set for distillation, as suggested by Hinton et al. (2015), we augment the training splits of the datasets by applying POS-guided word swapping and random masking for data augmentation, similar to prior work. For the distillation objective given in Equation (1), we use λ = 1 for the multi-label datasets and λ = 4 for the single-label datasets.
A.2 Hyperparameter Analysis for DocBERT

MSL analysis. A decrease in the maximum sequence length (MSL) corresponds to only a minor loss in F1 on Reuters (see the top-left subplot in Figure 2), possibly because Reuters has shorter documents. On IMDB (top-right subplot in Figure 2), lowering the MSL corresponds to a drastic fall in accuracy, suggesting that the entire document is necessary for this dataset.
On the one hand, these results appear obvious. On the other hand, one can argue that, since IMDB contains longer documents, truncating tokens may hurt less. The top two subplots in Figure 2 show that this is not the case, since truncating to even 256 tokens causes accuracy to fall below that of the much smaller LSTM_reg (see Table 2). From these results, we conclude that any amount of truncation is detrimental in document classification, though the level of degradation may differ.

Epoch analysis. The bottom two subplots in Figure 2 illustrate the F1 score of BERT fine-tuned with different numbers of epochs on AAPD and Reuters. Contrary to Devlin et al. (2019), who achieve the state of the art on small datasets with only a few epochs of fine-tuning, we find that smaller datasets require many more epochs to converge. On both datasets (see Figure 2), we see a significant drop in model quality when the BERT models are fine-tuned for only four epochs, as suggested in the original paper. On Reuters, using four epochs results in an F1 worse than even that of logistic regression (Table 2, row 1).