Out-of-Domain Detection for Low-Resource Text Classification Tasks

Out-of-domain (OOD) detection for low-resource text classification is a realistic but understudied task. The goal is to detect OOD cases with limited in-domain (ID) training data, since in real-world machine learning applications training data is often insufficient. In this work, we propose an OOD-resistant Prototypical Network to tackle this zero-shot OOD detection and few-shot ID classification task. Evaluations on real-world datasets show that the proposed solution outperforms state-of-the-art methods on the zero-shot OOD detection task, while maintaining competitive performance on the ID classification task.


Introduction
Text classification in real-world applications often consists of two components: in-domain (ID) classification and out-of-domain (OOD) detection (Kim and Kim, 2018; Shu et al., 2017; Shamekhi et al., 2018). ID classification refers to classifying a user's input with a label that exists in the training data, while OOD detection refers to assigning a special OOD tag to an input that does not belong to any of the labels in the ID training dataset (Dai et al., 2007). Recent state-of-the-art deep learning (DL) approaches to OOD detection and ID classification often require massive amounts of ID or OOD labeled data (Kim and Kim, 2018). In reality, many applications have very limited ID labeled data (i.e., few-shot learning) and no OOD labeled data (i.e., zero-shot learning), and existing OOD detection methods do not perform well in this setting.
One such application is intent classification for conversational AI services, such as IBM Watson Assistant. For example, Table 1 shows some of the utterances a chat-bot builder provided for training. Each class may have fewer than 20 training utterances, due to the high cost of manual labelling by domain experts. Meanwhile, the user also expects the service to effectively reject irrelevant queries (as shown at the bottom of Table 1). The challenge of OOD detection lies in the undefined in-domain boundary: although one can provide a certain amount of OOD samples to build a binary classifier for OOD detection, such samples cannot efficiently cover the infinite OOD space. Recent approaches, such as (Shu et al., 2017), make remarkable progress on OOD detection with only ID examples, but the amount of ID data they require cannot be satisfied in the few-shot scenario presented in Table 1. This work aims to build a model that can detect OOD inputs with limited ID data and zero OOD training data, while classifying ID inputs with high accuracy.

Learning similarities with a meta-learning strategy (Vinyals et al., 2016) has been proposed to deal with limited training examples per label (few-shot learning). In this line of work, Prototypical Networks (Snell et al., 2017), originally introduced for few-shot image classification, have proven promising for few-shot ID text classification. However, the use of prototypical networks for OOD detection remains unexplored.
To the best of our knowledge, this work is the first to adopt a meta-learning strategy to train an OOD-Resistant Prototypical Network that simultaneously detects OOD examples and classifies ID examples. The contributions of this work are two-fold: 1) a unified solution, based on a prototypical network, that can detect OOD instances and classify ID instances in a real-world low-resource scenario; 2) experiments and analysis on two datasets showing that the proposed model outperforms previous work on the OOD detection task, while maintaining state-of-the-art ID classification performance.

Related Work
Out-of-Domain Detection Existing methods often formulate the OOD task as a one-class classification problem and solve it with, e.g., one-class SVMs (Schölkopf et al., 2001) or one-class DL-based classifiers (Ruff et al., 2018; Manevitz and Yousef, 2007). Another line of work tackles OOD tasks with auto-encoder-based approaches and their variations (Ryu et al., 2017, 2018). Recently, a few papers have investigated ID classification and OOD detection simultaneously (Kim and Kim, 2018; Shu et al., 2017), but they fail in a low-resource setting.
Few-Shot Learning Few-shot learning approaches are promising for this low-resource setting. For example, (Vinyals et al., 2016; Bertinetto et al., 2016; Snell et al., 2017) use metric learning, learning a good similarity metric between input examples; other methods adopt a meta-learning framework and train the model to quickly adapt to new tasks with gradients on small samples, e.g., learning the optimization step sizes (Ravi and Larochelle) or the model initialization (Finn et al., 2017). Though most of these approaches were developed for computer vision, recent studies suggest that few-shot learning is promising in the text domain, including text classification (Jiang et al., 2018), relation extraction (Han et al., 2018), link prediction in knowledge bases (Xiong et al., 2018) and fine-grained entity typing (Xiong et al., 2019); we put it to the test on the OOD detection task.

Approach
In this paper, we target solving the zero-shot OOD detection problem for a few-shot meta-test dataset D under two constraints: 1) no labeled OOD examples are available for training; 2) the training size for each label in D_train is limited (e.g., less than 100 examples). These constraints prevent existing methods from efficiently training a model for either ID classification or OOD detection using D_train only. We propose an OOD-resistant prototypical network for both OOD detection and few-shot ID classification. Following (Snell et al., 2017) in few-shot image classification, we train a prototypical network on a set of meta-train tasks T and directly perform prediction on D without additional training. Our method differs from the prior work in that during meta-training, while we maximize the likelihood of the true label for an example in T_i, we also sample an example from another meta-train task T_j as a simulated OOD example and maximize the distance between this OOD instance and the prototypical vector of each ID label.

General Framework
As shown in Fig. 1, we train the model on a large-scale source dataset T with the following steps:
1. Sample a training task T_i from T (e.g., the Book category of Amazon Review in Section 4), and another task T_j from T − T_i (e.g., the Apps-for-Android category).
2. Sample an ID training example x_i^in from T_i, and a simulated OOD example x_j^out from T_j.
3. Sample N labels (N=4) from T_i in addition to the label of x_i^in. For the ground-truth label and the N negative labels, we select K training examples each (K-shot learning; we set K=20). If a label has fewer than K examples, we replicate the selected examples to reach K.
4. Encode x_i^in, x_j^out and the examples in the supporting set S_l^in of each label l with a deep network E(·). (Any DL structure can be used as the encoder, such as an LSTM or CNN; here we use a one-layer CNN with mean pooling. The detailed CNN hyper-parameters are introduced in Section 5.)
5. Following (Snell et al., 2017), generate a Prototypical Vector for each label by averaging the representations of all of that label's examples.
6. Optimize the model with an objective function defined over x_i^in, x_j^out and S^in. Details in Section 3.2.
7. Repeat these steps for multiple epochs (5k in this paper) to train the model, and select the best model based on an independent meta-valid set T_valid, which contains tasks homogeneous to the meta-test task D.
The only trainable parameters of this model are in the encoder E(·). Therefore the trained model can be easily transferred to the few-shot target domain.
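The prototype construction in Steps 4-5 can be sketched in a few lines. Note that the mean-of-embeddings encoder below is only a toy stand-in for the paper's one-layer CNN, and the tiny embedding table is an illustrative assumption:

```python
import numpy as np

def encode(tokens, emb, dim=200):
    """Toy stand-in for the encoder E(.): mean of word embeddings.
    The paper instead uses a one-layer CNN with mean pooling."""
    vecs = [emb[t] for t in tokens if t in emb]
    return np.mean(vecs, axis=0) if vecs else np.zeros(dim)

def prototype(support_examples, emb):
    """Prototypical vector of a label: average of its encoded supports."""
    return np.mean([encode(x, emb) for x in support_examples], axis=0)

def F(x_vec, proto):
    """Cosine similarity rescaled from [-1, 1] to [0, 1]."""
    cos = np.dot(x_vec, proto) / (np.linalg.norm(x_vec) * np.linalg.norm(proto) + 1e-8)
    return (cos + 1.0) / 2.0
```

Because the prototypes are simple averages of encoded supports, new labels at test time need no parameter updates, only new averages.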

Training Objective and Runtime
Prototypical networks (Snell et al., 2017) minimize a cross-entropy loss defined on the distance metric between x_i^in and the supporting sets:

L_in = −log [ exp(α · F(x_i^in, S_{l_i}^in)) / Σ_l exp(α · F(x_i^in, S_l^in)) ],  (1)

where l_i is the ground-truth label of x_i^in and α is a rescaling factor. Here we define F as a cosine similarity score (mapped to the range between 0 and 1) between the E(·)-encoded representation of x and the prototypical vector of a label. Our experiments show that this meta-learning approach is effective for ID classification, but not good enough for detecting OOD examples. We therefore propose two training losses in addition to L_in for OOD detection. The rationale behind this addition is to adopt examples from other tasks as simulated OOD examples for the current meta-train task. Specifically, we define a hinge loss on x_j^out and the closest ID supporting set in S^in,

L_ood = max(0, max_l F(x_j^out, S_l^in) − M_1),  (2)

which pushes examples from another task away from the prototypical vectors of the ID supporting sets.
We expect that optimizing only on L_in and L_ood will lower confidence on ID classification, because the system tends to mistakenly shrink the scale of F in order to minimize the loss on OOD examples. Therefore we add another loss to keep the confidence of classified ID labels high:

L_gt = max(0, M_2 − F(x_i^in, S_{l_i}^in)).  (3)
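The three losses can be sketched as follows. The hinge-loss forms below are standard max-margin assumptions consistent with the description above (the paper's exact equations are Eq. 1-3); the similarity scores F(x, S_l) are assumed precomputed:

```python
import numpy as np

def in_loss(sims, gt_label, alpha=10.0):
    """Cross-entropy over rescaled similarities F(x, S_l).
    sims: dict label -> similarity in [0, 1]; alpha: rescaling factor."""
    labels = list(sims)
    logits = np.array([alpha * sims[l] for l in labels])
    log_probs = logits - np.log(np.sum(np.exp(logits)))
    return -log_probs[labels.index(gt_label)]

def ood_loss(ood_sims, m1=0.4):
    """Hinge loss pushing a simulated OOD example's similarity to its
    closest ID prototype below the margin M1 (assumed form)."""
    return max(0.0, max(ood_sims.values()) - m1)

def gt_loss(sims, gt_label, m2=0.8):
    """Hinge loss keeping the ground-truth similarity above M2
    (assumed form), so ID confidences stay high."""
    return max(0.0, m2 - sims[gt_label])

def total_loss(sims, ood_sims, gt_label, beta=1.0, gamma=1.0):
    # Weighted sum of the three losses, as in the training objective.
    return (in_loss(sims, gt_label)
            + beta * ood_loss(ood_sims)
            + gamma * gt_loss(sims, gt_label))
```

The hyper-parameter values used here (alpha=10.0, m1=0.4, m2=0.8, beta=gamma=1.0) follow the settings reported in Section 5.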
The model is optimized on the sum of the three losses:

L = L_in + β · L_ood + γ · L_gt,  (4)
where α, β, γ, M_1 and M_2 are hyper-parameters, whose detailed values are given in Section 5. During inference, the prototypical vector per label is generated by averaging the encoded representations of all instances of that label in D_train, and the prediction is based on F(x, S_l^in). An input is detected as OOD when its highest confidence falls below a threshold.
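Inference can then be sketched as below. The threshold value is a hypothetical placeholder (in practice it is selected on validation data), and cos01 re-implements the rescaled cosine similarity F so the snippet is self-contained:

```python
import numpy as np

def cos01(x, p):
    # Cosine similarity mapped to [0, 1], as F in the paper.
    c = np.dot(x, p) / (np.linalg.norm(x) * np.linalg.norm(p) + 1e-8)
    return (c + 1.0) / 2.0

def predict(x_vec, prototypes, threshold=0.6):
    """prototypes: dict label -> averaged encoded representation of that
    label's D_train instances. `threshold` is a hypothetical confidence
    cutoff; inputs scoring below it for every label are tagged OOD."""
    sims = {l: cos01(x_vec, p) for l, p in prototypes.items()}
    best = max(sims, key=sims.get)
    if sims[best] < threshold:
        return "OOD", sims[best]
    return best, sims[best]
```

Because prediction reduces to nearest-prototype lookup plus one threshold comparison, no additional training is needed on the target domain.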

Datasets
Our methods are evaluated on two datasets. Each contains many tasks and is divided into meta-train, meta-valid and meta-test sets, which are used for background model training, hyper-parameter selection, and evaluation, respectively.
Amazon Review: We construct multiple tasks from the Amazon Review dataset (He and McAuley, 2016), following prior work. We convert it into a binary classification task of labeling the review sentiment (positive/negative). Of its product categories, 21 have enough examples and each is treated as a task (3 other categories are discarded). We randomly picked 13 categories as meta-train, 4 as meta-test and 4 as meta-valid. We construct a 2-way 100-shot problem per meta-test task by sampling 100 reviews per label in a category. For the test examples in meta-test and meta-valid, we sample other categories' examples as OOD and merge them with an equal number of ID instances. We use all available data for meta-train.
Conversation Dataset: An intent classification dataset for an AI conversational system. It has 539 categories/tasks; we allocate 497 tasks as meta-train and 42 as meta-test. This dataset differs from, and is more difficult than, typical ID few-shot learning data: 1) neither the meta-test nor the meta-train tasks are restricted to N-way K-shot classification, and the source dataset is highly imbalanced across labels; 2) each task has a varying number of labels (utterance intents), whereas the Amazon data always has two labels. 29% of the testing instances in meta-test are OOD; these are human-labeled, not generated from other tasks.

Experimental Results
Baselines: We compare our model O-Proto with 4 baselines: 1) OSVM (Schölkopf et al., 2001): a one-class SVM trained on the meta-test D_train, which learns a domain boundary by examining only ID examples. 2) LSTM-AutoEncoder (Ryu et al., 2017): a recent OOD detection method that trains an autoencoder using only ID examples. 3) Vanilla CNN: a classifier with a typical CNN structure that uses a confidence threshold for OOD detection. 4) Proto. Network (Snell et al., 2017): a plain prototypical network trained on T with only the loss L_in, which uses a confidence threshold for OOD detection. We test the Proto. Network with both a CNN and a bidirectional LSTM as the encoder E(·).
Hyper Parameters: We introduce the hyperparameters of our model and all baselines below.
We use the Python scikit-learn One-Class SVM as the basis of our OSVM implementation, with a Radial Basis Function (RBF) kernel and the gamma parameter set to 'auto'. We use squared hinge loss and L2 regularization.
We follow the same architecture as proposed in (Ryu et al., 2017) for the LSTM-AutoEncoder. In the LSTM, we set the input embedding dimension to 100 and the hidden size to 200, use RMSprop as the optimizer with a learning rate of 0.001, and train with a batch size of 32 for 100 epochs. For the autoencoder, we set the hidden size to 20, use Adam as the optimizer with a learning rate of 0.001, and train with a batch size of 32 for 10 epochs.
For the Vanilla CNN, we use the most common CNN architecture in NLP tasks: a convolutional layer with 128 filters on top of word embeddings, followed by a ReLU and a max pooling layer before the final softmax. We use Adam as the optimizer with a learning rate of 0.001, and train the model with a batch size of 64 for 100 epochs.
Our proposed model O-Proto uses the same CNN architecture, optimizer, and learning rate as the Vanilla CNN above. The input word embeddings are pre-trained on a 1-billion-token Wikipedia corpus. We set the batch size to 10. In Eq. 1, 2, 3 and 4, α, β, γ, M_1 and M_2 are hyper-parameters; we fix β and γ to 1.0 by default, and set α, M_1 and M_2 to 10.0, 0.4 and 0.8 according to the meta-valid performance on the Amazon dataset. The sentence encoder, a CNN, has 200 filters, followed by a tanh and a mean pooling layer before the final regression layer. The maximum length per example is 40 tokens, and any tokens beyond this length are discarded. During training, we set the number of sampled negative labels in Step 3 (Section 3.1) to at most four, so at most five labels are involved in a training step (1 positive, 4 negative). The supporting set size for each label is 20.
To make a fair comparison, Proto. Network uses the same hyper-parameters as O-Proto, except that β and γ, the weights of L_ood and L_gt, are set to zero.
Evaluation Metrics: Following (Ryu et al., 2017; Lane et al., 2007), we use a commonly used OOD detection metric, the equal error rate (EER), i.e., the error rate at the confidence threshold where the false acceptance rate (FAR) equals the false rejection rate (FRR). For ID classification we report the classification error rate (CER).
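The EER can be computed by sweeping the confidence threshold over all observed scores; the tie-breaking and averaging details in the sketch below are our assumptions, not the paper's exact procedure:

```python
def eer(id_scores, ood_scores):
    """Equal error rate: sweep the confidence threshold and return the
    error rate where FRR (rejected ID) and FAR (accepted OOD) are
    closest, averaging the two at that point."""
    thresholds = sorted(set(id_scores) | set(ood_scores))
    best = None
    for t in thresholds:
        frr = sum(s < t for s in id_scores) / len(id_scores)     # ID wrongly rejected
        far = sum(s >= t for s in ood_scores) / len(ood_scores)  # OOD wrongly accepted
        gap = abs(frr - far)
        if best is None or gap < best[0]:
            best = (gap, (frr + far) / 2.0)
    return best[1]
```

With perfectly separated scores the EER is 0; with fully overlapping score distributions it approaches 0.5.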
Results: Table 2 reports the results of all baselines and O-Proto. O-Proto without L_ood completely loses the ability to detect OOD inputs. Compared to the variant without L_gt, O-Proto with all three losses gives a mild improvement, and we also observe a more stable test performance across training epochs. Finally, we replace the CNN encoders with bidirectional LSTMs (the bottom of Table 2), which yield sentence representations of the same dimension as the CNN. For the Conversation data, we achieve the best validation performance when α and γ are 0.5. Both Proto. Network and O-Proto perform comparably with either encoder, showing that our proposed OOD approach is not limited to a specific sentence-encoding architecture.
Improvement in different K-shot settings: On the Amazon data, we construct different K-shot tasks as meta-test (results shown in Table 3), and observe consistent improvements on EER.
Effect of β: Fig. 2 and Fig. 3 show EER and CER under different β values on the two datasets. We observe that within a proper range of β (between 0.5 and 2.0), the model provides a stable improvement on EER while maintaining competitive CER.

Conclusion
Inspired by the Prototypical Network, we propose a new method to tackle OOD detection in low-resource settings. Evaluation on real-world datasets demonstrates that our method performs favorably against state-of-the-art algorithms on OOD detection, without adversely affecting performance on few-shot ID classification.