Learning Informative Representations of Biomedical Relations with Latent Variable Models

Extracting biomedical relations from large corpora of scientific documents is a challenging natural language processing task. Existing approaches usually focus on identifying a relation either in a single sentence (mention-level) or across an entire corpus (pair-level). In both cases, recent methods have achieved strong results by learning a point estimate to represent the relation; this is then used as the input to a relation classifier. However, the relation expressed in text between a pair of biomedical entities is often more complex than can be captured by a point estimate. To address this issue, we propose a latent variable model with an arbitrarily flexible distribution to represent the relation between an entity pair. Additionally, our model provides a unified architecture for both mention-level and pair-level relation extraction. We demonstrate that our model achieves results competitive with strong baselines for both tasks while having fewer parameters and being significantly faster to train. We make our code publicly available.


Introduction
The vast amount of scientific literature provides a significant source of information for biomedical research. Using this literature to identify relations between entities is an important task in various applications (van Mulligen et al., 2012; Segura-Bedmar et al., 2013; Bravo et al., 2015; Krallinger et al., 2017).
Existing approaches to biomedical relation extraction usually fall into one of two categories. Mention-level extraction aims to classify the relation between a pair of entities within a short span of text (usually a sentence). In contrast, pair-level extraction aims to classify the relation between a pair of entities across an entire paragraph, document or corpus.

* Work completed during internship at BenevolentAI.
For both mention-level and pair-level relation extraction, recent work has been focused on representation learning. This is considered to be one of the major steps towards making progress in artificial intelligence (Bengio et al., 2013). Representations of relations which understand their context are particularly important in biomedical research, where identifying fruitful targets is crucial due to the high costs of experimentation. Learning such representations is likely to require large amounts of unsupervised data due to the scarcity of labelled data in this domain.
Recent mention-level methods have been based on using large unsupervised models with Transformer networks (Vaswani et al., 2017) to learn representations of sentences containing pairs of entities. These representations are then used as the inputs to much smaller models, which perform supervised relation classification (Lee et al., 2020; Beltagy et al., 2019).
Recent pair-level methods have been based on encoding each mention of a pair of entities, and designing a mechanism to pool these encodings (across a paragraph, document, or corpus) into a single representation. This representation is then used to classify the relation between the entity pair (Verga et al., 2018;Jia et al., 2019).
However, representation learning methods for both mention-level and pair-level extraction typically use a point estimate for each representation. As a result, they may struggle to capture the nature of the true, potentially complex relations between each pair of entities. For example, Figure 1 shows sentences for two entity pairs which demonstrate that relation statements can be very different, typically depending on biological circumstances (e.g. anatomical location, experimental details, presence of a disease, etc). Such nuanced relations can be difficult to capture with a single point estimate.
[Figure 1: Two sets of sentences demonstrating the potentially complex nature of the relation between a pair of entities. Protein Akt and protein GSK3β: "...Akt negatively regulates GSK3β activity..."; "...Akt phosphorylates GSK3β...". Protein EAAT2 and disease ALS: "EAAT2/C1-4 were found to be equally expressed in ALS patients and controls."; "EAAT2 protein is significantly reduced in ALS in the motor cortex and spinal cord."]

We hypothesise that there is a true underlying relation for each entity pair, and that this relation can be multimodal (because of the aforementioned complexities). The sentences containing each pair are textual observations of these underlying relations.
We therefore propose a probabilistic model which uses a continuous latent variable to represent the true relation between each entity pair. The distribution of a sentence containing that pair is then conditioned on this latent variable. In order to be able to model the complex relations between each entity pair, we use an infinite mixture distribution for the latent representation.
Our model provides a unified architecture for learning representations of relations between entity pairs both at mention and pair level. We show that (an approximation to) the posterior distribution of the latent variable can be used for mention-level relation classification. We also demonstrate that the prior distribution from the same model can be used for pair-level classification. On both tasks, we achieve results competitive with strong baselines with a model which has fewer parameters and is significantly faster to train.

Model
In this section, we introduce our unified architecture for both mention-level and pair-level relation extraction. Throughout, we use the following notation:

• c represents a 'context', i.e. a sentence (or sequence of tokens) containing a pair of entities. c has tokens c_1, ..., c_T. c_{t_x} and c_{t_y} are the tokens representing the two entities. We replace the actual tokens denoting the two entities with generic <ENT> tokens (see the sketch after this list). Therefore, a context is given by:

  c = c_1, ..., c_{t_x−1}, <ENT>, c_{t_x+1}, ..., c_{t_y−1}, <ENT>, c_{t_y+1}, ..., c_T

• x and y are the input representations of the two entities.
  - For pair-level classification, x and y are unique identifiers for the two entities.
  - For mention-level classification, x and y are the types of the two entities, e.g. GENE and DISEASE. This is done in order to allow for fair comparisons with previous methods, which use the entity types for mention-level classification (see Section 4.2 for further details).
  x and y always refer to the first and second entities in c respectively.

• e(c_t) is the embedding of token c_t. e(x) and e(y) are the embeddings of the entities x and y.

• r represents the relation label.
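Below is a minimal, illustrative sketch (not taken from the paper) of how a tagged sentence could be converted into a context c with <ENT> tokens; the function name and span representation are assumptions.

```python
# Hypothetical preprocessing sketch: replace the two tagged entity mentions
# with generic <ENT> tokens to form a context c.
from typing import List, Tuple

def build_context(tokens: List[str],
                  entity_x_span: Tuple[int, int],
                  entity_y_span: Tuple[int, int]) -> List[str]:
    """Replace each entity's tokens with a single <ENT> token."""
    (x_start, x_end), (y_start, y_end) = entity_x_span, entity_y_span
    assert x_end <= y_start, "assume entity x occurs before entity y"
    return (tokens[:x_start] + ["<ENT>"] +
            tokens[x_end:y_start] + ["<ENT>"] +
            tokens[y_end:])

tokens = ["Akt", "phosphorylates", "GSK3B", "in", "neurons"]
print(build_context(tokens, (0, 1), (2, 3)))
# ['<ENT>', 'phosphorylates', '<ENT>', 'in', 'neurons']
```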
Approach

Labelled relation statements are often scarce, whereas unlabelled sentences are usually plentiful. In order to leverage these unlabelled sentences, we first train an unsupervised model to learn representations of entity pairs and the contexts in which they occur. We then train much smaller models to classify relations using the representations from the unsupervised model.

Representation learning model
When training the unsupervised representation learning model, we assume access to a corpus of sentences in which entities have been tagged but there are no relation labels. We train the representation model to maximise the conditional log-likelihood log p(c|x, y). θ will refer to the set of parameters of the representation model which we wish to optimise. A graph of the representation model is shown in Figure 2 and a more detailed explanation is given below.

There are many ways to express the same relation between a given pair of entities. For example, the sentences "John is Mary's brother" and "Mary is John's sister" express the same relation in different ways. In order to capture this phenomenon, we introduce a latent variable, z, to represent the true underlying relation. This will be the representation used for mention-level and pair-level relation classification. The conditional distribution is parametrised as:

  p_θ(c|x, y) = ∫ p_θ(c|z) p_θ(z|x, y) dz    (1)

Intuitively, p_θ(z|x, y) captures the true underlying relation between the two entities x and y, and p_θ(c|z) captures the variation in the multiple possible ways of expressing that relation.

For computational simplicity, we could choose p_θ(z|x, y) to be Gaussian. However, in reality, the true relation between a pair of entities is probably more complex than can be modelled well with a unimodal distribution. We therefore introduce another latent variable u such that:

  p_θ(z|x, y) = ∫ p_θ(z|x, y, u) p(u) du    (2)

For p(u), we use a standard Gaussian distribution, N(0, I). For p_θ(z|x, y, u), we again use a Gaussian distribution whose mean and variance are a function of x, y and u. We concatenate together e(x), e(y), e(x) ⊙ e(y) and u, and pass the resulting vector into a feedforward network to output the mean and variance of p_θ(z|x, y, u) (⊙ denotes element-wise multiplication). Using a nonlinear network allows the marginal distribution p_θ(z|x, y) to be an infinite mixture distribution (Mattei and Frellsen, 2018). The objective becomes:

  log p_θ(c|x, y) = log ∫∫ p_θ(c|z) p_θ(z|x, y, u) p(u) du dz    (3)

We parametrise p_θ(c|z) with an LSTM, due to its strong performance in language modelling (Graves, 2013; Bowman et al., 2016; Melis et al., 2018). The conditional probabilities for t = 1, ..., T are:

  p_θ(c_t | c_{1:t−1}, z) = softmax(W h^p_t)[c_t]    (4)

where W is a learnable parameter of the model, and h^p_t is computed as:

  h^p_t = LSTM([e(c_{t−1}); z], h^p_{t−1})    (5)

Complete hyperparameter details are provided in Section 4.1.
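As a concrete, hedged illustration of the paragraph above, the sketch below shows one way the mean and variance of p_θ(z|x, y, u) could be computed from e(x), e(y), e(x) ⊙ e(y) and u; the layer sizes and the log-variance parametrisation are assumptions, not details from the paper.

```python
# Sketch of the prior network for p_theta(z | x, y, u).
# Assumptions: 2-layer ReLU feedforward net (Section 4.1), log-variance output.
import torch
import torch.nn as nn

class PriorNetwork(nn.Module):
    def __init__(self, emb_dim=300, u_dim=300, z_dim=300, hidden=1024):
        super().__init__()
        in_dim = 3 * emb_dim + u_dim            # [e(x); e(y); e(x) * e(y); u]
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.ReLU(),
            nn.Linear(hidden, hidden), nn.ReLU(),
        )
        self.mean = nn.Linear(hidden, z_dim)
        self.log_var = nn.Linear(hidden, z_dim)

    def forward(self, e_x, e_y, u):
        h = self.net(torch.cat([e_x, e_y, e_x * e_y, u], dim=-1))
        return self.mean(h), self.log_var(h)

# Drawing z for a pair: u ~ N(0, I), then z ~ N(mean, diag(exp(log_var))).
prior = PriorNetwork()
e_x, e_y, u = torch.randn(8, 300), torch.randn(8, 300), torch.randn(8, 300)
mean, log_var = prior(e_x, e_y, u)
z = mean + torch.exp(0.5 * log_var) * torch.randn_like(mean)
```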

Training
Because of the nonlinear functions involved in p_θ(z|x, y, u) and p_θ(c|z), the integral in Equation (3) is intractable. We therefore perform approximate maximum likelihood estimation using stochastic gradient variational Bayes (SGVB) (Kingma and Welling, 2014; Rezende et al., 2014).
To do this, we parametrise a Gaussian inference distribution q_φ(u|x, y, c) (referred to as q_φ(u) henceforth, for brevity) with trainable parameters φ. This allows us to maximise the following lower bound on the log-likelihood:

  log p_θ(c|x, y) ≥ E_{q_φ(u) p_θ(z|x, y, u)}[log p_θ(c|z) + log p(u) − log q_φ(u)]    (6)

This bound can be approximated using Monte Carlo integration. It is optimised with respect to θ and φ jointly.

To parametrise q_φ(u), we use a bidirectional LSTM to encode the context, due to its ability to capture useful sentence-level information in a low-dimensional vector (Zhou et al., 2016; Peters et al., 2018). The bidirectional LSTM is run over the token embeddings e(c_1), ..., e(c_T), and its final forward and backward hidden states are concatenated to give a vector h^q. We concatenate h^q with e(x), e(y) and e(x) ⊙ e(y), and pass the resulting vector into a feedforward network to output the mean and variance of q_φ(u).
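The following sketch shows how a single-sample Monte Carlo estimate of this bound could be computed with the reparameterisation trick; the module names (`encoder`, `prior`, `decoder_log_prob`) are hypothetical stand-ins for the networks described above.

```python
# Hedged sketch of a single-sample estimate of the bound in Equation (6).
# `encoder` returns the mean/log-variance of q_phi(u | x, y, c),
# `prior` returns the mean/log-variance of p_theta(z | x, y, u),
# `decoder_log_prob` returns log p_theta(c | z) under the LSTM decoder.
import torch

def elbo_estimate(encoder, prior, decoder_log_prob, e_x, e_y, context):
    # q_phi(u): encode the context, then reparameterise
    q_mean, q_log_var = encoder(e_x, e_y, context)
    u = q_mean + torch.exp(0.5 * q_log_var) * torch.randn_like(q_mean)

    # p_theta(z | x, y, u): reparameterised sample of z
    z_mean, z_log_var = prior(e_x, e_y, u)
    z = z_mean + torch.exp(0.5 * z_log_var) * torch.randn_like(z_mean)

    # Reconstruction term E[log p_theta(c | z)] (single sample)
    rec = decoder_log_prob(context, z)

    # Analytic KL divergence between the Gaussian q_phi(u) and N(0, I)
    kl = 0.5 * torch.sum(torch.exp(q_log_var) + q_mean ** 2 - 1.0 - q_log_var, dim=-1)

    return (rec - kl).mean()
```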

Mention-level classification
In this section, we assume that the unsupervised representation model from Section 2.1 has been trained with x and y being the types of the two entities. The representations z can now be used as the inputs to a supervised mention-level relation classification model. For mention-level classification, we assume access to a corpus of sentences in which entities have been tagged and there are labels classifying the type of relation between the entity pair in each sentence. We train the mention-level classification model to maximise p(r|x, y, c). λ will refer to the set of parameters of the mention-level classification model which we wish to optimise.
The representation z of the entity pair and context would ideally be distributed according to the posterior p(z|x, y, c) from the representation model. We would then optimise the parameters λ using the following objective:

  ∫ p_λ(r|z) p(z|x, y, c) dz

However:

  p(z|x, y, c) = p_θ(c|z) p_θ(z|x, y) / p_θ(c|x, y) = p_θ(c|z) p_θ(z|x, y) / ∫ p_θ(c|z') p_θ(z'|x, y) dz'

As mentioned in Section 2.1.1, the integral in the denominator is intractable. Instead, the following approximation to the posterior can be used:

  p(z|x, y, c) ≈ ∫ p_θ(z|x, y, u) q_φ(u) du

This is an approximation to the posterior because maximising the objective in Equation (6) is equivalent to minimising the KL divergence from q_φ(u) p_θ(z|x, y, u) to p(z, u|x, y, c) (Kingma and Welling, 2014):

  log p_θ(c|x, y) = E_{q_φ(u) p_θ(z|x, y, u)}[log p_θ(c|z) + log p(u) − log q_φ(u)] + D_KL[q_φ(u) p_θ(z|x, y, u) || p(z, u|x, y, c)]

Using this approximation, the mention-level classification objective becomes:

  ∫∫ p_λ(r|z) p_θ(z|x, y, u) q_φ(u) du dz    (16)

Empirically, however, we find that the model trains much more easily using the following objective:

  E_{q_φ(u) p_θ(z|x, y, u)}[log p_λ(r|z)]    (17)

This is because, particularly at the start of training, the values of p_λ(r|z) can be very small. Note that, due to Jensen's inequality, the objective in Equation (17) is in fact a lower bound on the log of the objective in Equation (16):

  E_{q_φ(u) p_θ(z|x, y, u)}[log p_λ(r|z)] ≤ log ∫∫ p_λ(r|z) p_θ(z|x, y, u) q_φ(u) du dz

To parametrise p_λ(r|z), we use a shallow feedforward network with a softmax function at the output. Complete hyperparameter details are provided in Section 4.2.
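A hedged sketch of how the objective in Equation (17) could be approximated with a small number of samples (Section 4.2 uses 4) is shown below; `encoder`, `prior` and `classifier` are hypothetical modules corresponding to q_φ(u), p_θ(z|x, y, u) and p_λ(r|z).

```python
# Sketch: Monte Carlo approximation of Equation (17).
import torch
import torch.nn.functional as F

def mention_level_loss(encoder, prior, classifier, e_x, e_y, context,
                       labels, n_samples=4):
    q_mean, q_log_var = encoder(e_x, e_y, context)
    losses = []
    for _ in range(n_samples):
        # z ~ q_phi(u) p_theta(z | x, y, u)
        u = q_mean + torch.exp(0.5 * q_log_var) * torch.randn_like(q_mean)
        z_mean, z_log_var = prior(e_x, e_y, u)
        z = z_mean + torch.exp(0.5 * z_log_var) * torch.randn_like(z_mean)
        logits = classifier(z)  # shallow feedforward network with softmax output
        # cross_entropy gives -log p_lambda(r | z) for the true label
        losses.append(F.cross_entropy(logits, labels))
    return torch.stack(losses).mean()
```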

Pair-level classification
In this section, we assume that the unsupervised representation model from Section 2.1 has been trained with x and y being unique identifiers for the two entities. The representations z can now be used as the inputs to a supervised pair-level relation classification model. For pair-level classification, we assume access to a dataset with pairs of entity identifiers, and labels classifying the type of relation between each pair. Instead of learning p(r|x, y, c) as in mention-level classification, we now learn p(r|x, y).
Intuitively, for pair-level classification, we wish to classify the relation between a pair of entities based on everything that the unsupervised model has learned about those entities (through the sentences containing them). This is unlike mention-level classification, where we classify the relation described in a specific sentence.
For pair-level classification, we follow a very similar approach to that described in Section 2.2 for mention-level classification. However, we no longer base the input representation on the posterior distribution from the unsupervised model, p(z|x, y, c). Instead, the representation used will be distributed according to:

  p_θ(z|x, y) = ∫ p_θ(z|x, y, u) p(u) du

Intuitively, this is the natural distribution to use, because we are interested in the relation between the entities x and y, without a specific context to condition on.

We denote ψ as the parameters of the pair-level supervised model. Then, following the same reasoning as Section 2.2, the objective for the pair-level supervised model is:

  E_{p(u) p_θ(z|x, y, u)}[log p_ψ(r|z)]    (20)

To parametrise p_ψ(r|z), we use a shallow feedforward network with a softmax function at the output. Complete hyperparameter details are provided in Section 4.3.
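The corresponding sketch for the pair-level objective in Equation (20) differs from the mention-level case only in that u is drawn from the standard normal prior p(u) rather than from q_φ(u); module names remain hypothetical.

```python
# Sketch: Monte Carlo approximation of Equation (20).
import torch
import torch.nn.functional as F

def pair_level_loss(prior, classifier, e_x, e_y, labels, n_samples=4):
    losses = []
    for _ in range(n_samples):
        u = torch.randn_like(e_x)  # u ~ N(0, I); assumes u and e(x) share the same dimension
        z_mean, z_log_var = prior(e_x, e_y, u)
        z = z_mean + torch.exp(0.5 * z_log_var) * torch.randn_like(z_mean)
        losses.append(F.cross_entropy(classifier(z), labels))
    return torch.stack(losses).mean()
```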

Related work
Mention-level relation extraction is typically performed using supervised learning. In the general domain, Zhang et al. (2017) combine an LSTM with a position-aware attention mechanism to perform multiclass relation extraction. Soares et al. (2019) fine-tune the BERT (Devlin et al., 2019) architecture for relation extraction tasks by enforcing similarity between representations of sentences containing the same pair of entities across a corpus. Zhang et al. (2020) construct a teacher model to generate soft labels which guide the optimisation of a student network via knowledge distillation. In the biomedical and scientific domains, BioBERT (Lee et al., 2020) and SciBERT (Beltagy et al., 2019) train the BERT architecture on domain-specific corpora, achieving state-of-the-art results on mention-level relation extraction tasks. Zhang et al. (2018) combine an RNN over the sentence's words and a CNN over its dependency graph to classify drug-drug and protein-protein interactions.
Pair-level relation extraction usually relies on distant supervision (Mintz et al., 2009). In the general domain, Hoffmann et al. (2011) develop a latent variable model to perform multi-instance learning while handling overlapping relations. Lin et al. (2016) use an attention mechanism to pool the representations of sentences containing a given pair into a single representation, which is then used as the input to a classifier. Quirk and Poon (2017) extend distant supervision to relations expressed across sentence boundaries.

In contrast to our work, there does not appear to be prior research performing both mention-level and pair-level relation extraction with a unified model.

Representation learning model
We train the unsupervised representation model described in Section 2.1 using sentences from PubMed abstracts, PubMed Central (PMC) open-access full-text articles, and licensed full-text articles from Wiley and Springer. We take sentences with a maximum length of 140 tokens and tag the entities with their type using a dictionary-based method. Entities are linked to unique identifiers by first disambiguating entity types using a bidirectional LSTM sentence classifier, followed by type-specific term lookups. Note that if a sentence contains three or more entities, it is repeated in order to account for each possible pair of entities.
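A small, hypothetical sketch of this last point: a sentence with several tagged entities yields one training example per entity pair, and sentences longer than 140 tokens are discarded. The data structures are assumptions for illustration only.

```python
# Sketch: one training example per entity pair in a sentence.
from itertools import combinations
from typing import Dict, List, Tuple

MAX_LEN = 140

def examples_from_sentence(tokens: List[str],
                           entities: List[Tuple[int, int, str]]) -> List[Dict]:
    """entities: (start, end, type-or-identifier) spans, assumed non-overlapping."""
    if len(tokens) > MAX_LEN:
        return []
    examples = []
    for (xs, xe, x_id), (ys, ye, y_id) in combinations(sorted(entities), 2):
        context = tokens[:xs] + ["<ENT>"] + tokens[xe:ys] + ["<ENT>"] + tokens[ye:]
        examples.append({"context": context, "x": x_id, "y": y_id})
    return examples
```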

Architectures and training
To parametrise p θ (z|x, y, u), we use a 2-layer feedforward network with the ReLU nonlinearity. To parametrise p θ (c|z), we use a 1-layer LSTM. To parametrise q φ (u), we use a 1-layer bidirectional LSTM, the output of which is passed to a 2-layer feedforward network with the ReLU nonlinearity.
In order to evaluate the effect of the number of parameters on performance, we train four different versions of our representation learning model: {X-SMALL, SMALL, MEDIUM, LARGE}. These correspond to respective hidden state sizes of {128, 256, 512, 1024} in the networks. For all of the models, both u and z, as well as all embeddings, are 300-dimensional.
We train the unsupervised representation models using a single-sample approximation of the objective in Equation (6). We train for 400,000 iterations, using a minibatch size of 192 and optimising the parameters using Adam (Kingma and Ba, 2015) with a learning rate of 0.0001.
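For reference, the four configurations and shared training settings described above can be summarised as a simple config; the field names below are ours, not the paper's.

```python
# Summary of the model sizes and training hyperparameters from Section 4.1.
CONFIGS = {
    "X-SMALL": {"hidden_size": 128},
    "SMALL":   {"hidden_size": 256},
    "MEDIUM":  {"hidden_size": 512},
    "LARGE":   {"hidden_size": 1024},
}
SHARED = {
    "emb_dim": 300, "u_dim": 300, "z_dim": 300,
    "iterations": 400_000, "batch_size": 192,
    "optimizer": "Adam", "learning_rate": 1e-4,
}
```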

Optimisation challenges
The unsupervised objective in Equation (6) can be expressed as:

  E_{q_φ(u) p_θ(z|x, y, u)}[log p_θ(c|z)] − D_KL[q_φ(u) || p(u)]    (21)

When training latent variable models with autoregressive observation distributions (such as that in Equation (4)), this objective can induce local optima where q_φ(u) = p(u). This results in the KL divergence term in Equation (21) collapsing to 0, meaning the model ignores the latent variable altogether. To avoid such local optima, we use the following two methods (Bowman et al., 2016):

KL annealing: We multiply the KL divergence term by a constant weight which is linearly annealed from 0 to 1 over the first 10,000 iterations of training. This helps the model to escape local optima where D_KL[q_φ(u) || p(u)] = 0 early in training.
Token dropout: In Equation (5), we randomly drop the token embedding being passed to the next LSTM hidden state, using a dropout rate of 50%. This encourages the LSTM to rely more on the representation z than on the previous tokens when modelling the context.
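A hedged sketch of both tricks is given below; interpreting 'dropping' a token embedding as zeroing it out is our assumption, as is the helper structure.

```python
# Sketch: linearly annealed KL weight and token (word) dropout on decoder inputs.
import torch

def kl_weight(iteration: int, warmup: int = 10_000) -> float:
    """Anneal the KL weight linearly from 0 to 1 over the first `warmup` iterations."""
    return min(1.0, iteration / warmup)

def token_dropout(token_embeddings: torch.Tensor, rate: float = 0.5) -> torch.Tensor:
    """Randomly zero out whole token embeddings fed to the decoder LSTM.

    token_embeddings: (batch, seq_len, emb_dim)
    """
    keep = (torch.rand(token_embeddings.shape[:2],
                       device=token_embeddings.device) > rate).float()
    return token_embeddings * keep.unsqueeze(-1)

# Per-iteration loss (schematic): -(reconstruction - kl_weight(it) * kl)
```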

Computational costs
We show the computational costs of our unsupervised representation models in Table 1. We compare against BioBERT (Lee et al., 2020), a language model with state-of-the-art performance on relation extraction. All versions of our model have significantly fewer parameters than BioBERT. In terms of 'GPU days' (number of GPUs × training time in days), training BioBERT is approximately 25 to 40 times slower than training our model. In addition, inference is an order of magnitude faster with our model compared to BioBERT.

Mention-level classification
After training the unsupervised representation model (using the entity types for x and y), we use it to perform supervised mention-level relation classification, as described in Section 2.2. We use the EU-ADR (van Mulligen et al., 2012) and GAD (Bravo et al., 2015) datasets. In both datasets, each sentence contains a gene and a disease, and the task is to classify whether the sentence does or does not express a relation between them. Examples from both datasets are shown in Table 2 and dataset statistics are shown in Table 3. As per previous work, we report the performance using 10-fold cross validation on each dataset.
We compare our results with those reported for BioBERT (Lee et al., 2020). For a fair comparison, we use the same classifier architecture: a single-layer network with a softmax nonlinearity. As well as training the parameters λ of the classifier, we also fine-tune the parameters θ and φ of the representation model. Again, this is done to allow for a fair comparison with BioBERT (which follows the same procedure).
We approximate the objective in Equation (17) using 4 samples during training. We use a minibatch size of 8 and update the parameters using Adam with a learning rate of 0.00001. We train on EU-ADR for 200 iterations and on GAD for 3,000 iterations.
Note that the representations for BioBERT are 768-dimensional, whereas ours are 300-dimensional.

Results
We perform 10-fold cross validation, and report the mean precision, recall and F1-score in Table 4. On EU-ADR, all versions of our model outperform BioBERT, with our LARGE model achieving a significantly higher F1-score. On this task, all versions of our model have significantly higher recall than BioBERT, with the precision being similar. On GAD, BioBERT slightly outperforms our LARGE model, thanks to its higher precision. In addition, we find that, on both tasks, the performance monotonically increases with the size of the unsupervised representation model.
These results show that it is possible to achieve results competitive with the state-of-the-art while making significant efficiency gains, both in terms of memory and time.

Pair-level classification
In this section, we use the LARGE representation model from Section 4.1, trained using the unique entity identifiers for x and y. We fix the parameters of the unsupervised representation model and use it to perform supervised pair-level classification, as described in Section 2.3.
We construct a multiclass classification dataset by combining multiple third-party biomedical datasets. These datasets only provide pairs of entities which are related. Therefore, if an entity pair does not appear in any of the datasets, it is assumed to be unrelated and given the label NO-RELATION. If two entities are related, the label is given by the concatenation of the two entity types. This is therefore a multiclass classification problem, with the set of possible classes being {NO-RELATION, DISEASE-GENE, GENE-GENE, CHEMICAL-GENE, CHEMICAL-DISEASE}. Note that we only include entity pairs that occur in at least one sentence in the dataset used to train the representation learning model.
We randomly split the related entity pairs into training, validation and test sets. The set of entity pairs with label NO-RELATION is extremely large. We randomly assign a proportion of these to the validation and test sets; during training, we randomly sample a proportion of each minibatch from the remaining unrelated entity pairs. The dataset statistics are shown in Table 5.

For the pair-level classifier, we train a 2-layer model which has a 300-dimensional hidden layer with a skip connection. We approximate the objective in Equation (20) using 4 samples during training. We train for 100,000 iterations, using a minibatch size of 512 (of which 448 are sampled from the NO-RELATION set). We optimise the parameters using Adam with a learning rate of 0.0001. When making predictions on unseen data points, we only predict a label other than NO-RELATION if the predicted probability is higher than a threshold. This threshold is tuned to maximise the F1-score on the validation set.
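The threshold-tuning step could be sketched as follows; the use of scikit-learn, the grid of candidate thresholds, and restricting the F1 computation to the relation classes are all assumptions for illustration.

```python
# Sketch: tune the NO-RELATION decision threshold to maximise validation F1.
import numpy as np
from sklearn.metrics import f1_score

def predict_with_threshold(probs: np.ndarray, threshold: float, no_rel: int) -> np.ndarray:
    """Predict NO-RELATION unless the best relation-class probability exceeds the threshold."""
    rel_probs = probs.copy()
    rel_probs[:, no_rel] = -np.inf
    best = rel_probs.argmax(axis=1)
    return np.where(rel_probs.max(axis=1) > threshold, best, no_rel)

def tune_threshold(probs: np.ndarray, labels: np.ndarray, no_rel: int) -> float:
    rel_classes = [c for c in range(probs.shape[1]) if c != no_rel]
    candidates = np.linspace(0.0, 1.0, 101)
    scores = [f1_score(labels, predict_with_threshold(probs, t, no_rel),
                       labels=rel_classes, average="micro")
              for t in candidates]
    return float(candidates[int(np.argmax(scores))])
```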

Baselines
We compare our method with the following two baselines:

Co-occurrences: For every entity pair that occurs in at least one sentence in the dataset used to train the representation learning model, we predict the relation to be positive (i.e. the concatenation of the types of the two entities). By design, this method will have perfect recall.
Attention: This method is similar to those presented by Lin et al. (2016) and Verga et al. (2018). For a given pair of entities, we collect every sentence containing the pair from the dataset used to train the representation learning model. Each sentence is passed to an LSTM whose final state is taken as the sentence representation. The representations for all sentences for the given entity pair are pooled together into a single representation using an attention mechanism. This representation is then used as the input to a feedforward network with a softmax function at the output. This method is therefore trained on exactly the same dataset as our pair-level classifier.

The attention model is trained for 1,000,000 iterations using a minibatch size of 100 (of which 50 are sampled from the NO-RELATION set). The parameters are optimised using Adam with a learning rate of 0.000005. As with our model, when making predictions on unseen data points, we only predict a label other than NO-RELATION if the predicted probability is higher than a threshold. This threshold is tuned to maximise the F1-score on the validation set.
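A hedged sketch of this baseline is shown below; the exact architecture (dimensions, attention scoring, output layers) is an assumption in the spirit of Lin et al. (2016) and Verga et al. (2018), not the implementation used in the experiments.

```python
# Sketch: LSTM sentence encoder + attention pooling over sentences for one pair.
import torch
import torch.nn as nn

class AttentionBaseline(nn.Module):
    def __init__(self, vocab_size, emb_dim=300, hidden=300, n_classes=5):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, emb_dim)
        self.lstm = nn.LSTM(emb_dim, hidden, batch_first=True)
        self.attn = nn.Linear(hidden, 1)
        self.out = nn.Sequential(nn.Linear(hidden, hidden), nn.ReLU(),
                                 nn.Linear(hidden, n_classes))

    def forward(self, sentences):
        # sentences: (n_sentences, seq_len) token ids, all mentioning the same pair
        _, (h_n, _) = self.lstm(self.embed(sentences))
        sent_reps = h_n[-1]                                     # final LSTM state per sentence
        weights = torch.softmax(self.attn(sent_reps), dim=0)    # attention over sentences
        pooled = (weights * sent_reps).sum(dim=0)               # single pair representation
        return self.out(pooled)                                 # logits over relation classes
```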

Results
The precision, recall, and F1-score on the test set are reported in Table 6. Our model achieves a higher F1-score than the attention model. Unsurprisingly, both the attention model and our model achieve significantly higher precision than the co-occurrence baseline, at the expense of lower recall.
In contrast to the attention model, when classifying a new pair, our model does not need to encode all of the sentences containing that pair. This provides significant computational advantages, both in terms of memory and time.

Conclusion
We have presented a model for learning representations of pairs of biomedical entities from unlabelled text corpora. We use a latent variable with an arbitrarily flexible distribution in order to be able to capture the complex relations between each pair of entities. The unified architecture can be used for both mention-level and pair-level relation extraction. On both tasks, we achieve results competitive with strong baselines. We also show significant computational gains in terms of the number of parameters and training times.
Our model presents many avenues for future work. The results in Table 4 show that the model's performance improves with the size of the hidden states in the networks; this suggests that there are further gains achievable simply by providing the model with more parameters. The model could be further scaled up by using a hierarchy of latent variables to increase the expressive power of the representations.
Other directions include evaluating the benefits of having a representation which explicitly captures uncertainty about the relations. For example, this can be done by assessing if the model is less confident when making predictions about entity pairs which do not occur frequently in the unlabelled corpus. Additionally, since our model can produce a representation for any pair of entities (even those which do not occur together in the unlabelled corpus), it could be used in a link prediction setting to score unseen entity pairs.