Low-Complexity Probing via Finding Subnetworks

The dominant approach in probing neural networks for linguistic properties is to train a new shallow multi-layer perceptron (MLP) on top of the model's internal representations. This approach can detect properties encoded in the model, but at the cost of adding new parameters that may learn the task directly. We instead propose a subtractive pruning-based probe, where we find an existing subnetwork that performs the linguistic task of interest. Compared to an MLP, the subnetwork probe achieves both higher accuracy on pre-trained models and lower accuracy on random models, so it is both better at finding properties of interest and worse at learning on its own. Next, by varying the complexity of each probe, we show that subnetwork probing Pareto-dominates MLP probing in that it achieves higher accuracy given any budget of probe complexity. Finally, we analyze the resulting subnetworks across various tasks to locate where each task is encoded, and we find that lower-level tasks are captured in lower layers, reproducing similar findings in past work.


Introduction
While pre-training has produced large gains for natural language tasks, it is unclear what a model learns during pre-training. Research in probing investigates this question by training a shallow classifier on top of the pre-trained model's internal representations to predict some linguistic property (Adi et al., 2016; Shi et al., 2016; Tenney et al., 2019, inter alia). The resulting accuracy is then roughly indicative of whether the model encodes that property.
However, it is unclear how much is learned by the probe versus already captured in the model representations. This question has been the subject of much recent debate (Hewitt and Liang, 2019; Voita and Titov, 2020; Pimentel et al., 2020b, inter alia).
We would like the probe to find only and all properties captured by a model, leading to a tradeoff between accuracy and complexity: a linear probe is insufficient to find the non-linear patterns in neural models, but a deeper multi-layer perceptron (MLP) is complex enough to learn the task on its own. Motivated by this tradeoff and the goal of low-complexity probes, we consider a different approach based on pruning. Specifically, we search for a subnetwork, i.e. a version of the model with a subset of the weights set to zero, that performs the task of interest. As our search procedure, we build upon past work in pruning and perform gradient descent on a continuous relaxation of the search problem (Louizos et al., 2017; Mallya et al., 2018). The resulting probe has many fewer free parameters than MLP probes.

Our experiments evaluate the accuracy-complexity tradeoff compared to MLP probes on an array of linguistic tasks. First, we find that the neuron subnetwork probe has both higher accuracy on pre-trained models and lower accuracy on random models, so it is both better at finding properties of interest and less able to learn the tasks on its own. Next, we measure complexity as the bits needed to transmit the probe parameters (Pimentel et al., 2020a; Voita and Titov, 2020). Varying the complexity of each probe, we find that subnetwork probing Pareto-dominates MLP probing in that it achieves higher accuracy given any desired complexity. Finally, we analyze the resulting subnetworks across various tasks and find that lower-level tasks are captured in lower layers, reproducing similar findings in past work (Tenney et al., 2019). These results suggest that subnetwork probing is an effective new direction for improving our understanding of pre-training.

Related Work
Probing. Probing investigates whether a model captures some hypothesized property and typically involves learning a shallow classifier on top of the model's frozen internal representations (Adi et al., 2016; Shi et al., 2016; Conneau et al., 2018). Recent work has primarily applied this technique to pre-trained models. Clark et al. (2019), Hewitt and Manning (2019), and Manning et al. (2020) found that BERT captures various properties of syntax. Tenney et al. (2019) probed the layers of BERT for an array of tasks, and they found that their localization mirrored the classical NLP pipeline (part-of-speech, parsing, named entity recognition, semantic roles, coreference) in that lower-level tasks were captured in the lower layers.
However, these results are difficult to interpret due to the use of a learned classifier. One line of work suggests comparing the probe accuracy to random baselines, e.g. random models (Zhang and Bowman, 2018) or random control tasks (Hewitt and Liang, 2019). Other works take an information-theoretic view: Voita and Titov (2020) measure the complexity of the probe in terms of the bits needed to transmit its parameters, while Pimentel et al. (2020b) argue that probing should measure mutual information between the representation and the property. Pimentel et al. (2020a) propose a Pareto approach where they plot accuracy versus probe complexity, unifying several of these goals. We use these proposed metrics to compare our probing method to standard probing approaches.
Subnetworks. While pruning is widely used for model compression, some works have explored pruning as a technique for learning as well. Mallya et al. (2018) found that a model trained on ImageNet could be used for new tasks by learning a binary mask over the weights. More recently, Radiya-Dixit and Wang (2020) and Zhao et al. (2020) showed the analogous result in NLP that weight pruning can be used as an alternative to fine-tuning for pre-trained models. Our paper seeks to use pruning to reveal what the model already captures, rather than learn new tasks.

Subnetwork Probing
Given a task and a pre-trained encoder model with a classification head, our goal is to find a subnetwork with high accuracy on that task, where a subnetwork is the model with a subset of the encoder weights masked, i.e. set to zero. We search for this subnetwork via supervised gradient descent on the head and a continuous relaxation of the mask. We also mask at several levels of granularity, including pruning weights, neurons, or layers.
To learn the masks, we follow Louizos et al. (2017). Letting φ ∈ R^d denote the model weights, we associate the i-th weight φ_i with a real-valued parameter θ_i, which parameterizes a random variable Z_i ∈ [0, 1] representing the mask. Z_i follows the hard concrete distribution HardConcrete(β, θ_i) with temperature β and location θ_i:

U_i ∼ Uniform(0, 1),
S_i = σ((log U_i − log(1 − U_i) + θ_i)/β),
Z_i = min(1, max(0, S_i(ζ − γ) + γ)),

where σ denotes the sigmoid and γ = −0.1, ζ = 1.1 are constants. This random variable can be thought of as a soft version of the Bernoulli: S_i follows the concrete (or Gumbel-Softmax) distribution with temperature β (Maddison et al., 2016; Jang et al., 2016), and to put non-zero mass on 0 and 1, the distribution is stretched to the interval (γ, ζ) and clamped back to [0, 1].
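The sampling procedure above can be sketched in a few lines of plain Python. This is a minimal sketch: the function name and the default temperature β = 2/3 are our own choices for illustration, not values taken from the paper.

```python
import math
import random

GAMMA, ZETA = -0.1, 1.1  # stretch interval constants from the paper

def sample_hard_concrete(theta, beta=2/3, rng=random):
    """Draw one hard concrete sample Z = z(U, theta)."""
    u = rng.random()                                  # U ~ Uniform(0, 1)
    logit = math.log(u) - math.log(1.0 - u) + theta   # reparameterized logit
    s = 1.0 / (1.0 + math.exp(-logit / beta))         # concrete sample S
    s_stretched = s * (ZETA - GAMMA) + GAMMA          # stretch to (gamma, zeta)
    return min(1.0, max(0.0, s_stretched))            # clamp puts mass on {0, 1}
```

A large positive θ drives Z to exactly 1 (keep the weight), a large negative θ drives it to exactly 0 (prune it), while gradients still flow through θ in the interior of the interval.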
We will denote the mask as Z_i = z(U_i, θ_i) and the masked weights as φ * Z. We can then optimize the mask parameters θ via gradient descent. Specifically, let f(x; φ) denote the model. Then, given a data point (x, y) and a loss function L, we can minimize the expectation of the loss,

L(x, y, θ) = E_{U ∼ Uniform(0,1)^d} [ L(f(x; φ * z(U, θ)), y) ].

We estimate the expectation via sampling: we sample a single U and take the gradient ∇_θ L(f(x; φ * z(U, θ)), y). To encourage sparsity, we penalize the mask based on the probability it is non-zero,

R(θ) = Σ_{i=1}^d P(Z_i > 0) = Σ_{i=1}^d σ(θ_i − β log(−γ/ζ)).

Letting λ denote the regularization strength, our objective becomes

(1/|D|) Σ_{(x,y)∈D} L(x, y, θ) + λ R(θ).
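Because P(Z_i > 0) has a closed form under the hard concrete distribution, the penalty R(θ) is differentiable in θ. A minimal sketch of the penalty computation, again assuming a temperature of β = 2/3 (not stated in this section):

```python
import math

GAMMA, ZETA = -0.1, 1.1  # stretch interval constants from the paper

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def expected_l0(thetas, beta=2/3):
    """R(theta) = sum_i P(Z_i > 0) = sum_i sigmoid(theta_i - beta * log(-gamma/zeta))."""
    shift = beta * math.log(-GAMMA / ZETA)
    return sum(sigmoid(t - shift) for t in thetas)
```

The full objective adds λ * expected_l0(θ) to the average task loss, so gradient descent pushes mask locations negative (toward pruning) unless the weight is needed for the task.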

Probe Evaluation
To evaluate the accuracy-complexity tradeoff of a probe, we adapt methodology from recent work.
First, we consider the non-parametric test of probing a random model (Zhang and Bowman, 2018). We check probe accuracy on the pre-trained model, the model with the encoder randomly reset (reset encoder), and the model with the encoder and embeddings reset (reset all). An ideal probe should achieve high accuracy on the pre-trained model and low accuracy on the reset models.

Next, we consider a parametric test based on probe complexity. We first vary the complexity of each probe, where for subnetwork probing we associate multiple encoder weights with a single mask, and for the MLP probe we restrict the rank of the hidden layer. We then plot the resulting accuracy-complexity curve (Pimentel et al., 2020a).
To plot this curve, we need a measure of complexity that can compare probes of different types. Therefore, we measure complexity as the number of bits needed to transmit the probe parameters (Voita and Titov, 2020), where for simplicity we use a uniform encoding. In the subnetwork case, this encoding corresponds to using a single bit for each mask parameter. In the case of an MLP probe, each parameter is a real number, so the number of bits per parameter depends on its range and precision. For example, if each parameter lies in [a, b] and requires precision ε, then we need log₂((b − a)/ε) bits per parameter. To avoid having the choice of precision impact results, we plot lower and upper bounds of 1 and 32 bits per parameter.
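As an illustration, the complexity bookkeeping under this uniform encoding amounts to the following sketch; the default range [a, b] and precision ε are arbitrary example values, not choices made in the paper.

```python
import math

def subnetwork_probe_bits(num_mask_params):
    # Uniform encoding: one keep/prune bit per mask parameter.
    return num_mask_params

def mlp_probe_bits(num_params, a=-1.0, b=1.0, eps=2**-15):
    # Each real-valued parameter in [a, b] at precision eps
    # costs log2((b - a) / eps) bits.
    bits_per_param = math.log2((b - a) / eps)
    return num_params * bits_per_param
```

With a = −1, b = 1, and ε = 2⁻¹⁵ this gives 16 bits per MLP parameter, which falls between the 1-bit and 32-bit bounds plotted in the paper.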

Experimental Setup
We probe bert-base-uncased (Devlin et al., 2019) for the following tasks:

(1) Part-of-speech Tagging: We use the part-of-speech tags in the universal dependencies dataset (Zeman et al., 2017). As our classification head, we use dropout with probability p = 0.1, followed by a linear layer and softmax projecting from the BERT dimension to the number of tags.
(2) Dependency Parsing: We use the universal dependencies dataset (Zeman et al., 2017) and the biaffine head for classification (Dozat and Manning, 2016). We report macro-averaged labeled attachment score.
(3) Named Entity Recognition (NER): We use the data from the CoNLL 2003 shared task (Tjong Kim Sang and De Meulder, 2003) and the same classification head as for part-of-speech tagging. We report F1 using the CoNLL 2003 script.
Our primary probing baseline is the MLP probe with one hidden layer (MLP-1), where the hidden layer weight matrix is factored as the product UV^⊤ with U, V ∈ R^{d×r}. The choice of r restricts the rank of the hidden layer and thus its complexity. Then, if g(x; φ) is our pre-trained encoder and cls is the classification head, our two probes are f_Subnetwork(x) = cls(g(x; φ * Z)) and f_MLP-1(x) = cls(MLP-1(g(x; φ))).
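The rank restriction amounts to projecting the input down to r dimensions and back up through the factored matrix UV^⊤. A pure-Python sketch of just this factored map, with the nonlinearity and classification head omitted (our own simplification for illustration, not the paper's exact probe):

```python
def low_rank_hidden(h, U, V):
    """Compute U @ (V^T @ h), where U and V have shape (d, r).

    The product U V^T has rank at most r, and storing U and V
    costs 2*d*r parameters instead of d*d for a full matrix.
    """
    d, r = len(U), len(U[0])
    z = [sum(V[i][k] * h[i] for i in range(d)) for k in range(r)]     # V^T h: down-project to r dims
    return [sum(U[i][k] * z[k] for k in range(r)) for i in range(d)]  # U z: up-project back to d dims
```

Setting r = d recovers a full-rank hidden layer, which is the default MLP-1 configuration used outside the complexity-sweep experiments.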
While we vary the complexity of each probe to produce the accuracy-complexity plot, we default to neuron subnetwork probing and full rank MLP-1 probing in all other experiments.

Results
Accuracy-Complexity Tradeoff. Table 1 shows the results from the non-parametric experiments. When probing the pre-trained model, the subnetwork probe has much higher accuracy than the MLP-1 probe across all tasks. Furthermore, when probing the random models, the subnetwork probe has much lower accuracy for dependency parsing and NER, suggesting that the probe is less able to learn the task on its own. Overall, these numbers suggest that the subnetwork probe is a more faithful probe in that it finds properties when they are present, and does not find them in a random model.

Figure 1 plots the results from the parametric experiments, where we vary the complexity of each probe, apply it to the pre-trained model, and plot the resulting accuracy-complexity curve. We find that the subnetwork probe Pareto-dominates the MLP-1 probe in that it achieves higher accuracy for any complexity, even if we assume an overly optimistic MLP-1 lower bound of 1 bit per parameter. In particular, for part-of-speech and dependency parsing, the subnetwork probe achieves high accuracy even when given only 72 bits, while the MLP-1 probe falls off heavily at ∼20K bits.

Subnetwork Analysis. An auxiliary benefit of subnetwork probing is that we can examine the subnetworks produced by the procedure. One possibility is to look at the locations of the subnetworks, and one way to examine location is to count the number of unmasked weights in each layer. Figure 2 shows locations of the remaining parameters in the subnetworks extracted from the pre-trained model and the reset encoder model. To prune as many parameters as possible, we set λ_max to be the largest value out of (1, 5, 25, 125) such that accuracy is within 10% of fine-tuning accuracy (see the Appendix for more details). We then examine the sparsity levels of the attention heads for each layer. While the reset encoder model's subnetworks are uniformly distributed across the layers, the pre-trained model's subnetworks are localized and follow the order part-of-speech → dependencies → NER, reproducing the order found in Tenney et al. (2019). While Tenney et al. (2019) derived layer importance by training classifiers at each layer, we find location directly via pruning. This experiment strengthens their result and represents one example where subnetwork probing reveals additional insights into the model beyond accuracy.
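The localization analysis reduces to counting surviving mask entries per layer. A minimal sketch, assuming the learned masks are available as per-layer lists of 0/1 values (a hypothetical data layout, not the paper's actual code):

```python
def layer_density(masks_by_layer):
    """Fraction of unmasked (non-zero) mask entries in each layer."""
    return {layer: sum(1 for m in masks if m > 0) / len(masks)
            for layer, masks in masks_by_layer.items()}
```

Plotting these densities by layer for each task yields the kind of localization pattern shown in Figure 2.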
Together, these results show that subnetwork probing is more faithful to the model and offers richer analysis than existing probing approaches. While this work explores accuracy and location-based analysis, there are other possible directions, e.g., applying neuron explainability techniques. Therefore, we see subnetwork probing as a fruitful new direction for understanding pre-training.

Ethical Considerations
While pre-trained models have improved performance for many NLP tasks, they exhibit biases present in the pre-training corpora (Manzini et al., 2019; Tan and Celis, 2019; Kurita et al., 2019, inter alia). As a result, deploying pre-trained models runs the risk of reinforcing social biases. Probing gives us a tool to better understand and hopefully mitigate these biases. As one example of such a study, Vig et al. (2020) analyze how neurons and attention heads contribute to gender bias in pre-trained transformers. Therefore, while we analyze linguistic tasks in our paper, our method could also provide insights into model bias, e.g. by analyzing subnetworks for bias detection tasks like CrowS-Pairs (Nangia et al., 2020) or StereoSet (Nadeem et al., 2020).

A Appendix
A.1 Hyperparameters

The mask parameters are optimized using Adam with β = (0.9, 0.999), ε = 1 × 10⁻⁸, and learning rate 0.2 with linear warmup for the first 10% of the data. The classification head parameters are also optimized using Adam with the same hyperparameters and warmup, except with learning rate 5 × 10⁻⁵. The MLP-1 and fine-tuning baselines are also optimized using Adam with the same hyperparameters, warmup, and learning rate 5 × 10⁻⁵. We train for 30 epochs for all tasks.

Table 2 shows probing accuracies for λ_max ∈ (1, 5, 25, 125). Our method is consistently more selective than MLP-1 across the various values of λ_max, except for λ_max = 125, which seems to require too much sparsity.

A.3 Reproducibility Checklist
Experiments were run in Google Colab using a single 12GB NVIDIA Tesla K80 GPU. For each task, one run of fine-tuning took about half an hour. We used the transformers implementation of the bert-base-uncased model (Devlin et al., 2019), which has 12 layers, 768 hidden dimension, 12 heads, and 110M parameters. As data, we used the dev (2002 examples