FedED: Federated Learning via Ensemble Distillation for Medical Relation Extraction

Unlike other domains, medical texts are inevitably accompanied by private information, so sharing or copying these texts is strictly restricted. However, training a medical relation extraction model requires collecting these privacy-sensitive texts and storing them on one machine, which conflicts with privacy protection. In this paper, we propose a privacy-preserving medical relation extraction model based on federated learning, which enables training a central model without any single piece of private local data being shared or exchanged. Although federated learning has distinct advantages in privacy protection, it suffers from a communication bottleneck, mainly caused by the need to upload cumbersome local parameters. To overcome this bottleneck, we leverage a strategy based on knowledge distillation: the uploaded predictions of the ensemble of local models are used to train the central model, without requiring the upload of local parameters. Experiments on three publicly available medical relation extraction datasets demonstrate the effectiveness of our method.


Introduction
"Privacy - like eating and breathing - is one of life's basic requirements."
- Katherine Neville

Relation extraction is the task of mining factual knowledge from free text by labeling relations between entity mentions, and it has attracted increasing attention in recent years (Zeng et al., 2014; Xu et al., 2015a,b; Baldini Soares et al., 2019; Song et al., 2019). Applying automatic relation extraction to medical texts, such as electronic health records and discharge summaries, can be useful for many applications, including drug repurposing and medical knowledge graph construction.
Unlike other domains, medical texts are highly privacy-sensitive, because they can contain some of the most intimate details of a person's life: they document a patient's physical and mental health and can include information on social behaviors, personal relationships and financial status (Gostin and Hodge, 2002). To prevent private information leakage, sharing or copying medical texts is strictly restricted.
Previous relation extraction methods require centralizing training data from different medical platforms, such as hospitals and healthcare centers, on one server, and holding this privacy-sensitive data in one place puts patients' privacy at risk. This is one of the reasons that hinder the use of relation extraction in clinical practice. As a possible solution, federated learning (McMahan et al., 2016) has been proposed to make full use of privacy-sensitive data. Federated learning trains local models with private data at local platforms and aggregates these local models on a central server. In this framework, no single piece of private data is uploaded to or stored on the central server; only local models' parameters are sent to the server to update the central model.
Although federated learning has distinct advantages in privacy protection compared to centralized training, federated learning algorithms such as FedAvg (McMahan et al., 2016) require frequent communication between local platforms and the central server to upload and download model parameters. Communication is a critical bottleneck for applying federated learning to relation extraction, largely for the following reasons. First, state-of-the-art relation extraction models (Baldini Soares et al., 2019; Li et al., 2019b; Thillaisundaram and Togia, 2019) usually utilize transformer-based pretrained language models (Raffel et al., 2019; Devlin et al., 2019; Yang et al., 2019b) as backbone encoders, which have hundreds of millions or even billions of parameters. Second, the federated learning framework can include a massive number of local platforms (Li et al., 2019a), and communication between each platform and the central server is necessary. Third, upload bandwidth is typically limited to 1 MB/s or less in most situations. Given the cumbersome models, numerous local platforms and limited upload bandwidth, frequent uploads take an excessive amount of time. For example, in a single communication round, uploading a BERT-Large (Devlin et al., 2019) model takes more than 21 minutes and uploading a T5 (Raffel et al., 2019) model takes more than 12 hours. To overcome the communication bottleneck in federated relation extraction, it is necessary to develop a communication-efficient method that iteratively sends small messages as part of the training process, instead of sending the entire pretrained language encoder.
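As a rough check on these figures, the upload time follows directly from the parameter count and the bandwidth. The sketch below assumes 32-bit parameters, a 1 MB/s uplink, and approximate parameter counts of 340M for BERT-Large and 11B for the largest T5 variant; it is an estimate, not a measurement from the paper.

```python
# Back-of-the-envelope upload time, assuming 4-byte (fp32) parameters
# and a 1 MB/s uplink; parameter counts are approximate.
def upload_seconds(num_params: float, bandwidth_mb_per_s: float = 1.0) -> float:
    size_mb = num_params * 4 / 1e6       # model size in megabytes
    return size_mb / bandwidth_mb_per_s  # seconds to upload the model once

print(f"BERT-Large (~340M params): {upload_seconds(340e6) / 60:.0f} min")  # roughly 20+ minutes
print(f"T5-11B (~11B params):      {upload_seconds(11e9) / 3600:.1f} h")   # roughly 12+ hours
```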
In this paper, we introduce a privacy-preserving medical relation extraction model, named FedED. To prevent private information leakage, we leverage federated learning, in which raw privacy-sensitive medical texts are never shared. To overcome the communication bottleneck in federated relation extraction, we focus on reducing the size of the transmitted messages at each communication round. To this end, we formulate the central aggregation process in federated learning as learning a compact central model (student) from the ensemble (Dietterich, 2000; Breiman, 2001) of multiple local models (teacher). From this perspective, only the predicted labels on a small dataset need to be uploaded to the central server, because learning from a "teacher" model only requires the behavior of the "teacher" rather than the entire "teacher" network (Hinton et al., 2015). Moreover, the ensemble model (teacher) is powerful: it defines the upper extreme of aggregation when limited to a single communication round in federated learning (Yurochkin et al., 2019). To transfer the knowledge in the ensemble model to the central model, we leverage a strategy based on knowledge distillation (Hinton et al., 2015), which trains the central model by forcing it to make predictions similar to those of the ensemble model. To demonstrate the effectiveness of our method, we conduct extensive experiments on three different medical relation extraction datasets. The results show that our method not only outperforms the baselines but is also communication-efficient.
We summarize our contributions as follows:
• To protect patients' privacy, we propose the first (to the best of our knowledge) privacy-preserving medical relation extraction model based on federated learning, which decouples model training from the need for direct access to the highly privacy-sensitive data.
• To overcome the communication bottleneck in federated learning, we leverage a knowledge-distillation-based strategy that utilizes the uploaded predictions of the ensemble of local models to train the central model, without requiring the upload of the local models' parameters.
• The method yields promising results on three different medical relation extraction datasets, and we perform various experiments to verify the effectiveness of the proposed method.

Related Work
Our work builds on a rich line of recent efforts on relation extraction models and federated learning.

Relation Extraction
Relation extraction is a long-standing NLP task of mining factual knowledge from free text by labeling relations between entity mentions. A number of recent neural network approaches have been applied to relation extraction (Zeng et al., 2014; Xu et al., 2015a,b; Baldini Soares et al., 2019; Song et al., 2019). In this paper, we also adopt a pretrained language model as the backbone encoder. Applying relation extraction models to the medical field has great practical value, and there is a rich literature on medical relation extraction. Some studies focus on clinical relation extraction (Sahu et al., 2016; Munkhdalai et al., 2018; Ningthoujam et al., 2019) and others concentrate on biomedical relation extraction (Peng et al., 2017; Song et al., 2018, 2019). Compared with previous studies, we develop a federated relation extraction system to protect patients' privacy in medical relation extraction.

Federated Learning
Recently, McMahan et al. (2016), Konečnỳ et al. (2016a) and Konečnỳ et al. (2016b) proposed the concept of federated learning. The main idea of federated learning is to build machine learning models based on datasets that are distributed across multiple local platforms while preventing data leakage. Based on the distribution characteristics of the data, federated learning can be divided into three categories, i.e., horizontal federated learning, vertical federated learning and federated transfer learning (Yang et al., 2019a). This work focuses on horizontal federated learning, where local datasets share the same feature space but differ in samples. Because federated learning protects privacy, it is widely used in various fields. Chen et al. (2018) combined federated learning with meta learning for recommendation. Kim et al. (2017) proposed federated tensor factorization for computational phenotyping without sharing patient-level data. Liu and Miller (2020) proposed federated pretraining of BERT models using clinical notes from multiple silos. Ge et al. (2020) proposed a privacy-preserving medical NER method based on federated learning.

Task Definition
Relation extraction aims to extract relational facts from sentences: given a sentence with an entity pair e_1 and e_2, the task is to identify the relation between e_1 and e_2. In this paper, we focus on applying relation extraction to the medical domain. We define K medical platforms {P_1, P_2, ..., P_K}, each with a private relation extraction dataset D_i, and a central server that holds a small validation dataset D_v. Since medical data is usually private and sensitive, the goal is to obtain a medical relation extraction model on the central server under the condition that no local platform P_i exposes its private data D_i to others.
To solve this task, we propose a privacy-preserving medical relation extraction model based on federated learning. In the following sections, we first introduce the basic medical relation extraction model. Then, we present how to conduct privacy-preserving training in a communication-efficient way.

Medical Relation Extraction Model
Given the impressive performance of recent deep transformers (Vaswani et al., 2017) trained on variants of language modeling, we utilize the BERT model (Devlin et al., 2019) as the backbone encoder. In this section, we describe a simple way of representing relations with the deep transformer model. The model architecture is shown in Figure 1 and the details are as follows. First, we construct the input sequence s = {w_0, w_1, ..., w_n}, where w_0 = [CLS] and w_n = [SEP] are special start and end markers. Next, to improve the generalization of the model, we follow previous studies (He et al., 2013; Kim et al., 2015; Chauhan et al., 2019) and perform entity blinding on the sequence, where the words matching each entity are replaced with the target entity label. Then, to highlight the entity mentions, we augment the sequence with four reserved word pieces, i.e., <e1>, </e1>, <e2> and </e2>, to mark the beginning and end of each entity mention. After these steps, we obtain the prepared sequence ŝ.
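To make this preprocessing concrete, the sketch below applies entity blinding and entity-marker insertion to a toy tokenized sentence. The marker strings, the placeholder labels and the helper name prepare_sequence are illustrative assumptions, not the exact vocabulary of the original implementation.

```python
def prepare_sequence(tokens, e1_span, e2_span,
                     e1_label="ENTITY1", e2_label="ENTITY2"):
    """Entity blinding + entity markers (illustrative marker strings and labels)."""
    (s1, t1), (s2, t2) = e1_span, e2_span            # [start, end) token indices
    assert t1 <= s2, "assume e1 precedes e2 and the spans do not overlap"
    out = (tokens[:s1]
           + ["<e1>", e1_label, "</e1>"]             # blinded first entity, marked
           + tokens[t1:s2]
           + ["<e2>", e2_label, "</e2>"]             # blinded second entity, marked
           + tokens[t2:])
    return ["[CLS]"] + out + ["[SEP]"]

# Example:
# prepare_sequence(["the", "aspirin", "relieved", "his", "headache"], (1, 2), (4, 5))
# -> ['[CLS]', 'the', '<e1>', 'ENTITY1', '</e1>', 'relieved', 'his',
#     '<e2>', 'ENTITY2', '</e2>', '[SEP]']
```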
Given the prepared sequence ŝ as input, the output of BERT is H ∈ R^{m×d}, where m is the prepared sequence length and d is the output dimension of the BERT encoder. We use the representation of the first token of the sequence (the [CLS] token) as the sequence representation, denoted h_0 ∈ R^d. In addition, we obtain entity mention representations by summing the final hidden states of the word pieces in each entity mention, yielding two vectors h_{e_1} = sum([h_i, ..., h_j]) ∈ R^d and h_{e_2} = sum([h_k, ..., h_l]) ∈ R^d for the two entity mentions. Finally, the sequence representation and the two entity mention representations are concatenated and fed into a fully connected layer:

$$p(y \mid \hat{s}) = \mathrm{softmax}(W[\,h_0;\, h_{e_1};\, h_{e_2}\,] + b),$$

where W and b are trainable model parameters.
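The following sketch shows one way to implement this relation representation with Hugging Face Transformers; the encoder checkpoint, the mask-based entity pooling and the class name are assumptions for illustration rather than the paper's exact code.

```python
import torch
import torch.nn as nn
from transformers import BertModel

class RelationClassifier(nn.Module):
    """Minimal sketch of the relation head described above (checkpoint and pooling assumed)."""
    def __init__(self, num_relations, encoder_name="emilyalsentzer/Bio_ClinicalBERT"):
        super().__init__()
        self.encoder = BertModel.from_pretrained(encoder_name)
        d = self.encoder.config.hidden_size
        self.classifier = nn.Linear(3 * d, num_relations)   # W, b over [h_0; h_e1; h_e2]

    def forward(self, input_ids, attention_mask, e1_mask, e2_mask):
        # e1_mask / e2_mask: (batch, seq_len), 1.0 for word pieces inside each entity mention
        H = self.encoder(input_ids=input_ids,
                         attention_mask=attention_mask).last_hidden_state  # (batch, seq_len, d)
        h_cls = H[:, 0]                                    # [CLS] sequence representation
        h_e1 = (H * e1_mask.unsqueeze(-1)).sum(dim=1)      # sum of first-entity word pieces
        h_e2 = (H * e2_mask.unsqueeze(-1)).sum(dim=1)      # sum of second-entity word pieces
        return self.classifier(torch.cat([h_cls, h_e1, h_e2], dim=-1))  # relation logits
```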

Federated Training
To protect patients' privacy, we utilize federated learning to train the medical relation extraction model. In the federated framework, two types of models are needed, i.e., the local model and the central model, which share the same network structure but have different permissions to access private data. Local models are deployed at local platforms, such as hospitals, and can access their respective private local data. In contrast, the central model is deployed on a central server, such as a cloud server, which is strictly prohibited from accessing patients' private data. Here, following previous studies (McMahan et al., 2016; Bonawitz et al., 2019; Ge et al., 2020), we assume the central server belongs to a trusted third party, which means it will not launch malicious attacks against local platforms.
In this section, we present how to train the relation extraction model in the federated way, including secure local model update and the ensemble distillation based central model update.

Secure Local Model Update
The local model at each medical platform is trained on its own private data. Assume that the local platform P_i is selected to perform local computation in a round. P_i computes the gradients of the loss over all the data D_i it holds to update the parameters of its local model. We adopt cross-entropy as the local loss function:

$$\mathcal{L}_{local}(\Theta) = -\frac{1}{|D_i|}\sum_{(s,\,y)\in D_i}\log p(y \mid s;\Theta),$$

where |D_i| is the number of sentences in the local private data D_i, y is the gold relation label of sentence s, and Θ denotes all parameters of the local model. After local training, the local model at P_i accesses the validation data D_v from the central server, makes predictions on it with the trained parameters, and uploads the predicted labels to the central server. Compared with centralized training, the local model is trained only on its own data, and only the predicted labels, which generally contain far less privacy-sensitive information, are uploaded instead of the raw data.
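A minimal sketch of this secure local update is given below, assuming PyTorch DataLoaders over the private data D_i and the shared validation data D_v, and the RelationClassifier sketched earlier. Uploading raw logits (rather than hard labels) is one possible reading of the uploaded "predicted labels" that matches the logit averaging used by the aggregator; batch keys and hyperparameter defaults are assumptions.

```python
import torch
import torch.nn.functional as F

def local_update(model, private_loader, valid_loader, epochs=2, lr=2.5e-5, device="cuda"):
    """Train on the platform's private data D_i, then upload only predictions on D_v."""
    model.to(device).train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for _ in range(epochs):                       # E local epochs
        for batch in private_loader:              # minibatches of size B
            batch = {k: v.to(device) for k, v in batch.items()}
            logits = model(batch["input_ids"], batch["attention_mask"],
                           batch["e1_mask"], batch["e2_mask"])
            loss = F.cross_entropy(logits, batch["labels"])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    # Predict on the shared validation data; only these predictions leave the platform.
    model.eval()
    uploads = []
    with torch.no_grad():
        for batch in valid_loader:
            batch = {k: v.to(device) for k, v in batch.items()}
            logits = model(batch["input_ids"], batch["attention_mask"],
                           batch["e1_mask"], batch["e2_mask"])
            uploads.append(logits.cpu())
    return torch.cat(uploads)                     # shape: (|D_v|, num_relations)
```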

Central Model Update via Ensemble Distillation
The central server coordinates the massive number of local models to collaboratively train the central model. To this end, the central server contains a coordinator and an aggregator.
Coordinator controls the entire training process and is responsible for accepting and forwarding local platform connections. At the beginning of each communication round, the coordinator builds the medical relation extraction model on the central server and initializes it. Then, the coordinator randomly selects a C-fraction of local medical platforms, since we cannot require that all local platforms are always online in a real-world scenario. After that, the coordinator distributes the parameters of the central model to all selected local platforms, and the selected local models are initialized with these parameters, which ensures that all selected local models start from the same initial condition in this round. The selected local models are then trained on their respective private data at each local platform. The coordinator monitors each selected platform for uploads; once it receives uploads from a platform, the coordinator stores them for future aggregation. When all selected platforms finish local training, the stored uploads are sent to the aggregator to infer new central model parameters.
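A simulation-style sketch of one coordinator round is shown below, under the assumption that each platform is represented as an in-memory object exposing a private_loader and that local_update is the helper sketched in the previous section; in a real deployment the deep copy and the direct function call would be replaced by network communication.

```python
import copy
import random

def coordinator_round(central_model, platforms, C, valid_loader, aggregate_fn):
    """One communication round following the coordinator description above (simulation sketch)."""
    m = max(int(C * len(platforms)), 1)
    selected = random.sample(platforms, m)          # random C-fraction of local platforms
    all_uploads = []
    for platform in selected:
        # Each selected platform starts from the current central parameters ...
        local_model = copy.deepcopy(central_model)
        # ... trains on its own private data, and uploads only its predictions on D_v.
        uploads = local_update(local_model, platform.private_loader, valid_loader)
        all_uploads.append(uploads)
    # The stored uploads are handed to the aggregator to update the central model.
    aggregate_fn(central_model, all_uploads, valid_loader)
    return central_model
```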
Aggregator is the most critical part of federated training; it optimizes the central model based on the massive number of trained local models. To transfer the knowledge in these trained local models to the central model, we resort to a teacher-student framework: the ensemble of local models is viewed as the teacher, while the central model is regarded as the student. The knowledge in the teacher is transferred to the central model by forcing them to make similar predictions for any input instance. To this end, the central model is trained to minimize a distillation loss function whose target is the distribution of class probabilities predicted by the ensemble model. The typical choice of distillation loss is the Kullback-Leibler (KL) divergence between the distributions, D_KL(q||p), where p and q are the output label distributions of the student and the teacher respectively. The distribution of the teacher is obtained as follows:

$$q(y_i \mid s) = \frac{\exp\big(z(y_i \mid s)/\tau\big)}{\sum_{j}\exp\big(z(y_j \mid s)/\tau\big)},$$

where z(y_i|s) is the logit of the ensemble model (teacher) for class i, computed as the mean of the selected local models' logits for this class, and τ is a temperature parameter that controls the shape of the distribution for distilling richer knowledge from the ensemble model. In addition to the distillation loss, it is also beneficial to train the central model to predict the ground-truth labels using the standard cross-entropy loss. The overall objective is defined as follows:

$$\mathcal{L}_{central} = \mathcal{L}_{CE} + D_{KL}(q \,\|\, p).$$

The overall training procedure of FedED is illustrated in Algorithm 1.
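A minimal sketch of the aggregator's training objective is given below; stacking the per-platform logits into one tensor, the temperature value, and the equal weighting of the two loss terms are illustrative assumptions rather than the paper's exact configuration.

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, local_logits, labels, tau=2.0, alpha=0.5):
    """KL distillation from the ensemble of local models plus cross-entropy on gold labels.

    local_logits: (num_platforms, batch, num_relations); tau and alpha are assumed values.
    """
    teacher_logits = local_logits.mean(dim=0)                  # ensemble = mean of local logits
    q = F.softmax(teacher_logits / tau, dim=-1)                # teacher distribution (soft targets)
    log_p = F.log_softmax(student_logits / tau, dim=-1)        # student log-distribution
    kd = F.kl_div(log_p, q, reduction="batchmean") * tau ** 2  # D_KL(q || p), scaled by tau^2
    ce = F.cross_entropy(student_logits, labels)               # supervision from ground-truth labels
    return alpha * kd + (1 - alpha) * ce
```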

Experiments
In this section, we carry out an extensive set of experiments with the aim of answering the following research questions:
• RQ1: Does our model outperform the baseline methods? (see Section 4.4)
• RQ2: Is federated learning effective in medical relation extraction? (see Section 4.5)
• RQ3: Is our approach communication-efficient? (see Section 4.6)
• RQ4: What is the impact of increasing parallelism on our model? (see Section 4.7)
• RQ5: What is the impact of increasing computation per local platform on our model? (see Section 4.8)
In the remainder of the section, we describe the datasets, experimental settings, and all baselines.
Algorithm 1 FedED. The K local platforms are indexed by k; C is the fraction of local platforms that perform computation on each round; B is the local minibatch size; E is the number of local epochs; η is the learning rate.

Initialize Θ_0 on the central server
for each communication round t = 0, 1, 2, ... do
    m ← max(C × K, 1)
    J_t ← (random set of m local platforms)
    The server distributes Θ_t to J_t
    for each platform k ∈ J_t in parallel do
        Perform LocalUpdate(k, Θ_t)
    end for
    // The procedure of the Aggregator
    Collect the uploaded predictions on D_v from all platforms in J_t
    Average the local logits to obtain the ensemble (teacher) prediction
    Train the central model on D_v with the distillation and cross-entropy losses to obtain Θ_{t+1}
end for

Datasets
We evaluate our method on three medical relation extraction datasets: the 2010 i2b2/VA challenge dataset, GAD and EU-ADR; their statistics are listed in Tables 1 and 2. We randomly sample 20% of the training data for validation. To evaluate our method, we use the standard evaluation metric for each dataset: Micro-F1 for the 2010 i2b2/VA challenge dataset and F1-score for GAD and EU-ADR.

Experimental Settings
We use a controlled environment that is suitable for experiments and assume a synchronous update scheme that proceeds in rounds of communication.
For the 2010 i2b2/VA challenge dataset, we set the number of local platforms (K) to 100. For the EU-ADR and GAD datasets, the number of local platforms (K) is set to 50, since these datasets are small. The training data is randomly shuffled and then partitioned across the K local platforms, each receiving 1/K of the training data. This data partitioning simulates the scenario where each hospital is treated as a local platform and the central server is hosted by a trusted third party.
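A minimal sketch of this IID partitioning is shown below, assuming the training set is an in-memory list of examples; the remainder left over by integer division is simply dropped, which is one possible way to approximate equal 1/K shards.

```python
import random

def partition_iid(train_data, num_platforms, seed=0):
    """Shuffle the training data and split it into num_platforms equal shards,
    one per simulated local platform."""
    data = list(train_data)
    random.Random(seed).shuffle(data)
    shard_size = len(data) // num_platforms
    return [data[i * shard_size:(i + 1) * shard_size] for i in range(num_platforms)]

# e.g. shards = partition_iid(i2b2_train, num_platforms=100)
```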
In our experiments, we use Hugging Face's implementation (Wolf et al., 2019) of BERT (base version) and initialize the parameters of the BERT encoding layer with the pretrained Clinical BERT model (Alsentzer et al., 2019). The learning rate is set to 2.5e-05. We use dropout, set to 0.1, to mitigate overfitting. To conduct a fair comparison (presented in Section 4.4), we use the same hyper-parameters for all federated methods. The fraction of local platforms C is 0.1, and we also study adding more local platforms at each round of communication in Section 4.7. Since the batch size and the number of local epochs determine the number of secure local updates per round, the batch size B is fixed to 4 and the number of local epochs E is set to 2. We independently repeat each experiment 9 times and report the median F-score. All experiments are run on an NVIDIA GeForce RTX 2080 Ti.

Baselines
Under centralized training settings, we compare our medical relation extraction model (depicted in Section 3.2) with the following studies: (1) Bravo et al. (2015) combine the shallow linguistic kernel with the dependency kernel to mine the syntactic features of text; (4) Bhasuran and Natarajan (2018) employ an ensemble SVM with a rich feature set covering conceptual, syntax and semantic information; (5) Lee et al. (2020) propose a domain-specific language representation model, called BioBERT, pre-trained on large-scale biomedical corpora.
In the federated training manner, we compare our federated framework (depicted in Section 3.3) with the following baselines: (1) FedAvg (McMahan et al., 2016) averages the element-wise parameters of local models with weights proportional to the sizes of the local datasets; (2) FedAtt (Ji et al., 2019) leverages a layer-wise attention mechanism for model aggregation, which automatically attends to the weights of the relations between the central model and the different local models.

Results
Tables 3, 4 and 5 answer RQ1 by showing the results of our model against the baselines on the real-world medical datasets. Overall, our model significantly outperforms the baselines on these datasets.
In the centralized training manner, our method outperforms REflex (Chauhan et al., 2019), which builds a CNN upon pretrained embeddings, on the i2b2 dataset. In the federated training manner, our federated framework outperforms FedAvg (McMahan et al., 2016) and FedAtt (Ji et al., 2019). There are two possible reasons: (1) The performance of the ensemble model defines the upper extreme of aggregation when limited to a single communication round in federated learning (Yurochkin et al., 2019), and the central model benefits from learning from this ensemble model. (2) FedAvg and FedAtt only model central optimization as averaging or weighted averaging of local model parameters, which overlooks the complicated relationships between local model parameters. FedED instead forces the central model to mimic the behavior of the ensemble model rather than modeling the complex relationship between parameters.
Comparing the federated training manner with the centralized training manner, we find that centralized training achieves better performance. There are several possible reasons: (1) As the size of each local private dataset is small, the local model is prone to overfitting on it. (2) The local platforms are independent of each other; therefore, compared with centralized training, federated training lacks the ability to model the overall data distribution. Although federated training does not perform as well as centralized training, it is uniquely positioned to protect privacy. Moreover, our approach narrows the performance gap between federated and centralized training.

Effectiveness Test of Federated Learning
To test the effectiveness of federated learning, we simulate a real-world scenario where a third party only has a small amount of data, i.e., the validation set, and copying data from hospitals is prohibited. The results are shown in Table 6, which answers RQ2. From this table, we find that: (1) Due to data scarcity, the model trained only on the validation set cannot achieve satisfactory performance; (2) FedAvg and FedAtt can effectively improve the performance on relation types with abundant examples, such as "TeRP", "TrAP" and "PIP", but perform poorly on relation types with few examples; (3) Our proposed FedED improves performance on all relation types. We conjecture that ensemble distillation can capture the rich similarity structure between relation types, which boosts the performance.

Communication Efficiency Test
We turn to RQ3 in this section. As described in Section 3.3, FedED uploads only the predicted labels on the small validation set at each communication round, rather than the full parameters of the local models, so the size of the messages transmitted per round is far smaller than that of FedAvg and FedAtt.

Increasing Parallelism
To answer RQ4, we study the fraction of local platforms C, which controls how many local platforms are selected by the coordinator in each round. In Figure 2, we report the number of communication rounds necessary to achieve an F1 value of 72% on the test set. We find that: (1) Increasing parallelism speeds up convergence for all methods. When all local platforms are selected (C = 1), all methods reach the target F1 value with minimal communication cost, mainly because increased parallelism means more data is used in each round of training; (2) Our method requires far fewer communication rounds to reach the target F1 value than the other methods. We conjecture that this is because the central model (student) learns faster and more reliably when trained with the outputs of the ensemble model (teacher) as soft labels (Phuong and Lampert, 2019).

Increasing Computation Per Platform
Finally, we address RQ5. The number of local gradient updates per round is given by (|D_k| / B) × E, where B is the local batch size, E is the number of local training epochs and |D_k| is the size of the private data at local platform k. Decreasing B, increasing E, or both increases the computation per local platform per round.
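For instance, under the batch size and epoch settings used in our experiments (B = 4, E = 2), a hypothetical platform holding |D_k| = 200 training sentences performs (200 / 4) × 2 = 100 local gradient updates per round; halving B to 2 or doubling E to 4 would double this to 200 updates.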

Conclusion and Future Work
In this paper, we propose a privacy-preserving medical relation extraction model based on federated learning, named FedED. The main obstacle to applying federated learning to medical relation extraction is the communication bottleneck, which is caused by the need to upload cumbersome model parameters. To overcome this bottleneck, we leverage a knowledge-distillation-based strategy, which uses the uploaded predictions of the ensemble of local models to train the central model without requiring the upload of these cumbersome parameters. Our experiments on three benchmark datasets illustrate the advantages of our approach over previous federated algorithms. As future work, we plan to explore how to jointly extract entities and relations in federated settings.