DoubleTransfer at MEDIQA 2019: Multi-Source Transfer Learning for Natural Language Understanding in the Medical Domain

This paper describes our system for the MEDIQA 2019 competition. We use a multi-source transfer learning approach to transfer knowledge from MT-DNN and SciBERT to natural language understanding tasks in the medical domain. For fine-tuning, we use multi-task learning on NLI, RQE, and QA tasks in the general and medical domains to improve performance. The proposed methods prove effective for natural language understanding in the medical domain, and we rank first on the QA task.


Background
The MEDIQA 2019 shared tasks (Ben Abacha et al., 2019) aim to improve the current state-of-the-art systems for textual inference, question entailment, and question answering in the medical domain. This ACL-BioNLP 2019 shared task is motivated by a need to develop relevant methods, techniques, and gold standards for inference and entailment in the medical domain, and to apply them to improve domain-specific information retrieval and question answering systems. The shared task consists of three parts: i) natural language inference (NLI) on MedNLI, ii) Recognizing Question Entailment (RQE), and iii) Question Answering (QA).
Recent advances in NLP such as BERT (Devlin et al., 2018) have facilitated great improvements in many Natural Language Understanding (NLU) tasks (Liu et al., 2019b). BERT first trains a language model on a large-scale unlabeled corpus, and then the pretrained model is fine-tuned to adapt to downstream NLU tasks. This fine-tuning process can be seen as a form of transfer learning, where BERT learns knowledge from the large-scale corpus and transfers it to downstream tasks.
We investigate NLU in the medical (scientific) domain. Starting from BERT, we need to adapt to i) the change from a general-domain corpus to scientific language, and ii) the change from low-level language model tasks to complex NLU tasks. Although there is limited training data for NLU in the medical domain, we fortunately have pre-trained models from two intermediate steps:
• General NLU embeddings: We use MT-DNN (Liu et al., 2019b) trained on the GLUE benchmark (Wang et al., 2019). MT-DNN is trained on 10 tasks including NLI, question equivalence, and machine comprehension. These tasks correspond well to the target MEDIQA tasks but are in different domains.
• Scientific embeddings: We use SciBERT (Beltagy et al., 2019), which is a BERT model trained on SemanticScholar scientific papers. Although SciBERT obtained state-of-the-art results on several single-sentence tasks, it lacks knowledge from other NLU tasks such as those in GLUE.
In this paper, we investigate different methods to combine and transfer the knowledge from the two sources and illustrate our results on the MEDIQA shared task. We name our method DoubleTransfer, since it transfers knowledge from two different sources. Our method is based on fine-tuning both MT-DNN and SciBERT using multi-task learning, which has been shown to transfer knowledge efficiently (Caruana, 1997; Liu et al., 2015; Xu et al., 2019; Liu et al., 2019b), and on integrating models from both domains with ensembles.
Related Works. Transfer learning has been widely used in training models in the medical domain. For example, Romanov and Shivade (2018) transferred the knowledge learned from SNLI to MedNLI, a transfer from general-domain NLI to medical-domain NLI. They also employed word embeddings trained on MIMIC-III medical notes, which can be seen as a language model in the scientific domain. SciBERT (Beltagy et al., 2019) studies transferring knowledge from the SciBERT pretrained model to single-sentence classification tasks. Our problem is unique because of the prohibitive cost of training BERT: both BERT and SciBERT require a very long time to train, so we only explore how to combine the existing embeddings from SciBERT and MT-DNN. Transfer learning is also widely used in other NLP tasks, such as machine translation (Bahdanau et al., 2014) and machine reading comprehension (Xu et al., 2019).

Algorithm 1 Multi-task Fine-tuning with External Datasets
Require: In-domain datasets $D_1, \dots, D_{K_1}$; external-domain datasets $D_{K_1+1}, \dots, D_{K_2}$; max epoch; mixture ratio $\alpha$
1: Initialize the model $M$
2: for epoch = 1, 2, ..., max epoch do
   [...]
   Randomly pick $\alpha N$ mini-batches from $\bigcup_{k=K_1+1}^{K_2} D_k$ and add to $S$
   [...]
   Evaluate development set performance on $D_1, \dots, D_{K_1}$
12: end for
Ensure: Model with best evaluation performance

Methods
We propose a multi-task learning method for medical domain data. It employs datasets and tasks from both the medical domain and external domains, and leverages pre-trained models such as MT-DNN and SciBERT for fine-tuning. An overview of the proposed method is illustrated in Figure 1. To further improve performance, we propose to ensemble models trained from different initializations at the evaluation stage. Below we detail our methods for fine-tuning and ensembling.

Fine-tuning details
Algorithm. We fine-tune the two types of pretrained models on all three tasks using multi-task learning. As suggested by the MEDIQA paper, we also fine-tune our model on MedQuAD (Abacha and Demner-Fushman, 2019), a medical QA dataset. We provide details for fine-tuning on these datasets in Section 2.3. We additionally regularize the model by also training on MNLI (Williams et al., 2018). To prevent negative transfer from MNLI, we put a larger weight on MEDIQA data by sampling MNLI data with lower probability. Our algorithm is presented in Algorithm 1 and illustrated in Figure 1; it is a mixture-ratio method for multi-task learning inspired by Xu et al. (2019). We start with in-domain datasets $D_1, \dots, D_{K_1}$ (i.e., the MEDIQA tasks, $K_1 = 3$) and external datasets $D_{K_1+1}, \dots, D_{K_2}$ (in this case MNLI). We cast all training samples as sentence pairs $(s_1, s_2) \in D_k$, $k = 1, 2, \dots, K_2$. In each epoch of training, we use all mini-batches from the in-domain data, while only a small proportion (controlled by $\alpha$) of the mini-batches from the external datasets is used to train the model. In our experiments, the mixture ratio $\alpha$ is set to 0.5. We use MedNLI, RQE, QA, and MedQuAD in the medical domain as in-domain data and MNLI as external data. For MedNLI, we additionally find that using MedNLI as in-domain data and RQE, QA, and MedQuAD as external data can also help boost performance. We use models trained using both setups of external data for ensembling.
[...] et al., 2019). We add a simple softmax layer (or a linear layer for the QA and MedQuAD tasks) atop BERT as the answer module for fine-tuning. For the initialization in step 1 of Algorithm 1, we initialize all BERT weights with the pretrained weights and randomly initialize the answer layers. After multi-task fine-tuning, the joint model is further fine-tuned on each specific task to get better performance. We detail the training loss and fine-tuning process for each task in Section 2.3.
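To make the mixture-ratio idea concrete, the following is a minimal Python sketch of how one epoch's mini-batch sequence could be assembled. It is an illustration under stated assumptions, not the authors' implementation: the function name `build_epoch_batches`, the toy datasets, sampling without replacement, and capping the sample at the size of the external pool are all hypothetical choices.

```python
import random


def build_epoch_batches(in_domain, external, alpha, rng):
    """Return the shuffled mini-batch sequence for one training epoch.

    in_domain / external: lists of datasets, each dataset a list of mini-batches.
    alpha: mixture ratio; about alpha * N external mini-batches are added,
           where N is the number of in-domain mini-batches.
    """
    # Every in-domain mini-batch is used in every epoch.
    batches = [b for dataset in in_domain for b in dataset]
    n_in = len(batches)

    # Assumption: external mini-batches are pooled and sampled without
    # replacement, capped at the pool size.
    pool = [b for dataset in external for b in dataset]
    n_ext = min(len(pool), int(alpha * n_in))
    batches += rng.sample(pool, n_ext)

    # Shuffle so in-domain and external batches are interleaved at random.
    rng.shuffle(batches)
    return batches


# Toy data: each "mini-batch" is just a (dataset name, index) tuple.
rng = random.Random(0)
mediqa = [[("mednli", i) for i in range(6)], [("rqe", i) for i in range(4)]]
mnli = [[("mnli", i) for i in range(100)]]
epoch = build_epoch_batches(mediqa, mnli, alpha=0.5, rng=rng)
```

With 10 in-domain mini-batches and a mixture ratio of 0.5, the epoch contains 15 mini-batches, 5 of which come from the external pool; a fresh external sample is drawn each time the function is called.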
Objectives. MedNLI and RQE are classification tasks (three-way for MedNLI, binary for RQE), and we use a cross-entropy loss. Specifically, for a sentence pair $X$ we compute the loss
$$L(X) = -\sum_{c} \mathbb{1}(X, c)\, \log P_r(c \mid X),$$
where $c$ iterates over all possible classes, $\mathbb{1}(X, c)$ is the binary indicator (0 or 1) of whether class label $c$ is the correct classification for $X$, and $P_r(c \mid X)$ is the model's predicted probability of class $c$ for sample $X$.
We formulate QA and MedQuAD as regression tasks, and thus an MSE loss is used. Specifically, for a question-answer pair $(Q, A)$ we compute the MSE loss as
$$L(Q, A) = \left(y - \mathrm{score}(Q, A)\right)^2,$$
where $y$ is the target relevance score for the pair $(Q, A)$, and $\mathrm{score}(Q, A)$ is the model prediction for the same pair.
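As a small worked illustration of the two objectives, the sketch below evaluates them on single samples in plain Python. It assumes the standard per-sample cross-entropy and squared-error forms; the function names and the example numbers are hypothetical.

```python
import math


def cross_entropy(probs, gold):
    """Per-sample cross-entropy.

    The indicator 1(X, c) is 1 only for the gold class, so the sum over
    classes reduces to minus the log-probability of the gold class.
    """
    return -math.log(probs[gold])


def mse(score, target):
    """Per-sample squared error between predicted and target relevance."""
    return (target - score) ** 2


# A three-class (NLI-style) prediction whose gold label is class 0.
nli_loss = cross_entropy([0.7, 0.2, 0.1], gold=0)

# A QA pair with target relevance 1.0 and predicted score 0.5.
qa_loss = mse(0.5, 1.0)
```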

Model Ensembles
After fine-tuning, we ensemble models trained from MT-DNN and SciBERT under the different setups of in-domain and external datasets. Traditional methods typically fuse models by averaging their prediction probabilities. In our setting, the in-domain data is very limited and models tend to overfit; the predicted probabilities of an overfit model can be arbitrarily close to 1, so averaging favors the more overfit models. To counter this, we ensemble the models using a majority vote on their predictions and resolve ties using the sum of prediction probabilities. Suppose we have $M$ models, and the $m$-th model predicts the answer $p_m$ for a specific question. We first obtain the majority of the predictions, $\mathrm{maj}(\{\hat{y}_m\}_{m=1}^{M})$, and resolve ties by computing the sum of prediction probabilities $\sum_{m=1}^{M} p_m^{(y)}$. For the QA tasks (QA and MedQuAD), the task is cast as a regression problem, where a positive score means a correct answer and a negative score an incorrect one. We have $p_m \in \mathbb{R}$. We first compute the average score $p_{\mathrm{ensem}} = \frac{1}{M} \sum_{m=1}^{M} p_m$. We also compute each model's prediction as $\hat{y}_m = I(p_m \ge 0)$, where $I$ is the indicator function. We compute the ensemble prediction through a majority vote similar to the classification case: we predict the majority if no tie exists, and the sign of $p_{\mathrm{ensem}}$ otherwise. The final ranking of answers is produced by first ranking the (predicted) positive answers and then the (predicted) negative answers.
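The voting rule described in this section can be sketched in a few lines of Python. This is an illustrative reading of the text rather than the authors' code; the function names `ensemble_classify` and `ensemble_score` and the toy inputs are hypothetical.

```python
from collections import Counter


def ensemble_classify(probs):
    """Majority vote over per-model argmax predictions.

    probs: one probability vector per model. Ties between the most-voted
    classes are broken by the summed probability of each tied class.
    """
    votes = [max(range(len(p)), key=p.__getitem__) for p in probs]
    counts = Counter(votes)
    top = max(counts.values())
    tied = [c for c, n in counts.items() if n == top]
    return max(tied, key=lambda c: sum(p[c] for p in probs))


def ensemble_score(scores):
    """Majority vote over the signs of per-model regression scores.

    Returns True for a predicted-correct answer. If positive and negative
    votes are tied, the sign of the average score decides.
    """
    positive = sum(1 for s in scores if s >= 0)
    negative = len(scores) - positive
    if positive != negative:
        return positive > negative
    return sum(scores) / len(scores) >= 0


# Two models vote for class 1 with modest confidence; one very confident
# model votes for class 0. The majority wins, unlike probability averaging.
label = ensemble_classify([[0.9, 0.1], [0.4, 0.6], [0.45, 0.55]])

# A 1-1 tie is resolved by the summed probabilities (0.8 vs. 1.0).
tie_label = ensemble_classify([[0.6, 0.3, 0.1], [0.2, 0.7, 0.1]])
```

In the first example, averaging the probabilities would pick class 0 because of the single confident model, whereas the majority vote picks class 1; this is the behavior the section motivates when individual models may be overfit.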

Dataset-Specific Details
MedNLI: Since the MEDIQA shared task uses a different test set than the original MedNLI dataset, we merge the original MedNLI development set into the training set and use the original MedNLI test set for evaluation. Furthermore, MedNLI and MNLI are both NLI tasks, so we share the final-layer classifier between them. For MedNLI, we find that every three consecutive samples in the training set share the same premise with different hypotheses and contain exactly one entailment, one neutral, and one contradiction label. To this end, at prediction time we constrain the three predictions to be one of each kind and choose the most likely such assignment under the model's prediction probabilities.
RQE: We use the clinical question as the premise and the question from the FAQ as the hypothesis. We find that the test data distribution is quite different from the training data distribution. To mitigate this effect, we randomly move half of the evaluation data into the training set and evaluate on the remaining half.
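The one-of-each-kind constraint for MedNLI can be illustrated with a short Python sketch. The scoring rule used here, picking the label permutation with the highest joint log-probability, is an assumption about what "most likely" means; the text does not spell it out. The names `LABELS` and `constrained_triple` are hypothetical.

```python
import math
from itertools import permutations

LABELS = ("entailment", "neutral", "contradiction")


def constrained_triple(probs):
    """Assign one distinct label to each of three hypotheses.

    probs: three rows (one per hypothesis), each holding the model's
    probabilities for the labels in LABELS order. All 6 permutations are
    scored by joint log-probability (an assumed criterion).
    """
    best = max(
        permutations(range(3)),
        key=lambda perm: sum(
            math.log(probs[i][perm[i]] + 1e-12) for i in range(3)
        ),
    )
    return [LABELS[j] for j in best]


# Independent argmax would label the first two hypotheses "entailment";
# the constraint moves the second one to "neutral".
triple = constrained_triple([
    [0.7, 0.2, 0.1],
    [0.6, 0.3, 0.1],
    [0.1, 0.2, 0.7],
])
```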

QA: We use the answer as the premise and the question as the hypothesis. The QA task is cast as both a ranking task and a classification task. Each answer is associated with a relevance score in {1, 2, 3, 4}, and an additional rank over all the answers for a specific question is given. We use a modified score to incorporate both pieces of information: suppose there are $m$ answers with relevance score $s \in \{1, 2, 3, 4\}$. Then the $i$-th most relevant answer among these $m$ answers gets the modified score $s - \frac{i-1}{m}$. In this way the scores are uniformly distributed in $(s-1, s]$. We shift all scores by $-2$ so that a positive score corresponds to a correct answer and a negative score to an incorrect one. We also tried pairwise losses to incorporate the ranking but did not find that they improved performance much. We find that the development set distribution is inconsistent with the test data: the training and test sets consist of both LiveQAMed and Alexa questions, whereas the development set seems to contain only LiveQAMed questions. We reshuffle the training and development sets to make them similar: we use the last 25 questions in the original development set (LiveQAMed questions) and the last 25 Alexa questions (from the original training set) as our development set, and use the remaining questions as our training set. This results in 1,504 training pairs and 431 validation pairs. Due to the limited size of the QA dataset, we use cross-validation that divides all pairs into 5 slices and trains 5 models, using each slice in turn as the validation set. We train both MT-DNN and SciBERT on each of these 5 setups, obtaining 10 models, and ensemble all 10.
MedQuAD: We use 10,109 questions from MedQuAD because the remaining questions are not available due to copyright issues. The original MedQuAD dataset contains only positive question-answer pairs. We add negative samples to the dataset by randomly sampling an answer from the same web page. For each positive QA pair, we add two negative samples. The resulting 30,327 pairs are randomly divided into 27,391 training pairs and 2,936 evaluation pairs. We then use the same method as for QA to train on MedQuAD; we also share the same answer module between QA and MedQuAD.
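The modified QA score can be computed as in the following Python sketch, which groups the answers to one question by relevance level, orders each group by the given rank, and applies the formula and the shift by −2. The function name `modified_scores`, the tuple layout, and the toy answers are hypothetical.

```python
def modified_scores(answers):
    """Compute shifted modified scores for the answers to one question.

    answers: list of (answer_id, relevance, rank) with relevance in
    {1, 2, 3, 4} and rank 1 meaning the most relevant answer overall.
    Returns a dict mapping answer_id to s - (i - 1) / m - 2, where m is
    the number of answers sharing relevance s and i is the answer's
    position within that group.
    """
    groups = {}
    for answer_id, relevance, rank in answers:
        groups.setdefault(relevance, []).append((rank, answer_id))

    scores = {}
    for relevance, group in groups.items():
        group.sort()  # order answers within a relevance level by rank
        m = len(group)
        for i, (_, answer_id) in enumerate(group, start=1):
            scores[answer_id] = relevance - (i - 1) / m - 2
    return scores


toy = [("a", 4, 1), ("b", 4, 2), ("c", 2, 3), ("d", 1, 4)]
scores = modified_scores(toy)
```

Sorting the answers by these scores in descending order reproduces the reference ranking, while the relevance level remains recoverable from the score range.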

Implementation and Hyperparameters
We implement our method using PyTorch and Pytorch-pretrained-BERT, as an extension to MT-DNN. We also use the PyTorch-compatible SciBERT pretrained model provided by AllenNLP. Each training example is truncated to at most 384 tokens for MT-DNN models and 512 tokens for SciBERT models. We use a batch size of 16 for MT-DNN and 40 for SciBERT. For fine-tuning, we train the models for 20 epochs using a learning rate of $5 \times 10^{-5}$. After that, we further fine-tune from the best multi-task model for 6 epochs on each dataset, using a learning rate of $5 \times 10^{-6}$. We ensemble all models with an accuracy larger than 87.7 for MedNLI, 83.5 for shuffled RQE, and 83.0 for QA. We ensemble 4 models for MedNLI and 14 models for RQE. For QA, we ensemble 10 models from cross-validation and 7 models using the normal training-validation approach.

Results
In this section, we provide the leaderboard performance and conduct an analysis of the effect of ensemble models from different sources.

Test Set Performance and LeaderBoards
The results for the MedNLI dataset are summarized in Table 1. Our method places 3rd on the leaderboard and substantially improves upon previous state-of-the-art (SOTA) methods.
The results for the RQE dataset are summarized in Table 2. Our method places 7th on the leaderboard. It shows a very large discrepancy between dev set and test set performance. We think this is because the test set is quite different from the dev set, and because the dev set is very small and easy to overfit.
The results for the QA dataset are summarized in Table 3. Our method reaches first place on the leaderboard in accuracy and precision, and has the 3rd-highest MRR. We note that the Spearman score is not consistent with the other scores on the leaderboard; the Spearman score is computed only on the predicted positive answers, so a method can obtain a very high Spearman score by almost never predicting positive labels.

Ensembles from Different Sources
We compare the effect of ensembling from different sources in Table 4.

Single-Model Performance
For completeness, we report the single-model performance on the MedNLI development set under various multi-task learning setups and initializations in Table 5. (1) The Naïve approach treats MedNLI, RQE, QA, and MedQuAD as in-domain data in Algorithm 1, without any external data; (2) the Ratio approach treats MedNLI as in-domain data and RQE, QA, and MedQuAD as external data in Algorithm 1; (3) the Ratio+MNLI approach treats MedNLI, RQE, QA, and MedQuAD as in-domain data and MNLI as external data in Algorithm 1. Note that MNLI is much larger than the medical datasets, so if we use RQE, QA, MedQuAD, and MNLI together as external data, the performance is very similar to the third setting. We did not conduct experiments on single-dataset settings, as previous works have suggested that multi-task learning can obtain much better results than single-task models (Liu et al., 2019b; Xu et al., 2019). Overall, the best results are achieved by using SciBERT as the pre-trained model together with multi-task learning with MNLI. The models trained by mixing in-domain data (the second setup) are also competitive. We therefore use models from both setups for ensembling.


Conclusion
We present new methods for multi-source transfer learning in the medical domain. Our results show that ensembles from different sources improve model performance far more than ensembles from a single source. Our methods prove effective in the MEDIQA 2019 shared task.
Algorithm 1 (continued):
   [...] batches in $S$ in a random order to obtain a sequence $B = (b_1, \dots, b_L)$, where $L = N + \alpha N$
8: for each mini-batch $b \in B$ do
9: Perform gradient update on $M$ with loss $l(b) = \sum_{(s_1, s_2) \in b} l(s_1, s_2)$

Figure 1: Illustration of the proposed multi-source multi-task learning method.
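For the classification tasks (MedNLI and RQE), we have $p_m \in \mathbb{R}^C$, where $C$ is the number of categories. Let $\hat{y}_m = \arg\max_i p_m^{(i)}$ be the prediction of model $m$, where $p_m^{(i)}$ is the $i$-th dimension of $p_m$. The final prediction is chosen as
$$\hat{y}_{\mathrm{ensem}} = \arg\max_{y \in \mathrm{maj}(\{\hat{y}_m\}_{m=1}^{M})} \sum_{m=1}^{M} p_m^{(y)}.$$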

Table 2: The leaderboard for the RQE task (link). Scores are accuracy (%). Our method ranked 7th on the leaderboard.

Table 3: The leaderboard for the QA task (link).
[...] Table 4). On the other hand, if we ensemble three models from different sources (#1+#2+#5 and #1+#5+#6 in Table 4), the resulting model gains more than 3% in accuracy over the numerical average. This shows that ensembling models from different sources has a clear advantage over ensembling single-source models.

Table 4: Comparison of ensembles from different sources. Avg. Acc stands for average accuracy, the numerical average of each individual model's accuracy. Esm. Acc stands for ensemble accuracy, the accuracy of the resulting ensemble model. For ensembles, MT-DNN means all three models are from MT-DNN, and similarly for SciBERT; MultiSource denotes that the ensembled models come from two different sources.

Table 5: Single-model performance on MedNLI development data. Naïve means simply integrating all medical-domain data; Ratio means using MedNLI as in-domain data and the other medical-domain data as external data; Ratio+MNLI means using medical-domain data as in-domain data and MNLI as external data.