Large Product Key Memory for Pretrained Language Models

Product key memory (PKM), proposed by Lample et al. (2019), improves prediction accuracy by efficiently increasing model capacity with insignificant computational overhead. However, its empirical application has so far been limited to causal language modeling. Motivated by the recent success of pretrained language models (PLMs), we investigate how to incorporate large PKM into PLMs that can be finetuned for a wide variety of downstream NLP tasks. We define a new memory usage metric, and careful observation using this metric reveals that most memory slots remain outdated during the training of PKM-augmented models. To train better PLMs by tackling this issue, we propose two simple but effective solutions: (1) initialization from model weights pretrained without memory and (2) augmenting PKM by addition rather than replacing a feed-forward network. We verify that both are crucial for the pretraining of PKM-augmented PLMs, enhancing memory utilization and downstream performance. Code and pretrained weights are available at https://github.com/clovaai/pkm-transformers.


Introduction
Larger model capacity has brought improvements in accuracy by enabling better modeling of data. However, increasing model capacity causes a significant increase in computational cost at both training and inference time. To address this issue, Lample et al. (2019) propose product key memory (PKM), which enables very efficient and exact nearest neighbor search over a large number of learnable memory slots. They substitute a feed-forward network (FFN) in a transformer block (Vaswani et al., 2017) with a PKM layer. Augmenting networks with large PKM layers increases model capacity with only a slight increase in inference time. Lample et al. (2019) prove the efficiency of PKM on causal language models (CLMs) in terms of a superior trade-off between perplexity and inference speed. For instance, they achieve a PKM-augmented CLM with only 12 layers that is more accurate and twice as fast as a 24-layer baseline. However, the use of PKM with a pretrained language model (PLM) such as BERT (Devlin et al., 2018), which is helpful for downstream tasks (Wang et al., 2018), has not been examined in the literature. In our experiments, plain PKM improves masked language modeling (MLM) perplexity but not downstream performance.

* Equal contribution. † TJ was an intern at Clova AI while doing this work.

Table 1: Comparison of inference speed between different model sizes and the memory layers. We run each model on the classification task with batch size 1 and measure inference speed on a single V100 GPU. We follow the model size settings of BERT (Devlin et al., 2018) and use two memory layers with the recommended PKM hyper-parameter settings of Lample et al. (2019), as described in §5. As marked in bold, BERT BASE with our proposed residual memory (ResM) is much faster than BERT LARGE, while having more parameters.
We measure various memory utilization metrics to analyze how many memory slots contribute to the model prediction. Careful examination of memory utilization during and after training demonstrates that only a few memory slots play an important role (§3.1). We attribute this phenomenon, called catastrophic drift, to the sparsely updated memory parameters. The low memory utilization implies that the model capacity from memory is not fully exploited, which prompts us to develop methods that overcome this issue.
We find that initialization from weights pretrained without memory is essential for pretraining PKM-augmented PLMs. Moreover, rather than replacing an FFN with a PKM as Lample et al. (2019) do, we show that adding a PKM to a transformer layer (Vaswani et al., 2017) via a residual connection (He et al., 2016), without removing the FFN, is advantageous. Both the initialization (§4.1) and our proposed residual memory (ResM, §4.2) prevent a sudden change of the transformer parameters, thus allowing the memory parameters to be trained better with less catastrophic drift. Consequently, we obtain a PKM-augmented BERT BASE that has accuracy comparable to BERT LARGE while being faster. As demonstrated in Table 1, a model with a large memory is much faster than a model with twice as many transformer layers, although it has far more weights, and ResM does not slow down inference much. Accuracy comparisons between them appear in the later sections.
The main contributions of this work are summarized as follows. First, we explore how to incorporate PKM into PLMs to be finetuned for downstream tasks and find that naive application does not work well. Second, by carefully monitoring memory utilization, we attribute this to a catastrophic drift during training. Lastly, we propose simple yet effective solutions to the observed catastrophic drift problem: (1) weight initialization from a model pretrained without PKM and (2) the residual memory layer. We empirically verify that both are crucial to achieving improved accuracy. To our knowledge, this is the first work that successfully applies PKM to PLMs.

Transformers and Product Key Memory
A transformer encoder maps a sequence of input tokens into a sequence of continuous representations based on a self-attention mechanism (Vaswani et al., 2017). The transformer architecture is a stack of layers, each consisting of a multi-head attention sub-layer and a feed-forward sub-layer. Due to its remarkable prediction accuracy, the transformer has become a standard architecture in natural language processing.
On the other hand, a memory architecture can also be used to design a function that maps one continuous representation to another as a layer in neural networks. When a query vector is given, a standard memory-augmented neural network finds the k-NN keys in the memory layer and returns a weighted sum of the corresponding value vectors, where the weights are normalized scores of the dot products between the query vector and the key vectors. Lample et al. (2019) propose product key memory (PKM), which can significantly increase model capacity based on fast and exact nearest neighbor search. They plug a PKM layer into a transformer architecture by replacing an existing feed-forward layer with it, while keeping similar computational efficiency.
We explain the mechanism of PKM here to make the paper self-contained. A product key is a pair of sub-keys, meaning that there are |K| = C² different memory slots when the codebook size of each sub-key is C. A given query vector is split into two halves, and the score of a product key is the sum of the dot products between each sub-query vector and the corresponding sub-key vector. We can thus increase the size of the key space effectively with a sufficiently large C. Exact nearest neighbor search over the product key set can be done efficiently by first finding the k-NN in each sub-key space and then finding the k-NN again among the k² combinations of sub-key pairs. In addition, a multi-head memory attention mechanism, like self-attention in transformers, is used to increase the representation power of the memory layer.
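The two-stage search described above can be sketched as follows, assuming dot-product scoring over two sub-key codebooks (a minimal numpy illustration; function and variable names are ours, not from the released code):

```python
import numpy as np

def product_key_topk(query, subkeys1, subkeys2, k):
    """Exact top-k search over |K| = C^2 product keys.

    query:    (d,) vector, split into two halves of size d/2.
    subkeys1: (C, d/2) first sub-key codebook.
    subkeys2: (C, d/2) second sub-key codebook.
    Product key (i, j) scores <q1, k1_i> + <q2, k2_j>.
    Returns (scores, flat_indices) of the k best product keys.
    """
    d = query.shape[0]
    q1, q2 = query[: d // 2], query[d // 2:]

    s1 = subkeys1 @ q1  # (C,) sub-scores for the first half
    s2 = subkeys2 @ q2  # (C,) sub-scores for the second half

    # Top-k in each sub-key space: O(C) work instead of O(C^2).
    top1 = np.argsort(s1)[-k:]
    top2 = np.argsort(s2)[-k:]

    # Combine the k x k candidate pairs and re-select the global top-k.
    cand = s1[top1][:, None] + s2[top2][None, :]  # (k, k)
    order = np.argsort(cand, axis=None)[-k:][::-1]
    rows, cols = np.unravel_index(order, cand.shape)

    C = subkeys1.shape[0]
    flat = top1[rows] * C + top2[cols]  # index into the C^2 slots
    return cand[rows, cols], flat
```

The search is exact because any global top-k product key must have both of its sub-keys inside the per-half top-k sets (otherwise k better-scoring pairs would already exist).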

Pretrained Language Models
Transfer learning from pretrained language models (PLMs) has brought a paradigm shift in NLP, with remarkable improvements on a wide range of downstream tasks. Based on the transformer architecture (Vaswani et al., 2017), BERT (Devlin et al., 2018) is trained with two pretraining tasks, (1) masked language modeling (MLM) and (2) next sentence prediction (NSP), and achieves significant improvements on fine-tuning tasks. RoBERTa (Liu et al., 2019) removes NSP and increases the batch size and training corpus to train a more robust language model, indicating that a larger batch size and more training data benefit PLM performance. Following this trend, language models with far more parameters (Raffel et al., 2019; Shoeybi et al., 2019; Brown et al., 2020) have recently been trained on huge text corpora. Despite their remarkable performance, their computational cost in training and inference is prohibitive. Improving the trade-off between accuracy and efficiency is one of the crucial research directions.

Memory-Augmented Language Models
Memory-augmented neural networks (Weston et al., 2014; Sukhbaatar et al., 2015) have the ability to solve complex algorithmic tasks and to decouple memory capacity from the number of model parameters. Chandar et al. (2016) propose a hierarchical memory network to access a large external memory efficiently. Rae et al. (2016) enable efficient training of a large memory in neural networks via a sparse read and write mechanism; however, it requires regular re-training to avoid a catastrophic drift. REALM (Guu et al., 2020) also suffers from a similar issue, so it refreshes its index asynchronously every several hundred training steps.
In addition to PKM (Lample et al., 2019), augmenting a language model with a memory architecture is a promising research direction. For example, EaE (Févry et al., 2020) and FaE (Verga et al., 2020) jointly train a memory that is interleaved in a transformer and dedicated to entities (or facts) with sparse updates, and access only a small portion of the memory at inference time. In contrast, each memory slot in PKM and in our models does not have an explicit meaning. Sukhbaatar et al. (2019) augment the self-attention layers with persistent memory vectors and remove the feed-forward layers. Khandelwal et al. (2019) augment a pretrained language model with a nearest neighbor language model that retrieves k-nearest neighbors from a datastore of key-value pairs, built from the training data, mapping a context vector to the target word. Khandelwal et al. (2019) also consider only causal language modeling, and applying the same approach to the masked language modeling widely used for PLMs is nontrivial.

Memory Utilization Analysis
As shown in our experiments (Table 2), a large PKM provides a significant gain in masked language modeling in terms of perplexity. Surprisingly, however, the downstream task performance finetuned from PKM-augmented PLMs is similar to, or sometimes worse than, that without PKM in our experiments. Nevertheless, it is challenging to investigate what is going on under the hood. We presume that this frustrating outcome comes from the catastrophic drift explained later (§3.1), and it prompts us to scrutinize memory utilization (§3.2) thoroughly.

Catastrophic Drift
PKM is jointly trained with the other transformer parameters. In every training step, only a small portion (chosen as k-NN) of the memory parameters is sparsely updated. Even when a memory slot is selected in the top-k, if the selection frequency is low or the slot is only selected with a low rank within the top-k, the updates of the memory parameters relevant to this slot may be marginal.
If memory parameters (especially value vectors) are not updated (or rarely updated) for a while, they become stale. Stale parameters are unlikely to match the newly updated model parameters, so they tend to remain unused. We call this situation a catastrophic drift. Moreover, catastrophic drift is more severe in finetuning because it relies on a small number of examples and training steps.
We hypothesize that this catastrophic drift occurs during the training of a PKM-augmented LM and that it is one plausible cause of poor performance. This problem is overlooked by Lample et al. (2019) because it is concealed by increasing the number of memory slots |K|, the number of heads H, or k. With sufficiently large memory hyper-parameters, memory usage (see §3.2 for the definition) becomes close to 100%. For example, in their experiments and ours, memory usage is almost 100% when using 4 memory heads, selecting 32 keys per head, and using 512² memory slots. Considering only top-k memory usage, the memory parameters seemingly appear to be used effectively to their full extent.

Memory Utilization Metrics
Following Lample et al. (2019), we measure the memory utilization of trained PKM-augmented models in terms of (1) memory usage and (2) KL divergence from the uniform distribution, using held-out data. Besides the standard memory usage, we propose to measure top-1 memory usage, which only counts memory slots as used when they are selected as top-1 rather than top-k, and we use it to monitor the degree of catastrophic drift.
For every memory slot i, we count the number of times it is selected in the k-NN (u_i) and as top-1 (t_i), and we sum the weights it receives throughout all memory accesses (w_i), where a weight is the normalized attention score of key i when an input x is given to the language model with the memory. Memory usage (MU) is the fraction of values that are accessed at least once, and top-1 memory usage (MU_1) is the fraction of values that are accessed as top-1 at least once:

MU = |{i : u_i > 0}| / |K|,    MU_1 = |{i : t_i > 0}| / |K|.

The KL divergence from the uniform distribution is calculated for the normalized average counts (KL_u) and the normalized average weights (KL_w):

KL_u = log |K| + Σ_i ū_i log ū_i,    KL_w = log |K| + Σ_i w̄_i log w̄_i,

where |K| is the number of memory slots, and ū_i and w̄_i are u_i and w_i normalized to sum to 1 over all slots.

Lample et al. (2019) propose PKM and show its advantage in causal language modeling. We investigate how to extend the use of large PKM to PLMs such as BERT (Devlin et al., 2018) and RoBERTa (Liu et al., 2019), which serve as good initialization points for downstream tasks. By monitoring top-1 memory usage, we observe that catastrophic drift really occurs. Low memory utilization of PKM-augmented PLMs means that the model does not fully exploit the increased capacity of the memory and thus is unlikely to gain much accuracy. To resolve the catastrophic drift, we introduce additional modifications for better pretraining: initialization from pretrained weights (§4.1) and a residual memory layer (§4.2).
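These metrics can be computed directly from per-slot access statistics; the following is a minimal numpy sketch (function and variable names are ours):

```python
import numpy as np

def memory_utilization(topk_counts, top1_counts, weight_sums):
    """Memory utilization metrics over |K| slots from held-out accesses.

    topk_counts[i]:  times slot i was selected in the top-k (u_i).
    top1_counts[i]:  times slot i was selected as top-1 (t_i).
    weight_sums[i]:  total attention weight slot i received (w_i).
    Returns (MU, MU_1, KL_u, KL_w).
    """
    K = len(topk_counts)
    mu = np.count_nonzero(topk_counts) / K       # memory usage
    mu_top1 = np.count_nonzero(top1_counts) / K  # top-1 memory usage

    def kl_from_uniform(v):
        # KL(p || uniform) = log K + sum_i p_i log p_i, with p normalized.
        p = v / v.sum()
        nz = p[p > 0]
        return float(np.log(K) + np.sum(nz * np.log(nz)))

    return mu, mu_top1, kl_from_uniform(topk_counts), kl_from_uniform(weight_sums)
```

A perfectly uniform access pattern yields KL = 0; the more peaked the usage, the larger the divergence.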

Initialization from Pretrained Weights
Learning transformer parameters and memory parameters together from scratch is difficult due to the discrepancy between them, as described in §3.1. To remedy this issue, we first pretrain a language model without memory layers, and then pretrain a model with memory layers initialized from the already pretrained language model. Since the transformer parameters are initialized from a well-trained language model, they change only gradually, and we expect staleness to be mitigated as a result. Despite requiring two stages of training, a language model trained with this initialization performs much better and has higher memory usage than one trained for the same number of steps from scratch, as shown in Table 2.

He et al. (2016) propose ResNet to train very deep convolutional networks. A residual connection enables easier optimization and gains accuracy from increased depth. We borrow this idea by introducing a residual connection when augmenting a PKM, to alleviate the catastrophic drift.

Residual Memory Layer
When we replace an FFN layer of pretrained networks with a PKM layer, the model struggles to fit the data in the early stage because the function of this layer suddenly changes from a well-trained one to a random one (see the green line in Figure 3). We hope to prevent this undesirable circumstance while keeping the strong representation power of product key memory. To this end, we propose the residual memory (ResM) layer, which adds the memory layer to a transformer block in the form of a residual connection (He et al., 2016) instead of replacing the FFN layer. Due to the residual connection, the function of the layer does not deviate severely from that of the original pretrained weights, which helps training start from a stable point. Figure 1 displays how the residual memory layer differs from the previous models. More precisely, we can formulate these layers as

y = LN(x + α · FFN(x) + β · PKM(x)),

where LN indicates layer normalization (Ba et al., 2016). (α, β) = (1, 0), (0, 1), and (1, 1) correspond to the FFN layer, the PKM layer, and the ResM layer, respectively.
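The (α, β) formulation can be sketched as a toy numpy function, with `ffn` and `pkm` passed in as arbitrary callables (an illustration of the combination rule only, not the actual BERT sub-layer code):

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize each row to zero mean and unit variance (no learned scale/bias).
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def sublayer(x, ffn, pkm, alpha, beta):
    """y = LN(x + alpha * FFN(x) + beta * PKM(x)).

    (alpha, beta) = (1, 0): plain FFN sub-layer,
                    (0, 1): PKM replacing the FFN,
                    (1, 1): residual memory (ResM).
    """
    return layer_norm(x + alpha * ffn(x) + beta * pkm(x))
```

With (1, 1), the pretrained FFN path is preserved and the PKM output is merely added, so the layer's function stays close to its initialization.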

Product Key Memory
Our implementation is based on HuggingFace's Transformers library, and the PKM part is borrowed from the XLM repository. We add two memory layers at regular intervals among the intermediate layers: i.e., {4, 8} in 12-layer models and {2, 4} in 6-layer models. We leave exploring the effect of changing the number and position of memory layers to future work. We use 512² (≈ 262k) memory slots with 4 memory heads and select 32 keys per head for each memory layer in all experiments. We set the dimensions of the key and value vectors to 256 and 768, respectively. We use query batch normalization to increase key coverage during training. We measure the top-1 memory usage and the KL divergence to assess how effectively the model uses its memory capacity.
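For reference, the settings above amount to the following configuration sketch (the dictionary keys are illustrative names of ours, not the library's):

```python
# Reported PKM settings from this section; key names are ours.
PKM_CONFIG = {
    "memory_layer_positions": [4, 8],  # {2, 4} for 6-layer models
    "subkey_codebook_size": 512,       # |K| = 512**2 product keys
    "memory_heads": 4,
    "knn": 32,                         # keys selected per head
    "key_dim": 256,                    # product key dim: two 128-d sub-keys
    "value_dim": 768,
    "query_batchnorm": True,
}

# Total number of addressable memory slots per memory layer.
num_slots = PKM_CONFIG["subkey_codebook_size"] ** 2  # 262,144 (~262k)
```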

Pretraining
We use 12-layer BERT BASE models with and without PKM. For pretraining, we use English Wikipedia and BookCorpus (Zhu et al., 2015), 17GB in total, as the training corpus, like BERT (Devlin et al., 2018). We use the same vocabulary and tokenizer as Devlin et al. (2018). We train models with a batch size of 1024 sequences for 500,000 steps. We use the Adam optimizer (Kingma and Ba, 2014) with a learning rate of 1e-4 and a linear warmup schedule over the first 10,000 steps. The memory values are learned with a sparse update and a learning rate of 1e-3, following Lample et al. (2019). With half-precision training on 32 NVIDIA V100 GPUs, pretraining took 2.8 days without PKM and 5.1 days with PKM (or with ResM). To evaluate the pretrained models themselves, we measure the perplexity of masked language modeling on the test sets of WikiText-2, WikiText-103, and PG-19 (Rae et al., 2019). Since the pretraining corpus covers WikiText-2 and WikiText-103, perplexity on them is a proxy for training perplexity; because the PG-19 dataset comes from different sources, perplexity on PG-19 can be regarded as test perplexity.
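The two learning rates above imply splitting the parameters into groups before building the optimizers. The following is a hypothetical sketch under our own naming convention (the actual setup uses HuggingFace Transformers and the XLM PKM code, whose parameter names may differ):

```python
# Hypothetical sketch: route memory value parameters to a separate,
# higher learning rate intended for sparse updates, and everything
# else to the dense Adam update. The "memory.values" substring is an
# assumed naming convention, not taken from the released code.
def param_groups(named_params):
    memory_values, others = [], []
    for name, param in named_params:
        if "memory.values" in name:
            memory_values.append(param)
        else:
            others.append(param)
    return [
        {"params": others, "lr": 1e-4},                         # dense Adam
        {"params": memory_values, "lr": 1e-3, "sparse": True},  # sparse update
    ]
```

In PyTorch, the second group would typically be handed to a sparse-capable optimizer such as `torch.optim.SparseAdam`, since the value embeddings only receive gradients for the k selected slots per step.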

Finetuning
For fine-tuning, we use SQuAD 1.1 (Rajpurkar et al., 2016) and GLUE (Wang et al., 2018). We report dev set results instead of test set results to compare our variants, and we report the median of 5 runs with different random seeds for each fine-tuning task. We measure exact match (EM) and F1 scores on SQuAD 1.1. For QQP, a binary classification task, the F1 score is used for the GLUE leaderboard; however, we use accuracy as the metric on the development set because the F1 score varies considerably depending on the random seed. Finetuning details appear in Table 3.

Memory Utilization Surprisingly, the top-1 memory usage of the PKM-augmented PLM at the 4th layer is about 2%, which is remarkably low, even though the top-32 memory usage at this layer is almost 100%. In other words, the model does not exploit the lower memory layer effectively. With a residual connection, the top-1 memory usage of all layers becomes reasonably high. Similar to He et al. (2016), the residual connection helps to learn deep networks with memory, resulting in improved accuracy. Moreover, with the initialization from pretrained weights, the top-1 memory usage is more than 95%. With both the initialization and ResM, the top-1 memory usage increases and the KL divergence decreases significantly, implying better exploitation of the memory layers. This is made possible by preventing the memory parameters from suffering the catastrophic drift.

Pretraining Results
We check, among the saved checkpoints, when each memory slot was last used, and then count the number of slots by last-used checkpoint. Figure 2 indirectly indicates how many memory slots are never again selected as top-1. This figure provides evidence that a model with the initialization and residual memory suffers less staleness than a model with plain PKM.

Figure 2: Histogram for staleness evaluation of PKM-augmented PLMs. We save model checkpoints every 100k steps during the entire 500k pretraining steps. The histogram illustrates, for each saved checkpoint, how many memory slots were last used at that checkpoint. For example, if a key is used at the 200k checkpoint and never used after that, it is likely to remain stale after 200k. Because the total number of memory slots is fixed at 512², a model whose mass lies toward the right of the graph is better.
Masked Language Modeling Augmenting a large PKM always improves masked language modeling compared to a model without memory. Figure 3 shows the training curves of the models after initialization. It shows that the residual connection prevents the deviation that PKM suffers at the beginning: although the models converge to a similar perplexity after very long training, the initial perplexity of PKM is much larger than that of ResM, even with initialization from pretrained weights. In sum, both the initialization from a pretrained PLM and the residual memory layer are beneficial for a PLM with memory to perform better in masked language modeling.

Table 4: Experimental results of fine-tuning PKM-augmented PLMs. Models (a)-(f) are the same as in Table 2. We borrow the pretrained weights of BERT BASE and BERT LARGE from Devlin et al. (2018). We fine-tune these models on SQuAD 1.1 (Rajpurkar et al., 2016) and the GLUE tasks (Wang et al., 2018).

Finetuning Results

Table 4 shows the experimental results of finetuning using our pretrained models.
Downstream Performance Although a large PKM helps masked language modeling, the downstream performance on several tasks with plain PKM is worse than the baseline without memory. We think this is because the catastrophic drift problem is especially severe in the fine-tuning step: the downstream dataset sizes and numbers of training steps are too small to fit the memory parameters accordingly.
Better memory utilization from the initialization and the residual connection also leads to better downstream accuracy on most of the datasets. We report the fine-tuning results using the pretrained BERT LARGE weights from Devlin et al. (2018) in Table 4. We believe that our best PKM-augmented BERT BASE would remain comparable in performance to BERT LARGE even if we pretrained BERT LARGE ourselves, while being much faster, as described in Table 1.
On the assumption that sparsely updating memory parameters with a limited number of examples and training steps might be vulnerable to the catastrophic drift, we also try freezing the memory parameters during fine-tuning, as in Table 5. However, this degrades the downstream performance. Table 6 shows the memory usage and KL divergence of fine-tuned PKM-augmented models. The comparison of fine-tuned PKM-augmented models in terms of memory usage shows trends similar to those of pretraining. The initialization and the residual memory improve memory usage, meaning better exploitation of model capacity for downstream tasks. Especially on a large dataset like MNLI (Williams et al., 2017), the memory usage of the fine-tuned model reaches almost 100%, similar to the pretrained models, due to the sufficient number of training steps to update the memory parameters. On the other hand, interestingly, the initialization and the residual memory do not always reduce the KL divergence. We presume this is because fine-tuning on classification tasks encourages input examples of the same class to be clustered into similar representations, so the model accesses similar patterns of memory slots while still utilizing many of them.

Table 6: Memory utilization of PKM-augmented models after fine-tuning. We measure the memory utilization metrics (MU, MU_1, KL_u, and KL_w) at the 4th and 8th layers after fine-tuning on the MNLI-m (Williams et al., 2017), SST-2 (Socher et al., 2013), and CoLA (Warstadt et al., 2019) datasets as examples. We use the same fine-tuned models as in Table 2.

Memory Utilization
To validate the assumption mentioned above, we check the difference in memory usage between positive and negative examples using SST-2 (Socher et al., 2013), a binary classification task that predicts the sentiment of a movie review. To measure the difference, we calculate (1) the KL divergence between the two distributions (positive/negative) and (2) the intersection over union (IOU), a widely used metric in object detection (Ren et al., 2015), on the top-1 memory usage. We calculate IOU = |P ∩ N| / |P ∪ N|, where P and N are the sets of memory slots selected as top-1 at least once for positive and negative examples, respectively. As illustrated in Figure 4, our best PKM-augmented model shows much higher KL and lower IOU in every layer than the plain PKM-augmented model, implying better discriminative ability.
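Both comparisons can be sketched from per-slot top-1 counts for the two classes (a minimal numpy illustration; names are ours, and we add smoothing to the KL computation to avoid zero probabilities, which the text does not specify):

```python
import numpy as np

def usage_divergence(pos_counts, neg_counts, eps=1.0):
    """Compare top-1 memory usage between two example groups.

    pos_counts[i], neg_counts[i]: top-1 selection counts of slot i for
    positive and negative examples, respectively.
    Returns (kl, iou): KL(pos || neg) over smoothed count distributions,
    and IOU of the sets of slots each group uses at least once.
    """
    p = pos_counts + eps  # add-eps smoothing (our choice)
    q = neg_counts + eps
    p = p / p.sum()
    q = q / q.sum()
    kl = float(np.sum(p * np.log(p / q)))

    P = set(np.flatnonzero(pos_counts))  # slots used by positives
    N = set(np.flatnonzero(neg_counts))  # slots used by negatives
    iou = len(P & N) / len(P | N)
    return kl, iou
```

Higher KL and lower IOU together indicate that the two classes access largely different memory slots, i.e., the memory is discriminative.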
Other Pretrained Models We release the code and pretrained weights to encourage researchers and practitioners to easily utilize and reproduce our work, allowing application to different model sizes and other backbone architectures. In particular, we apply our methods to the DistilBERT model (Sanh et al., 2019), a 6-layer transformer trained by knowledge distillation (Hinton et al., 2015) from BERT BASE. Similarly, it obtains accuracy comparable to BERT BASE, as shown in Table 7. Moreover, we believe our approaches could also be helpful for other tasks.
PKM vs. ResM One might argue that the gap between the PKM model and the ResM model is attributable to the difference in model size. We claim that the impact of the architectural difference between PKM and ResM goes beyond merely having more parameters: ResM achieves better memory utilization, resulting in better final performance. A 0.3 higher average GLUE score with only 9M more parameters (less than 2% of the entire model) is significant, considering that BERT-Large achieves a 1.9 higher average GLUE score with 230M more parameters than BERT-Base (0.3/9 ≫ 1.9/230).

Conclusion and Future Work
This work starts from the unexpected result that directly applying PKM to PLMs does not work well on downstream tasks, contrary to Lample et al. (2019). In this paper, we successfully augment PLMs with PKM using two ingredients, weight initialization and a residual connection, based on observations of memory utilization and catastrophic drift during training. Consequently, we encourage the use of memory architectures such as PKM for PLMs in practical settings. Although our approach mitigates the catastrophic drift problem to some extent, we leave further study of it during both pretraining and finetuning as future work. One possible solution is to regularize a PKM by structured dropout on the memory keys, like DropHead (Zhou et al., 2020). It would also help to prune unnecessary memory slots on demand at inference time.