Efficient Automatic Punctuation Restoration Using Bidirectional Transformers with Robust Inference

Though people rarely speak in complete sentences, punctuation confers many benefits to the readers of transcribed speech. Unfortunately, most ASR systems do not produce punctuated output. To address this, we propose a solution for automatic punctuation that is both cost efficient and easy to train. Our solution benefits from the recent trend in fine-tuning transformer-based language models. We also modify the typical framing of this task by predicting punctuation for sequences rather than individual tokens, which makes for more efficient training and inference. Finally, we find that aggregating predictions across multiple context windows improves accuracy even further. Our best model achieves a new state of the art on benchmark data (TED Talks) with a combined F1 of 83.9, representing a 48.7% relative improvement (15.3 absolute) over the previous state of the art.


Introduction
Enabling computers to use speech as input has long been an aspirational goal in the field of human-computer interaction. Recent advances have had dramatic impact across multiple domains (e.g. relieving medical professionals from having to transcribe medical dictation, improving real-time spoken language translation (Gu et al., 2017), and affording convenience through conversational interfaces like those in virtual personal assistants (McTear et al., 2016)). For use cases that require reading transcribed speech, however, it is often still a challenge to recover meaningful clause boundaries from disfluent, errorful utterances.
Humans rely on punctuation for readability, perhaps because it lessens the burden of ambiguous phrasing. Studies have found that removing punctuation from manual transcriptions can be even more detrimental to understanding than a word error rate of 15% or 20% (Tündik et al., 2018). Reading comprehension is also significantly slower without punctuation (Jones et al., 2003). For downstream NLP models, the lack of clausal boundaries can significantly decrease accuracy (e.g. a 4.6% BLEU decrease in NMT; Vandeghinste et al. 2018). This likely reflects the discrepancy between well-segmented training corpora and ASR output.
To address the lack of punctuation in ASR output, we propose an automatic punctuation model that leverages the recent trend in unsupervised pre-training (Devlin et al., 2019) and the parallel architecture of transformer networks (Vaswani et al., 2017). Unsupervised pre-training dramatically reduces the amount of labeled data required for superior performance on this task. Additionally, the model's departure from a recurrent architecture allows direct connections between all input tokens. This enables the network to more easily model long-distance dependencies (e.g. on one hand, ... on the other, ...) for improved punctuation performance. The departure from a recurrent architecture also allows computations to be performed in parallel within each layer, with speed limited by the number of layers rather than the number of time steps (the former usually being far fewer). In addition to the parallel nature of the hidden layers, our network also predicts labels for all tokens in the input simultaneously. This significantly speeds up inference compared with making an individual prediction for each token. During training, the parallel prediction task provides a richer signal than a sequential task, thereby making more efficient use of each example. Furthermore, advancing the prediction window by less than the window's width (e.g. steps of 20 with a window of 50) allows aggregating multiple windows of context when predicting a token's label. The network thereby effectively becomes its own prediction ensemble, boosting accuracy further. Given that the aggregate predictions are obtained independently, these calculations too can be performed in parallel.
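As a toy illustration of the overlapping-window scheme (using the example values above, a window of 50 advanced in steps of 20; the function and its names are ours, for illustration only):

```python
# Overlapping prediction windows: with window=50 and stride=20, each token
# away from the sequence edges is covered by ceil(50 / 20) = 3 windows, so
# its final label can aggregate up to 3 independent predictions.

def window_spans(num_tokens: int, window: int = 50, stride: int = 20):
    """Yield (start, end) spans of overlapping prediction windows."""
    start = 0
    while start < num_tokens:
        yield start, min(start + window, num_tokens)
        start += stride

spans = list(window_spans(120))
print(spans)                                    # [(0, 50), (20, 70), ...]
print(sum(1 for s, e in spans if s <= 60 < e))  # token 60 is covered 3 times
```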
Several previous approaches, though also non-sequential, use simpler network architectures (e.g. DNNs (Yi et al., 2017; Che et al., 2016) or CNNs (B. Garg and Anika, 2018; Che et al., 2016; Żelasko et al., 2018)), which have less predictive power. The handful of approaches that make use of Transformer architectures are not bidirectional (Nguyen et al., 2019; Vāravs and Salimbajevs, 2018; Wang et al., 2018). Our model also differs from the above in that it leverages pre-training to reduce training time and increase accuracy. The one previous work that uses a pre-trained bidirectional transformer (Cai and Wang, 2019) only predicts punctuation one token at a time, which significantly increases both training and inference time. It is also unable to aggregate predictions across multiple contexts, limiting performance.

Method
Architecture The network architecture can be seen in Figure 1.
The first component of our network is a pre-trained language model (RoBERTa-base; Liu et al., 2019) employing the recent deep bidirectional Transformer architecture (Devlin et al., 2019; Vaswani et al., 2017). The network's input is a sequence of unpunctuated, lowercased words tokenized using RoBERTa's tokenization scheme (see Liu et al. 2019 for details). We then add two additional linear layers after the pre-trained network, with each layer preserving the fully-connected nature of the entire network. The first linear layer maps from the masked language model output space to a hidden state space for each input token, with parameters shared across tokens. The second linear layer operates on the concatenation of the hidden state representations into a single vector for the prediction window, which allows the tokens to interact arbitrarily within the window. We then apply batch normalization (Ioffe and Szegedy, 2015) and dropout (best results obtained with a rate of 0.2; Hinton et al. 2012) prior to predicting punctuation marks for all tokens in the window.
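To make the stack concrete, the following PyTorch sketch gives one plausible rendering of this architecture using the Hugging Face implementation of RoBERTa (Wolf et al., 2020). The hidden-state size (1500) and dropout rate (0.2) are the values reported below; `mix_dim`, the layer names, and other details are illustrative assumptions rather than our exact implementation.

```python
import torch.nn as nn
from transformers import RobertaModel

class PunctuationNet(nn.Module):
    """Sketch: RoBERTa core, a shared per-token linear layer, a window-wide
    linear layer over the concatenated hidden states, batch norm, dropout,
    and a classifier emitting |classes| * window_length outputs."""

    def __init__(self, window: int = 100, hidden: int = 1500,
                 mix_dim: int = 4096, n_classes: int = 4):
        super().__init__()
        self.window, self.n_classes = window, n_classes
        self.core = RobertaModel.from_pretrained("roberta-base")
        d_model = self.core.config.hidden_size          # 768 for roberta-base
        self.per_token = nn.Linear(d_model, hidden)     # shared across tokens
        self.window_mix = nn.Linear(window * hidden, mix_dim)  # tokens interact
        self.norm = nn.BatchNorm1d(mix_dim)
        self.drop = nn.Dropout(0.2)
        self.classifier = nn.Linear(mix_dim, n_classes * window)

    def forward(self, input_ids, attention_mask=None):
        h = self.core(input_ids, attention_mask=attention_mask).last_hidden_state
        h = self.per_token(h)                           # (batch, window, hidden)
        h = self.window_mix(h.flatten(1))               # concatenate, then mix
        h = self.drop(self.norm(h))
        logits = self.classifier(h)                     # (batch, classes * window)
        return logits.view(-1, self.window, self.n_classes)
```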
When aggregating predictions across contexts, activations at the sequence layer are summed for each token prior to classification (see Figure 1 for a visualization). Prediction is performed in parallel during both training and inference, with the output size of the final classifier being |classes| × window_length.

Figure 1: The punctuation network takes as input a sequence of unpunctuated words tokenized in the same manner as RoBERTa. It outputs predictions for these sequences individually during training (layer O_n). For validation and testing, however, these labels are aggregated across overlapping context windows to obtain the final punctuation predictions (layer Y_n). Note that while the pre-trained LM's output begins as vocabulary distributions, it ceases to be so once the entire network undergoes fine-tuning.

Training Schedule It is worth noting that we use only the TED Talks dataset described below for training but enjoy significant benefits from a sizable pre-training corpus. Although prediction is performed on multiple tokens at once, the same number of training samples are generated from the corpus by moving the sliding window one token at a time over the input. To perform gradient descent, we use LookAhead (Zhang et al., 2019) with RAdam (Liu et al., 2020) as the base optimizer. We use a simple cross-entropy function to calculate the loss for each token's classification prediction. Our best performing model (see Table 1) uses a prediction window size of 100, a final-layer dropout of 0.2, and a hidden-state space of dimensionality 1500. The top two linear layers (henceforth referred to as the "top layers") are initially trained from scratch while the transformer core remains frozen. Then, having selected the model version with the lowest validation loss from training the top layers, the transformer core is unfrozen, and we fine-tune the parameters of the entire network. We then select the model version with the lowest validation loss to prevent overfitting.
We train the top layers for nine epochs with a mini-batch size of 1000 (using 100-token sequences) while the transformer is frozen. The lowest validation loss for the top layers is usually achieved around the sixth epoch. We then unfreeze the transformer and fine-tune the entire network for three more epochs with a mini-batch size of 250. We typically observe the lowest validation loss midway through the first epoch while fine-tuning. It is worth noting that a highly competitive model (82.6 overall F1) can be trained with just 1 epoch each for the top layers and fine-tuning. This training can be completed in slightly less than 1 hour on a p3.16xlarge AWS instance (with 8x Tesla V100 GPUs).
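The two-stage schedule can be sketched as follows; `train` and `best_checkpoint` are hypothetical helpers standing in for an ordinary training loop with validation-based checkpoint selection.

```python
# Stage 1: train the top layers with the transformer core frozen;
# Stage 2: unfreeze everything and fine-tune. `model` is the sketch above;
# `train` and `best_checkpoint` are hypothetical helpers.

def set_core_frozen(model, frozen: bool):
    for p in model.core.parameters():
        p.requires_grad = not frozen

set_core_frozen(model, True)
train(model, epochs=9, batch_size=1000)    # lowest val loss typically ~epoch 6
model.load_state_dict(best_checkpoint())   # keep best top-layer weights

set_core_frozen(model, False)
train(model, epochs=3, batch_size=250)     # lowest val loss often mid-epoch 1
model.load_state_dict(best_checkpoint())   # final model: lowest val loss
```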
For the LookAhead optimizer, we use a sync rate of 0.5 and a sync period of 6. The RAdam optimizer, used as the model's base optimizer, has its learning rate set to 10^-5, with β1 = 0.9, β2 = 0.999, and ε = 10^-8. We do not use weight decay.
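These settings map directly onto, for example, the third-party torch_optimizer package (the choice of library here is an assumption; any implementation of RAdam and LookAhead with these hyperparameters would do):

```python
import torch_optimizer as optim

# RAdam base optimizer with the stated hyperparameters, wrapped in LookAhead
# with sync period k=6 and sync rate alpha=0.5; no weight decay.
base = optim.RAdam(model.parameters(), lr=1e-5, betas=(0.9, 0.999),
                   eps=1e-8, weight_decay=0.0)
optimizer = optim.Lookahead(base, k=6, alpha=0.5)
```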
Data To train the network and evaluate its performance (both at test and validation time), we use the IWSLT 2012 TED Talks dataset (Cettolo et al., 2012). This dataset is a common benchmark in automatic punctuation (e.g. Kim 2019) and consists of a 2.1M word training set, a 296k word validation set, and a 12.6k word test set (for reference transcription, 12.8k for ASR output). Each word is labeled with the punctuation mark that follows it, yielding a 4-class classification problem: comma, period, question mark, or no punctuation. The class balance of the training dataset is as follows: 85.7% no punctuation, 7.53% comma, 6.3% period, 0.47% question mark.
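The labeling scheme can be illustrated with a minimal sketch (whitespace tokenization here is a simplification; the real pipeline uses RoBERTa's sub-word tokenization and lowercasing as described above):

```python
# Each word is labeled with the punctuation mark that follows it (or none),
# yielding the 4-class problem described above.
LABELS = {"": 0, ",": 1, ".": 2, "?": 3}

def label_words(text: str):
    words, labels = [], []
    for token in text.split():
        mark = token[-1] if token[-1] in ",.?" else ""
        words.append(token.rstrip(",.?").lower())
        labels.append(LABELS[mark])
    return words, labels

print(label_words("Hello, how are you? I am fine."))
# (['hello', 'how', 'are', 'you', 'i', 'am', 'fine'], [1, 0, 0, 3, 0, 0, 2])
```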

Results
The results of our best performing model relative to previous results published on this benchmark can be found in Table 1. Additionally, we conducted a number of ablation experiments manipulating various aspects of the architecture and training routine. In providing accuracy comparisons, all results in this section are reported in terms of the absolute change in the overall F1 measure.
In place of the pre-trained RoBERTa-base language model, which provided the best result, we also evaluated several alternative pre-trained models, as implemented by Wolf et al. (2020): BERT-base, ALBERT-base, XLNet-base, and T5-base (see Table 1). The performance benefit of RoBERTa-base over BERT-base is likely due to the significant increase in pre-training corpus size. The lower performance of ALBERT-base may be due to the sharing of parameters across layers. It is interesting to note that XLNet-base provides higher recall for periods and question marks, and T5-base for commas and question marks, but both sacrifice significant precision to achieve this.
In addition to the LookAhead optimizer using RAdam as its base, we also evaluated: LookAhead with Adam (-1.5%), RAdam alone (-1.6%), and Adam alone (-2.9%; Kingma and Ba 2017). Given the class imbalance inherent in the dataset between the no punctuation class and all the punctuation marks, we tested focal loss (Lin et al., 2018), class weighting, and their combination, but found that none outperformed simple cross-entropy loss.
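For reference, a minimal version of the focal loss variant we tested (Lin et al., 2018); the focusing parameter γ = 2 here is the common default from that paper, not a tuned value from our experiments:

```python
import torch.nn.functional as F

def focal_loss(logits, targets, gamma: float = 2.0):
    """Cross-entropy down-weighted for well-classified examples: the
    (1 - p)^gamma factor focuses training on hard, often minority-class,
    tokens. Expects logits of shape (N, C) and integer targets of shape (N,)."""
    log_p = F.log_softmax(logits, dim=-1)
    log_pt = log_p.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    pt = log_pt.exp()
    return (-(1 - pt) ** gamma * log_pt).mean()
```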
Perhaps the most noteworthy result is the comparison between parallel prediction (described above) and sequential prediction, wherein the forward pass predicts punctuation for one token at a time using a context window centered on that token. Sequential prediction requires longer inference times (>15x) yet yields only a marginal performance benefit (2.2%) relative to parallel prediction without aggregation across multiple contexts. Ensembling predictions over multiple contexts overcomes this performance gap while retaining an advantage with respect to inference time: compared to the self-ensemble approach, sequential prediction is >4x slower and 5.4% less accurate.
A less obvious choice must be made between a single parallel prediction and multiple aggregated predictions, given the additional runtime of multiple predictions (see Table 2 for details). For our purposes, the 7.6% improvement is worth the increase in inference time, which is sub-linear given GPU parallelization but still appreciable. While our best method sums activations from different contexts to obtain the aggregate predictions, we also tested adding normalized probabilities across classes and then renormalizing, but we found it resulted in slightly worse performance (-0.3%).
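A minimal sketch of the summation-based aggregation (the data layout and names are illustrative; in our network the summed activations come from the sequence layer, as in Figure 1):

```python
import torch

def aggregate(window_logits: dict, num_tokens: int, n_classes: int = 4):
    """Sum per-token activations from overlapping windows, then classify.
    `window_logits` maps each window's start offset to a
    (window_length, n_classes) tensor of activations."""
    summed = torch.zeros(num_tokens, n_classes)
    for start, logits in window_logits.items():
        summed[start:start + logits.size(0)] += logits
    return summed.argmax(dim=-1)   # final per-token punctuation labels
```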
In addition to the RoBERTa-base model whose results are reported here, we also trained a RoBERTa-large model. There was no appreciable performance difference between the two sizes (the large model scoring 0.4% lower); however, the large model incurred a significant slowdown (≈1.5x). This may imply that the base model size is adequately powered for punctuation tasks, at least on manually transcribed English datasets similar to the benchmark. This is supported by the findings of Kovaleva et al. (2019), who found BERT-base to be overparameterized for most downstream tasks, implying RoBERTa-large would be extremely overparameterized. A smaller pre-trained language model option is DistilRoBERTa, a knowledge-distilled version of RoBERTa (analogous to DistilBERT; Sanh et al. 2020). The DistilRoBERTa network is 12% smaller and performs inference ≈1.2x faster, but sacrifices 9.1% in accuracy on the benchmark.
The previous state of the art approach was a multi-headed attention network on top of multiple stacked bidirectional GRU layers (Kim, 2019).
Given the recurrent nature of the GRU layers, the network is subject to the shortcomings of sequential computation discussed in the Introduction. Our findings illustrate yet another language task where transformers outperform previous recurrent neural network approaches.
Our approach enjoys a 48.7% relative improvement (15.3 absolute) over the previous state of the art (Kim, 2019). Given the ablation results presented above, we attribute the performance gains to the deeply bidirectional transformer architecture, the benefit of leveraging RoBERTa's pre-trained language model (trained on ≈33B words), and the aggregation of multiple prediction contexts for robust inference. Some performance gain may also be attributed to the addition of an encoding layer trained solely on the punctuation task.
One of the more notable findings is that the non-recurrent nature of the entire network allows for a large degree of parallelization, resulting in a more competitive runtime than previous recurrent approaches. While source code was not openly available for benchmarking runtime against Kim (2019), we did compare against a similar approach from Tilk and Alumäe (2016), which was roughly 78.8x slower on GPUs and 1.2x slower on a CPU when evaluating the TED Talks test set.
The results presented here have not benefited from any rigorous hyperparameter tuning (e.g. grid search or Bayesian optimization). We leave such tuning to future work, as a rigorous, systematic approach may yield appreciable further improvements in accuracy.

Conclusion
We have presented a state of the art automatic punctuation system which aggregates multiple prediction contexts for robust inference on transcribed speech. The use of multiple prediction contexts, unsupervised pre-training, and increased parallelism makes it possible to achieve significant performance gains without increased runtime or cost.
On a different dataset, Boháč et al. (2017) reported human agreement of around 76% for punctuation location and 70% for the choice of punctuation mark. Although we have yet to make a direct comparison, it is possible our model is already competitive with human performance on this task. Future work will explore how this performance