Text Graph Transformer for Document Classification

Text classification is a fundamental problem in natural language processing. Recent studies applied graph neural network (GNN) techniques to capture global word co-occurrence in a corpus. However, previous works are not scalable to large-sized corpora and ignore the heterogeneity of the text graph. To address these problems, we introduce a novel Transformer-based heterogeneous graph neural network, namely Text Graph Transformer (TG-Transformer). Our model learns effective node representations by capturing structure and heterogeneity from the text graph. We propose a mini-batch text graph sampling method that significantly reduces computing and memory costs to handle large-sized corpora. Extensive experiments have been conducted on several benchmark datasets, and the results demonstrate that TG-Transformer outperforms state-of-the-art approaches on the text classification task.


Introduction
Text classification is a widely studied problem in natural language processing and has been addressed in many real-world applications such as news filtering, spam detection, and health record systems (Kowsari et al., 2019; Che et al., 2015; Zhang et al., 2018). The objective is to assign corresponding labels to textual units based on text representations.
Deep learning models like Convolutional Neural Networks (CNN) (Kim, 2014) and Recurrent Neural Networks (RNN) (Hochreiter and Schmidhuber, 1997) have been applied for text representation learning instead of traditional hand-crafted features, such as n-grams and bag-of-words (BoW) (Joulin et al., 2016). Researchers have recently turned to Graph Neural Networks (GNN) to exploit global features in text representation learning, which learn node embeddings by aggregating information from neighbors through edges. Defferrard et al. (2016) first generalized CNN to graphs for the text classification task. Then Yao et al. (2019) applied Graph Convolutional Network (GCN) (Kipf and Welling, 2016) on a corpus-level heterogeneous text graph and achieved state-of-the-art performance. Liu et al. (2020) further improved classification accuracy by expanding the text graph with semantic and syntactic contextual information.
However, these GCN-based models on heterogeneous text graphs suffer from two practical issues. Firstly, none of these models are scalable to large-sized corpora due to high computation and memory costs: computation over all nodes in the graph is required at each layer during training. Secondly, all these models ignore the heterogeneity of the text graph, which consists of both document and word nodes. Distinguishing nodes of different types benefits node representation learning.
To address the above problems, we propose a novel Transformer-based heterogeneous GNN model, namely Text Graph Transformer (TG-Transformer). Instead of learning based on the full text graph, we propose a text graph sampling method that enables subgraph mini-batch training. The significantly reduced computing and memory costs make the model scalable to large-sized corpora. Moreover, we utilize Transformer to aggregate information in subgraph batches with two proposed graph structural encodings. We also distinguish the learning process of nodes of different types to fully utilize the heterogeneity of the text graph. The main contributions of this work are as follows: 1. We propose Text Graph Transformer, a heterogeneous graph neural network for text classification. It is the first scalable graph-based method for the task to the best of our knowledge. 2. We propose a novel heterogeneous text graph sampling method that significantly reduces computing and memory costs.
3. We perform experiments on several benchmark datasets, and the results demonstrate the effectiveness and efficiency of our model.

Methodology
In this section, we introduce TG-Transformer in detail. First, we present how to construct a heterogeneous text graph for a given corpus. Then, we introduce our text graph sampling method, which generates subgraph mini-batches from the text graph. These subgraph batches are fed into TG-Transformer to learn effective node representations for classification. The overall structure of our model is shown in Fig. 1.

Text Graph Building
To capture global word co-occurrence within the corpus, we build a heterogeneous text graph G = (U, V, E, F). The text graph contains two types of nodes: word nodes (U) representing all the words in the corpus vocabulary and document nodes (V) representing all documents in the corpus. The text graph also contains two types of edges: word-document edges (E) and word-word edges (F). Word-document edges are built based on word occurrence within documents, with edge weights measured by the term frequency-inverse document frequency (TF-IDF) method. Word-word edges are built based on local word co-occurrence within sliding windows in the corpus, with edge weights measured by point-wise mutual information (PMI):

PMI(w_i, w_j) = log( (N_{i,j} / N) / ((N_i / N) * (N_j / N)) ),

where N_i, N_j, and N_{i,j} are the numbers of sliding windows in the corpus that contain word w_i, word w_j, and both w_i and w_j, respectively, and N is the total number of sliding windows in the corpus.
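The PMI edge construction above can be sketched as follows. This is a minimal illustration, not the paper's released code: the function name `pmi_edges` and the choice to keep only positive-PMI pairs (as in TextGCN) are our assumptions.

```python
import math
from collections import Counter
from itertools import combinations

def pmi_edges(docs, window_size=3):
    """Compute PMI-weighted word-word edges from sliding windows.

    docs: list of tokenized documents (lists of words).
    Returns a dict mapping (w_i, w_j) pairs to positive PMI weights.
    """
    windows = []
    for doc in docs:
        if len(doc) <= window_size:
            windows.append(doc)
        else:
            windows.extend(doc[i:i + window_size]
                           for i in range(len(doc) - window_size + 1))
    n_windows = len(windows)                  # N
    word_count = Counter()                    # N_i
    pair_count = Counter()                    # N_{i,j}
    for w in windows:
        uniq = sorted(set(w))                 # count each word once per window
        word_count.update(uniq)
        pair_count.update(combinations(uniq, 2))
    edges = {}
    for (wi, wj), n_ij in pair_count.items():
        pmi = math.log((n_ij / n_windows) /
                       ((word_count[wi] / n_windows) * (word_count[wj] / n_windows)))
        if pmi > 0:                           # keep positive-PMI edges only
            edges[(wi, wj)] = pmi
    return edges
```

A higher weight indicates that two words co-occur more often than their marginal frequencies would predict.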

Text Graph Sampling
To reduce computing and memory cost, we propose a text graph sampling method. Instead of learning based on the entire text graph, TG-Transformer is trained on sampled subgraph mini-batch, making it scalable to large-sized corpus. We separate sub-graph sampling as a pre-process step in an unsupervised manner for controlling the time costs in model learning.
We first calculate the intimacy matrix S of the text graph based on the PageRank algorithm:

S = α · (I − (1 − α) · Ā)^{-1},

where the factor α ∈ [0, 1] is usually set as 0.15, Ā = D^{-1/2} A D^{-1/2} is the normalized symmetric adjacency matrix, A is the adjacency matrix of the text graph, and D is its corresponding diagonal degree matrix. Each entry S_{i,j} measures the intimacy score between node i and node j.
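The closed-form intimacy matrix can be computed directly with dense linear algebra; a minimal sketch (the function name is ours, and a dense inverse is only practical for illustration-sized graphs):

```python
import numpy as np

def intimacy_matrix(A, alpha=0.15):
    """PageRank-based intimacy matrix S = alpha * (I - (1 - alpha) * A_hat)^-1.

    A: dense symmetric adjacency matrix (numpy float array).
    """
    deg = A.sum(axis=1)
    d_inv_sqrt = np.zeros_like(deg)
    nz = deg > 0                              # guard against isolated nodes
    d_inv_sqrt[nz] = deg[nz] ** -0.5
    A_hat = d_inv_sqrt[:, None] * A * d_inv_sqrt[None, :]  # D^{-1/2} A D^{-1/2}
    n = A.shape[0]
    return alpha * np.linalg.inv(np.eye(n) - (1 - alpha) * A_hat)
```

Since sampling is a pre-processing step done once per corpus, this computation does not affect per-epoch training cost.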
For any document target node v i ∈ V, we sample its context subgraph C(v i ) of size k by selecting its top k intimate neighbour word nodes u j ∈ U.
Meanwhile, for any word target node u_i ∈ U, we first calculate the ratios of the two types of incident edges:

r_w(u_i) = |F(u_i)| / (|F(u_i)| + |E(u_i)|),  r_d(u_i) = |E(u_i)| / (|F(u_i)| + |E(u_i)|),

where F(u_i) and E(u_i) are the sets of word-word edges and word-document edges incident to u_i with intimacy scores larger than a threshold θ. We then sample its context subgraph C(u_i) of size k by selecting its top k · r_w(u_i) intimate neighbour word nodes and its top k · r_d(u_i) intimate neighbour document nodes, respectively.
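The two sampling rules can be sketched in one routine. This is an illustrative reading of the procedure, not the authors' implementation: the function name, the rounding of the per-type budgets, and the tie-breaking order are our assumptions.

```python
import numpy as np

def sample_context(S, target, word_nodes, doc_nodes, k, theta=0.0):
    """Sample the context subgraph for one target node from intimacy matrix S.

    word_nodes / doc_nodes: index lists of each node type.
    theta: intimacy-score threshold for counting incident edges.
    """
    scores = S[target]

    def top(indices, m):
        # top-m most intimate neighbours of the given type (tie order arbitrary)
        cand = [i for i in indices if i != target and scores[i] > theta]
        return sorted(cand, key=lambda i: -scores[i])[:m]

    if target in doc_nodes:
        # document target: sample word neighbours only
        return top(word_nodes, k)
    # word target: split the budget k by the edge-type ratios r_w / r_d
    n_w = sum(1 for i in word_nodes if i != target and scores[i] > theta)
    n_d = sum(1 for i in doc_nodes if scores[i] > theta)
    total = max(n_w + n_d, 1)
    k_w = round(k * n_w / total)              # k * r_w(u_i)
    return top(word_nodes, k_w) + top(doc_nodes, k - k_w)
```

Because r_w(u_i) + r_d(u_i) = 1, the two per-type budgets always sum to k.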

Text Graph Transformer
Based on the sampled subgraph mini-batch, TG-Transformer updates the text graph nodes' representations iteratively for classification. We build one model for each target node type (document/word) to model heterogeneity. The input of our model is the raw feature embeddings of the nodes in the subgraph batch, injected with the following two extra structural encodings.

Heterogeneity Encoding
The heterogeneity encoding captures the document and word node types in the text graph. Similar to the segment encoding in (Devlin et al., 2018), we use 0 and 1 to encode document nodes and word nodes, respectively.

Weisfeiler-Lehman Structural Encoding
We adopt the WL Role Embedding of (Zhang et al., 2020a) to capture the structure of the text graph. The Weisfeiler-Lehman (WL) algorithm (Niepert et al., 2016) labels nodes according to their structural roles in the graph. For a node v_j (document or word node) in the sampled subgraph, we denote its WL code as WL(v_j) ∈ N, and the encoding is defined as:

PE_WL(v_j)[2l] = sin( WL(v_j) / 10000^{2l / d_h} ),
PE_WL(v_j)[2l+1] = cos( WL(v_j) / 10000^{2l / d_h} ).

These two encodings have the same dimension (i.e., d_h) as the original raw feature embeddings, so we add them together as the initial node representation for the input subgraph, denoted as H^{(0)}.
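The two structural encodings can be sketched together as below. The sinusoidal form follows the position-embedding recipe used for WL role embeddings in (Zhang et al., 2020a); the function name, the constant-vector realization of the heterogeneity encoding (a learned embedding would also fit), and the even-d_h assumption are ours.

```python
import numpy as np

def structural_encodings(node_types, wl_codes, d_h):
    """Build the two structural encodings added to raw feature embeddings.

    node_types: 0 for document nodes, 1 for word nodes (heterogeneity
    encoding, broadcast across all d_h dimensions).
    wl_codes: integer WL role labels, embedded sinusoidally (d_h even).
    """
    n = len(node_types)
    het = np.repeat(np.asarray(node_types, dtype=float)[:, None], d_h, axis=1)
    wl = np.zeros((n, d_h))
    pos = np.asarray(wl_codes, dtype=float)[:, None]
    idx = np.arange(0, d_h, 2, dtype=float)           # dimensions 2l
    angle = pos / (10000 ** (idx / d_h))
    wl[:, 0::2] = np.sin(angle)                       # even dims: sine
    wl[:, 1::2] = np.cos(angle)                       # odd dims: cosine
    return het + wl
```

The returned matrix is simply added to the raw feature embeddings to form H^{(0)}.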

Graph Transformer Layer
The D-layer graph transformer aggregates information from the subgraph batch to learn the target node representation. Each graph transformer layer contains three trainable matrices W_Q, W_K, W_V ∈ R^{d_h × d_h}, and the queries Q, keys K, and values V are generated by multiplying the input correspondingly:

Q = H^{(l-1)} W_Q,  K = H^{(l-1)} W_K,  V = H^{(l-1)} W_V.

Then a TG-Transformer layer can be denoted as:

H^{(l)} = softmax( Q K^T / sqrt(d_h) ) V + G-res( H^{(l-1)} ),

where G-res refers to the graph residual term in (Zhang and Meng, 2019), used to alleviate the over-smoothing issue of GNNs. The output of the last layer H^{(D)} is averaged as the final representation z of the target node and fed into a softmax classifier:

ŷ = softmax( W z + b ).

Based on the sampled subgraphs for all the nodes in the training set T, we define the cross-entropy loss function as:

L = − Σ_{n ∈ T} Σ_{i=1}^{d_y} y_{n,i} · log( ŷ_{n,i} ),

where n ∈ T denotes the target word/document nodes in the training set, d_y is the label vector dimension, and y_n represents the ground-truth label vector of node n.
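As a concrete sketch, one single-head layer can be written as follows. The graph residual term is omitted for brevity, and the function names are ours, not the paper's:

```python
import numpy as np

def softmax(x, axis=-1):
    """Numerically stable softmax."""
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def graph_transformer_layer(H, W_Q, W_K, W_V):
    """One single-head layer: softmax(Q K^T / sqrt(d_h)) V.

    H: (n, d_h) node representations of the subgraph batch.
    W_Q, W_K, W_V: (d_h, d_h) trainable projection matrices.
    """
    Q, K, V = H @ W_Q, H @ W_K, H @ W_V
    d_h = H.shape[1]
    attn = softmax(Q @ K.T / np.sqrt(d_h), axis=-1)   # (n, n) attention weights
    return attn @ V
```

Stacking D such layers (with the graph residual added back) and mean-pooling the final output yields the representation z fed to the softmax classifier.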

Experimental Setup
Datasets We evaluate the effectiveness of our model on five benchmark datasets: the R52 and R8 Reuters datasets for news document classification, the Ohsumed dataset for medical bibliographic classification, and two large-scale review rating datasets: IMDB and Yelp 2014. Detailed statistics of the datasets are summarized in Table 1.
Baselines We compare our method with three classical baseline models: CNN (Kim, 2014), LSTM (Liu et al., 2016), and fastText (Joulin et al., 2016) using the average of word/n-gram embeddings. In addition, we compare with three state-of-the-art GNN-based models: TextGCN (Yao et al., 2019) using GCN, Text GNN (Huang et al., 2019) using text-level graphs, and TensorGCN (Liu et al., 2020) using semantic and syntactic contextual information.
Implementation We set the node representation dimension as 300 and initialize with GloVe word embeddings (Pennington et al., 2014). We train a 2-layer graph transformer with a hidden size of 32 and 4 attention heads. We use mini-batch SGD with the Adam optimizer (Kingma and Ba, 2014), and the dropout rate is set as 0.5. The initial learning rate is set as 0.001 and decayed with a weight decay of 5e-4. 10 percent of the training set is randomly selected as the validation set, and we stop training early when the validation loss no longer decreases.

Experiment Results
Table 2 presents the classification accuracy of our model compared with baseline methods. GNN-based models generally perform better than sequential and bag-of-words models due to their ability to model global word co-occurrence in the corpus, and TG-Transformer outperforms other graph models with much less memory and computing cost. This is likely due to the utilization of the text graph's heterogeneity and effective representation learning by the graph transformer layers. Moreover, TG-Transformer performs well on large-sized corpora such as IMDB and Yelp 2014. We also evaluate model efficiency with training time per epoch, as shown in Table 3. It can be observed that our text graph sampling method reduces the computing cost significantly and makes our model scalable to large corpora, where previous GNN-based models such as TextGCN are not applicable due to computing power limits.
Hyperparameter Here we analyze the effect of the subgraph size k in sampling. Parameter k has a large influence on model performance since it defines the number of neighbour nodes used to update the target node representation. During parameter tuning, we notice that learning performance improves steadily as k increases from 1 to an optimal value (i.e., 23 for R8) and starts decreasing as k increases further. The same trend is observed for all datasets. The computing cost to train the model also increases with larger k, but remains lower than that of other GNN-based models.

Ablation Study
We perform ablation studies to analyze our model further, as shown in Table 4. In (1), we remove the two structural encodings and only use raw feature embeddings as input. The decreased performance demonstrates that the structural encodings capture useful heterogeneous graph structure information. In (2), we remove pre-trained word embeddings and initialize all the nodes with random vectors. Model performance decreases more substantially, demonstrating the significance of pre-trained word embeddings as initial node representations in our model. In (3), we train a single model to update and learn both target node types (document and word) jointly. The slightly decreased classification accuracy reflects the importance of modeling the heterogeneity of the text graph.
Related Work

Text Classification
Traditional text classification studies rely on hand-crafted features like BoW (Zhang et al., 2010) and n-grams (Wang and Manning, 2012). With the development of deep learning, researchers applied CNNs (Kim, 2014; Zhang et al., 2015), LSTMs (Tai et al., 2015; Liu et al., 2016), word embedding techniques (Joulin et al., 2016; Pennington et al., 2014), and attention mechanisms (Yang et al., 2016; Wang et al., 2016) in text classification models, steadily improving accuracy. Recently, graph-based text classification models have received growing attention due to their ability to model global information in a corpus (Yao et al., 2019; Peng et al., 2018; Zhang et al., 2020b; Nikolentzos et al., 2019). Our paper follows this line of work on developing novel GNNs for text classification.

Graph Neural Network
Representative GNN models proposed to date include GCN (Kipf and Welling, 2016), Graph Attention Network (GAT) (Veličković et al., 2017), and GraphSAGE (Hamilton et al., 2017). GCN models are based on an approximated graph convolution operator, while GAT relies on the self-attention mechanism. Recently, Transformer (Vaswani et al., 2017) models have been applied in novel GNN designs (Hu et al., 2020; Zhang et al., 2020a).

Conclusion
In this paper, we proposed a scalable heterogeneous graph model, TG-Transformer, for text classification. Experimental results demonstrate its effectiveness and efficiency compared to state-of-the-art methods. It also opens the possibility of parallelization and pre-training for GNN models in future research.