Bootstrapped Q-learning with Context Relevant Observation Pruning to Generalize in Text-based Games

We show that Reinforcement Learning (RL) methods for solving Text-Based Games (TBGs) often fail to generalize to unseen games, especially in small data regimes. To address this issue, we propose Context Relevant Episodic State Truncation (CREST) for irrelevant token removal in observation text for improved generalization. Our method first trains a base model using Q-learning, which typically overfits the training games. The base model's action token distribution is then used to perform observation pruning that removes irrelevant tokens. A second bootstrapped model is retrained on the pruned observation text. Our bootstrapped agent shows improved generalization in solving unseen TextWorld games, using 10x-20x fewer training games than previous state-of-the-art methods while also requiring fewer training episodes.


Introduction
Reinforcement Learning (RL) methods are increasingly being used for solving sequential decision-making problems from natural language inputs, like text-based games (Narasimhan et al., 2015; He et al., 2016; Yuan et al., 2018; Zahavy et al., 2018), chat-bots (Serban et al., 2017) and personal conversation assistants (Dhingra et al., 2017; Li et al., 2017; Wu et al., 2016). In this work, we focus on Text-Based Games (TBGs), which require solving goals like "Obtain coin from the kitchen" based on a natural language description of the agent's observation of the environment. To interact with the environment, the agent issues text-based action commands ("go west"), upon which it receives a reward signal used for training the RL agent.
Traditional text-based RL methods focus on the problems of partial observability and large action spaces. However, the topic of generalization to unseen TBGs is less explored in the literature. We show that previous RL methods for TBGs often exhibit poor generalization to unseen test games. We hypothesize that such overfitting is caused by the presence of irrelevant tokens in the observation text, which might lead to action memorization. To alleviate this problem, we propose CREST, which first trains an overfitted base model on the original observation text of the training games using Q-learning. Subsequently, we apply observation pruning such that, for each episode of the training games, we remove the observation tokens that are not semantically related to the base policy's action tokens. Finally, we re-train a bootstrapped policy on the pruned observation text using Q-learning, which improves generalization by removing irrelevant tokens. Figure 1 shows an illustrative example of our method: given the goal "Who's got a virtual machine and is about to play through a fast paced round of textworld? You do! Retrieve the coin in the balmy kitchen." and an observation describing unguarded exits ("You've entered a studio. ... You need an unguarded exit? You should try going south. You don't like doors? Why not try going west, that entranceway is unblocked."), the bootstrapped policy issues the action "go south". Experimental results on TextWorld games (Côté et al., 2018) show that our proposed method generalizes to unseen games using almost 10x-20x fewer training games compared to SOTA methods, and features significantly faster learning. Recent methods (Adolphs and Hofmann, 2019; Ammanabrolu and Riedl, 2018; Ammanabrolu and Hausknecht, 2020; Yin and May, 2019; Adhikari et al., 2020) use various heuristics to learn better state representations for efficiently solving complex TBGs.

Base model
We consider the standard sequential decision-making setting: a finite-horizon Partially Observable Markov Decision Process (POMDP), represented as (s, a, r, s′), where s is the current state, s′ the next state, a the current action, and r(s, a) the reward function. The agent receives a state description s_t that is a combination of text describing the agent's observation and the goal statement.
The action consists of a combination of verb and object outputs, such as "go north", "take coin", etc. The overall model has two modules: a representation generator and an action scorer, as shown in Figure 2. The observation tokens are fed to the embedding layer, which produces a sequence of vectors x^t = {x^t_1, x^t_2, ..., x^t_{N_t}}, where N_t is the number of tokens in the observation text at time-step t. We obtain hidden representations of the input embedding vectors using an LSTM model as h^t_i = f(x^t_i, h^t_{i-1}). We compute attention weights (Bahdanau et al., 2014) over the input tokens, with the weight on the j-th token given by

α^t_j = softmax_j(v^T tanh(W_h h^t_j + b_attn)),

where W_h, v and b_attn are learnable parameters. The context vector at time-step t is computed as the weighted sum of the hidden vectors, c^t = Σ_{j=1}^{N_t} α^t_j h^t_j. The context vector is fed into the action scorer, where two multi-layer perceptrons (MLPs), Q(s, v) and Q(s, o), produce the Q-values over available verbs and objects from a shared MLP's output. The original works of Narasimhan et al. (2015); Yuan et al. (2018) do not use the attention layer. LSTM-DRQN replaces the shared MLP with an LSTM layer so that the model remembers previous states, thus addressing the partial observability in these environments.
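As a concrete sketch, the attention and context-vector computation above can be written in NumPy as follows. This is a minimal illustration, not the paper's implementation; parameter shapes and the softmax normalization over tokens are standard assumptions.

```python
import numpy as np

def softmax(x):
    # Numerically stable softmax over a 1-D score vector.
    e = np.exp(x - x.max())
    return e / e.sum()

def attention_context(h, W_h, v, b_attn):
    """Additive attention over LSTM hidden states h (shape N_t x d).

    Computes alpha_j = softmax_j(v^T tanh(W_h h_j + b_attn)) and returns
    the context vector c = sum_j alpha_j * h_j along with the weights.
    """
    scores = np.array([v @ np.tanh(W_h @ h_j + b_attn) for h_j in h])
    alpha = softmax(scores)
    c = (alpha[:, None] * h).sum(axis=0)
    return c, alpha
```

In the full model the context vector c would then be passed to the action scorer MLPs.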
Q-learning (Watkins and Dayan, 1992; Mnih et al., 2015) is used to train the agent. The parameters of the model are updated by optimizing the following loss function obtained from the Bellman equation (Sutton et al., 1998):

L = (r + γ max_{a′} Q(s′, a′) − Q(s, a))^2, (3)

where Q(s, a) is obtained as the average of the verb and object Q-values, and γ ∈ (0, 1) is the discount factor. The agent is given a reward of 1 from the environment on completing the objective. We also use the episodic discovery bonus (Yuan et al., 2018) as a reward during training, which introduces curiosity (Pathak et al., 2017), encouraging the agent to uncover unseen states for accelerated convergence.
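A minimal sketch of this training objective, assuming the standard squared TD error and the verb/object averaging described in the text (function names are ours, not from the paper's code):

```python
def combined_q(q_verb, q_obj):
    # Q(s, a) is taken as the average of the verb and object Q-values.
    return 0.5 * (q_verb + q_obj)

def td_loss(q_sa, reward, q_next_max, gamma=0.9, done=False):
    # Squared TD error from the Bellman equation:
    #   L = (r + gamma * max_a' Q(s', a') - Q(s, a))^2
    # On terminal steps the bootstrap term is dropped.
    target = reward if done else reward + gamma * q_next_max
    return (target - q_sa) ** 2
```

In practice the reward here would be the environment reward plus the episodic discovery bonus during training.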

Context Relevant Episodic State Truncation (CREST)
Traditional LSTM-DQN and LSTM-DRQN methods are trained on observation text containing irrelevant textual artifacts (like "You don't like doors?" in Figure 1), which leads to overfitting in small data regimes. Our CREST module removes unwanted tokens in the observation that do not contribute to decision making. Since the base policy overfits on the training games, the action commands issued by it can successfully solve the training games, thus yielding correct (observation text, action command) pairs for each step in the training games. Therefore, by only retaining tokens in the observation text that are contextually similar to the base model's action command, we remove unwanted tokens in the observation, which might otherwise cause overfitting.
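The overall procedure can be outlined as below. This is only a structural sketch with hypothetical interfaces (`q_learn`, `prune_fn`, and `action_tokens` are placeholders for the paper's components, not a real API):

```python
def crest_pipeline(train_games, q_learn, prune_fn):
    """CREST in outline: train an (overfit) base policy, collect its
    per-game action tokens, prune observations, then retrain."""
    # Step 1: train a base policy with Q-learning (expected to overfit).
    base = q_learn(train_games)
    # Steps 2-3: aggregate the base policy's action tokens per game and
    # prune each game's observations to context-relevant tokens.
    pruned_games = [prune_fn(game, base.action_tokens(game))
                    for game in train_games]
    # Step 4: retrain a bootstrapped policy on the pruned observations.
    return q_learn(pruned_games)
```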
The distance between tokens is computed using cosine similarity, D(a, b). Token Relevance Distribution (TRD): We run inference with the overfitted base model on each training game (indexed by k) and aggregate all the action tokens issued for that particular game as the Episodic Action Token Aggregation (EATA), A_k. For each token w_i in a given observation text o^k_t at step t of the k-th game, we compute the Token Relevance Distribution (TRD) C as

C_i = max_{a ∈ A_k} D(w_i, a),

where the i-th token w_i's score is computed as the maximum similarity to all tokens in A_k. This relevance score is used to prune irrelevant tokens in the observation text by creating a hard attention mask using a threshold value.
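The TRD computation and threshold pruning can be sketched as follows. The toy `emb` dictionary stands in for pretrained word vectors (the paper uses ConceptNet embeddings); tokenization details are omitted:

```python
import numpy as np

def token_relevance(obs_tokens, action_tokens, emb):
    """Score each observation token w_i by its maximum cosine
    similarity D(w_i, a) over the game's aggregated action tokens A_k."""
    def cos(a, b):
        return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-8))
    return [max(cos(emb[w], emb[a]) for a in action_tokens)
            for w in obs_tokens]

def prune(obs_tokens, scores, threshold=0.5):
    # Hard attention mask: keep only tokens whose relevance
    # clears the threshold.
    return [w for w, s in zip(obs_tokens, scores) if s >= threshold]
```

With a threshold of 0.5, tokens unrelated to any action token are dropped while near-synonyms of action words survive.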

Experimental Results
Setup: We used the easy, medium, and hard modes of the Coin-collector TextWorld (Côté et al., 2018; Yuan et al., 2018) framework for evaluating our model's generalization ability. The agent has to collect a coin that is located in a particular room. We trained each method on various numbers of training games (denoted by N#) to evaluate generalization from a small number of games.
Quantitative comparison: We compare the performance of our proposed model with LSTM-DQN (Narasimhan et al., 2015) and LSTM-DRQN (Yuan et al., 2018).
Figures 2(b) and 2(c) show the reward of various trained models with increasing training episodes on easy and medium games. Our method shows improved out-of-sample generalization on validation games with about 10x-20x fewer training games (500 vs. 25, 50), with accelerated training using drastically fewer training episodes compared to previous methods.
We report performance on unseen test games in Table 1. Parameters corresponding to the best validation score are used. Our method trained with N25 and N50 games for easy and medium levels respectively achieves performance similar to 500 games for SOTA methods. We perform an ablation study with and without attention in the policy network and show that the attention mechanism alone does not substantially improve generalization. We also compare the performance of various word embeddings for TRD computation and find that ConceptNet gives the best generalization performance.
Pruning threshold: In this experiment, we test our method's response to changing threshold values for observation pruning. Figure 4(a) and Figure 4(b) reveal that thresholds of 0.5 for easy games and 0.7 for medium games give the best validation performance. A very high threshold might remove relevant tokens as well, leading to failure in training, whereas a low threshold value would retain most irrelevant tokens, leading to over-fitting.
Zero-shot transfer: In this experiment, agents trained on games with quest lengths of 15 rooms were tested on unseen game configurations with quest lengths of 20 and 25 rooms respectively, without retraining, to study the zero-shot transferability of our learned agents to unseen configurations. The results in the bar charts of Figure 4(c) for N50 easy games show that our proposed method can generalize to unseen game configurations significantly better than previous state-of-the-art methods on the coin-collector game.
Generalizability to other games: In the above experimental section, we reported the results on the coin-collector environment, where the nouns and verbs used in the training and testing games have a strong overlap. In this section, we present some discussion about the generalizability of this method to other games, where the context-relevant tokens for a particular game might never be seen in the training games.
Although these nouns were absent in the training action distribution, our proposed method can assign a high score to these words (except "knife"), since they are similar in concept to the training actions. An appropriate threshold (e.g., th = 0.4) can retain most tokens, and it can be automatically tuned using validation games as shown in Fig 4(a) and 4(b). Thus, as described in Section 5, assuming some level of overlap between training and testing knowledge domains, our method is generalizable and can reduce overfitting for RL in NLP. A large training and testing distribution gap is a much more difficult problem, even for supervised ML and conventional RL settings, and is out of the scope of this paper.

Conclusion
We present a method for improving generalization in TBGs using irrelevant token removal from observation texts. Our bootstrapped model trained on the salient observation tokens obtains generalization performance similar to SOTA methods with 10x-20x fewer training games, and shows accelerated convergence. In this paper, we have restricted our analysis to TBGs that feature similar domain distributions in training and test games. In the future, we wish to handle the topic of generalization in the presence of domain differences, such as novel objects and goal statements in test games that were not seen in training.

A Description of Text-based games
We used the TextWorld (Côté et al., 2018) framework for evaluating our model's generalization ability on text-based games. For each game, the agent is provided with a goal statement and an observation text describing only the current state of the world around it. The agent has to overcome partial observability using memory because it never sees the full state of the world. The games are inspired by the chain experiments used in (Plappert et al., 2017; Osband et al., 2016) for evaluating exploration in RL policies. The agent has to navigate through various rooms that are randomly connected to form a chain, finally reaching the goal state. We ensure that the goal statement does not contain navigational instructions by using the "--only-last-action" option. The agent is rewarded only when it successfully achieves the end-goal. We use a quest length (number of rooms to travel before reaching the goal state) of L15 for training our policies. We use the Coin-collector environment for evaluating our experiments, where the agent has to collect a coin that is located in a particular room. For different games, the location of the coin and the interconnectivity between the rooms are different. We experiment with three modes of this challenge according to (Yuan et al., 2018): easy, there are no distractor rooms (dead-ends) along the optimal path, and the agent needs to choose a command
that only depends on the previous state; medium, there is one distractor room per room on the optimal path, and the agent has to issue the reverse of its previous command to come out of such distractor rooms; hard, there are two distractor rooms per room on the optimal path, and the agent has to issue the reverse of its previous command to come out of such distractor rooms, in addition to remembering longer into the past to successfully keep track of which paths it has already traveled.

B Experimental Setup
For the TextWorld coin collector environment, we use 10 verbs and 10 nouns in the vocabulary, which is learned using Q-learning. This is different from previous methods that use only 2 words and 5 objects, thus increasing the complexity of Q-learning slightly. However, it is to be noted that the problem of generalization exists even for a smaller number of action tokens, as reported in (Yuan et al., 2018), and is not significantly aggravated on a slightly larger action space (10 vs. 100 combinations). The configuration used in our base and bootstrapped model learning is the same as in previous methods, with the only change being the addition of the attention layer. We trained each environment for 6000 epochs, with annealing over 3600 epochs from a starting value of 1.0 to 0.2. Each training experiment took about 5-6 hours for completion. Our experiments were conducted on an Ubuntu 16.04 system with a Titan X (Pascal) GPU. We use a single LSTM network with 100-dimensional hidden units in the representation generator. For the action scorer, a single LSTM network with a 64-dim hidden unit (for DRQN) and two MLPs for verb and object Q-values were used. The number of trainable parameters in our policy network is 128,628 for the model with attention and 125,364 without attention.
In our experiments, we wish to investigate the generalization property of our method on a small number of training games. To that end, we wish to answer the following questions: (1) Can our proposed observation masking system outperform previous RL methods for TBGs using less training data with accelerated learning? (2) Is there a positive correlation between the strength of observation masking and generalization performance? (3) Does our method understand the semantic meaning of the games to perform zero-shot generalization to unseen configurations of TBGs?

C Improved generalization by CREST
The generalization ability of the learned base policy is measured by the performance on unseen games that were not used during training the policy network. We measure the reward obtained by the agent in each episode, which is the metric of success in our experiments. During the evaluation of unseen games, only the environment reward is used and the episodic discovery bonus is turned off. Since a reward of 1.0 is obtained on the completion of the game, the average reward can also be interpreted as the average success rate in solving the games.

(Example observation texts for the easy (N50), medium, and hard games are shown with Figure 6.)
The verb and object tokens corresponding to the maximum Q-value are chosen as the action command. Traditional LSTM-DQN and LSTM-DRQN methods are trained on observation text descriptions that include irrelevant textual artifacts, which might lead to overfitting in small data regimes. To demonstrate this effect, we plot the performance of LSTM-DRQN (SOTA on coin-collector) and LSTM-DQN on the Coin-Collector easy, medium, and hard games on various numbers of training games and 20 unseen validation games in Figure 5. In each training episode, a random batch of games is sampled from the available training games and Q-learning is performed.
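Greedy action selection from the two Q-heads can be sketched as follows (function and variable names are illustrative):

```python
import numpy as np

def select_action(q_verbs, q_objs, verbs, objs):
    # Pick the verb and the object with maximum Q-value and join
    # them into a command string such as "go south".
    return f"{verbs[int(np.argmax(q_verbs))]} {objs[int(np.argmax(q_objs))]}"
```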
For a large number of training games (500), the SOTA policies can solve most of the validation games (especially the easy games). However, the performance degrades significantly for a smaller number of training games, while the training performance shows a 100% success rate, indicating overfitting. This kind of behavior might occur if the agent associates certain action commands with irrelevant tokens in the observation. For example, the agent might encounter games in training where the observation tokens "a typical kind of place" co-occur with the action "go east". In this case, the agent might learn to associate such irrelevant tokens with the "go east" command without actually learning the true dependency on tokens like "there is a door to the east".

C.1 Quantitative evaluation of generalization
Our proposed method shows better generalization performance, as is evident from Figure 5, due to the removal of unwanted tokens from the observation text. We show the visualization of Token Relevance Distributions (TRDs) obtained by our method for easy, medium, and hard games in Figure 6. Each token has a similarity score between 0 and 1, indicating how relevant it is for making decisions about the next action. Tokens with a score less than a threshold are pruned. We also perform such observation pruning in the testing phase. Therefore, our proposed method learns on such clean observation texts and is also tested on unseen pruned texts, which leads to improved generalization.

C.2 Zero-shot transfer
While in the previous experiments the training and evaluation games had the same quest length configuration, in this experiment we evaluate our method on configurations of coin-collector never seen during training. Specifically, during training, we use games with quest lengths of 15 rooms. The models trained on this configuration are tested on games with quest lengths of 20 and 25 rooms respectively, without any retraining. This is aimed at studying the zero-shot transferability to configurations that the agent has never encountered before. The results, shown in Table 2 for all modes of the coin-collector games, demonstrate that our proposed observation masking method can also generalize to unseen game configurations with increased quest length and largely outperforms the previous state-of-the-art methods. Our method, CREST, learns to retain important tokens in the observation text, which leads to a better semantic understanding and, in turn, better zero-shot transfer.
In contrast, the previous methods can overfit to unwanted tokens in the observation text that do not contribute to the decision-making process.

D Discussion
Empirical evaluation shows that our observation masking method can successfully reduce the overfitting problem in RL for text-based games by removing irrelevant tokens. Our method also learns at an accelerated rate, requiring fewer training episodes due to the pruned textual representations. We show that observation masking leads to better generalization, as demonstrated by the superior performance of our CREST method with accelerated convergence and fewer training games compared to the state-of-the-art methods.
In this paper, we assume that the domain distributions of the training and evaluation environments are similar, because our goal is to explore generalization by observation pruning without additional heuristic learning components. This means the evaluation games will have objectives similar to those seen during the training games, and similar objects will be encountered in the evaluation games without any novel objects. For example, if the goal objective is set as "pick up the coin" in the training games, it will not be changed to "eat the apple", which was never seen in training.
To handle such environments with domain divergence, training needs to be performed with external datasets that show a satisfactory level of overlap with the domain of the unseen test games. Such training from external sources can be readily combined with our proposal in this paper using previous hand-crafted methods like Adolphs and Hofmann (2019); Ammanabrolu and Riedl (2018).

Figure 1 :
Figure 1: Our method retains context-relevant tokens from the observation text (shown in green) while pruning irrelevant tokens (shown in red). A second policy network re-trained on the pruned observations generalizes better by avoiding overfitting to unwanted tokens.

Figure 2 :
Figure 2: (a) Overview of the Context Relevant Episodic State Truncation (CREST) module using the Token Relevance Distribution for observation pruning. Our method shows better generalization from 10x-20x fewer training games and faster learning with fewer episodes on (b) "easy" and (c) "medium" validation games.

Figure 3 :
Figure 3: Ranking of context-relevant tokens from observation text by our token relevance distribution.

Figure 4 :
Figure 4: Comparison of validation performance for various thresholds on (a) easy and (b) medium games; (c) our method trained on L15 games and tested on L20 and L25 games significantly outperforms the previous methods.

Figure 5 :
Figure 5: Training and validation learning curves for various games. The metric (y-axis) is the average of a final reward of 1.0 on completion of the quest and 0.0 otherwise, thus measuring the average success rate. Our method shows better generalization from significantly fewer training games and faster learning with fewer episodes for all cases of "easy", "medium" and "hard" validation games.

Figure 6 :
Figure 6: The relevance distribution of observation tokens for easy, medium and hard games, along with the original observation text. The top row shows non-terminal observations where the "coin" is not present. The second row shows terminal states. Each relevance score is bounded between [0, 1]. The bootstrapped model is trained on tokens whose relevance is above a threshold, which removes irrelevant tokens.

Table 1 :
The average success rate of various methods on 20 unseen test games. Experiments were repeated with 3 random seeds. Our method, trained on almost 20x less data, has a success rate similar to state-of-the-art methods.
An example observation looks like this: ". . . YOU SEE A FRIDGE. THE FRIDGE CONTAINS SOME WATER, A DICED CILANTRO AND A DICED PARSLEY. YOU WONDER IDLY WHO LEFT THAT HERE. WERE YOU LOOKING FOR AN OVEN? BECAUSE LOOK OVER THERE, IT'S AN OVEN. WERE YOU LOOKING FOR A TABLE? BECAUSE LOOK OVER THERE, IT'S A TABLE. THE TABLE IS MASSIVE. ON THE TABLE YOU MAKE OUT A COOKBOOK AND A KNIFE. YOU SEE A COUNTER. HOWEVER, THE COUNTER, LIKE AN EMPTY COUNTER, HAS NOTHING ON IT . . .". The objective of this game is to prepare a meal following the recipe found in the kitchen and eat it.
Observation: You have fallen into a salon. Not the salon you'd expect. No, this is a salon. You start to take note of what's in the room. You need an unguarded exit? You should try going east. You don't like doors? Why not try going north, that entranceway is unblocked. There is an unblocked exit to the south. You don't like doors? Why not try going west, that entranceway is unblocked.

Table 2 :
Average success rate in zero-shot transfer to other configurations. We trained the RL policies on L15 games and test the performance on L20 and L25 unseen game configurations. CREST significantly outperforms the previous methods on such tasks for all cases of easy, medium and hard games.