Text Level Graph Neural Network for Text Classification

Recently, researchers have explored graph neural network (GNN) techniques for text classification, since GNNs do well at handling complex structures and preserving global information. However, previous GNN-based methods face two main practical problems: a fixed corpus-level graph structure, which does not support online testing, and high memory consumption. To tackle these problems, we propose a new GNN-based model that builds a graph for each input text with globally shared parameters, instead of a single graph for the whole corpus. This removes the dependence between an individual text and the entire corpus, which supports online testing while still preserving global information. In addition, we build graphs with much smaller windows in the text, which not only extracts more local features but also significantly reduces the number of edges and the memory consumption. Experiments show that our model outperforms existing models on several text classification datasets while consuming less memory.


Introduction
Text classification is a fundamental problem in natural language processing (NLP) with many applications, such as spam detection and news filtering (Jindal and Liu, 2007; Aggarwal and Zhai, 2012). The essential step in text classification is text representation learning.
With the development of deep learning, neural networks such as Convolutional Neural Networks (CNN) (Kim, 2014) and Recurrent Neural Networks (RNN) (Hochreiter and Schmidhuber, 1997) have been employed for text representation. Recently, a new kind of neural network named Graph Neural Network (GNN) has attracted wide attention (Battaglia et al., 2018). GNN was first proposed by Scarselli et al. (2009) and has been used in many NLP tasks, including text classification (Defferrard et al., 2016), sequence labeling (Zhang et al., 2018a), neural machine translation (Bastings et al., 2017), and relational reasoning (Battaglia et al., 2016). Defferrard et al. (2016) first employed a Graph Convolutional Neural Network (GCN) for text classification and outperformed traditional CNN models. Further, Yao et al. (2019) improved on the work of Defferrard et al. (2016) by adding article nodes and weighted edges to the graph, and their model outperformed state-of-the-art text classification methods.
However, these GNN-based models usually build one graph for the whole corpus, which causes the following problems in practice. First, they consume a lot of memory because of the numerous edges. Since these methods build a single graph for the whole corpus and use edges with fixed weights, which considerably limits the expressive power of the edges, they have to use a large connection window to obtain a global representation. Second, it is difficult for these models to perform online testing, because the structure and parameters of their graph depend on the corpus and cannot be modified after training.
To address the above problems, we propose a new GNN-based method for text classification. Instead of building a single corpus-level graph, we build a text level graph for each input text. In a text level graph, we connect word nodes within a reasonably small window in the text rather than fully connecting all word nodes. The representations of the same nodes and the weights of edges are shared globally and can be updated in the text level graphs through a message passing mechanism, in which a node takes in information from neighboring nodes to update its representation. Finally, we summarize the representations of all nodes in the graph to predict the result. With this design, text level graphs remove the dependence between a single input text and the entire corpus, which supports online testing. Connecting words within a small contextual window also consumes less memory: it excludes the many words that are far away in the text and have little relation to the current word, and thus significantly reduces the number of edges. The message passing mechanism lets each node perceive the information around it and obtain a precise meaning in its specific context.
In our experiments, our method achieves state-of-the-art results on several text classification datasets and consumes significantly less memory than previous methods.

Method
In this section, we introduce our method in detail. First, we show how to build a text level graph for a given text; all parameters of the text level graph are taken from globally shared matrices. Then, we introduce the message passing mechanism on these graphs, which gathers information from the context. Finally, we describe how to predict the label of a given text from the learned representations. The overall architecture of our model is shown in Figure 1.

Building Text Graph
We denote a text with $l$ words as $T = \{r_1, \ldots, r_i, \ldots, r_l\}$, where $r_i$ is the representation of the $i$-th word. $r_i$ is a vector initialized with a $d$-dimensional word embedding and can be updated during training. To build a graph for a given text, we regard all the words that appear in the text as the nodes of the graph. Each edge starts from a word in the text and ends at one of its adjacent words. Concretely, the graph of text $T$ is defined as:

$$N = \{r_i \mid i \in [1, l]\},$$
$$E = \{e_{ij} \mid i \in [1, l];\ j \in [i - p, i + p]\},$$

where $N$ and $E$ are the node set and edge set of the graph, and the word representations in $N$ and the edge weights in $E$ are taken from globally shared matrices. $p$ denotes the number of adjacent words on each side that are connected to each word in the graph. In addition, we map all edges that occur fewer than $k$ times in the training set to a single "public" edge, so that its parameters are adequately trained.

Figure 1: Structure of the graph for a single text "he is very proud of you.". For ease of display, the figure sets p = 2 for the node "very" (nodes and edges colored red) and p = 1 for the other nodes (colored blue); in practice, a single value of p is used for all nodes. All parameters in the graph come from the globally shared representation matrix (of size |V| × d), shown at the bottom of the figure.
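To make the construction concrete, here is a minimal Python sketch under several assumptions: the text is already tokenized, words stand in for row indices into the global node matrix, edge keys stand in for indices into the global edge-weight matrix, and edge frequencies are counted over the training corpus with the same window p. The names `window_pairs`, `count_edges`, `build_text_graph`, and `PUBLIC` are hypothetical illustrations, not the authors' code.

```python
from collections import Counter

PUBLIC = "<public>"  # hypothetical key standing in for the shared "public" edge


def window_pairs(tokens, p):
    """Yield directed (source, target) word pairs whose positions are at most p apart."""
    l = len(tokens)
    for i in range(l):
        for j in range(max(0, i - p), min(l, i + p + 1)):
            if j != i:  # j in [i - p, i + p], excluding the position itself
                yield tokens[i], tokens[j]


def count_edges(corpus, p):
    """Count each directed word pair over all windows of a tokenized training corpus."""
    counts = Counter()
    for tokens in corpus:
        counts.update(window_pairs(tokens, p))
    return counts


def build_text_graph(tokens, p, edge_counts, k):
    """Build the graph of one text: distinct words become nodes and words at most
    p positions apart are linked. Each edge records the key of its weight in the
    global edge table; pairs seen fewer than k times in training share PUBLIC."""
    nodes = sorted(set(tokens))
    edges = set()
    for a, b in window_pairs(tokens, p):
        key = (a, b) if edge_counts[(a, b)] >= k else PUBLIC
        edges.add((a, b, key))
    return nodes, edges


tokens = ["he", "is", "very", "proud", "of", "you", "."]
counts = count_edges([tokens], p=1)
nodes, edges = build_text_graph(tokens, p=1, edge_counts=counts, k=2)
print(len(nodes), len(edges))  # 7 12
```

Because a new text only needs lookups into the shared tables, the same function builds a graph for an unseen text without touching any other document, which is the property that makes online testing possible.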
Compared with previous graph-building methods, our approach greatly reduces the scale of the graph in terms of nodes and edges, so the text level graph consumes much less GPU memory. Moreover, previous methods are unfriendly to newly arriving text, whereas our approach handles it naturally, because the graph for each text depends only on its own content.

Message Passing Mechanism
Convolution can extract information from local features (LeCun et al., 1989). In the graph domain, convolution is implemented by spectral approaches (Bruna et al., 2014; Henaff et al., 2015) or non-spectral approaches (Duvenaud et al., 2015). In this paper, a non-spectral method named the message passing mechanism (MPM) (Gilmer et al., 2017) is employed for convolution. MPM first collects information from adjacent nodes and then updates each node's representation based on its original representation and the collected information, which is defined as:

$$M_n = \max_{a \in N_n^p} e_{an} r_a,$$
$$r'_n = (1 - \eta_n) M_n + \eta_n r_n,$$

where $M_n \in \mathbb{R}^d$ is the message that node $n$ receives from its neighbors; max is a reduction function that takes the maximum value on each dimension to form a new output vector. $N_n^p$ denotes the nodes representing the words within $p$ positions of $n$ in the original text; $e_{an} \in \mathbb{R}$ is the edge weight from node $a$ to node $n$, which is updated during training; and $r_n \in \mathbb{R}^d$ denotes the previous representation of node $n$. $\eta_n \in \mathbb{R}$ is a trainable variable for node $n$ that indicates how much of the information in $r_n$ should be kept, and $r'_n$ denotes the updated representation of node $n$.

MPM makes node representations depend on their neighborhoods, so the representations carry information from the context. Therefore, even for polysemous words, the meaning in context can be determined by the weighted information from neighbors. Moreover, the parameters of text level graphs are taken from globally shared matrices, so the representations also carry global information, as in other graph-based models.
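The two MPM equations can be sketched in plain Python as follows. This is an illustrative sketch, not the authors' implementation: vectors are Python lists rather than tensors, the helper name `message_passing` is hypothetical, and giving a node without neighbors a zero message is our own assumption.

```python
def message_passing(reps, neighbors, edge_w, eta):
    """One MPM step on a text level graph.

    reps:      node -> list of d floats (r_n)
    neighbors: node -> list of neighbor nodes a in N_n^p
    edge_w:    (a, n) -> scalar edge weight e_an
    eta:       node -> scalar gate eta_n
    Returns node -> updated representation r'_n (all nodes read the old reps).
    """
    updated = {}
    for n, r_n in reps.items():
        # weighted messages e_an * r_a from every neighbor a of n
        msgs = [[edge_w[(a, n)] * x for x in reps[a]] for a in neighbors.get(n, [])]
        # M_n: element-wise max over messages (zero vector if no neighbors: an assumption)
        m_n = [max(col) for col in zip(*msgs)] if msgs else [0.0] * len(r_n)
        # r'_n = (1 - eta_n) * M_n + eta_n * r_n
        updated[n] = [(1 - eta[n]) * m + eta[n] * r for m, r in zip(m_n, r_n)]
    return updated


reps = {"a": [1.0, -2.0], "b": [3.0, 0.5], "c": [-1.0, 4.0]}
neighbors = {"a": ["b"], "b": ["a", "c"], "c": ["b"]}
edge_w = {("b", "a"): 1.0, ("a", "b"): 2.0, ("c", "b"): 0.5, ("b", "c"): 1.0}
eta = {"a": 0.5, "b": 0.0, "c": 1.0}
new = message_passing(reps, neighbors, edge_w, eta)
print(new["b"])  # [2.0, 2.0]
```

With eta_n = 0 a node is replaced entirely by its message, and with eta_n = 1 it keeps its old representation, which is how the gate trades context against the node's own embedding.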
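Finally, the representations of all nodes in the text are used to predict the label of the text:

$$y_i = \mathrm{softmax}\Big(\mathrm{ReLU}\big(W^{\top} \sum_{n \in N_i} r'_n + b\big)\Big),$$

where $W \in \mathbb{R}^{d \times c}$ is a matrix that maps the vector into the output space of $c$ classes, $N_i$ is the node set of text $i$, and $b \in \mathbb{R}^c$ is a bias.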
The goal of training is to minimize the cross-entropy loss between the ground-truth label and the predicted label:

$$\mathit{loss} = -\, g_i \cdot \log y_i,$$

where $g_i$ is the one-hot vector of the ground-truth label and the logarithm is applied element-wise.
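A plain-Python sketch of the readout and the loss, assuming the sum readout, ReLU, and softmax as written above, with `W` stored as d rows of c entries so that the product matches W transposed. The helper names `predict` and `cross_entropy` are hypothetical; a real model would use a tensor library.

```python
import math


def predict(node_reps, W, b):
    """y = softmax(ReLU(W^T * sum_n r'_n + b)).

    node_reps: list of updated node vectors r'_n (each of length d)
    W:         d rows of c weights (W in R^{d x c})
    b:         c biases
    """
    d, c = len(W), len(b)
    # sum readout over all nodes of the text
    s = [sum(r[k] for r in node_reps) for k in range(d)]
    # linear map to c class scores followed by ReLU
    z = [max(0.0, sum(W[k][j] * s[k] for k in range(d)) + b[j]) for j in range(c)]
    # numerically stable softmax
    top = max(z)
    exps = [math.exp(v - top) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(g, y):
    """-g . log(y) for a one-hot ground-truth vector g."""
    return -sum(gj * math.log(yj) for gj, yj in zip(g, y) if gj > 0)


y = predict([[1.0, 0.0], [0.0, 1.0]], W=[[1.0, 0.0], [0.0, 0.0]], b=[0.0, 0.0])
loss = cross_entropy([1.0, 0.0], y)
```

Because the readout is a sum over nodes, no separate document node is needed; the text representation exists only as this aggregate.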

Experiments
In this section, we describe our experimental setup and report our experimental results.

Experimental Setup
For our experiments, we use the R8, R52, and Ohsumed datasets. R8 and R52 are both subsets of the Reuters-21578 dataset. The Ohsumed corpus is extracted from the MEDLINE database. Since MEDLINE is designed for multi-label classification, we remove texts with two or more labels. For all datasets, we randomly select 10% of the texts in the training set to build a validation set. An overview of the datasets is given in Table 1.
We compare our method with the following baseline models. Note that the results of some models are taken directly from Yao et al. (2019).
• CNN Proposed by Kim (2014); performs convolution and max pooling on word embeddings to obtain a representation of the text.
• LSTM Defined by Liu et al. (2016); uses the last hidden state as the representation of the text. Bi-LSTM is its bi-directional variant.
• Text-GCN A graph-based text classification model proposed by Yao et al. (2019), which builds a single large graph for the whole corpus.
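The dataset preparation described above (dropping multi-label texts and holding out 10% of the training set for validation) can be sketched as below. The input format, a list of (text, labels) pairs, the fixed random seed, and the helper name `prepare_splits` are assumptions for illustration; the filter matters only for Ohsumed and leaves single-label data unchanged.

```python
import random


def prepare_splits(train_docs, val_ratio=0.1, seed=0):
    """train_docs: list of (text, labels) pairs.
    Keeps single-label texts, then randomly holds out val_ratio of them for validation."""
    single = [(text, labels[0]) for text, labels in train_docs if len(labels) == 1]
    rng = random.Random(seed)
    order = list(range(len(single)))
    rng.shuffle(order)
    n_val = round(len(single) * val_ratio)
    val = [single[i] for i in order[:n_val]]
    train = [single[i] for i in order[n_val:]]
    return train, val


docs = [(f"doc{i}", ["c1"]) for i in range(20)] + [("multi", ["c1", "c2"])]
train, val = prepare_splits(docs)
print(len(train), len(val))  # 18 2
```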

Implementation Details
We set the dimension of node representations to 300 and initialize them with random vectors or GloVe (Pennington et al., 2014). The threshold k discussed in Section 2.1 is set to 2. We use the Adam optimizer (Kingma and Ba, 2014) with an initial learning rate of $10^{-3}$, and L2 weight decay is set to $10^{-4}$. Dropout with a keep probability of 0.5 is applied after the dense layer. The batch size is 32. We stop training if the validation loss does not decrease for 10 consecutive epochs. For baseline models, we use the default parameter settings from their original papers or implementations. For models using pre-trained word embeddings, we use 300-dimensional GloVe word embeddings. Table 2 reports the results of our model against the baseline methods; our model achieves state-of-the-art results.
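The hyperparameters listed above can be collected into a configuration object together with the early-stopping rule. This is a hypothetical sketch: `TrainConfig` and `stopping_epoch` are illustrative names, and we read "does not decrease for 10 consecutive epochs" as 10 epochs without a new minimum validation loss.

```python
from dataclasses import dataclass


@dataclass
class TrainConfig:
    embed_dim: int = 300         # d, node representation size (random or GloVe init)
    min_edge_count: int = 2      # k: edges seen fewer times map to the public edge
    learning_rate: float = 1e-3  # initial Adam learning rate
    weight_decay: float = 1e-4   # L2 weight decay
    keep_prob: float = 0.5       # dropout keep probability after the dense layer
    batch_size: int = 32
    patience: int = 10           # epochs without a new best validation loss before stopping


def stopping_epoch(val_losses, patience):
    """Return the 1-based epoch at which early stopping triggers, or None."""
    best = float("inf")
    since_best = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best:
            best, since_best = loss, 0
        else:
            since_best += 1
            if since_best >= patience:
                return epoch
    return None


cfg = TrainConfig()
losses = [1.0, 0.9, 0.8] + [0.85] * cfg.patience
print(stopping_epoch(losses, cfg.patience))  # 13
```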

Experimental Results
We note that graph-based models outperform traditional models such as CNN, LSTM, and fastText. This is likely due to the characteristics of the graph structure. A graph allows each word node to have a different number of neighbors, which enables word nodes to learn more accurate representations through different collocations. In addition, the relationships between words can be recorded in the edge weights and shared globally. Traditional models can do neither.

Table 2: Accuracy on several text classification datasets. Models marked with "*" have all word vectors initialized with GloVe word embeddings. We run all models 10 times and report mean results.
We also find that our model performs better than graph-based models such as Graph-CNN. Graph-CNN represents documents with a bag-of-words model, similar to ours, but it connects word nodes within a large window without weighted edges and therefore cannot distinguish the importance of different words. In contrast, our model employs trainable edge weights, which let words express themselves differently in different collocations. Moreover, the weights are shared globally, so they are trained on all texts in the corpus that contain the same collocation.
We also note that our model outperforms the previous state-of-the-art model, Text-GCN. This is likely due to the more expressive edges discussed above and to differences in representation learning. Text-GCN learns word representations from corpus-level co-occurrence, while our model is trained within a contextual window, like traditional word embeddings. Therefore, our model can benefit from pre-trained word embeddings and achieve better results. Table 3 compares the memory consumption and the number of edges of Text-GCN and our model. The results show that our model has a significant advantage in memory consumption.

Figure 2: Model performance using p from 1 to 19 and "∞" (fully connected). All hyperparameters except p are the same. The left and right y-axes show the accuracy on the R8 and Ohsumed datasets, respectively.

Analysis of Memory Consumption
As discussed in Section 2.1, each word in our model is connected only to adjacent words in the text, while Text-GCN, which is based on a corpus-level graph, connects nodes within a relatively large window. Because Text-GCN uses co-occurrence information as fixed weights, it has to enlarge the window size to obtain more accurate co-occurrence weights. Therefore, our model yields a much sparser edge weight matrix than Text-GCN. In addition, since the representation of a text is computed as the sum of the representations of its word nodes, our model has no text nodes, which further reduces memory consumption.
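A quick count illustrates why a small window keeps text level graphs sparse. The sketch below is our own illustration, not a reproduction of Table 3: it counts directed edges over token positions for a text of l tokens. For l > p the total is 2pl - p(p+1), which grows linearly in l, whereas a fully connected graph grows as l(l-1). Because repeated words share one node, the number of distinct edges in an actual graph can only be smaller.

```python
def text_level_edges(l, p):
    """Directed edges over l token positions when each position links to the
    positions at most p away on either side (j in [i - p, i + p], j != i)."""
    return sum(min(p, i) + min(p, l - 1 - i) for i in range(l))


def fully_connected_edges(l):
    """Directed edges when every position links to every other position."""
    return l * (l - 1)


for p in (1, 2, 3):
    print(p, text_level_edges(100, p), fully_connected_edges(100))
# p=1: 198, p=2: 394, p=3: 588, versus 9900 when fully connected
```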

Analysis of Edges
To understand the effect of different connection windows, we compared the performance on the R8 and Ohsumed datasets with different values of p; the results are reported in Figure 2. We find that accuracy increases as p grows and peaks when each word is connected to about 3 neighboring words. Beyond that, accuracy fluctuates and decreases as p increases. This suggests that when nodes are connected only to their nearest neighbors, they cannot capture dependencies that span multiple words in the context, whereas when they are connected to distant neighbors (much larger p), the graphs become increasingly similar to fully connected graphs, which ignore local features. In addition, fewer edges mean lower memory consumption. Our model has fewer edges than previous methods, which is another advantage of the proposed model.

Table 4: Results of ablation studies. We run all models 5 times and report mean results.

Ablation Study
To further analyze our model, we perform ablation studies and Table 4 shows the results.
In (1), we fix the edge weights and initialize them with point-wise mutual information (PMI), setting the sliding window size to 20 as in Yao et al. (2019). Removing the trainable edges makes the model perform worse on all datasets, which demonstrates the effectiveness of trainable edges. We attribute this mainly to the ability of trainable edges to model the relations between words better than fixed edges.
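For reference, the fixed PMI weights in ablation (1) follow the sliding-window PMI used by Text-GCN. The sketch below is our reading of that definition, not code from either paper: it estimates p(i) and p(i, j) as the fraction of windows containing the word or word pair, treats a text shorter than the window as one window, and keeps only pairs with positive PMI. The function name `pmi_edge_weights` is hypothetical.

```python
import math
from collections import Counter
from itertools import combinations


def pmi_edge_weights(corpus, window=20):
    """Fixed PMI edge weights over sliding windows of tokenized texts."""
    n_windows = 0
    word_windows = Counter()  # #W(i): windows containing word i
    pair_windows = Counter()  # #W(i, j): windows containing both i and j
    for tokens in corpus:
        if len(tokens) <= window:
            spans = [tokens]  # assumption: a short text counts as a single window
        else:
            spans = [tokens[s:s + window] for s in range(len(tokens) - window + 1)]
        for span in spans:
            n_windows += 1
            uniq = sorted(set(span))
            word_windows.update(uniq)
            pair_windows.update(combinations(uniq, 2))
    weights = {}
    for (i, j), n_ij in pair_windows.items():
        # PMI(i, j) = log(p(i, j) / (p(i) p(j))), probabilities as window fractions
        pmi = math.log(n_ij * n_windows / (word_windows[i] * word_windows[j]))
        if pmi > 0:  # keep only positively associated pairs
            weights[(i, j)] = pmi
    return weights


w = pmi_edge_weights([["a", "b"], ["a", "b"], ["c", "d"]], window=20)
```

Unlike the trainable edges of the full model, these weights are computed once from corpus statistics and never updated, which is exactly what ablation (1) removes.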
In (2), we replace max-reduction with mean-reduction. In the original model, a node obtains its new representation from the received messages by taking the maximum value along each dimension. Table 4 shows that max-reduction achieves better results. The node reduction function is similar to the pooling operation in CNNs: max-reduction highlights highly discriminative features and provides nonlinearity, which helps achieve better results.
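The difference between the two reductions can be seen on a toy set of messages. In this illustrative sketch (the helper `reduce_messages` and the numbers are hypothetical), a single strong value on one dimension survives max-reduction but is diluted by mean-reduction.

```python
def reduce_messages(msgs, how="max"):
    """Combine equal-length neighbor messages dimension by dimension."""
    cols = list(zip(*msgs))
    if how == "max":
        return [max(col) for col in cols]
    if how == "mean":
        return [sum(col) / len(col) for col in cols]
    raise ValueError(f"unknown reduction: {how}")


msgs = [[4.0, 0.0], [1.0, 0.0], [1.0, 3.0]]
print(reduce_messages(msgs, "max"))   # [4.0, 3.0]
print(reduce_messages(msgs, "mean"))  # [2.0, 1.0]
```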
In (3), we remove the pre-trained word embeddings and initialize all nodes with random vectors. Compared with the original model, performance decreases slightly without pre-trained word embeddings, which suggests that pre-trained word embeddings help improve the performance of our model.

Related Work
In this section, we review related work on GNNs and text classification.

Graph Neural Networks
Graph Neural Networks (GNNs) have received extensive attention recently (Zhou et al., 2018; Zhang et al., 2018b; Wu et al., 2019). GNNs can model non-Euclidean data, while traditional neural networks can only model regular grid data. However, many real-world tasks, such as knowledge graphs (Hamaguchi et al., 2017) and social networks (Hamilton et al., 2017), as well as many other research areas, involve data in the form of trees or graphs. GNNs were therefore proposed (Scarselli et al., 2009) to apply deep learning techniques to data in the graph domain.

Text Classification
Text classification is a classic problem in natural language processing with a wide range of real-world applications. Traditional text classification methods such as bag-of-words (Zhang et al., 2010), n-grams (Wang and Manning, 2012), and topic models (Wallach, 2006) mainly focus on feature engineering and algorithms. With the development of deep learning, more and more deep learning models have been applied to text classification. Kim (2014) and Liu et al. (2016) applied CNNs and RNNs to text classification and achieved results much better than those of traditional models.
With the development of GNNs, graph-based classification models have gradually emerged (Hamilton et al., 2017; Veličković et al., 2017; Peng et al., 2018). Yao et al. (2019) proposed Text-GCN and achieved state-of-the-art results on several mainstream datasets. However, Text-GCN suffers from high memory consumption and lacks support for online testing. The model presented in this paper solves these problems and achieves better results.

Conclusion
In this paper, we proposed a new graph-based text classification model that uses text level graphs instead of a single graph for the whole corpus. Experimental results show that our model achieves state-of-the-art performance and has a significant advantage in memory consumption.