Do Different Neural Networks Learn The Same Things?
Have you ever had a dataset and asked yourself: does this model learn something different from that model? This is the question that Nguyen et al. address in their paper “Do Wide And Deep Networks Learn The Same Things?” [1].
Overview
For the rest of this post, I quote from the paper like this:
Deep neural network architectures are typically tailored to available computational resources by scaling their width and/or depth. Remarkably, this simple approach to model scaling can result in state-of-the-art networks for both high- and low-resource regimes (Tan & Le, 2019).
and write my explanations below.
I am focusing on the core parts of the 24-page paper: setting up the networks and dataset, computing the CKA, and creating the figures. This is also the outline for this post: we start with a short introduction and then set up the networks. I then explain the equations behind the CKA scores. Lastly, we write the code to create the visualizations from the paper.
Introduction
block structure
Using the code developed in this post, we can visualize such block structures by comparing the activations of two models.
Experimental setup
[…] our experimental setup consists of a family of ResNets (He et al., 2016; Zagoruyko & Komodakis, 2016) trained on standard image classification datasets CIFAR-10, CIFAR-100 and ImageNet.
Let’s get this into code. To keep computation feasible, I focused on three ResNets [2], namely ResNet50, ResNet101, and ResNet152, and on the CIFAR-10 dataset [3] only.
ResNet50
The smallest ResNet can be instantiated quite easily: we load the model, which is available from Keras, with weights pre-trained on ImageNet [7]. We set the input shape to (32, 32, 3), which is the shape of one image in the CIFAR-10 dataset. We do not need the fully connected output layer, so we set include_top to False. Without it, the base model outputs a 4D tensor, which we pool to get a 2D output. Since our dataset has 10 classes, we add our own Dense output layer with 10 neurons and softmax activation.
Finally, we create a Model instance with the base model’s input as its input and our Dense layer as its output:
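Since the original code is not reproduced here, the following is a minimal sketch of this setup, assuming TensorFlow 2.x and `tf.keras.applications`. The helper name `build_resnet_classifier` is hypothetical, and `pooling="avg"` is an assumption, because the post only says that the output is pooled.

```python
def build_resnet_classifier(arch="ResNet50", num_classes=10, weights="imagenet"):
    """Hypothetical helper: pre-trained ResNet base plus a new softmax head.

    arch names a class in tf.keras.applications, e.g. "ResNet50",
    "ResNet101" or "ResNet152".
    """
    # Imported inside the function so this module loads without TensorFlow.
    import tensorflow as tf

    base_cls = getattr(tf.keras.applications, arch)
    base = base_cls(
        include_top=False,        # drop the 1000-class ImageNet classifier
        weights=weights,          # "imagenet" downloads the pre-trained weights
        input_shape=(32, 32, 3),  # shape of one CIFAR-10 image
        pooling="avg",            # assumption: global average pooling, 4D -> 2D
    )
    outputs = tf.keras.layers.Dense(num_classes, activation="softmax")(base.output)
    # Building from base.input/base.output keeps every ResNet layer in model.layers,
    # which is convenient later when we read out hidden activations.
    return tf.keras.Model(inputs=base.input, outputs=outputs)
```

With this helper, `build_resnet_classifier("ResNet101")` and `build_resnet_classifier("ResNet152")` would create the larger variants in the same way.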
ResNet101
The procedure is similar for the larger ResNet101:
ResNet152
And also similar for the ResNet152:
CIFAR-10 dataset
Getting the CIFAR-10 dataset is very convenient, since it’s readily available from TensorFlow. We download the train and test dataset, and rescale the pixel values from 0–255 to 0–1.0, making them floats:
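A minimal sketch of this step, assuming `tf.keras.datasets.cifar10.load_data()` for the download; the function names `rescale` and `load_cifar10` are my own. The rescaling is kept in a separate function so it can be checked without downloading anything.

```python
import numpy as np


def rescale(images):
    """Map uint8 pixel values in [0, 255] to float32 values in [0, 1]."""
    return images.astype(np.float32) / 255.0


def load_cifar10():
    """Download CIFAR-10 via Keras and rescale both splits."""
    import tensorflow as tf  # imported lazily; the download happens on first call

    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
    return (rescale(x_train), y_train), (rescale(x_test), y_test)
```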
Measuring similarity
We use linear centered kernel alignment (Kornblith et al., 2019; Cortes et al., 2012) to measure similarity between neural network hidden representations.
How does one measure the similarity between two neural networks? One way is to calculate the similarity between the output of (hidden) layers.
Conceptually, this is quite simple: you take a data batch and, instead of capturing only the final output, you also capture the output of every hidden layer. You do this for both networks and end up with two sets of activations: one set from network A and one set from network B.
You then compare each activation from A pairwise with each activation from B. And that’s essentially all.
There are some requirements on such a similarity metric, though. First, its range has to be between 0 and 1, with 0 indicating two completely dissimilar activations and 1 indicating two identical ones. Further, it has to handle activations of unequal shape, that is, layers with different numbers of neurons (the number of examples in the batch still has to match).
This metric is the (linear) centered kernel alignment [4][5]:
To reduce memory consumption, we compute CKA as a function of average HSIC scores computed over k minibatches:
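Written out, with bold symbols for matrices as in [1], the minibatch CKA is:

$$
\mathrm{CKA}_{\text{minibatch}} = \frac{\frac{1}{k}\sum_{i=1}^{k} \mathrm{HSIC}_1(\mathbf{X}_i\mathbf{X}_i^\top, \mathbf{Y}_i\mathbf{Y}_i^\top)}{\sqrt{\frac{1}{k}\sum_{i=1}^{k} \mathrm{HSIC}_1(\mathbf{X}_i\mathbf{X}_i^\top, \mathbf{X}_i\mathbf{X}_i^\top)}\,\sqrt{\frac{1}{k}\sum_{i=1}^{k} \mathrm{HSIC}_1(\mathbf{Y}_i\mathbf{Y}_i^\top, \mathbf{Y}_i\mathbf{Y}_i^\top)}}
$$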
where Xᵢ ∈ ℝ^(n×p₁) and Yᵢ ∈ ℝ^(n×p₂) are matrices containing the activations of two layers, one with p₁ neurons and another p₂ neurons, to the same minibatch of n examples sampled without replacement.
Before explaining the equation, we need to know what HSIC is:
We use an unbiased estimator of HSIC (Song et al., 2012) so that the value of CKA is independent of the batch size:
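Written out, the unbiased estimator HSIC₁ used in [1] is:

$$
\mathrm{HSIC}_1(\mathbf{K}, \mathbf{L}) = \frac{1}{n(n-3)}\left( \operatorname{tr}(\tilde{\mathbf{K}}\tilde{\mathbf{L}}) + \frac{\mathbf{1}^\top\tilde{\mathbf{K}}\mathbf{1}\;\mathbf{1}^\top\tilde{\mathbf{L}}\mathbf{1}}{(n-1)(n-2)} - \frac{2}{n-2}\,\mathbf{1}^\top\tilde{\mathbf{K}}\tilde{\mathbf{L}}\mathbf{1} \right)
$$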
where K̃ and L̃ are obtained by setting the diagonal entries of K and L to zero.
OK, let’s cover these equations in detail, beginning with HSIC₁, on which the CKA score builds:
HSIC
HSIC is short for the Hilbert-Schmidt Independence Criterion [6], which measures the statistical dependence between two sets of samples. In our case, these are the activations of two layers.
Let’s go over the equation again, which follows [6], but this time highlighting important parts:
The first box marks the tr() operator, which is short for the trace of a matrix: the sum of the elements on its main diagonal. The matrix we compute the trace of is the product of the modified matrices K̃ and L̃, i.e., K and L with their diagonals set to zero.
Now, why are the matrices written in bold? This is a standard convention for matrices and vectors. It also leads to the second marked part, the bold 1: this is a vector of length n filled with ones (the all-ones vector, not to be confused with a unit vector, which has length 1).
How do we know whether it’s a vector or a matrix? Look at the last marked part, 1ᵀ. The ᵀ indicates that we transpose what is on its left side. If 1 were a square (n, n) matrix filled only with ones, the transpose would be pointless, because such a matrix is symmetric and equals its own transpose. As a vector, however, 1ᵀK̃1 reduces K̃ to a single number: the sum of all its entries.
Last, how do we know the length of the 1 vector? It is determined by the dimension of the matrices K and L, which are square matrices of size n × n. And they are square since we call HSIC₁ as HSIC₁(XᵢXᵢᵀ, YᵢYᵢᵀ),
where we calculate the dot-product between the original matrix and its transposed version for the first and second arguments. A dot product on the matrix level is the same as the multiplication of those two matrices. (I decided to stick to the dot-product after reading [4], where they use this naming convention in section 3). And since we multiply a matrix of shape (m, k) with its transposed version of shape (k, m), we get a square matrix of shape (m, m).
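The building blocks from the last few paragraphs can be checked numerically. This small NumPy sketch uses random data with made-up sizes (6 examples, 4 features) purely for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(6, 4))    # m = 6 examples, k = 4 features
K = X @ X.T                    # (6, 4) @ (4, 6) -> square (6, 6) matrix
ones = np.ones(K.shape[0])     # the all-ones vector 1, its length taken from K

trace = np.trace(K)            # tr(K): sum of the main diagonal
total = ones @ K @ ones        # 1ᵀ K 1: sum of all entries, a scalar

J = np.ones((6, 6))            # an all-ones square matrix ...
print(np.array_equal(J, J.T))  # ... equals its own transpose: True
```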
We use minibatches of size n = 256 obtained by iterating over the test dataset 10 times, sampling without replacement within each epoch.
This tells us the dimensions our matrices will have later: an activation of shape (256, a*b*c*…) becomes a square matrix of shape (256, 256) once it is multiplied with its transpose.
On code level, this yields the following:
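As the original snippet is not shown here, this is a minimal NumPy sketch of HSIC₁ written directly from the equation; the function name `hsic1` is my own. It computes the product K̃L̃ once and reuses it for the trace and the last term.

```python
import numpy as np


def hsic1(K, L):
    """Unbiased HSIC estimator HSIC_1 (Song et al., 2012) for n x n Gram matrices."""
    n = K.shape[0]
    if n < 4:
        raise ValueError("HSIC_1 needs at least 4 examples (it divides by n - 3)")
    # K~ and L~: float copies of K and L with the diagonal set to zero.
    K_t = np.array(K, dtype=np.float64)
    L_t = np.array(L, dtype=np.float64)
    np.fill_diagonal(K_t, 0.0)
    np.fill_diagonal(L_t, 0.0)

    ones = np.ones(n)
    KL = K_t @ L_t
    trace_term = np.trace(KL)                                   # tr(K~ L~)
    sum_term = (ones @ K_t @ ones) * (ones @ L_t @ ones) / ((n - 1) * (n - 2))
    cross_term = 2.0 / (n - 2) * (ones @ KL @ ones)             # 2/(n-2) 1ᵀ K~ L~ 1
    return (trace_term + sum_term - cross_term) / (n * (n - 3))
```

A useful sanity check: a constant kernel carries no information, so `hsic1(np.ones((n, n)), L)` should be zero up to floating-point error.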
Having covered HSIC, we can proceed with CKA:
CKA
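CKA is short for centered kernel alignment [5]; it normalizes the HSIC between two layers by the HSIC of each layer with itself.
The numerator of the fraction is the average of the individual HSIC₁ scores over the minibatches. This is what the Σ, together with the factor 1/k, expresses, and it corresponds to this part of the paper:
we compute CKA as a function of average HSIC scores computed over k minibatches
This simply means that we take a few batches (of size 256) and, for each batch, separately capture the activations of model A and model B. For each batch, we use these activations to compute the HSIC₁ terms, which are averaged over all batches before being combined into a single CKA score.
The denominator is the normalizing factor: the product of the square roots of the averaged self-similarities HSIC₁(K, K) and HSIC₁(L, L). It ensures that two identical representations get a score of 1.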
Although the C in CKA stands for centered, I found no explicit mention of the centering process in [1]. Looking it up in [4], I found the centering matrix H = Iₙ − (1/n)11ᵀ, which appears in their HSIC estimator HSIC(K, L) = tr(KHLH)/(n − 1)².
This is similar to the centering process used in the last part of the HSIC₁ calculation.
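To see what this centering does, here is a small NumPy sketch (the function names are my own). Multiplying K by H on both sides subtracts the row and column means, so every row and column of HKH sums to zero. The biased estimator from [4] is shown only for illustration; the paper itself uses the unbiased HSIC₁.

```python
import numpy as np


def centering_matrix(n):
    """H = I_n - (1/n) 1 1ᵀ."""
    return np.eye(n) - np.ones((n, n)) / n


def hsic_biased(K, L):
    """Biased HSIC estimator tr(K H L H) / (n - 1)^2, as written in [4]."""
    n = K.shape[0]
    H = centering_matrix(n)
    return np.trace(K @ H @ L @ H) / (n - 1) ** 2


def double_center(K):
    """Same result as H @ K @ H, computed with means instead of matrix products."""
    return K - K.mean(axis=0, keepdims=True) - K.mean(axis=1, keepdims=True) + K.mean()
```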
When implementing this in Python, the calculation for a single batch of 256 images from the CIFAR-10 dataset took one hour. I therefore reduced the amount of computation by using only one batch (k = 1), which removes the Σ.
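With a single batch, the averages disappear and the minibatch formula from [1] reduces to:

$$
\mathrm{CKA}(\mathbf{K}, \mathbf{L}) = \frac{\mathrm{HSIC}_1(\mathbf{K}, \mathbf{L})}{\sqrt{\mathrm{HSIC}_1(\mathbf{K}, \mathbf{K})\,\mathrm{HSIC}_1(\mathbf{L}, \mathbf{L})}}, \qquad \mathbf{K} = \mathbf{X}\mathbf{X}^\top,\quad \mathbf{L} = \mathbf{Y}\mathbf{Y}^\top
$$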
Results
We begin our study by investigating how the depth and width of a model architecture affects its internal representation structure. How do representations evolve through the hidden layers in different architectures? How similar are different hidden layer representations to each other? To answer these questions, we use the CKA representation similarity measure outlined in Section 3.1.
We find that as networks become wider and/or deeper, their representations show a characteristic block structure: many (almost) consecutive hidden layers that have highly similar representations.
Let’s get this into code and verify:
We start by training our ResNets on the CIFAR-10 dataset. To make computation fast, we use TPUs and a large batch size of 256.
The training code uses a distribution strategy object from TensorFlow, which manages TPU training. One epoch takes around 8 seconds, and since the focus of this paper (and this post) is not on high accuracy, we set the number of epochs to 10.
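A minimal sketch of such a training setup, assuming TensorFlow 2.x on a TPU runtime such as Colab. The TPU connection follows the usual `TPUClusterResolver` pattern and falls back to the default strategy when no TPU is found. The function names, the Adam optimizer, and the validation split are my own assumptions; the post only states the batch size of 256 and the 10 epochs.

```python
def get_strategy():
    """Return a TPUStrategy on a TPU runtime, otherwise the default strategy."""
    import tensorflow as tf  # imported lazily

    try:
        resolver = tf.distribute.cluster_resolver.TPUClusterResolver()
        tf.config.experimental_connect_to_cluster(resolver)
        tf.tpu.experimental.initialize_tpu_system(resolver)
        return tf.distribute.TPUStrategy(resolver)
    except (ValueError, RuntimeError):
        # No TPU available: fall back to the default (CPU/GPU) strategy.
        return tf.distribute.get_strategy()


def train_model(strategy, build_model, x_train, y_train, x_test, y_test,
                epochs=10, batch_size=256):
    """Build and compile the model inside the strategy scope, then fit it."""
    with strategy.scope():
        model = build_model()
        model.compile(
            optimizer="adam",                        # assumption
            loss="sparse_categorical_crossentropy",  # CIFAR-10 labels are integers
            metrics=["accuracy"],
        )
    model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs,
              validation_data=(x_test, y_test))
    return model
```

Here `build_model` is any zero-argument function that returns an uncompiled Keras model, such as the ResNet constructor described earlier. Calling `get_strategy()` once and reusing the result avoids re-initializing the TPU for every network.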
We similarly train the other ResNets:
The epochs take around 15 to 20 seconds for the larger models.
Now we need some code to calculate the similarity between two activations.
We begin by flattening all but the first axis (the batch dimension), which turns an array of shape (n, a, b, c) into one of shape (n, a*b*c). No information is lost in this process. After reshaping the activations, we can compute their CKA score. To save memory, we delete the reshaped activations afterwards and return only the score.
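A minimal, self-contained NumPy sketch of this comparison, assuming the single-batch CKA with the unbiased HSIC₁ from [1]; the names `hsic1` and `compare_activations` are my own. Instead of an explicit `del`, the function returns only a float, so the flattened copies and Gram matrices are released when it returns.

```python
import numpy as np


def hsic1(K, L):
    """Unbiased HSIC estimator HSIC_1 (Song et al., 2012), as defined in [1]."""
    n = K.shape[0]
    K_t = np.array(K, dtype=np.float64)
    L_t = np.array(L, dtype=np.float64)
    np.fill_diagonal(K_t, 0.0)  # K~
    np.fill_diagonal(L_t, 0.0)  # L~
    ones = np.ones(n)
    KL = K_t @ L_t
    return (np.trace(KL)
            + (ones @ K_t @ ones) * (ones @ L_t @ ones) / ((n - 1) * (n - 2))
            - 2.0 / (n - 2) * (ones @ KL @ ones)) / (n * (n - 3))


def compare_activations(act_a, act_b):
    """Single-batch linear CKA between two activations with the same batch size."""
    n = act_a.shape[0]
    # Flatten all but the batch axis: (n, a, b, c) -> (n, a*b*c).
    x = act_a.reshape(n, -1).astype(np.float64)
    y = act_b.reshape(n, -1).astype(np.float64)
    K, L = x @ x.T, y @ y.T  # (n, n) linear-kernel Gram matrices
    # Note: the unbiased self-terms can be negative for tiny n, giving NaN here.
    score = hsic1(K, L) / np.sqrt(hsic1(K, K) * hsic1(L, L))
    return float(score)
```

Because CKA is invariant to isotropic scaling, `compare_activations(a, 3.0 * a)` should be 1 up to floating-point error, just like `compare_activations(a, a)`.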
This code takes two activations, but how do we get them in the first place? We can use the Keras backend for this.
We code a short function that takes a model, and returns a function that, given a data batch, returns the output of every intermediate layer:
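The post uses the Keras backend for this. The sketch below swaps in a different but common technique: building a Keras sub-model whose outputs are all hidden layers. It assumes TensorFlow 2.x and a functional model whose first layer is the InputLayer; the name `make_activation_fn` is my own.

```python
import numpy as np


def make_activation_fn(model):
    """Return a function mapping a data batch to the outputs of all hidden layers."""
    import tensorflow as tf  # imported lazily

    # Skip the InputLayer: its "output" is just the input batch.
    layer_outputs = [layer.output for layer in model.layers[1:]]
    extractor = tf.keras.Model(inputs=model.input, outputs=layer_outputs)

    def get_activations(batch):
        outputs = extractor(batch, training=False)
        if not isinstance(outputs, (list, tuple)):
            outputs = [outputs]  # a single-output model returns a bare tensor
        return [np.asarray(out) for out in outputs]

    return get_activations
```

Keep in mind that holding the activations of every layer of a ResNet152 for a batch of 256 images needs a considerable amount of memory.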
We can visualize the result as a heatmap, with the x and y axes representing the layers of the network, going from the input layer to the output layer.
The following function takes two models and a data batch. We use the previous method to get all the hidden activations for the data batch. We then create a placeholder for our heatmap:
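A minimal sketch of filling this placeholder, assuming the hidden activations of both models have already been collected into two lists. The function name is hypothetical, and the similarity function is passed in (for example, a single-batch CKA), which keeps this sketch independent of the CKA code.

```python
import numpy as np


def similarity_heatmap(acts_a, acts_b, similarity):
    """Pairwise similarity between every layer of model A and every layer of model B."""
    heatmap = np.zeros((len(acts_a), len(acts_b)))  # placeholder, filled below
    for i, act_a in enumerate(acts_a):
        for j, act_b in enumerate(acts_b):
            heatmap[i, j] = similarity(act_a, act_b)
    return heatmap
```

The double loop makes one similarity call per pair of layers, which helps explain the long runtimes reported below. One possible speed-up would be to compute each layer's self-term HSIC₁(K, K) once and reuse it across all pairs.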
The heatmaps start off as showing a checkerboard-like representation similarity structure, which arises because representations after residual connections are more similar to other post-residual representations than representations inside ResNet blocks. As the model gets wider or deeper, we see the emergence of a distinctive block structure — a considerable range of hidden layers that have very high representation similarity (seen as a yellow square on the heatmap). This block structure mostly appears in the later layers (the last two stages) of the network.
Based on the previous function we can visualize such a heatmap by calling
and then use pyplot to visualize this matrix:
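A minimal matplotlib sketch of such a plot. The function name, axis labels, and the `magma` colormap (dark for low, bright for high similarity) are my own choices. `origin="lower"` puts the input layers of both models at the bottom-left corner.

```python
import matplotlib.pyplot as plt


def plot_similarity(heatmap, name_a="Model A", name_b="Model B"):
    """Show a layer-by-layer similarity matrix (rows: model A, columns: model B)."""
    fig, ax = plt.subplots(figsize=(6, 6))
    image = ax.imshow(heatmap, origin="lower", cmap="magma", vmin=0.0, vmax=1.0)
    ax.set_xlabel(f"Layers of {name_b}")
    ax.set_ylabel(f"Layers of {name_a}")
    fig.colorbar(image, ax=ax, label="CKA")
    return fig, ax
```

In a notebook, calling `plot_similarity(heatmap, "ResNet50", "ResNet101")` followed by `plt.show()` displays the figure.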
Comparing the activations of the ResNet50 with those of the ResNet101 took around an hour.
Inspecting this image, we see the block structure mentioned above at the beginning: the first few layers of both models seem to learn quite similar internal representations. Deeper in the networks, the picture is more diverse: there are vertical lines (activations from ResNet101’s layers) that are completely distinct from the activations of the smaller ResNet.
Such blank spaces become more apparent when comparing the smaller ResNet with the ResNet152 (which took 2 hours on Colab):
Again, we can see small block structures in the early layers. Notably, the first 100 to 150 layers of the big ResNet152 seem to learn things very similar to the first 20-ish layers of the much smaller model. This might be indicative of overparameterization:
Even though we more than double the number of parameters (about 2.4×), from 25 636 712 for the smallest ResNet to 60 419 944 for the ResNet152, these additional layers do not seem to result in different representations.
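A quick check of this ratio, using the parameter counts stated above:

```python
resnet50_params = 25_636_712
resnet152_params = 60_419_944

ratio = resnet152_params / resnet50_params
extra = resnet152_params - resnet50_params
print(f"{ratio:.2f}x as many parameters, {extra:,} more in total")
# prints: 2.36x as many parameters, 34,783,232 more in total
```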
Note: Keep in mind that there are several caveats to this argument: due to limited resources, I could only compute the activations for one batch and for a few networks. Readers are encouraged to use my code and run it with more batches.
What is visually evident is the checkerboard-like pattern in the lower layers.
That’s it. The code is available as a Colab notebook here.
Summary
We covered the core parts of the paper “Do Wide And Deep Networks Learn The Same Things?”. We began by setting up three ResNets and preparing the CIFAR-10 dataset. We then examined the equations behind the paper and implemented them in Python. In the last part, we wrote several short functions to compare two networks’ activations, enabling us to create figures similar to those in the paper.
Where to go from here?
There are a few open points:
- Implement the PCA for activations (section 5.1 of [1])
- With more compute, train more networks
- Use more batches to average the similarities
References
[1] T. Nguyen, et al., Do Wide And Deep Networks Learn The Same Things? Uncovering How Neural Network Representations Vary With Width And Depth (2020), arXiv
[2] K. He, et al., Deep residual learning for image recognition (2015), Proceedings of the IEEE conference on computer vision and pattern recognition
[3] A. Krizhevsky, et al., Learning multiple layers of features from tiny images (2009), Tech report
[4] S. Kornblith, et al., Similarity of neural network representations revisited (2019), International Conference on Machine Learning
[5] C. Cortes, et al., Algorithms for learning kernels based on centered alignment (2012), The Journal of Machine Learning Research
[6] L. Song, et al., Feature selection via dependence maximization (2012), Journal of Machine Learning Research
[7] J. Deng, et al., ImageNet: A large-scale hierarchical image database (2009), IEEE conference on computer vision and pattern recognition