Integrating Task-Specific Information into Pretrained Language Models for Low-Resource Fine-Tuning

Pretrained Language Models (PLMs) have improved the performance of natural language understanding in recent years. Such models are pretrained on large corpora, which encode the general prior knowledge of natural language but are agnostic to information characteristic of downstream tasks. This often results in overfitting when fine-tuning on low-resource datasets, where task-specific information is limited. In this paper, we integrate label information as a task-specific prior into the self-attention component of pretrained BERT models. Experiments on several benchmarks and real-world datasets suggest that the proposed approach can substantially improve the performance of pretrained models when fine-tuning with small datasets.


Introduction
Recently, Pretrained Language Models (PLMs) (Devlin et al., 2018; Radford et al., 2019) have yielded significant progress on various natural language processing (NLP) tasks, e.g., natural language understanding and text generation. Existing PLMs are usually pretrained in a task-agnostic manner, in which the model is expected to capture the general knowledge of natural language from a large corpus, independent of downstream-specific information. This is not a problem when data is abundant in the downstream dataset, in which case the model can effectively extract task-specific information during fine-tuning. However, in real scenarios, data may be difficult to collect and labeling is usually expensive. We show that PLMs pretrained with general knowledge can overfit without enough guidance from task-specific information, resulting in degraded performance during testing.

A clear-cut solution to this problem is to focus more on samples that are relevant to the target task during pretraining. However, this requires task-specific pretraining, which in most cases is prohibitive in computation or time. Another approach is to pretrain on an auxiliary dataset before fine-tuning on the target task (Phang et al., 2018). Such a method requires the availability of an appropriate auxiliary dataset; unfortunately, in some cases it may negatively impact the downstream transfer (Wang et al., 2018a). Label embeddings (Akata et al., 2015) can be regarded as a feature-based definition of a classification task, in which detailed information about the task is encoded. One natural question is whether we can combine the general knowledge in a PLM and the task-specific characterization contained within label embeddings for better fine-tuning on low-resource tasks.

* These authors contributed equally to this work.
In this paper, we propose to utilize the label embeddings as a task-specific prior, complementary to the general prior already encoded during pretraining. We learn and integrate these label embeddings into BERT models (Devlin et al., 2018) to regularize its self-attention modules, so the task-irrelevant tokens or patterns can be readily filtered out, while the task-specific information can be enhanced during fine-tuning. Such a modification is compatible with any PLM built upon self-attention and will not degrade the original pretrained structure.
In order to validate the performance of our approach in a real-world setting, we collected two text classification datasets from the online patient portal of a large academic health system, each with a few thousand sequences. These are the first datasets for automatic patient message triage, a task that constitutes an important problem in the field of clinical data analysis. Experimental results show that our approach significantly improves the performance of fine-tuning on low-resource datasets, e.g., those consisting of only several thousand data samples.

Related Work
Label embeddings have been previously leveraged for image classification (Akata et al., 2015), multimodal learning between images and text (Kiros et al., 2014), text recognition in images (Rodriguez-Serrano and Perronnin, 2015), zero-shot learning (Li et al., 2015; Ma et al., 2016) and text classification (Zhang et al., 2017). Notably, LEAM (Wang et al., 2018b) jointly embeds words (tokens) and labels in a common latent space as a means to improve the performance on general text classification tasks. Further, Moreo et al. (2019) concatenate label embeddings with word embeddings. However, this approach cannot be directly implemented in PLMs, since the new (concatenated) embedding is not compatible with the pretrained parameters. We instead integrate label embeddings into the self-attention of BERT models, so the attention can be regularized to better focus on task-relevant information.

The BERT Model
The encoder of BERT and other popular PLMs is built upon the transformer architecture, which is composed of multiple layers of multi-head self-attention and position-wise feed-forward layers.

Multi-head Self-attention The multi-head self-attention is an ensemble of multiple single-head self-attention modules. Let X ∈ R^{L×D} be the embedding matrix of the input sequence with length L. For each single head, the input sequence is first mapped into the key, query and value triplet,

    Q_i = X W_i^Q,  K_i = X W_i^K,  V_i = X W_i^V,  (1)

    A_i = softmax(Q_i K_i^T / √d),  H_i = A_i V_i,  (2)

where i = 1, ..., h, h is the number of heads, softmax(·) is the row-wise softmax function and d is the head dimension. A_i is the attention score matrix representing the compatibility between Q_i and K_i. The multi-head self-attention is defined by concatenating and projecting {H_i}_{i=1}^h, the representations of each head, into Ĥ ∈ R^{L×D}.
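The single-head computation described above can be sketched in NumPy as follows (a minimal illustration; the shapes and projection matrices are placeholder assumptions, not BERT's actual weights):

```python
import numpy as np

def softmax(x, axis=-1):
    # Row-wise softmax, stabilized by subtracting the per-row maximum.
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def single_head_attention(X, W_Q, W_K, W_V):
    # Map the input to queries, keys and values.
    Q, K, V = X @ W_Q, X @ W_K, X @ W_V
    d = Q.shape[-1]
    # Scaled dot-product attention scores and the head output.
    A = softmax(Q @ K.T / np.sqrt(d))   # (L, L) attention score matrix
    H = A @ V                           # (L, d) head representation
    return H, A
```

In the full multi-head module, h such heads run in parallel and their outputs are concatenated and projected back to dimension D.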
Position-wise Feed-Forward Layer After self-attention, a fully connected network is applied on each token representation x,

    FFN(x) = max(0, x W_1 + b_1) W_2 + b_2,  (3)

which consists of two linear transformations with a ReLU activation in between.
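The feed-forward layer amounts to a one-line computation per token (weights here are placeholders):

```python
import numpy as np

def ffn(x, W1, b1, W2, b2):
    # Two linear maps with a ReLU in between, applied to each token row.
    return np.maximum(0.0, x @ W1 + b1) @ W2 + b2
```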
In BERT, the input sequence starts with a [CLS] token, whose hidden state is extracted as the sequence representation for classification. Let CE(·, ·) be the cross-entropy loss, C(·) the final classifier and enc(·) the encoder consisting of a stack of transformer layers. The classification loss can be written as

    L_cls = Σ_{(X,y)∈D} CE(C(enc(X)_[CLS]), y),  (4)

where enc(X)_[CLS] is the representation of [CLS] after encoding, y is the classification label and D is the dataset.
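A sketch of this [CLS]-based loss on a single example, assuming a hypothetical linear classifier and that the [CLS] representation is row 0 of the encoder output:

```python
import numpy as np

def cross_entropy(logits, y):
    # CE(., .) on one example, computed from raw logits with a
    # max-subtracted log-sum-exp for numerical stability.
    z = logits - logits.max()
    return float(np.log(np.exp(z).sum()) - z[y])

def classification_loss(enc_out, W_cls, b_cls, y):
    # Classify from the encoded [CLS] token (assumed to be row 0).
    logits = enc_out[0] @ W_cls + b_cls
    return cross_entropy(logits, y)
```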
In the context of graph embeddings (Kipf and Welling, 2016), the [CLS] token acts as a super node that connects to all other tokens (nodes) and aggregates global information during self-attention (convolution). After training, the embedding of the [CLS] token should contain the task-specific information, so that it mostly attends to task-relevant information during inference. However, the embeddings of PLMs are pretrained agnostic to downstream tasks. When fine-tuning with low-resource datasets where label information is scarce, a single [CLS] token may not capture enough task-specific information, and the model may overfit to task-irrelevant tokens or patterns in the input sequences.

Integrating Label Embedding into Self-Attention
In this paper, we propose to leverage label embeddings to optimize the self-attention modules, so the model can better focus on task-relevant information when fine-tuned with small datasets. We reformulate the representations in (1) as

    X = [x_[CLS]; X̄],

where x_[CLS] ∈ R^{1×D} and X̄ ∈ R^{(L−1)×D} represent the embeddings of [CLS] and of the other tokens in the sequence, respectively. The attention score matrix can then be rewritten as

    A = softmax(Q K^T / √d) = [S; Ā],  (5)

where S ∈ R^{1×L} is the row of attention scores from [CLS] to all tokens. Let X_l ∈ R^{M×D} be the label embedding matrix, where M is the number of classes. We first compute the cross attention between X_l and X as

    A_l = softmax(Q_l K^T / √d),  Q_l = X_l W^Q,  (6)

where X_l is encoded into Q_l with the same mapping matrix W^Q as in (1). Then, we compute a modified cross-attention row vector S' by concatenating S and A_l by row and keeping the maximum value of each column,

    S' = maxcol([S; A_l]).  (7)

As a result, S' represents the maximum attention score of an input token with both [CLS] and the label embeddings. A new attention score matrix A' can be obtained by replacing S with S' in (5),

    A' = [S'; Ā].  (8)

In (8), when a token is highly relevant to one of the labels, it will result in a larger attention score in S', so the [CLS] embedding will be less affected by irrelevant information in the sequence, unlike (2) where only attention from the current [CLS] embedding is considered. The proposed attention layer is shown in Figure 1(b). The attention score matrix A in (2) is replaced by A' in (8). All other components are the same as the original layers in BERT in (1)-(3).
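The modified [CLS] attention row can be sketched for a single head as below. This is a minimal NumPy reading of the description above; the text does not specify whether the merged row is renormalized after the column-wise max, so this sketch leaves it unnormalized:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def label_augmented_attention(X, X_l, W_Q, W_K):
    # X: (L, D) token embeddings with row 0 = [CLS]; X_l: (M, D) labels.
    d = W_Q.shape[1]
    Q, K = X @ W_Q, X @ W_K
    A = softmax(Q @ K.T / np.sqrt(d))        # (L, L); row 0 is S
    Q_l = X_l @ W_Q                          # labels share the same W_Q
    A_l = softmax(Q_l @ K.T / np.sqrt(d))    # (M, L) cross attention
    # Column-wise max over the [CLS] row and the label attention rows.
    S_new = np.vstack([A[:1], A_l]).max(axis=0)
    # Replace S with S'; all other rows stay unchanged.
    A_new = A.copy()
    A_new[0] = S_new
    return A_new
```

A token that any label attends to strongly thus keeps a high score in the [CLS] row, even if the current [CLS] embedding alone would have ignored it.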
We share the same label embedding X_l across all layers. The label embedding is adapted at each layer via W^Q in the multi-head attention module. As shown in Figure 1a, we also feed X_l into the final classifier C(·), so the label embeddings can be classified into their corresponding classes. The final loss for classification is then

    L = Σ_{(X,y)∈D} CE(C(enc(X)_[CLS]), y) + λ Σ_{i=1}^{M} CE(C(X_l^i), i),  (9)

where X_l^i is the i-th label embedding and λ is a trade-off parameter between the regularization on label embeddings and the original classification loss.
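A sketch of the combined loss on a single example, with raw logits standing in for the classifier outputs C(·):

```python
import numpy as np

def final_loss(cls_logits, y, label_logits, lam=3.0):
    # Task cross-entropy on [CLS] plus lam times a regularizer asking
    # the i-th label embedding to be classified as class i.
    def ce(logits, t):
        z = logits - logits.max()
        return float(np.log(np.exp(z).sum()) - z[t])
    reg = sum(ce(label_logits[i], i) for i in range(len(label_logits)))
    return ce(cls_logits, y) + lam * reg
```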
The label embeddings can be initialized randomly or by the pretrained embeddings of relevant keywords. When the label is not identified by keywords, e.g., in sentence entailment tasks, their embeddings can be initialized with the representations of [CLS], averaged over samples from the same class. All other parameters can be initialized from the pretrained BERT. This modification can be adapted to any PLM with self-attention modules.
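For labels not identified by keywords, the class-averaged [CLS] initialization can be sketched as follows (a hypothetical helper; `cls_reps` holds the [CLS] representation of each training sample):

```python
import numpy as np

def init_label_embeddings_from_cls(cls_reps, labels, num_classes):
    # Label embedding i = mean [CLS] representation over the training
    # samples of class i (for tasks whose labels have no keywords).
    X_l = np.zeros((num_classes, cls_reps.shape[1]))
    for i in range(num_classes):
        X_l[i] = cls_reps[labels == i].mean(axis=0)
    return X_l
```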

Experiments
We focus on fine-tuning with small datasets. We integrate label embeddings into the pretrained (Bio)BERT models, and fine-tune on various classification benchmarks as well as two real-world clinical datasets that we collected from the online patient portal of a large academic health system.

Public Benchmarks

Table 1 shows the results of integrating label embeddings into the pretrained bert-base-uncased model on 9 public classification benchmarks of various sizes. We find that our method improves over BERT on small datasets, e.g., WNLI, MRPC and CoLA, which typically have only several thousand data samples available for fine-tuning. This shows that the BERT model, which is pretrained with task-agnostic objectives, is more likely to overfit when there is limited task-specific information during fine-tuning. Meanwhile, our method produces comparable results on larger datasets such as MNLI and QQP. This is consistent with the study in Lazar (2003), where additional priors are less useful as the size of the dataset grows. These results suggest that our method is most suitable for fine-tuning with smaller amounts of data, and that our approach to injecting the label information is at least not detrimental to the original pretrained model. This supports the intuition of combining the pretrained general knowledge and the task-specific information for better fine-tuning with small datasets.

Figure 2: Examples of the attention from the [CLS] token in the final attention layer, for (a) BioBERT and (b) our method. The sequences are sampled from the Message-urgency dataset. Red indicates a higher attention score. Our method better focuses on keywords, e.g., 'chest', 'bad' and 'stairs', which are more likely to occur in urgent requests, while BioBERT fine-tuned on such a small dataset tends to overfit to task-irrelevant words, such as 'holiday', 'school' and 'tests'.

We note that label information can improve the results on many natural language inference tasks, e.g., WNLI and QQP, where classes are identified not by keywords but rather by certain patterns in the input sentence pair. This may be because self-attention encodes these input patterns into intermediate tokens, which act as pseudo keywords.

Patient Message Triage
We further evaluate the proposed approach in a real-world scenario of patient message classification. This task is motivated by the increasing popularity of online patient portals. Most of the patient messages generated from the portal are non-urgent, while the doctors are expected to focus on the urgent requests, which amount to only a small portion (about 10%) of all messages. As a result, health providers must spend considerable time just identifying urgent messages, making them less efficient at emergency responses. We obtained two healthcare datasets (Message-urgency and Acknowledgment) from the online portal of a large academic health system. Detailed descriptions of these two datasets can be found in Appendix A. We apply our method to the BioBERT pretrained model (Lee et al., 2020), which has the same architecture as BERT but is further pretrained on clinical corpora. Results are shown in Table 2. Our model improves on all the baselines in terms of F1 score, which validates the usefulness of the proposed method for low-resource fine-tuning in real scenarios.

Conclusion
We propose to integrate task specific information into PLMs that are pretrained with task-agnostic objectives. To do this, we leverage label embeddings to regularize the self-attention in PLMs. Results on public benchmarks and real-world datasets suggest that our method can effectively improve the results for low resource fine-tuning.

A Description of healthcare datasets
In this work, we utilized 1,756 web portal messages generated from 10/2014 to 08/2018 by adult patients (> 18 years old) of a large academic medical center. The Electronic Health Record (EHR) system (Epic, Verona, WI, USA) with its associated patient portal (MyChart) was the source of all patient messages. A custom-built Application Programming Interface (API) securely transferred the portal messages from the EHR enterprise data warehouse into a highly protected virtual network space provided by the medical center. Approved users were allowed access to work with the identifiable protected health information. These messages consisted of free, unstructured plain text sent by patients to their healthcare team. Responses and messages sent from the clinician or health system to the patient were excluded from the analysis.

A.1 Message-urgency dataset
In the Message-urgency dataset, portal messages were manually labeled by experienced sub-specialty (cardiology) clinicians into three levels of priority: non-urgent, medium and urgent. Non-urgent labels include notes of appreciation (e.g., thank you). The medium-urgency class contains messages that could reasonably be responded to in 1-3 days. Urgent messages are those requiring an immediate phone call to the patient by the clinician. Conditions suggesting acute myocardial infarction, exacerbation of heart failure, respiratory distress or possible stroke were labeled as urgent and would be inappropriate for an asynchronous patient portal.

A.2 Acknowledgment dataset
The Acknowledgment dataset is randomly selected from patients' responses to the hospital. A significant portion of these messages is purely acknowledgment, such as 'Thank you'. It would be helpful if this type of message could be filtered out, so that hospital staff can focus on non-trivial messages. A doctor and a nurse labeled and validated this dataset.

B Implementation Details
For all experiments, we fine-tune the pretrained model for 3 epochs with a learning rate of 2e-5 and a batch size of 32, using the Adam optimizer. λ is generally set to 3. We set the warm-up steps to 10% of the total training steps. We do not apply weight decay, and gradient norms are clipped to 1. Experiments on the public benchmarks are run on a TITAN X (Pascal) 1080 GPU. The healthcare experiments are run on CPUs in a secured virtual machine system.
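The warm-up behavior can be sketched as below; note that the decay shape after warm-up (linear to zero) is our assumption, as the text only specifies the warm-up fraction:

```python
def lr_at_step(step, total_steps, base_lr=2e-5, warmup_frac=0.1):
    # Linear warm-up over the first 10% of steps, then (assumed)
    # linear decay of the learning rate down to zero.
    warmup = max(1, int(warmup_frac * total_steps))
    if step < warmup:
        return base_lr * (step + 1) / warmup
    return base_lr * max(0.0, (total_steps - step) / (total_steps - warmup))
```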