Variance-reduced First-order Meta-learning for Natural Language Processing Tasks

First-order meta-learning algorithms have been widely used in practice to learn initial model parameters that can be quickly adapted to new tasks, owing to their efficiency and effectiveness. However, existing studies find that the meta-learner can overfit to some specific adaptation when tasks are heterogeneous, which significantly degrades performance. In Natural Language Processing (NLP) applications, datasets are often diverse and each task has its own characteristics. Therefore, to address the overfitting issue when applying first-order meta-learning to NLP applications, we propose to reduce the variance of the gradient estimator used in task adaptation. To this end, we develop a variance-reduced first-order meta-learning algorithm. The core of our algorithm is a novel variance reduction term added to the gradient estimate during task adaptation. Experiments on two NLP applications, few-shot text classification and multi-domain dialog state tracking, demonstrate the superior performance of our proposed method.


Introduction
Meta-learning has recently emerged as a promising approach for solving many natural language processing tasks, such as few-shot text classification (Obamuyide and Vlachos, 2019; Bao et al., 2019), low-resource language understanding (Gu et al., 2018; Dou et al., 2019; Yu et al., 2020a), and multi-domain dialogue systems (Qian and Yu, 2019; Huang et al., 2020). In particular, model-agnostic meta-learning (MAML) (Finn et al., 2017), a widely used meta-learning approach, trains an initial model that can be adapted to a new task with a small number of optimization steps and little training data. However, MAML requires the computation of second-order derivatives, which can be costly for reinforcement learning and NLP applications. Therefore, numerous computationally efficient MAML variants (Finn et al., 2017; Li et al., 2017; Nichol et al., 2018; Antoniou et al., 2018; Zintgraf et al., 2019; Song et al., 2020) have been proposed in recent years. First-order meta-learning (Finn et al., 2017; Nichol et al., 2018) is widely used in practice because it is easy to implement, eliminates the computationally intensive second-order derivatives of MAML, and achieves state-of-the-art performance.
Although meta-learning, including first-order meta-learning, has shown promising performance in many applications (Triantafillou et al., 2019), it still struggles to learn on diverse task distributions (Triantafillou et al., 2020; Rebuffi et al., 2017; Yu et al., 2020c). First-order meta-learning consists of task adaptation and meta updates. Task adaptation obtains a task-specific model for each task by performing several optimization steps starting from the current meta model. The meta update then aggregates the gradient information of the task-specific models to obtain a new meta model. Many previous works (Zhao et al., 2018; Karimireddy et al., 2019; Charles and Konečný, 2020) have observed that local update methods, including first-order meta-learning, which perform multiple optimization steps on local data, can overfit to atypical local data. In the context of first-order meta-learning, the large variance of the gradient estimator drives the task-specific models away from each other during task adaptation, so the gradients used in the meta update point in diverse directions. Furthermore, since the gradient magnitudes can also differ substantially across tasks, a task with a much larger gradient will dominate, and the meta update will overfit to this dominating task. Similar issues have been studied in multi-task learning: Yu et al. (2020b) showed that conflicting gradients, i.e., two gradients with negative cosine similarity, can significantly degrade performance when the difference in gradient magnitudes is large.
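To make the notion of conflicting gradients concrete, the minimal sketch below uses two hypothetical two-dimensional task gradients (made-up numbers, not from any experiment). It checks whether they conflict, i.e., have negative cosine similarity, and shows that when one gradient is ten times larger, their plain average points almost entirely along the larger one and even against the smaller task's gradient. The helper names are illustrative only.

```python
import math

def cosine_similarity(g1, g2):
    """Cosine similarity between two gradient vectors given as lists of floats."""
    dot = sum(a * b for a, b in zip(g1, g2))
    n1 = math.sqrt(sum(a * a for a in g1))
    n2 = math.sqrt(sum(b * b for b in g2))
    return dot / (n1 * n2)

def conflicting(g1, g2):
    """Two gradients conflict when their cosine similarity is negative."""
    return cosine_similarity(g1, g2) < 0.0

def average_gradient(grads):
    """Unweighted average of task gradients, as in a plain aggregation step."""
    n = len(grads)
    return [sum(col) / n for col in zip(*grads)]

# Hypothetical task gradients: they conflict, and the second is about 10x larger.
g_small = [1.0, 0.0]
g_large = [-10.0, 1.0]
avg = average_gradient([g_small, g_large])
print(conflicting(g_small, g_large))           # True
print(cosine_similarity(avg, g_large) > 0.9)   # True: the average follows the large task
print(cosine_similarity(avg, g_small) < 0.0)   # True: it even opposes the small task
```

In this toy case the averaged direction is [-4.5, 0.5], so an update along it would increase the loss of the task with the small gradient.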
The above gradient variance issue, i.e., the large variance of the gradient estimator, is significant in NLP applications because many NLP datasets have diverse properties, and the tasks for meta-learning in NLP applications also have their own unique characteristics. For example, the MultiWOZ dataset (Budzianowski et al., 2018) for dialog systems and the Spider dataset for semantic parsing both consist of complex, cross-domain examples. To address this gradient variance issue when applying first-order meta-learning to NLP applications, we propose a variance-reduced first-order meta-learning (VFML) algorithm. The key idea of our algorithm is to add a novel variance reduction term to the task adaptation steps to reduce the variance of the gradient estimator. We evaluate our proposed method on two NLP applications: few-shot text classification and domain adaptation in multi-domain dialog state tracking. Experiments on several benchmark datasets show that our method produces models that achieve better performance than the Reptile baseline (Nichol et al., 2018).

Problem Setup and Preliminaries
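Let $\mathcal{T} = \{T_i\}_{i \in I}$ be the set of all tasks, where $I$ is the task index set. Suppose $T_i$ is drawn from $\mathcal{T}$ with probability $p_i$, and let $p$ denote this probability distribution over $\mathcal{T}$. Our goal is to find an initial model $\theta$ that achieves a small loss on a new task $T_i$ after a few steps of updates. Therefore, we want to solve the following problem:

$$\min_{\theta} \; F(\theta) := \mathbb{E}_{i \sim p}\Big[ L_i\big(f_i^K(\theta)\big) \Big], \qquad (2.1)$$

where $L_i$ is the loss on task $T_i$ and $f_i^K(\theta)$ is the function that updates the initial model parameter $\theta$ for $K$ steps on task $T_i$.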
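A toy sketch of objective (2.1) may help fix the notation. It assumes one-dimensional quadratic task losses $L_i(w) = \tfrac{1}{2}(w - c_i)^2$ over a finite task set with probabilities $p_i$, and takes $f_i^K$ to be $K$ steps of plain gradient descent with a hypothetical step size `eta`; none of these choices come from the paper.

```python
def grad(c, w):
    # Gradient of the toy task loss L_c(w) = 0.5 * (w - c)**2.
    return w - c

def loss(c, w):
    # Toy task loss; c is the (hypothetical) optimum of task c.
    return 0.5 * (w - c) ** 2

def adapt(c, theta, K, eta):
    # f_i^K(theta): K steps of gradient descent on task c, starting from theta.
    w = theta
    for _ in range(K):
        w = w - eta * grad(c, w)
    return w

def meta_objective(theta, tasks, probs, K, eta):
    # F(theta) = sum_i p_i * L_i(f_i^K(theta)) for a finite task set.
    return sum(p * loss(c, adapt(c, theta, K, eta)) for c, p in zip(tasks, probs))

print(meta_objective(0.0, [-1.0, 1.0], [0.5, 0.5], K=1, eta=0.5))  # 0.125
```

For this toy family the adapted error is $(1-\eta)^K(\theta - c_i)$, so the meta-objective rewards initializations from which every task is reachable in a few steps; with two symmetric tasks the best $\theta$ is their midpoint.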

First-order meta-learning
To solve problem (2.1), MAML uses the task adaptation $f_i^K(\theta)$ together with the following meta update based on sampled tasks:

$$\theta \leftarrow \theta - \tau \, \frac{1}{|I_b|} \sum_{i \in I_b} \nabla_\theta L_i\big(f_i^K(\theta)\big),$$

where $\tau$ is the step size, $I_b$ is the index set of the sampled tasks, and $f_i^K(\theta)$ is usually $K$ steps of gradient descent. More efficient and effective MAML variants are the first-order methods (Finn et al., 2017; Nichol et al., 2018). For instance, Finn et al. (2017) proposed to replace the Hessian matrix in the meta update with an identity matrix, which leads to first-order MAML (FOMAML), whose meta update uses $\nabla L_i(f_i^K(\theta))$, the task gradient evaluated at the adapted parameters. Nichol et al. (2018) proposed Reptile to further simplify FOMAML by using the following meta update:

$$\theta \leftarrow \theta + \tau \, \frac{1}{|I_b|} \sum_{i \in I_b} \big(f_i^K(\theta) - \theta\big).$$

In this work, we propose a new method based on Reptile to improve the performance of first-order meta-learning methods.
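The two first-order meta updates discussed above can be contrasted on the same toy quadratic tasks $L_c(w) = \tfrac{1}{2}(w - c)^2$ (an assumption for illustration only). In this minimal sketch, FOMAML averages task gradients evaluated at the adapted parameters, while Reptile moves $\theta$ toward the average adapted parameters; the function names are hypothetical.

```python
def grad(c, w):
    # Gradient of the toy task loss 0.5 * (w - c)**2.
    return w - c

def adapt(c, theta, K, eta):
    # K steps of gradient descent on task c starting from theta.
    w = theta
    for _ in range(K):
        w = w - eta * grad(c, w)
    return w

def fomaml_step(theta, batch, K, eta, tau):
    # FOMAML: average of task gradients evaluated at the adapted parameters.
    g = sum(grad(c, adapt(c, theta, K, eta)) for c in batch) / len(batch)
    return theta - tau * g

def reptile_step(theta, batch, K, eta, tau):
    # Reptile: move theta toward the average of the adapted parameters.
    d = sum(adapt(c, theta, K, eta) - theta for c in batch) / len(batch)
    return theta + tau * d

print(fomaml_step(0.0, [1.0, 3.0], K=1, eta=0.5, tau=1.0))   # 1.0
print(reptile_step(0.0, [1.0, 3.0], K=1, eta=0.5, tau=1.0))  # 1.0

theta = 0.0
for _ in range(60):
    theta = reptile_step(theta, [1.0, 3.0], K=1, eta=0.5, tau=1.0)
print(round(theta, 6))  # 2.0
```

On these symmetric quadratics both updates coincide and converge to the mean task optimum; in general they differ, and Reptile needs no gradient evaluation beyond those already made during adaptation.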

Method
Our proposed algorithm for meta-learning is illustrated in Algorithm 1. In the following discussion, we use $\nabla L_{i,B_t^i}$ to denote the mini-batch stochastic gradient for task $i$, where $B_t^i$ is the sample index set. The main idea of our method is to construct a variance reduction term $v$, motivated by the stochastic recursive momentum technique proposed by Cutkosky and Orabona (2019). The term $v$ is used in the task adaptation step (line 4 in Algorithm 1) to reduce the variance of the gradient estimator. More specifically, we use the gradient

$$g_k^i = \gamma \, \nabla L_{i,B_k^i}(w_k^i) + (1-\gamma)\, v$$

(see Algorithm 2) to update the task-specific model for task $T_i$. That is, $g_k^i$ is a weighted sum of the mini-batch stochastic gradient $\nabla L_{i,B_k^i}(w_k^i)$ and the variance reduction term $v$, with weight $(1-\gamma)$ on $v$. When $\gamma = 1$, the method reduces to Reptile. We initialize the variance reduction term $v_0$ by averaging the gradients of a set of randomly sampled tasks, computed at the initialization $\theta_0$.
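The task adaptation rule above can be sketched as follows. This is a minimal illustration, not the authors' implementation: it assumes hypothetical one-dimensional toy tasks whose examples are noisy targets `y` with per-example loss $\tfrac{1}{2}(w - y)^2$, and it keeps $v$ fixed during one adaptation. How $v$ is refreshed between outer iterations (which involves $\beta$) follows Algorithm 1 and is not modeled here.

```python
import random

def minibatch_grad(w, targets):
    # Gradient of the mean toy loss 0.5 * (w - y)**2 over a mini-batch of targets y.
    return sum(w - y for y in targets) / len(targets)

def make_task(center, noise, rng):
    # Hypothetical task: targets are center + Gaussian noise; returns a batch sampler.
    return lambda b: [rng.gauss(center, noise) for _ in range(b)]

def init_v(theta0, samplers, b):
    # v_0: average of mini-batch gradients at theta0 over randomly sampled tasks.
    return sum(minibatch_grad(theta0, s(b)) for s in samplers) / len(samplers)

def vfml_adapt(theta, sample_batch, v, K, eta, gamma, b):
    # Task adaptation with g = gamma * stochastic gradient + (1 - gamma) * v.
    # gamma = 1 recovers the plain SGD inner loop used by Reptile.
    w = theta
    for _ in range(K):
        g = gamma * minibatch_grad(w, sample_batch(b)) + (1.0 - gamma) * v
        w = w - eta * g
    return w

rng = random.Random(0)
tasks = [make_task(c, 0.5, rng) for c in (-2.0, 0.0, 1.0, 3.0)]
v0 = init_v(0.0, rng.sample(tasks, 3), b=10)
w_task = vfml_adapt(0.0, tasks[3], v0, K=3, eta=0.1, gamma=0.5, b=10)
```

Setting `gamma=1.0` in `vfml_adapt` ignores `v` entirely, which mirrors the statement that the method reduces to Reptile.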
Next, we briefly discuss the intuition for why our proposed method can reduce the variance of the gradient estimator.
Let $\sigma_1^2$ denote the variance of the mini-batch stochastic gradient within a task, and let $\sigma_2^2$ denote the variance introduced by the dissimilarity between tasks. Intuitively, the variance of the gradient estimator in Reptile is governed by both $\sigma_1^2$ and $\sigma_2^2$. If we have a large number of examples for each task, then $\sigma_1^2$ will be small, and the variance of the gradient estimator in Reptile will be determined by $O(\sigma_2^2)$. When the task distribution is very diverse, $\sigma_2^2$ will be large, which can significantly degrade performance. For VFML, in contrast, the variance of $g_k^i$ will be dominated by

$$O\Big(\gamma^2 \sigma_2^2 + (1-\gamma)^2 \big(\beta^2 \bar{\sigma}_2^2 + (1-\beta)^2 \Delta_{t+1}^2\big)\Big).$$

Since $\bar{\sigma}_2^2$ can be much smaller than $\sigma_2^2$, and $\Delta_{t+1}^2$ goes to zero as our algorithm converges, the variance of $g_k^i$ can be much smaller than $\sigma_2^2$ with appropriately chosen $\beta$ and $\gamma$. Therefore, the role of the variance reduction term $v$ is to alleviate the variance introduced by task dissimilarity.
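The effect of mixing in an averaged term can be checked numerically in a deliberately simplified toy model; this is an illustration of the mechanism, not a reproduction of the bound above (it ignores $\beta$ and $\Delta_{t+1}$). The sketch assumes task optima $c \sim N(0, \sigma_2^2)$, mini-batch noise with variance $\sigma_1^2 / b$, and a $v$ formed by averaging gradients from `n_avg` other randomly sampled tasks. With $s^2 = \sigma_2^2 + \sigma_1^2/b$, the Reptile-style estimator has variance $s^2$, while the mixed estimator has variance $\gamma^2 s^2 + (1-\gamma)^2 s^2 / n_{\text{avg}}$.

```python
import random
import statistics

def task_gradient(rng, sigma1, sigma2, b):
    # Stochastic gradient at w = 0 for a random toy task: the task optimum
    # c ~ N(0, sigma2^2) models task dissimilarity, and the mean of b noisy
    # targets adds N(0, sigma1^2 / b) mini-batch sampling noise.
    c = rng.gauss(0.0, sigma2)
    noise = rng.gauss(0.0, sigma1 / b ** 0.5)
    return -(c + noise)

def estimator_variances(gamma, n_avg, sigma1=0.1, sigma2=1.0, b=10, trials=5000, seed=0):
    # Empirical variances of the single-task gradient (Reptile-style) and of
    # gamma * task gradient + (1 - gamma) * v, with v averaged over n_avg tasks.
    rng = random.Random(seed)
    reptile, mixed = [], []
    for _ in range(trials):
        g_task = task_gradient(rng, sigma1, sigma2, b)
        v = sum(task_gradient(rng, sigma1, sigma2, b) for _ in range(n_avg)) / n_avg
        reptile.append(g_task)
        mixed.append(gamma * g_task + (1.0 - gamma) * v)
    return statistics.pvariance(reptile), statistics.pvariance(mixed)

r, m = estimator_variances(gamma=0.3, n_avg=50)
print(f"Reptile-style: {r:.3f}  mixed: {m:.3f}")
```

With the defaults the two values should come out near 1.0 and 0.1. The price of a small $\gamma$ is that each task-specific model is pulled toward the average-task direction rather than its own gradient, which is why $\gamma$ and $\beta$ are tuned in the experiments.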

Experiments
We evaluate our proposed method on one simulation experiment and two NLP applications: text classification and dialog state tracking.

Simulation
To validate the effectiveness of our proposed method, we consider one-dimensional sine wave regression (Finn et al., 2017; Nichol et al., 2018). The goal is to learn a neural network that can quickly adapt to a given sine wave function after a few adaptation steps. We follow the experimental setup of previous work (Nichol et al., 2018) and compare our proposed method with Reptile (Nichol et al., 2018) in terms of the mean squared error between the output of the adapted neural network and the sine wave function.

Parameters: For both methods, we sample 10 tasks at each outer-loop iteration and use 10 examples, i.e., b = 10, to compute the mini-batch stochastic gradients. We choose K = 3 and η = 0.01 for the task adaptation step, and τ = 1 for the meta update.

Figures 1(a) and 1(b) show that VFML requires fewer iterations and achieves lower training and test error than Reptile. Figures 1(c) and 1(d) illustrate that our proposed method can quickly converge to a given sine wave function. These results validate the superiority of VFML.
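The task distribution for this simulation can be sketched as below. The amplitude range [0.1, 5.0], phase range [0, π], and input range [-5, 5] are the values commonly used for sine regression since Finn et al. (2017); they are assumptions here, since this section does not restate them. The sketch only samples 10 tasks with b = 10 points each and scores a predictor by mean squared error; the network and the adaptation loop are omitted.

```python
import math
import random

def sample_sine_task(rng):
    # Assumed ranges (not stated in this paper): amplitude in [0.1, 5.0], phase in [0, pi].
    amplitude = rng.uniform(0.1, 5.0)
    phase = rng.uniform(0.0, math.pi)
    return lambda x: amplitude * math.sin(x + phase)

def sample_batch(rng, task, b=10):
    # b inputs drawn uniformly from [-5, 5] (assumed) with noiseless sine targets.
    xs = [rng.uniform(-5.0, 5.0) for _ in range(b)]
    return xs, [task(x) for x in xs]

def mse(predict, xs, ys):
    # Mean squared error between a predictor's outputs and the sine targets.
    return sum((predict(x) - y) ** 2 for x, y in zip(xs, ys)) / len(xs)

rng = random.Random(0)
tasks = [sample_sine_task(rng) for _ in range(10)]  # 10 tasks per outer-loop iteration
batches = [sample_batch(rng, t) for t in tasks]      # b = 10 examples per task
print(len(batches), len(batches[0][0]))              # 10 10
```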

Few-shot Text Classification
We consider two text classification datasets: Amazon (He and McAuley, 2016) and FewRel (Han et al., 2018). The Amazon dataset consists of customer reviews from 24 product categories; following previous work (Bao et al., 2019), we sample 1,000 reviews for each category. For this dataset, the goal is to classify a given review into its product category. FewRel is a relation classification dataset in which each example is a sentence annotated with a head entity, a tail entity, and their relation. For FewRel, the goal is to predict the relation between the head and tail entities in a given sentence.
We follow the experimental setup of previous work (Bao et al., 2019) and consider the N-way K-shot setting, where N is the number of classes in each task and K is the number of examples per class.

Baseline models: For this problem, we consider the convolutional neural network (CNN) based model proposed in (Bao et al., 2019). More specifically, we use a CNN as the embedding model to generate the input representation and a one-hidden-layer neural network with 300 units and ReLU activation as the classifier.

Parameters: For both Reptile and our method, we choose the number of adaptation steps K from the grid {1, 3, 5, 10} and η from {0.01, 0.05, 0.1, 0.3, 0.5} for the task adaptation step, and set τ = 1 for the meta update. For our proposed method, we choose γ from the grid {0.1, 0.3, 0.5, 0.7, 0.9} and β from {0.05, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0}.

Figure 1: Figures 1(a) and 1(b) show the training error and test error versus the number of iterations.

Table 1: Results of text classification on the Amazon and FewRel datasets. We consider 5-way K-shot settings with K = 5, 10, 50. We report the classification accuracy with the standard deviation over 10 trials.

Table 2: Results of different meta-learning methods on the Amazon and FewRel datasets in 5-way 5-shot settings. We report the classification accuracy with the standard deviation over 10 trials.

Results: Table 1 summarizes the comparison of different methods on the Amazon and FewRel datasets for text classification. The results are averaged over 10 runs. In the 5-way 5-shot setting, our proposed method achieves 1% and 0.8% improvements in classification accuracy on the Amazon and FewRel datasets, respectively.

Analysis: We also consider the 5-way 10-shot and 5-way 50-shot settings. These two settings evaluate our proposed method when the variance of the gradient estimator is dominated by the variance introduced by task dissimilarity, since more examples per task reduce the within-task sampling variance.
The results show that with 50 shots, our proposed method achieves gains of 2.02% and 2.5% in classification accuracy on the Amazon and FewRel datasets, respectively. The results in the 10-shot and 50-shot settings validate the effectiveness of the variance reduction term in alleviating the variance of the gradient estimator introduced by task dissimilarity. We also compare our proposed method with MAML and FOMAML (Finn et al., 2017) on the Amazon and FewRel datasets in the 5-way 5-shot setting. Table 2 shows that our method outperforms these two baselines.
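The N-way K-shot tasks used in these experiments can be built by episode sampling. The sketch below is a minimal, hypothetical version: it picks N classes, draws K support and `n_query` query examples per class without replacement, and relabels classes 0..N-1 within the episode. The query-set size and the toy data (24 categories of 1,000 placeholder reviews, mirroring the Amazon setup) are assumptions for illustration.

```python
import random

def sample_episode(data_by_class, n_way, k_shot, n_query, rng):
    # Pick N classes, then K support and n_query query examples per class.
    # Labels are re-indexed 0..N-1 within the episode.
    classes = rng.sample(sorted(data_by_class), n_way)
    support, query = [], []
    for label, cls in enumerate(classes):
        examples = rng.sample(data_by_class[cls], k_shot + n_query)
        support += [(x, label) for x in examples[:k_shot]]
        query += [(x, label) for x in examples[k_shot:]]
    return support, query

# Hypothetical stand-in data: 24 categories with 1,000 placeholder reviews each.
data = {f"cat{i}": [f"review{i}_{j}" for j in range(1000)] for i in range(24)}
support, query = sample_episode(data, n_way=5, k_shot=5, n_query=10, rng=random.Random(0))
print(len(support), len(query))  # 25 50
```

Note that K here counts examples per class, whereas in the Parameters paragraphs K denotes the number of adaptation steps; the sketch uses `k_shot` to keep the two apart.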

Dialog State Tracking
We also test our VFML method on the task of multi-domain dialog state tracking (DST). We experiment on MultiWOZ (Budzianowski et al., 2018), a large-scale multi-domain human-human dialog dataset that was introduced to facilitate research on the DST problem. This corpus contains 8,438 multi-turn dialogues with an average of 13.7 turns per dialogue. Multi-domain dialog state tracking on MultiWOZ is a challenging task for meta-learning because dialogues differ across domains; for example, the dialog states and user utterances for hotel and train are quite different. We use the five most frequent domains: restaurant, hotel, attraction, taxi, and train. Following the setup of Huang et al. (2020), we train on three source domains (hotel, restaurant, and train) and evaluate on the target domains (taxi and attraction), using 1% of the target-domain data for fine-tuning. We compare our method with Reptile and with train-from-scratch, i.e., training a randomly initialized model using only data from the target domain. We use joint and slot accuracy to evaluate the different methods. Joint accuracy measures the accuracy of dialogue states, where a dialogue state is correctly predicted only if the values for all (domain, slot) pairs are correctly predicted. Slot accuracy measures the accuracy of each (domain, slot, value) tuple in the dialog state.

Baseline models: We quantify the benefits of different meta-learning algorithms by comparing results on top of the TRADE model architecture. TRADE is an encoder-decoder model that uses two BiGRUs to encode the sequence of dialogue turns and then generates the corresponding (domain, slot, value) tuples. We set the hidden size of the encoder and decoder to 400 and use GloVe embeddings (Pennington et al., 2014).

Parameters: For both Reptile and our method, we choose the number of adaptation steps K from the grid {1, 3, 5} and η from {0.01, 0.05, 0.1} for the task adaptation step, and set τ = 1 for the meta update.
For our proposed method, we choose γ from the grid {0.1, 0.3, 0.5, 0.7, 0.9} and β from {0.05, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0}. Following previous work (Huang et al., 2020), we set the batch size to 32 and the dropout rate to 0.2. For the fine-tuning step, we search the batch size over the grid {4, 8, 16, 32} and the step size over {0.01, 0.05, 0.1}. For both methods, we stop training early once the validation accuracy converges.

Results: Table 3 reports the joint and slot accuracy of the different methods. With 1% of the target-domain data for fine-tuning, our proposed method achieves 2.44% and 1.01% improvements in slot and joint accuracy over Reptile on Attraction, and 15.27% and 10.59% gains in slot and joint accuracy over train-from-scratch. Similar improvements are obtained for Taxi.

Analysis: We also consider the case where more target-domain data is available for fine-tuning. Table 3 shows that the more target-domain data we have, the larger the gains of our method. For example, with 10% of the Taxi data, our method achieves 6.64%/13.13% improvements in slot/joint accuracy over train-from-scratch and 3.32%/1.09% gains in slot/joint accuracy over Reptile. Note that the performance of train-from-scratch does not change across 1%/5%/10% of the Taxi data, owing to the small size of the Taxi dataset. If train-from-scratch uses the entire Taxi data, its joint/slot accuracy is 75.61%/89.61%. These results show that meta-learning indeed helps when the target data is small, and that VFML uses the small amount of target data more effectively than Reptile.
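The two DST metrics defined above can be sketched as follows. This is a hypothetical minimal version, not TRADE's evaluation script: each turn's state is assumed to be a dict mapping (domain, slot) to a value, slots missing from a dict count as "none", and slot accuracy is averaged over a fixed list of all (domain, slot) pairs, which is one common convention. The slot names and values are made up.

```python
def joint_and_slot_accuracy(predictions, golds, all_slots):
    # predictions / golds: one dict per dialogue turn mapping (domain, slot) -> value.
    # Slots absent from a dict are treated as "none" (an assumption of this sketch).
    joint_correct = 0
    slot_correct = 0
    for pred, gold in zip(predictions, golds):
        matches = [pred.get(s, "none") == gold.get(s, "none") for s in all_slots]
        joint_correct += all(matches)  # a turn counts only if every slot is right
        slot_correct += sum(matches)   # each (domain, slot) pair counts separately
    n_turns = len(golds)
    return joint_correct / n_turns, slot_correct / (n_turns * len(all_slots))

# Hypothetical Taxi-domain example with three slots and two turns.
slots = [("taxi", "departure"), ("taxi", "destination"), ("taxi", "leaveat")]
gold = [{("taxi", "departure"): "cambridge"},
        {("taxi", "departure"): "cambridge", ("taxi", "destination"): "ely"}]
pred = [{("taxi", "departure"): "cambridge"},
        {("taxi", "departure"): "cambridge", ("taxi", "destination"): "london"}]
joint, slot = joint_and_slot_accuracy(pred, gold, slots)
print(joint, round(slot, 4))  # 0.5 0.8333
```

A single wrong value makes the whole second turn fail under joint accuracy but costs only one of three slots under slot accuracy, which is why joint accuracy is typically the lower and stricter of the two numbers.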

Conclusion
We propose a novel first-order meta-learning method that reduces the variance of the gradient estimator used in task adaptation for NLP tasks. We show on both few-shot text classification and DST that our method achieves better performance than existing methods. An interesting direction for future work is to study domain adaptation methods built upon our new algorithm.