MSnet: A BERT-based Network for Gendered Pronoun Resolution

The pre-trained BERT model achieves state-of-the-art results across a wide range of natural language processing tasks. To address gender bias in the gendered pronoun resolution task, I propose a novel neural network model based on the pre-trained BERT. The model is a mention score classifier. It uses a parameter-free attention mechanism to compute the contextual representation of entity spans, and a vector to represent the triple-wise semantic similarity among the pronoun and the entities. In stage 1 of the gendered pronoun resolution task, a variant of this model trained with the fine-tuning approach reduced the multi-class logarithmic loss to 0.3033 in 5-fold cross-validation on the training set and to 0.2795 on the test set. This variant also won 2nd place in stage 2 of the task with a score of 0.17289. The code for this paper is available at: https://github.com/ziliwang/MSnet-for-Gendered-Pronoun-Resolution

Recent work has indicated that pre-trained language representation models benefit coreference resolution. In the past few years, deep learning methods for language representation have developed rapidly, and newer methods have been shown to significantly improve other natural language processing tasks (Peters et al., 2018; Radford and Salimans, 2018; Devlin et al., 2018). The latest of these is Bidirectional Encoder Representations from Transformers (BERT) (Devlin et al., 2018), which underpins state-of-the-art models on many tasks.
In this paper, I present a novel neural network model based on the pre-trained BERT for the gendered pronoun resolution task. The model is a mention score classifier named the Mention Score Network (MSnet for short), and it is trained on the public GAP dataset. In particular, the model uses an attention mechanism to compute the contextual representation of each entity span, and a vector to represent the triple-wise semantic similarity among the pronoun and the entities. Since MSnet cannot be fine-tuned in the standard way, I employ a two-step strategy: MSnet is first trained with BERT frozen, and then both are tuned together. Two variants of MSnet were submitted to the gendered pronoun resolution task; their logarithmic losses in local 5-fold cross-validation on the training set are 0.3033 and 0.3042, respectively. In stage 2 of the task, by averaging the fold models' predictions on the test set, they obtained scores of 0.17289 and 0.18361, respectively, and won 2nd place.

Model
Since the goal of the Gendered Pronoun Resolution task is to label whether a pronoun refers to entity A, entity B, or NEITHER, I aim to learn the reference probability distribution $P(E_i \mid D)$ from the input document $D$:

$$P(E_i \mid D) = \frac{\exp\big(s(E_i, D)\big)}{\sum_{E_j \in E} \exp\big(s(E_j, D)\big)}$$

where $E_i$ is a candidate referent of the pronoun, $E = \{A, B, \text{NEITHER}\}$, and $s$ is the score function, implemented by the neural network architecture described in detail in the following subsection.
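The reference distribution is obtained by normalizing the three mention scores. The following is a minimal sketch, assuming a standard softmax over $E$ and using plain Python in place of the PyTorch model; `reference_distribution` and `ENTITIES` are hypothetical names, not taken from the released code.

```python
import math

# Candidate referents of the pronoun: E = {A, B, NEITHER}.
ENTITIES = ("A", "B", "NEITHER")


def reference_distribution(scores):
    """Normalize mention scores s(E_i, D) into P(E_i | D) with a softmax.

    `scores` maps each candidate in ENTITIES to a real-valued score.
    """
    # Subtracting the largest score before exponentiating avoids overflow
    # and leaves the normalized distribution unchanged.
    top = max(scores[e] for e in ENTITIES)
    exps = {e: math.exp(scores[e] - top) for e in ENTITIES}
    total = sum(exps.values())
    return {e: exps[e] / total for e in ENTITIES}


probs = reference_distribution({"A": 2.0, "B": 0.5, "NEITHER": -1.0})
print(max(probs, key=probs.get))  # prints A
```

The predicted referent is simply the candidate with the highest probability, and the training loss is the negative log of the probability assigned to the gold label.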

The Mention Score Network
The mention score network is built on the pre-trained BERT model (Figure 1). It has three layers: the span representation layer, the similarity layer, and the mention score layer. Each is described below.

Span Representation Layer: The contextual representation is crucial for accurately predicting the relation between the pronoun and the entities. Inspired by Lee et al. (2017), I adopt the hidden states of the pre-trained BERT transformers as the contextual representation. Devlin et al. (2018) showed that concatenating the token representations from the top hidden layers of the pre-trained BERT Transformer performs close to fine-tuning the entire model, so the top hidden states are preferred for computing the representations of entity spans. Because most entity spans consist of several tokens, their token representations must be recombined into a single span representation. I present two methods for this:

1) Mean-pooling method:

$$x^*_{(j,l)} = \frac{1}{N} \sum_{i \in \text{span } j} x_{(i,l)}$$

where $x_{(i,l)}$ denotes the hidden state of the $i$-th token in the $l$-th layer of BERT, $x^*_{(j,l)}$ denotes the contextual representation of entity span $j$, and $N$ is the number of tokens in span $j$.

2) Attention mechanism: Instead of weighting each token equally, I use an attention mechanism to weight the tokens:

$$s_{(i,l)} = \frac{x_{(p,l)} \cdot x_{(i,l)}}{\sqrt{d_H}}, \qquad a_{(i,j,l)} = \frac{\exp\big(s_{(i,l)}\big)}{\sum_{k \in \text{span } j} \exp\big(s_{(k,l)}\big)}, \qquad x^*_{(j,l)} = \sum_{i \in \text{span } j} a_{(i,j,l)}\, x_{(i,l)}$$

The weights $a_{(i,j,l)}$ are derived automatically from the contextual similarity $s_{(i,l)}$ between the pronoun $x_{(p,l)}$ and the token $x_{(i,l)}$ in span $j$. Unlike commonly used attention functions, this one has no parameters and is more space-efficient in practice. In the scaling factor $\sqrt{d_H}$, $d_H$ denotes the hidden size of BERT; the factor counteracts the extremely small gradients caused by the large magnitude of dot products (Vaswani et al., 2017).

Similarity Layer: Inspired by the pairwise similarity of Lee et al. (2017), I use a vector $s_l$ to represent the triple-wise semantic similarity among the pronoun and the entities at the $l$-th layer of BERT. It is built from $a_l$, $b_l$ and $p_l$, the contextual representations of entity A, entity B and the pronoun at the $l$-th layer, using dot products (·) and element-wise multiplications (∘). The vector $\hat{s}_l$ is then learned from $s_l$ by a single-layer feed-forward neural network with weights $W$ and bias $b$.

Mention Score Layer: The mention score layer is also a feed-forward neural network. It computes the mention scores from the distance vector $d$ between the pronoun and its candidate entities and the concatenated similarity vector $s = [\hat{s}_1; \dots; \hat{s}_L]$. Here $d_a$ (or $d_b$) denotes the distance encoding of entity A (or B), computed from START, the index of the start token of a span; $\hat{s}_l$ denotes the similarity vector computed from the representations of the $l$-th layer in BERT; and $L$ is the total number of layers used for representation. $w_{dist}$ is a learnable weight for encoding the distance, with a corresponding learnable bias $b_{dist}$, and $W_{E_i}$ is the learnable weight for scoring entity $E_i$, with a corresponding learnable bias $b_{E_i}$.
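The two span-representation methods of the span representation layer can be sketched for a single BERT layer as follows. This is a minimal plain-Python illustration, assuming each hidden state is a list of floats and that the attention scores are scaled by $\sqrt{d_H}$ as in Vaswani et al. (2017); `mean_pool` and `attention_pool` are hypothetical helper names, and the dropout applied during training is omitted.

```python
import math


def mean_pool(tokens):
    """Mean-pooling: average the hidden states of the span's tokens."""
    n = len(tokens)
    dim = len(tokens[0])
    return [sum(t[k] for t in tokens) / n for k in range(dim)]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def attention_pool(pronoun, tokens):
    """Parameter-free attention: weight each span token by the softmax of its
    scaled dot-product similarity to the pronoun, then take the weighted sum."""
    d_h = len(pronoun)  # hidden size of the layer
    sims = [dot(pronoun, t) / math.sqrt(d_h) for t in tokens]
    top = max(sims)  # shift for numerical stability
    exps = [math.exp(s - top) for s in sims]
    z = sum(exps)
    weights = [e / z for e in exps]
    return [sum(w * t[k] for w, t in zip(weights, tokens)) for k in range(d_h)]


pronoun = [1.0, 0.0]
span = [[1.0, 0.0], [0.0, 1.0]]
print(mean_pool(span))  # prints [0.5, 0.5]
print(attention_pool(pronoun, span))  # leans toward the token more similar to the pronoun
```

Because the attention weights come only from dot products with the pronoun, the method adds no learnable parameters; a one-token span simply returns that token's hidden state.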

Experiments
I trained the model on the Kaggle platform using script kernels, which run in the computational environment provided by docker-python¹. I used PyTorch as the deep learning framework and the pytorch-pretrained-BERT package² to load and fine-tune the pre-trained BERT model.

Dataset
The GAP Coreference Dataset³ (Webster et al., 2018) contains 4,454 records and is officially split into three parts: a development set (2,000 records), a test set (2,000 records), and a validation set (454 records). Following stage 1 of the Gendered Pronoun Resolution task⁴, the official test and validation sets are combined into the training set in the experiments (2,454 records), while the official development set serves as the test set.
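The stage-1 split and the three-way labels can be prepared as in the sketch below. It assumes the GAP TSV header columns `ID`, `A-coref` and `B-coref` (with `TRUE`/`FALSE` values) from the public GAP release; `read_gap`, `gap_label` and `stage1_splits` are hypothetical helpers, not part of the released MSnet code.

```python
import csv


def read_gap(path):
    """Read one GAP TSV file into a list of dicts keyed by the header row."""
    with open(path, newline="", encoding="utf-8") as f:
        # QUOTE_NONE: treat quote characters in the text as ordinary characters.
        return list(csv.DictReader(f, delimiter="\t", quoting=csv.QUOTE_NONE))


def gap_label(row):
    """Map GAP's A-coref / B-coref flags to one of "A", "B" or "NEITHER"."""
    a = row["A-coref"].strip().upper() == "TRUE"
    b = row["B-coref"].strip().upper() == "TRUE"
    if a and b:
        raise ValueError(f"row {row['ID']} marks both A and B as coreferent")
    if a:
        return "A"
    return "B" if b else "NEITHER"


def stage1_splits(development, test, validation):
    """Stage-1 setup: official test + validation -> training set,
    official development -> test set."""
    return list(test) + list(validation), list(development)
```

Assuming the official file names `gap-development.tsv`, `gap-test.tsv` and `gap-validation.tsv`, calling `stage1_splits(read_gap("gap-development.tsv"), read_gap("gap-test.tsv"), read_gap("gap-validation.tsv"))` should yield 2,454 training rows and 2,000 test rows.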

Preprocessing
In the experiments, WordPiece is used to tokenize the documents. To keep each tokenized document under 300 tokens, I remove leading or trailing tokens from a few documents. Next, the special tokens [CLS] and [SEP] are added to the beginning and end of each token sequence.
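A sketch of this preprocessing step is shown below. The text does not say how the head or tail is chosen, so this version assumes the mention positions have already been mapped to WordPiece indices, trims the tail first, trims the head only as much as needed to keep every mention, and applies the length limit before [CLS] and [SEP] are added. `truncate_and_wrap` is a hypothetical helper.

```python
def truncate_and_wrap(tokens, mention_positions, limit):
    """Keep at most `limit` tokens that still cover every mention position,
    then add [CLS]/[SEP]. Returns the wrapped tokens and the shifted positions."""
    lo, hi = min(mention_positions), max(mention_positions)
    if hi - lo + 1 > limit:
        raise ValueError("mentions do not fit in a window of this size")
    if len(tokens) <= limit:
        start = 0
    else:
        # Prefer cutting the tail; cut the head only when the last mention
        # would otherwise fall outside the window.
        start = max(0, hi - limit + 1)
    kept = tokens[start:start + limit]
    wrapped = ["[CLS]"] + kept + ["[SEP]"]
    # Positions move left by `start` and right by one for the [CLS] token.
    shifted = [p - start + 1 for p in mention_positions]
    return wrapped, shifted


tokens = [f"t{i}" for i in range(10)]
wrapped, shifted = truncate_and_wrap(tokens, [7, 8], limit=5)
print(wrapped)  # prints ['[CLS]', 't4', 't5', 't6', 't7', 't8', '[SEP]']
```

Returning the shifted positions keeps the pronoun and entity indices aligned with the BERT input, which both span-representation methods and the START-based distance encoding rely on.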

Hyper-parameters
Pre-trained BERT model: Since increasing the model size of BERT can lead to significant improvements even on very small-scale tasks (Devlin et al., 2018), I explore the effect of BERT_BASE and BERT_LARGE in the experiments.
I use uncased_L-12_H-768_A-12⁵ as BERT_BASE and cased_L-24_H-1024_A-16⁶ as BERT_LARGE; both are converted to the PyTorch format with the conversion script in pytorch-pretrained-BERT.
Hidden Layers for Representation: Since Devlin et al. (2018) showed that using representations from appropriate hidden layers of BERT can improve model performance, the set of hidden layers $L$ (described in Section 2) is treated as a hyper-parameter tuned in the experiments.
Dimension of Similarity Vector: Since a vector is used to represent the task-specific semantic similarity, its dimension $\hat{s}_{dim}$ may influence performance. A smaller dimension may lose part of the information, while a larger one may cause generalization problems.
Span Contextual Representation: As described in Section 2, either the mean-pooling method or the attention method can be used to compute the contextual representation of an entity's token span. The choice between them is therefore treated as a hyper-parameter in the experiments.
Tunable Layers: I use two different approaches to train the MSnet model. The first is the feature-based approach, which trains MSnet with the BERT part frozen. The second is the fine-tuning approach, which tunes the parameters of BERT and MSnet simultaneously. Howard and Ruder (2018) showed that discriminative fine-tuning performs better than ordinary fine-tuning, which may indicate that a pre-trained language model has a hierarchical structure. One possible explanation is that the lower hidden layers extract word meanings and grammatical structure, and the higher layers process them into higher-level semantic information. Accordingly, I freeze the embedding layer and the bottom hidden layers of BERT to preserve word meanings and grammatical structure, and tune only the top $L_{tuning}$ hidden layers.
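Selecting which BERT parameters to tune can be reduced to a check on parameter names. This sketch assumes the name layout of pytorch-pretrained-BERT's `BertModel` (`embeddings.*`, `encoder.layer.<i>.*`, `pooler.*`); `is_tunable` is a hypothetical helper, and keeping the unused pooler frozen is an assumption of this sketch.

```python
import re

# Matches "encoder.layer.<i>." at the start of a name or after a dot.
_LAYER = re.compile(r"(?:^|\.)encoder\.layer\.(\d+)\.")


def is_tunable(param_name, num_layers, l_tuning):
    """True if the parameter belongs to one of the top `l_tuning` transformer
    layers; embeddings, bottom layers and the pooler stay frozen."""
    match = _LAYER.search(param_name)
    if match is None:
        return False  # embeddings, pooler, or anything outside the encoder
    return int(match.group(1)) >= num_layers - l_tuning


# BERT_LARGE has 24 layers; tuning the top 12 unfreezes layers 12..23.
print(is_tunable("encoder.layer.12.attention.self.query.weight", 24, 12))  # prints True
print(is_tunable("encoder.layer.11.output.dense.weight", 24, 12))  # prints False
```

In PyTorch the result would be applied with `for name, p in bert.named_parameters(): p.requires_grad = is_tunable(name, 24, 12)`, so that only the selected layers receive gradient updates.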

Training Details
To improve the generalization ability of the model, I apply dropout (Srivastava et al., 2014) to the input of the feed-forward neural network in the similarity layer and to the concatenated vector in the mention score layer. The dropout rate is set to 0.6, the best value found by tuning. I also apply dropout, with a rate of 0.4, to the token representations when the attention mechanism is used to compute the contextual representation of a span. Additionally, I apply batch normalization (Ioffe and Szegedy, 2015) before the dropout operation in the mention score layer. As introduced in Section 3.3, I train MSnet separately with the feature-based approach and the fine-tuning approach; the training details are described below.
Feature-based Approach: In the feature-based approach, I train the model by minimizing the cross-entropy loss using the Adam optimizer (Kingma and Ba, 2014) with a batch size of 32. Tuning the learning rate on the training data showed that 3e-4 was the best setting. The maximum number of epochs is set to 30, and early stopping is used to prevent MSnet from over-fitting.
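Early stopping with a maximum epoch budget can be sketched as below. The patience for the feature-based approach is not stated in the text, so it is a parameter here; `EarlyStopping` and `train` are hypothetical names, and restoring the best checkpoint is left out for brevity.

```python
class EarlyStopping:
    """Signal a stop once the validation loss has not improved for
    `patience` consecutive epochs."""

    def __init__(self, patience):
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True to stop training."""
        if val_loss < self.best:
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


def train(run_epoch, max_epochs, patience):
    """`run_epoch(epoch)` trains one epoch and returns the validation loss.
    Returns (epochs actually run, best validation loss)."""
    stopper = EarlyStopping(patience)
    for epoch in range(max_epochs):
        if stopper.step(run_epoch(epoch)):
            return epoch + 1, stopper.best
    return max_epochs, stopper.best


losses = [0.9, 0.8, 0.85, 0.86, 0.87, 0.5]
print(train(lambda e: losses[e], max_epochs=30, patience=3))  # prints (5, 0.8)
```

In this toy run, training stops after the fifth epoch even though a later epoch would have improved, which is the usual trade-off of a small patience value.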
Fine-tuning Approach: In the fine-tuning approach, standard single-step training did not work, so I adopt a two-step tuning strategy. In step 1, MSnet is trained with the feature-based approach. In step 2, MSnet and BERT are tuned together with a small learning rate.
Because the two steps share the same optimization landscape, the model in step 2 may fail to escape the local minimum it entered in step 1. To reduce the probability of this, I adopt two training strategies for step 1: 1) premature: MSnet is deliberately trained to under-fitting by using a small maximum number of training epochs, set to 10 in the experiments. 2) mature: MSnet is trained to a proper fit, using a weight decay of 0.01, early stopping after 4 epochs without improvement, and a maximum of 20 training epochs in the experiments. All other training parameters of the two strategies are the same as in the feature-based approach.
In step 2, I again trained the model by minimizing the cross-entropy loss, but with two different optimizers. For BERT, I used the Adam optimizer with the weight decay fix implemented in pytorch-pretrained-BERT. For MSnet, the standard Adam optimizer was used. Both optimizers use a learning rate of 5e-6 and a weight decay of 0.01. The maximum number of training epochs is 20, and early stopping is triggered after 4 epochs without improvement. The batch size was 5 because of GPU memory limitations.
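The stated settings of the two-step schedule can be collected in one configuration. The layout below is hypothetical: the key names are illustrative, "early stopping at 4 epochs" is read as a patience of 4, `BertAdam` names the pytorch-pretrained-BERT optimizer with the weight decay fix, and values the text does not give (such as weight decay for the premature strategy) are simply absent.

```python
# Settings shared by the feature-based approach and both step-1 strategies
# (BERT is frozen throughout step 1).
FEATURE_BASED = {"optimizer": "Adam", "lr": 3e-4, "batch_size": 32,
                 "max_epochs": 30, "bert_frozen": True}

STEP1 = {
    # Premature: deliberately under-fit MSnet with a small epoch budget.
    "premature": {**FEATURE_BASED, "max_epochs": 10},
    # Mature: fit MSnet properly with weight decay and early stopping.
    "mature": {**FEATURE_BASED, "max_epochs": 20,
               "weight_decay": 0.01, "patience": 4},
}

# Step 2: BERT (top layers) and MSnet tuned together, one optimizer each.
STEP2 = {
    "bert": {"optimizer": "BertAdam", "lr": 5e-6, "weight_decay": 0.01},
    "msnet": {"optimizer": "Adam", "lr": 5e-6, "weight_decay": 0.01},
    "batch_size": 5,  # limited by GPU memory
    "max_epochs": 20,
    "patience": 4,
}

for name, cfg in STEP1.items():
    print(name, cfg["max_epochs"], cfg.get("weight_decay"))
```

Keeping the premature and mature variants as overrides of one shared dictionary makes the single intended difference between them (epoch budget, weight decay, early stopping) explicit.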

Evaluation
I report the multi-class logarithmic loss in 5-fold cross-validation on the training set, and the loss of the averaged predictions of the five fold models on the test set. The running time of the scripts is also reported as a reference for the computational cost of MSnet.
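Averaging the fold models' predictions and scoring them can be sketched as follows. The metric assumes the Kaggle-style multi-class log loss, which clips probabilities to [1e-15, 1 − 1e-15] and renormalizes each row before averaging −log p(gold); `average_folds` and `multiclass_log_loss` are hypothetical helpers.

```python
import math

LABELS = ("A", "B", "NEITHER")


def average_folds(fold_predictions):
    """Average class probabilities over fold models.
    `fold_predictions`: one list per fold of [pA, pB, pNEITHER] rows."""
    n_folds = len(fold_predictions)
    n_rows = len(fold_predictions[0])
    return [[sum(fold[r][k] for fold in fold_predictions) / n_folds
             for k in range(len(LABELS))]
            for r in range(n_rows)]


def multiclass_log_loss(predictions, labels, eps=1e-15):
    """Clip, renormalize each row, then average -log p(gold label)."""
    total = 0.0
    for probs, label in zip(predictions, labels):
        clipped = [min(max(p, eps), 1 - eps) for p in probs]
        z = sum(clipped)
        total -= math.log(clipped[LABELS.index(label)] / z)
    return total / len(labels)


folds = [[[1.0, 0.0, 0.0]], [[0.0, 1.0, 0.0]]]
averaged = average_folds(folds)
print(averaged)  # prints [[0.5, 0.5, 0.0]]
print(multiclass_log_loss(averaged, ["A"]))  # close to log(2)
```

Averaging before scoring tempers over-confident predictions from individual folds, which matters for log loss because a confidently wrong prediction is heavily penalized.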

Feature-based Approach
The results of the MSnet variants trained with the feature-based approach are shown in Table 1. The comparison between model #1 and model #2 shows that combining the top 4 hidden layers for contextual representation is better than using the top layer alone. A possible reason is that gender-related semantic information may be partly transformed into higher-level semantic information as it passes through the hidden layers of BERT. In addition, switching from BERT_BASE to BERT_LARGE reduces the loss in 5-fold CV on the training set from 0.4699±0.0431 to 0.4041±0.0532, which demonstrates that increasing the model size of BERT can lead to a remarkable improvement on this small-scale task. The exploration of the representation layers shows that the appropriate number of representation layers is proportional to the number of hidden layers in BERT. In other words, BERT_LARGE has more modeling power than BERT_BASE, using a more complex function to do the same work.
The comparison among models #4, #6, #7 and #8 shows that the dimension of the similarity vector has only a slight effect on the performance of MSnet (Table 1); the best loss, 0.3736±0.0465, is obtained with the dimension set to 16. Changing the method for computing the span contextual representation from mean-pooling to the attention mechanism reduces the CV loss on the training set by ∼0.02, which demonstrates that the attention mechanism used in the experiments is effective for computing the contextual representation of entity spans. To the best of my knowledge, this attention mechanism is novel: it has no learnable parameters, and it is more space-efficient and more explainable in practice.

Fine-tuning Approach
The fine-tuning experiments were based on model #9, and the results are shown in Table 2. The comparison between model #10 and model #11 shows only a slight difference in performance. Both strategies make the fine-tuning of MSnet effective and reduce the CV loss on the training set by ∼0.054 compared with the feature-based approach. Furthermore, tuning $L_{tuning}$ shows that the best setting is to tune the top 12 hidden layers of BERT; tuning more or fewer layers reduces the performance of MSnet. A possible reason is that tuning fewer layers limits the model's ability to transform basic semantics into gender-related semantics, while also tuning the bottom layers damages the extraction of the underlying semantics when training on a small dataset.
As the approach changed from feature-based to fine-tuning, the roles of some hyper-parameters changed. The most obvious is the set of hidden layers for contextual representation: in the feature-based approach it combines the semantics of several hidden layers, whereas in the fine-tuning approach it constrains the contextual representations of those layers to carry the same semantics. Although this change in role was not deliberate, an improvement in the performance of the model was observed in the experiments.

Results in Stage 2
The gendered pronoun resolution task had two stages, and I submitted models #10 and #11 in stage 2 because they performed best in 5-fold cross-validation on the training set. The final scores of the models were 0.17289 (model #10) and 0.18361 (model #11). This result further indicates that the premature strategy is better than the mature one, which may be explained by the former leaving a more explorable optimization landscape for step 2 of the fine-tuning approach.

Conclusion
This paper presented a novel network model based on the pre-trained BERT for the gendered pronoun resolution task. The model is a mention score classifier that uses an attention mechanism to compute the contextual representation of entity spans and a vector to represent the triple-wise semantic similarity among the pronoun and the entities. I trained the model with both the feature-based approach and the two-step fine-tuning approach. On the GAP dataset, the model trained by the fine-tuning approach with the premature strategy achieves a multi-class logarithmic loss of 0.3033 in local 5-fold cross-validation, and a score of 0.17289 on the test set in stage 2 of the task. I believe MSnet can serve as a strong new baseline for the gendered pronoun resolution task as well as for coreference resolution. The code for training the model is available at: https://github.com/ziliwang/MSnet-for-Gendered-Pronoun-Resolution