Layer-wise Guided Training for BERT: Learning Incrementally Refined Document Representations

Although BERT is widely used by the NLP community, little is known about its inner workings. Several attempts have been made to shed light on certain aspects of BERT, often with contradicting conclusions. A frequently raised concern focuses on BERT's over-parameterization and under-utilization issues. To this end, we propose a novel approach to fine-tune BERT in a structured manner. Specifically, we focus on Large Scale Multilabel Text Classification (LMTC), where documents are assigned one or more labels from a large predefined set of hierarchically organized labels. Our approach guides specific BERT layers to predict labels from specific hierarchy levels. Experimenting with two LMTC datasets, we show that this structured fine-tuning approach not only yields better classification results but also leads to better parameter utilization.


Introduction
Despite BERT's (Devlin et al., 2019) popularity and effectiveness, little is known about its inner workings. Several attempts have been made to demystify certain aspects of BERT (Rogers et al., 2020), often leading to contradicting conclusions. For instance, Clark et al. (2019) argue that attention measures the importance of a particular word when computing the next-level representation of this word. However, Kovaleva et al. (2019) showed that most attention heads contain trivial linguistic information and follow a vertical pattern (attention to [cls], [sep], and punctuation tokens), which could be related to under-utilization or over-parameterization issues. Other studies attempted to link specific BERT heads with linguistically interpretable functions (Htut et al., 2019; Clark et al., 2019; Kovaleva et al., 2019; Voita et al., 2019; Hoover et al., 2020; Lin et al., 2019), agreeing that no single head densely encodes enough relevant information; instead, different linguistic features are learnt by different attention heads. We hypothesize that this largely contributes to the lack of attention-based explainability in BERT. Another open topic is how knowledge is distributed across BERT layers. Most studies agree that syntactic knowledge is gathered in the middle layers (Hewitt and Manning, 2019; Goldberg, 2019; Jawahar et al., 2019), while the final layers are more task-specific. Most importantly, it seems that semantic knowledge is spread across the model, explaining why non-trivial tasks are better solved at the higher layers (Tenney et al., 2019).
Driven by the above discussion, we propose a novel fine-tuning approach where different parts of BERT are guided to directly solve increasingly challenging classification tasks following an underlying label hierarchy. Specifically, we focus on Large Scale Multilabel Text Classification (LMTC), where documents are assigned one or more labels from a large predefined set. The labels are organized in a hierarchy from general to specific concepts. Our approach attempts to tie specific BERT layers to specific hierarchy levels. In effect, each of these layers is responsible for predicting the labels of the corresponding level. We experiment with two LMTC datasets (EURLEX57K, MIMIC-III) and several variations of structured BERT training. Our contributions are: (a) We propose a novel structured approach to fine-tune BERT where specific layers are tied to specific hierarchy levels; (b) We show that structured training yields better results than the baseline across all levels of the hierarchy, while also leading to better parameter utilization.


Datasets

EURLEX57K (Chalkidis et al., 2019) contains 57k EU legislative documents. Each document is approx. 700 words long and is annotated with one or more concepts from EUROVOC, which contains 7,391 concepts organized in an 8-level hierarchy. We truncate the hierarchy to 6 levels by discarding the last 2 levels, which contain 50 rarely used labels. MIMIC-III (Johnson et al., 2017) contains approx. 52k discharge summaries from US hospitals. Each summary is approx. 1.6k words long and is annotated with one or more ICD-9 codes. ICD-9 contains 22,395 codes organized in a 7-level hierarchy. We truncate the hierarchy to 6 levels, discarding the first level, which contains only 4 general codes.

Label Augmentation: In both datasets, we assume that if a label l is assigned to a document, then all of its ancestors should also be assigned to that document. Hence, we augment the annotations of each document with all the ancestors of its assigned labels.
For instance, in EUROVOC, if a document is annotated with the label grape, it will also be annotated with grape's ancestors, i.e., fruit, plant product, and agri-foodstuffs (Figure 2). This assumption is perfectly valid, while also having the added side effect of providing a more accurate test-bed for evaluation. For example, suppose a classifier mistakenly annotated the document with citrus fruit, a sibling of grape. In the non-augmented case it would receive a score of zero. By contrast, in the augmented case, assuming it correctly identified all the ancestors of citrus fruit, it would receive a much higher score of 0.75, having correctly assigned the three ancestors of grape but not the (more specialized) label itself. Thus, we believe the model is evaluated more fairly in the augmented case with respect to the hierarchy. This type of evaluation is also in line with the literature on hierarchical classification (Kosmopoulos et al., 2015).
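The augmentation step above can be sketched as closing each document's label set under the ancestor relation, assuming the hierarchy is available as a child-to-parent map (the map below is a toy fragment built from the running example, not the real EUROVOC):

```python
from typing import Dict, Set

def augment_labels(labels: Set[str], parent: Dict[str, str]) -> Set[str]:
    """Close a document's label set under the ancestor relation."""
    augmented = set(labels)
    for label in labels:
        node = label
        while node in parent:      # walk up to the root
            node = parent[node]
            augmented.add(node)
    return augmented

# Toy EUROVOC-like fragment from the grape example.
parent = {
    "grape": "fruit",
    "citrus fruit": "fruit",
    "fruit": "plant product",
    "plant product": "agri-foodstuffs",
}
augment_labels({"grape"}, parent)
# → {'grape', 'fruit', 'plant product', 'agri-foodstuffs'}
```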

Structured Learning with BERT
Before we proceed with the description of our methods (Figure 1), we introduce some notation. Given a label hierarchy L of depth d, L_n denotes the set of labels in the n-th level of this hierarchy (n ≤ d). Each classifier f_i = σ(W_i · c_i + b_i) operates on c_i, the representation of the [cls] token in the i-th BERT layer, where σ is the sigmoid activation function. Note that the sizes of W_i and b_i depend on the number of labels that f_i is responsible for predicting, i.e., if f_i predicts the labels of L_n, then W_i ∈ R^{|L_n|×768} and b_i ∈ R^{|L_n|×1}.
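As a dimensionality check, the per-level classifier can be sketched with NumPy (the shapes follow the paper's notation; the weights below are random placeholders, not trained parameters):

```python
import numpy as np

def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))

def level_classifier(cls_repr: np.ndarray, W: np.ndarray, b: np.ndarray) -> np.ndarray:
    """f_i = sigmoid(W_i . c_i + b_i): one probability per label of level n."""
    return sigmoid(W @ cls_repr + b)

rng = np.random.default_rng(0)
hidden = 768                                   # BERT-BASE hidden size
n_labels = 21                                  # e.g. |L_1| in EURLEX57K
c_i = rng.normal(size=hidden)                  # [cls] vector of layer i
W_i = rng.normal(scale=0.02, size=(n_labels, hidden))
b_i = np.zeros(n_labels)
probs = level_classifier(c_i, W_i, b_i)        # shape (|L_n|,), values in (0, 1)
```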
FLAT: This is a simple baseline which uses f_12 to predict all labels in the hierarchy in a flat manner. In effect, W_12 ∈ R^{|L|×768} and b_12 ∈ R^{|L|×1}.

LAST-SIX: This method uses the classifiers f_7 through f_12 to predict the labels in L_1 through L_6, respectively. Our intuition is that layers 1-6 will retain and enhance their pre-trained functionality, i.e., syntactic knowledge and contextualized representations, while layers 7-12 will leverage this knowledge to better solve their individual tasks. We also expect the model to show higher parameter utilization in layers 7-12, since they are forced to solve gradually more refined classification tasks.
ONE-BY-ONE: This method utilizes the full depth of BERT in a "skip one, use one" fashion, i.e., it uses classifiers f i , i ∈ {2, 4, . . . , 12}. In effect, the odd layers (1, 3, . . . , 11) are updated only indirectly, through the classification tasks of the even layers. We expect that the odd layers will learn rich latent representations to facilitate the classifiers of the even layers. Spreading the classification tasks across the whole depth of the model will potentially lead to better parameter utilization. On the other hand, it could harm the model's pre-trained functionality and hence its performance.

IN-PAIRS: This method also exploits the full depth of BERT, but now the layers are grouped in 6 pairs, p_n ∈ {(1, 2), (3, 4), . . . , (11, 12)}. The classifier responsible for the labels of L_n operates on the concatenated [cls] tokens of the corresponding pair, e.g., f_1 = σ(W_1 · [c_1; c_2] + b_1) is trained on the labels of L_1. We expect IN-PAIRS to have better parameter utilization than ONE-BY-ONE, although the risk of hindering performance is now even higher.
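The only change relative to the single-layer classifiers is that the input dimension doubles to 2 × 768; a minimal sketch (random placeholder weights, illustrative level size):

```python
import numpy as np

def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))

def pair_classifier(c_odd, c_even, W, b):
    """IN-PAIRS: f_n = sigmoid(W_n . [c_{2n-1}; c_{2n}] + b_n)."""
    return sigmoid(W @ np.concatenate([c_odd, c_even]) + b)

rng = np.random.default_rng(1)
c1, c2 = rng.normal(size=768), rng.normal(size=768)  # [cls] vectors of the pair
W1 = rng.normal(scale=0.02, size=(21, 1536))         # input size is 2 * 768
b1 = np.zeros(21)
p = pair_classifier(c1, c2, W1, b1)                  # probabilities for level 1
```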
HYBRID: Similarly to LAST-SIX, this method skips some of the lower BERT layers (3 instead of 6). Also, it ties L_1, L_2, and L_6, which are the hierarchy levels with the fewest labels, to layers 4, 5, and 12, respectively. Finally, similarly to IN-PAIRS, the remaining BERT layers are grouped in pairs and tied to the rest of the hierarchy levels. We expect the first three layers to retain and enhance their pre-trained functionality, while the hierarchy levels with a large number of labels will benefit from the additional parameters at their disposal.

Experiments
We report R-Precision (Manning et al., 2009) at each hierarchy level, as well as micro (flat) and macro averages across all levels. Table 1 shows the results on both datasets. In EURLEX57K our structured methods always outperform the baseline, mostly by a large margin. LAST-SIX achieves the best overall results and is superior to the other structured methods at all hierarchy levels, indicating that allowing the lower layers to retain and enhance their pre-trained functionality is crucial. Similar observations can be made for MIMIC-III, but in this case the importance of not damaging BERT's pre-trained functionality is even higher, as evident from the only minor improvements of ONE-BY-ONE and IN-PAIRS over FLAT.

Comparing [cls] representations: Figure 3 shows the average angular distances between the [cls] representations of each layer on the development data of EURLEX57K. The angular distance is calculated on unit (L2-normalized) vectors, takes values in [0, 1], and a distance of 0.5 indicates an angle of 90°. We observe that ONE-BY-ONE leads to larger angles between the representations than LAST-SIX, which in turn yields larger angles than FLAT. In effect, ONE-BY-ONE, and to a lesser extent LAST-SIX, lead to better parameter utilization than FLAT. To better support this claim we provide a geometric interpretation. We first L2-normalize all [cls] representations. Each normalized representation can be interpreted as a vector with its initial point at the origin and its terminal point on the surface of a 768-dimensional hyper-sphere (centered at the origin). The larger the angle between two [cls] vectors, the further apart they are on the hyper-sphere's surface. Effectively, [cls] vectors with large angles between them cover a larger subarea of the hyper-sphere's surface, indicating that the vector space is utilized to a higher extent, which directly implies better parameter utilization.
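The angular-distance measure described above can be sketched as follows (a plausible reading of the stated definition, not the authors' exact code): normalize both vectors, take the arccosine of their dot product, and rescale by π so that orthogonal vectors land at 0.5.

```python
import numpy as np

def angular_distance(u: np.ndarray, v: np.ndarray) -> float:
    """Angular distance in [0, 1] between two vectors (0.5 = 90 degrees)."""
    u = u / np.linalg.norm(u)              # L2-normalize onto the unit sphere
    v = v / np.linalg.norm(v)
    cos = np.clip(np.dot(u, v), -1.0, 1.0) # guard against rounding outside [-1, 1]
    return float(np.arccos(cos) / np.pi)

angular_distance(np.array([1.0, 0.0]), np.array([0.0, 1.0]))  # orthogonal → 0.5
```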
Comparing attention distributions: Figure 4 shows the KL-Divergence of the average (across heads) attention for all layers on the development data of EURLEX57K. A high KL-Divergence indicates that two layers attend to different sub-word units. Moreover, Table 2 reports the entropy (left column per layer) of the average (across heads) attention distribution per layer. A high entropy indicates that a layer attends to more sub-word units. Table 2 also reports the average KL-Divergence (right column per layer) between the attention distributions of each possible pair of heads in a layer. A high KL-Divergence here indicates that each head attends to different sub-word units. A first observation is that FLAT attends to almost the same sub-word units across layers (small entropy differences and low KL-Divergence across layers). Interestingly, the different attention heads focus on different sub-word units only in the middle layers (5-8). On the other hand, all the structured methods show better utilization of the attention mechanism, having higher entropy and KL-Divergence both across heads (Table 2) and across layers (Figure 4).
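Both diagnostics above reduce to standard formulas over attention distributions; a sketch under the assumption that each distribution is a probability vector over sub-word positions (the two example distributions are made up for illustration):

```python
import numpy as np

def entropy(p: np.ndarray) -> float:
    """Shannon entropy of an attention distribution over sub-word positions."""
    p = p / p.sum()
    return float(-np.sum(p * np.log(p + 1e-12)))

def kl_divergence(p: np.ndarray, q: np.ndarray) -> float:
    """KL(p || q) between two attention distributions (layers or heads)."""
    p, q = p / p.sum(), q / q.sum()
    return float(np.sum(p * np.log((p + 1e-12) / (q + 1e-12))))

uniform = np.ones(4) / 4                       # attends broadly → high entropy
peaked = np.array([0.97, 0.01, 0.01, 0.01])    # attends to one token → low entropy
entropy(uniform) > entropy(peaked)             # True
```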

Related Work
Our approach is similar to that of Wehrmann et al. (2018), but they experiment with fully connected networks, which are not well suited for text classification, contrary to stacked Transformers (Vaswani et al., 2017; Devlin et al., 2019). Similarly, Yan et al. (2015) used Convolutional Neural Networks, albeit with shallow hierarchies (2 levels). Although our approach leverages the label hierarchy, it should not be confused with hierarchical classification methods (Silla and Freitas, 2011), which typically employ one classifier per node and cannot scale up to large hierarchies when considering neural classifiers. A notable exception is the work of You et al. (2019), who employed one bidirectional LSTM with label-wise attention (You et al., 2018) per hierarchy node. However, for their method to scale up, they use probabilistic label trees (Khandagale et al., 2019) to organize the labels in their own shallow hierarchy, which does not follow the abstraction levels of the original hierarchy. To the best of our knowledge, we are the first to apply this approach to pre-trained language models.

Conclusions and Further Work
We proposed a novel guided approach to fine-tune BERT, where specific layers are tied to specific hierarchy levels. Experimenting with two LMTC datasets, we showed that structured training not only yields better results than a flat baseline, but also leads to better parameter utilization. In the future we will try to further increase parameter utilization by guiding BERT's attention heads to explicitly focus on specific hierarchy parts. We also plan to improve the explainability of our methods with respect to the utilization of their parameters.

Table 3: Label distribution across the EUROVOC and ICD-9 hierarchy levels. Concepts (labels) are arranged from more abstract (levels 1-2) to more specialized ones (levels 6-8). Labels with an asterisk are truncated in our experiments.

References
A Preprocessing

For MIMIC-III, we use regular expressions tailored for the biomedical domain. While document length is severely reduced after normalization, if a document still exceeds 512 tokens, we use the first 512 tokens and ignore the rest.

B Experimental Setup
All our methods build on BERT-BASE and are implemented in TensorFlow 2. For EURLEX57K we use the original BERT-BASE (Devlin et al., 2019), while for MIMIC-III we use SCIBERT (Beltagy et al., 2019), which has the same architecture (12 layers, 768 hidden units, 12 attention heads) and better suits biomedical documents. Our models are tuned by grid-searching three learning rates (2e-5, 3e-5, 5e-5) and two drop-out rates (0, 0.1). We use the Adam optimizer (Kingma and Ba, 2015) with early stopping on validation loss. In preliminary experiments, we found that weighting the individual losses with respect to the number of labels in each level is crucial. We therefore weigh each loss by the percentage of labels at the corresponding level, i.e., w_n = |L_n| / |L|, where |L_n| is the number of labels in the n-th level of the hierarchy and |L| is the total number of labels across all levels, e.g., in EURLEX57K, w_1 = 21/8093 ≈ 0.0026.
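The loss weighting above can be sketched in a few lines. Only |L_1| = 21 and the total |L| = 8093 come from the paper; the remaining level sizes below are made-up placeholders chosen to sum to 8093:

```python
def level_loss_weights(labels_per_level):
    """w_n = |L_n| / |L|: weight each level's loss by its share of labels."""
    total = sum(labels_per_level)
    return [n / total for n in labels_per_level]

# Hypothetical per-level sizes; only the first entry (21) and the
# total (8093) are taken from the paper.
sizes = [21, 337, 3269, 3175, 1036, 255]
weights = level_loss_weights(sizes)
round(weights[0], 4)   # → 0.0026, matching the paper's w_1 example
```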

C Evaluation in LMTC
The LMTC literature (Rios and Kavuluru, 2018; Chalkidis et al., 2019) mostly uses information retrieval evaluation measures. We support the premise that when the number of labels is so large, the problem mimics retrieval, with each document acting as a query and the model having to score the relevant labels higher than the rest. In our study, however, it would be confusing to report the standard retrieval measures Recall@K, Precision@K, and nDCG@K, since we evaluate our classifiers at each hierarchy level and reasonable values for K fluctuate greatly between levels, as the number of labels per level varies vastly (see Table 3). Instead, we prefer R-Precision (Manning et al., 2009), which is Precision@R where R is the number of gold labels associated with each document. It follows that R-Precision can neither under-estimate (penalize) nor over-estimate the performance of the models (Chalkidis et al., 2019).

D Peculiarities of MIMIC-III dataset
In our experiments we observe lower performance on MIMIC-III, which can be attributed to a number of characteristics of the dataset. Firstly, the documents contain a lot of non-trivial biomedical terminology, which naturally makes the classification task more difficult. Further, discharge summaries describe a patient's condition during their hospitalization, and therefore the proper label annotations change throughout the document as the patient's diagnosis changes or as they exhibit new symptoms, e.g., "the patient was admitted to the hospital with no heart issues, [. . . ] the patient had a heart failure and died.". Both the in-domain language and the constant change of events make the dataset more challenging than EURLEX57K, where documents are better organized and well written, with simpler language. It therefore seems reasonable that in MIMIC-III allowing the lower BERT layers to retain and enhance their pre-trained functionality, without explicitly guiding them, is of utmost importance. We would like to highlight that even though we use SCIBERT (Beltagy et al., 2019), which comes with a scientific vocabulary, we observe that specialized biomedical terms are often over-fragmented into multiple sub-word units, e.g., 'atelectasis' splits into ['ate', '##lect', '##asis']. Thus, the initial layers need to decipher these over-fragmented sub-word units and reconstruct the original word semantics. By contrast, in EURLEX57K, classifying general concepts in the initial layers is plausible even when considering only the sub-word unit embeddings.
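The over-fragmentation effect can be reproduced with a toy WordPiece-style tokenizer (greedy longest-match-first segmentation; the three-entry vocabulary below is a made-up fragment for illustration, not SciBERT's actual ~30k-entry vocabulary):

```python
def wordpiece_tokenize(word, vocab):
    """Greedy longest-match-first sub-word segmentation (WordPiece-style)."""
    pieces, start = [], 0
    while start < len(word):
        end = len(word)
        prefix = "" if start == 0 else "##"   # continuation pieces get '##'
        while end > start and prefix + word[start:end] not in vocab:
            end -= 1                           # shrink until a piece matches
        if end == start:
            return ["[UNK]"]                   # no piece matches at all
        pieces.append(prefix + word[start:end])
        start = end
    return pieces

vocab = {"ate", "##lect", "##asis"}
wordpiece_tokenize("atelectasis", vocab)   # → ['ate', '##lect', '##asis']
```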

E Discussion on model utilization
We present additional results for the rest of the methods (IN-PAIRS, HYBRID). Figure 5 shows the average angular distances between the [cls] representations of each layer for all considered methods. We observe that the distances of IN-PAIRS between consecutive [cls] representations follow a similar pattern to those of ONE-BY-ONE, with the exception of distances of 0.25+, which are denser in the upper layers for IN-PAIRS. This is reasonable, since in IN-PAIRS all layers directly contribute to the classification tasks. The pattern of HYBRID is very similar to ONE-BY-ONE and IN-PAIRS, except for the first three non-guided layers, in which the distances bear close resemblance to those of the corresponding layers in LAST-SIX. Similar observations hold for MIMIC-III (Figure 7). Finally, Figure 6 shows the KL-Divergence of the average (across heads) attention for all layers on the development data. All structured methods show better utilization of the attention mechanism than FLAT, having higher KL-Divergence across layers. By contrast, in MIMIC-III, all structured methods follow a similar pattern of low KL-Divergence across layers (Figure 8), even lower than in the upper layers of FLAT, i.e., the models attend to similar sub-word positions across layers. We aim to further study and explain this behaviour in future work.