Structured Multi-Label Biomedical Text Tagging via Attentive Neural Tree Decoding

We propose a model for tagging unstructured texts with an arbitrary number of terms drawn from a tree-structured vocabulary (i.e., an ontology). We treat this as a special case of sequence-to-sequence learning in which the decoder begins at the root node of an ontological tree and recursively elects to expand child nodes as a function of the input text, the current node, and the latent decoder state. We demonstrate that this method yields state-of-the-art results on the important task of assigning MeSH terms to biomedical abstracts.


Introduction
We consider the task of multilabel text annotation, where labels are drawn from an ontology. We are motivated by problems in biomedical NLP (Zweigenbaum et al., 2007; Demner-Fushman et al., 2016). Specifically, scientific abstracts in this domain are typically associated with multiple Medical Subject Heading (MeSH) terms. MeSH is a controlled, hierarchically structured vocabulary that facilitates semantic labeling of texts at varying levels of granularity. This in turn supports semantic indexing of biomedical literature, thus facilitating improved search and retrieval. (This problem also resembles tagging clinical notes with ICD codes; Mullenbach et al., 2018.) At present, MeSH annotation is largely performed manually by highly skilled annotators employed by the National Library of Medicine (NLM). Automating this annotation task is thus highly desirable, and there have been considerable efforts to do so. The BioASQ challenge, in particular, concerns MeSH annotation, and competitive systems have emerged from this in past years (Liu et al., 2014; Tsoumakas et al., 2013); these constitute baseline approaches in the present work.

Figure 1: Neural Tree Decoding (NTD) model. Input text is encoded, and a decoder then conditionally traverses the label tree to select all relevant nodes to apply, with node-wise attention induced over the input text. Example input shown in the figure: "In this trial we enrolled 100 diabetics, aged 20-50. Patients received insulin or placebo. The primary outcome was blood pressure."
More generally, MeSH annotation is a specific instance of multi-label classification, which has received substantial attention in general (Elisseeff and Weston, 2002; Fürnkranz et al., 2008; Read et al., 2011; Bhatia et al., 2015; Daumé III et al., 2017; Chen et al., 2017; Jernite et al., 2016). Our work differs from these prior efforts in that MeSH tagging involves structured multi-label classification: the label space is a tree in which nodes represent nested semantic concepts, and the specificity of these increases with depth.
Past efforts in multi-label classification have considered hierarchical and tree-based approaches for tagging (Jernite et al., 2016; Beygelzimer et al., 2009; Daumé III et al., 2017), but these have not assumed a given structured label space; instead, these efforts have attempted to induce trees to improve inference efficiency. By contrast, we propose to explicitly capitalize on a known output structure, codified here by the target ontology from which tags are drawn. We realize this by recursively traversing the tree to make (conditional) binary tag application predictions.
The contribution of this work is a neural sequence-to-sequence (seq2seq) model (Bahdanau et al., 2014) for structured multi-label classification. Our approach entails encoding the input text to be tagged using an RNN, and then decoding into the ontological output space. This involves a tree traversal beginning at the root of the tree. At each step, the decoder decides whether to 'expand' children as a function of a hidden state vector, node embeddings, and induced attention weights over the input text. This approach is schematized in Figure 1. Expanded nodes are added to the predicted tag set. This process is repeated recursively until either leaf nodes are reached or no children are selected for expansion. This neural tree decoding (NTD) model outperforms state-of-the-art models for MeSH tagging.

Model
Overview. Our model is an instance of an encoder-decoder architecture. For the encoder, we adopt a standard Gated Recurrent Unit (GRU) network (Cho et al., 2014a), which yields hidden states for the tokens comprising an input document. The decoder network consumes these outputs and begins at the root of the ontological tree. It induces an attention distribution over encoder states, which is used together with the current decoder state vector to inform which (if any) of its immediate children are applicable to the input text (Figure 1). This decoding process proceeds recursively for all children deemed relevant. Below we provide more in-depth technical detail regarding the constituent modules.
The encoder (ENC) consumes as input a raw sequence of words, here composing an abstract. These are passed through an embedding layer, producing a sequence of word embeddings x (for clarity we omit a document index here), which are then passed through a GRU (Cho et al., 2014b) to obtain a sequence of hidden vectors h. These are then passed to our neural tree decoder, which is responsible for tagging the encoded text with an arbitrary number of terms from the label tree, i.e., sequences in the structured output space. This module traverses the label space top-down, beginning at the root, thus exploiting the concept hierarchy codified by the tree structure.
At each step in the decoding process, the decoder is positioned at a particular node n in the tree. The children (immediate descendants) of this node are then considered for expansion in turn, based on a hidden state vector $s_n$ and a context vector $c_n$. Both of these are initialized to zero vectors and recursively updated during traversal, i.e., as nodes are selected for expansion (and hence added to the predicted tag set). More specifically, the context vector that informs the decision to expand node v in the label hierarchy from its parent node n is a weighted sum of the encoder hidden states h, where weights reflect induced attention over inputs, conditioned on n. That is:
$$c_n = \sum_{i} \alpha_{n,i} \, h_i \qquad (1)$$
$$\alpha_{n,i} = \frac{\exp\big(a(s_n, h_i; \theta_n)\big)}{\sum_{j} \exp\big(a(s_n, h_j; \theta_n)\big)} \qquad (2)$$
where a is a simple multi-layer perceptron (MLP) with node-specific parameters $\theta_n$. Here both sums range over the length of the input text. Given $c_n$, we then estimate the probability that child label v is applicable to the current input text as a function of the decoder state vector ($s_n$), the current context vector ($c_n$) and the decoder parameters. In particular, this is realized via a standard linear layer with sigmoid activations, parameterized by a weight matrix W comprising independent weight vectors for each output node v. Thus the score for a particular output node v is $\sigma(W_v \cdot [s_n, c_n])$, where $W_v$ denotes the weight vector for output node v.
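The attention-weighted context vector and the per-child sigmoid score described above can be illustrated with a minimal, dependency-free sketch. The function names are hypothetical, and the attention MLP a is not implemented: its unnormalized outputs are simply passed in as a list, which is an assumption made to keep the example short.

```python
import math


def softmax(scores):
    """Normalize unnormalized attention scores into weights that sum to 1."""
    top = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def context_vector(hidden_states, attn_scores):
    """Attention-weighted sum of encoder hidden states (one list per token)."""
    weights = softmax(attn_scores)
    dim = len(hidden_states[0])
    return [
        sum(w * h[d] for w, h in zip(weights, hidden_states))
        for d in range(dim)
    ]


def child_probability(w_v, s_n, c_n):
    """sigma(W_v . [s_n, c_n]): probability that child v applies to the text."""
    features = list(s_n) + list(c_n)  # concatenation [s_n, c_n]
    z = sum(w * f for w, f in zip(w_v, features))
    return 1.0 / (1.0 + math.exp(-z))


# Toy values (illustrative only): two tokens with 2-d hidden states,
# equal attention scores, and a zero decoder state.
h = [[1.0, 0.0], [0.0, 1.0]]
c_n = context_vector(h, attn_scores=[0.0, 0.0])
p_v = child_probability([0.0, 0.0, 1.0, -1.0], s_n=[0.0, 0.0], c_n=c_n)
```

With equal attention scores the context vector is the plain average of the hidden states; in the model itself the scores would come from the node-specific MLP, so different nodes can attend to different parts of the abstract.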
Pseudocode for the training and decoding procedures is presented in Algorithm 1. In the NODELOSS function, n denotes a particular node. The set of hidden vectors induced by the encoder (corresponding to the inputs) is denoted by h, s is the hidden state of the decoder, and y is the reference label (this encodes a path in the output tree). We assume the decoder, DEC, consumes input representations, a node index and a hidden state, and yields a context vector $c_n$ for n and an updated state vector $s_n$; in our case the latter is implemented via a GRU. The advantage of using an RNN during decoding is that this allows the exploitation of learned, distributed hidden representations of partial tree paths, which inform node-wise attention and subsequent predictions.

Algorithm 1: Training and decoding for the NTD model

    function NODELOSS(n, h, s, y)
        l_n ← 0
        c_n, s_n ← DEC(h, n, s)
        for each child v ∈ children(n) do
            ŷ_v ← σ(W_v · [s_n, c_n])
            p_v ∝ depth of v in tree
            B_v ∼ Ber(p_v)
            if B_v then
                l_n ← l_n + L(ŷ_v, y)
            if ŷ_v > τ then
                l_n ← l_n + NODELOSS(v, h, s_n, y)
        return l_n

    function TRAIN(x, y, α, epochs)
        θ ← INIT(θ)
        e ← 0
        while e < epochs do
            for each instance x_i ∈ x do
                h_i ← ENC(x_i)
                s_0 ← 0
                l_i ← NODELOSS(ROOT, h_i, s_0, y_i)
                ∆θ ← BACKPROP(l_i)
                θ ← θ + α∆θ
            e ← e + 1
        return θ

Incurring loss for all nodes along the path specified by y would place a disproportionate amount of emphasis on correctly applying terms that are 'higher' in the ontology, as loss will be propagated for the initial predictions concerning the application of these and then also, due to recursive application, for all of their children (and so on). Thus we only incur (and hence backpropagate) loss for a node v stochastically, according to a Bernoulli distribution B with parameter $p_v$. We set $p_v$ to be proportional to the depth of node v in the tree, such that we are likely to incur larger loss for deeper (rarely occurring) nodes.
We operationalize this as $p_v = \min(1,\ 0.5 + m / f_v)$, where m is the count corresponding to the least frequently observed node in the training corpus and $f_v$ is the count for node v. In Section 4 we demonstrate the benefit of this approach.
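As a small worked illustration of this rule, the sketch below computes the backpropagation probability for each node from its training-set count and then draws the Bernoulli indicators. The function names, node names and counts are invented for illustration and are not taken from the dataset.

```python
import random


def backprop_probabilities(node_counts):
    """p_v = min(1, 0.5 + m / f_v), with m the smallest node count."""
    m = min(node_counts.values())
    return {v: min(1.0, 0.5 + m / f_v) for v, f_v in node_counts.items()}


def sample_loss_mask(probs, rng):
    """Draw B_v ~ Bernoulli(p_v); loss is incurred only where B_v is True."""
    return {v: rng.random() < p for v, p in probs.items()}


# Hypothetical counts: a frequent top-level term and two rarer descendants.
counts = {"Diseases": 1000, "Diabetes Mellitus": 100, "Rare Term": 10}
probs = backprop_probabilities(counts)
# probs is approximately {"Diseases": 0.51, "Diabetes Mellitus": 0.6, "Rare Term": 1.0}
mask = sample_loss_mask(probs, random.Random(0))
```

Two properties follow directly from the formula: no node is ever sampled with probability below 0.5, and every node whose count is at most 2m is always included.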
At train time we use teacher forcing (Williams and Zipser, 1989) during decoding. That is, we revert the model back to the correct (training) tree subsequence when it goes off-course, and continue decoding from there. We have elided this detail from the pseudocode for clarity.
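For contrast with training, the inference-time traversal described in this section can be sketched as a short recursive function. This is a simplified illustration under stated assumptions, not the released implementation: `dec_step` and `score` are hypothetical stand-ins for DEC and the sigmoid output layer, the threshold name `tau` mirrors τ in Algorithm 1, and the toy tree and scores are invented.

```python
def decode(node, children, dec_step, score, state, tau=0.5):
    """Collect all tags in the subtree below `node` that the model expands."""
    state = dec_step(node, state)  # update decoder state on arriving at node
    predicted = set()
    for child in children.get(node, ()):
        if score(child, state) > tau:
            predicted.add(child)  # expanded nodes join the predicted tag set
            predicted |= decode(child, children, dec_step, score, state, tau)
    return predicted


# Toy label tree and fixed scores (hypothetical values, for illustration).
TOY_CHILDREN = {
    "ROOT": ["Diseases", "Chemicals"],
    "Diseases": ["Diabetes Mellitus", "Hypertension"],
    "Chemicals": ["Insulin"],
}
TOY_SCORES = {
    "Diseases": 0.9,
    "Chemicals": 0.2,
    "Diabetes Mellitus": 0.8,
    "Hypertension": 0.3,
    "Insulin": 0.95,
}


def toy_dec_step(node, state):
    """Stand-in for DEC: the 'state' is just the path of visited nodes."""
    return state + (node,)


def toy_score(child, state):
    """Stand-in for sigma(W_v . [s_n, c_n]): look up a fixed score."""
    return TOY_SCORES[child]


tags = decode("ROOT", TOY_CHILDREN, toy_dec_step, toy_score, state=())
# tags == {"Diseases", "Diabetes Mellitus"}
```

Note that "Insulin" is never scored in this toy run because its parent falls below the threshold: a top-down traversal prunes entire subtrees, which is the mechanism by which decoding stops when no children are selected for expansion.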

Experimental setup
Below we describe experimental details concerning our implementation, datasets and baselines. Code and data to reproduce our results are available at https://github.com/gauravsc/NTD.

Implementation Details
We limited the vocabulary to the 50,000 most frequent words. Word embeddings were initialized to pre-trained vectors induced via word2vec, trained over a large set of abstracts indexed on PubMed. Ontology node embeddings were pre-trained using DeepWalk (Perozzi et al., 2014), fit over PubMed.

Dataset
Our dataset comprises abstracts of articles describing randomized controlled trials (RCTs) from PubMed along with their MeSH terms. The MeSH annotations were manually applied by professionals at the National Library of Medicine (NLM). The label space underlying MeSH terms is codified by a publicly available ontology. We split this dataset into disjoint sets for training/development and final evaluation (Table 1). We further separated the former into train, validation and development test subsets, to refine our approach. For our final evaluation we used a held-out set of 10,000 abstracts that were not seen in any way during model development and/or hyperparameter tuning. We performed extensive hyperparameter tuning for the baseline models to ensure fair comparison; details regarding this tuning are provided in the Appendix.

Baselines
We compare our proposed approach to three baselines, including two prior winners of the annual BioASQ challenge, which includes an automated MeSH annotation task. However, it is important to note that we used a different (and considerably smaller) dataset in the current work, as compared to the corpus used in the BioASQ challenge.
LSSI (Tsoumakas et al., 2013) uses an approach that involves predicting both the number of terms and which terms to apply to a given abstract. It uses linear models for both tasks, which operate over TF-IDF representations of abstracts. Specifically, a regressor is trained to predict k, the number of MeSH terms to be applied to an abstract. Simultaneously, a binary linear SVM is trained independently for each MeSH term appearing in the train set. At test time, these SVMs provide scores for each term and the top-$\hat{k}$ terms are applied, where $\hat{k}$ is the estimate from the aforementioned regressor.
UIUC (Liu et al., 2014) uses a learning-to-rank model to identify the top MeSH terms for an abstract from a candidate set of terms, which is obtained from the nearest neighbours of the abstract. Additionally, one SVM classifier is trained for each of the MeSH terms (similar to the above approach), and scores for each are used to obtain additional terms to be added to the candidate set. In the end, a threshold (tuned on the validation set) is used to select the final set of terms to be assigned.
Finally, we consider a deep multilabel classification model, DML (Rios and Kavuluru, 2015), that takes as input unstructured abstracts and activates the output nodes corresponding to the relevant MeSH terms. In brief, embedded tokens are fed through a CNN to induce a vector representation, which is then passed on to the dense output layer. Finally, this is passed through a sigmoid activation function. Note that this model exploits the same pre-trained word embeddings as our model does.
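The test-time selection rule of the LSSI baseline (rank terms by per-term classifier score, then keep as many as the regressor predicts) can be sketched as follows. This is only an illustration of the selection step: the function name is hypothetical, the scores are invented, and rounding the regressor output to the nearest integer is an assumption rather than a detail taken from Tsoumakas et al. (2013).

```python
def select_top_k(term_scores, k_hat):
    """Keep the k highest-scoring terms, with k the rounded regressor output."""
    k = max(0, round(k_hat))  # assumption: round to nearest, never negative
    ranked = sorted(term_scores, key=term_scores.get, reverse=True)
    return ranked[:k]


# Hypothetical per-term SVM scores for one abstract.
svm_scores = {
    "Humans": 1.3,
    "Insulin": 0.7,
    "Blood Pressure": 0.2,
    "Mice": -0.9,
}
selected = select_top_k(svm_scores, k_hat=2.4)
# selected == ["Humans", "Insulin"]
```

Unlike the tree decoder, this rule treats the label space as flat: the ranking ignores where terms sit in the ontology.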

Evaluation metrics
We first evaluate model performance via output node-wise precision, recall and F1 measure. However, these metrics are overly strict in the sense that a model will be penalized equally for all mistakes, regardless of whether they are nearby or far from the target in the label tree. This is problematic because whether to apply a specific MeSH term or its immediate parent may be somewhat subjective in practice. To quantify this, and to explore the extent to which explicitly decoding into the target label space yields improved predictions, we also consider a measure that we refer to as semantic distance (SD):
$$\mathrm{SD}(Y, \hat{Y}) = \frac{1}{|Y|} \sum_{y \in Y} \min_{\hat{y} \in \hat{Y}} \mathrm{dist}(y, \hat{y}) \qquad (3)$$
where $Y$ and $\hat{Y}$ are the sets of target and predicted terms respectively, and dist is a function that returns the shortest distance between two nodes in the label ontology tree. The idea is that this penalizes less for 'near misses'. Thus if a model fails to apply a particular tag t, but does apply one near to t in the label tree, then it is penalized less. We hypothesize that our model will improve results markedly with respect to this metric, given our exploitation of the tree structure.
As in the case of recall, SD can be 'gamed': one can achieve a perfect score by predicting that all nodes apply to a given abstract. Thus this is only meaningful alongside complementary metrics like F1.
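To make the behaviour of SD concrete, including the way it can be gamed, here is a small sketch on a toy tree. It assumes that SD averages, over the target terms, the tree distance to the nearest predicted term; the averaging, the function names and the toy ontology are all assumptions made for illustration.

```python
def path_to_root(node, parent):
    """Return the list of nodes from `node` up to the root, inclusive."""
    path = [node]
    while node in parent:
        node = parent[node]
        path.append(node)
    return path


def tree_distance(a, b, parent):
    """Number of edges on the shortest path between a and b in the tree."""
    steps_from_a = {n: i for i, n in enumerate(path_to_root(a, parent))}
    for j, n in enumerate(path_to_root(b, parent)):
        if n in steps_from_a:  # first shared ancestor is the lowest one
            return steps_from_a[n] + j
    raise ValueError("nodes are not in the same tree")


def semantic_distance(target, predicted, parent):
    """Mean distance from each target term to its nearest predicted term."""
    if not target or not predicted:
        raise ValueError("target and predicted sets must be non-empty")
    return sum(
        min(tree_distance(y, p, parent) for p in predicted) for y in target
    ) / len(target)


# Toy ontology as a child -> parent map (hypothetical, not real MeSH).
PARENT = {
    "Diseases": "ROOT",
    "Chemicals": "ROOT",
    "Diabetes Mellitus": "Diseases",
    "Hypertension": "Diseases",
    "Insulin": "Chemicals",
}
target = {"Diabetes Mellitus", "Insulin"}

near_miss = semantic_distance(target, {"Diseases", "Insulin"}, PARENT)  # 0.5
far_miss = semantic_distance(target, {"Hypertension"}, PARENT)          # 3.0
gamed = semantic_distance(target, set(PARENT) | {"ROOT"}, PARENT)       # 0.0
```

Predicting the parent "Diseases" instead of "Diabetes Mellitus" costs one edge, whereas an unrelated prediction costs several. Predicting every node drives the score to zero, which is why SD has to be read alongside precision-sensitive metrics such as F1.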

Results
Results on the test set (which was completely held out during development) are reported in Table 2. The proposed Neural Tree Decoding model with stochastic backpropagation (NTD-s) bests the most competitive baseline (LSSI) in F1 score by over 2 points.
To explore the effect of backpropagating loss from nodes in proportion to their depth in the ontology, we also include results for a deterministic variant that does not do this, NTD-d. This version does not perform as well, demonstrating the utility of the proposed training approach.
The metrics reported thus far do not account for the structure in the output space. We thus additionally report results with respect to the semantic distance (SD) metric (Eq. 3). We observe a marked performance increase of ∼21% over the best performing baseline. This is intuitive given that we are explicitly decoding into the label tree structure, and demonstrates the ability of our model to learn the ontological structure, thereby predicting semantically appropriate terms.

Conclusions, Discussion & Limitations
We developed a neural attentive sequence tree decoding model for structured multilabel classification where labels are drawn from a known ontology. The proposed method can decode an input text into a tree of labels, effectively using the structure in the output space. We demonstrated that this model outperformed SOTA approaches for the important task of tagging biomedical abstracts with Medical Subject Heading (MeSH) terms on a modestly sized training corpus. Code and data to reproduce these results are available at https://github.com/gauravsc/NTD.
One limitation of our model is that it is comparatively slow, due to having to traverse the tree structure during decoding. Prediction speed may not be a major issue in practice, as articles on PubMed could be batch tagged nightly as they arrive. However, slow decoding also means lengthy training (see Appendix, section A.2 for details). For this reason we have here used a modest training set of ∼20k abstracts, which is smaller than corpora used in prior work on this task. Given the relative expressiveness of our model, we expect it to benefit substantially from additional training data, more so than the simpler baseline architectures. But at present this is only a conjecture.
In future work we thus hope to apply this model to larger datasets, and to address the efficiency issue. Concerning the latter, sibling subtrees may be traversed in parallel, conditioned on the hidden state of their parent. Another promising direction would be to move to convolutional encoder and decoder architectures, designing the latter in a way that similarly capitalizes on the label space tree structure.