Ilya's 30 Papers to Carmack: Identity Mappings in ResNets
This post is part of a series of paper reviews, covering the ~30 papers Ilya Sutskever sent to John Carmack to learn about AI. To see the rest of the reviews, go here.
[EDIT: this one was accidentally sent out before the ResNet review. I ended up just publishing both of them at the same time]
Paper 16: Identity Mappings in Deep Residual Networks
Before reading this review, I highly recommend reading my previous paper review on Deep Residual Networks. I use a lot of overlapping terminology and framing.
High level
ResNets add skip connections between layers, such that information can pass ‘unchanged’ up the model, and gradients back down. This has the effect of allowing different parts of the input signal to make their way through the model ‘uncorrupted’.
In theory, you could use any function of the input to pass the residuals up from lower levels through to upper ones. You could use a learned one (this is the F(x) = H(x) - Wx formulation I mentioned earlier), or a gated one, or some scaled one. In the original paper they use an identity function — F(x) = H(x) - 1*x — and don’t really explore the others.
In theory, you can also use any activation function between layers. In the original paper they use a ReLU, and again don’t really explore other options.
In this paper, the authors explore other options for both the skip connection function and the activation function, and find that, surprisingly, the identity function for both the skip connection AND the activation happens to work best. The authors propose the following modification to ResNets based on their findings:
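Concretely, the change moves batch norm and ReLU from after the weight layers (and the final addition) to before the weight layers, so the addition with the shortcut is left completely untouched. Roughly, with W1 and W2 standing in for the two conv layers and BN for batch norm (my paraphrase of the paper's figure, not its exact notation):

original ('post-activation'): output = ReLU(BN(W2 * ReLU(BN(W1 * x))) + x)

proposed ('pre-activation'): output = W2 * ReLU(BN(W1 * ReLU(BN(x)))) + x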
This is somewhat unintuitive. Generally, adding more parameters to a model leads to higher representation capacity, which in turn leads to richer representations and better outcomes. This is a foundational assumption of the scaling hypothesis. Even the authors of the first paper imply that learning weights in the skip connection ought to be better [1]. Similarly, adding non-linearities between model layers is standard practice and nearly always improves results [2]. So what gives? Why is it that in this case, removing parameters and non-linearities actually results in better outcomes?
The rest of the paper attempts to examine, empirically and intuitively, why a strict identity function is ideal for very deep models.
The authors start by theorizing about gradients.
If you have an identity in the skip-connection AND in the activation functions, you essentially have a direct line of signal all the way through the model. This single continuous ‘residual stream’ leads to the best outcomes, because the model can read and write to the residual stream as needed, instead of having it be interrupted by the structure of the architecture.
More formally:
You can express any layer as a function of a previous layer (i.e. you get the recursive residual stream all the way back to the beginning)
You can express any layer as a sum of the previous residuals. This is in contrast to other setups, where layer L is normally expressed as a product of the previous layers [3].
The resulting backprop gradients are also nice. At each layer, you get one term that carries the gradient directly, unmodulated by any weights, and one term that flows through the weight layers themselves. These terms are additive, which makes unstable gradients much less likely and keeps the layers much more modular.
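To make this concrete, here is the algebra from the paper, with the weight arguments dropped to keep the notation light. Unrolling the recursion from any shallower layer l up to any deeper layer L gives

x_L = x_l + F(x_l) + F(x_{l+1}) + ... + F(x_{L-1})

and backpropagating through that sum gives

dE/dx_l = dE/dx_L * (1 + d/dx_l [F(x_l) + ... + F(x_{L-1})])

The '1' inside the parentheses is the direct, unmodulated gradient term; the rest is the contribution of the weight layers.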
By comparison, the original ResNet architecture only has direct signals within an individual ResNet block, because the ReLU breaks up the signal between blocks. As a result, signals/gradients between ResNet blocks may experience unnecessary instability, and there are stricter cross-layer dependencies.
So we have some ideas for why identities may work well. Do we have any ideas about why other forms of residual skip connections — like those with additional weights or gates — work poorly?
First, let’s analytically reason through just a scalar change: we multiply the skip connection by a constant. Instead of learning F(x) = H(x) - x, the model has to learn F(x) = H(x) - cx, where c is a constant [4]. In very deep networks, if the scalar c is > 1, you get a term that grows exponentially with depth and swamps the signal; if c is < 1, the residual connection just vanishes. This should sound like exploding/vanishing gradients, because it basically is. Another way of thinking about this is that the constant forces the model to learn a more complicated F(x) function to 'balance' the constant out [5]. In general, the same problem applies to even more complicated transforms like weights or gates: you may still get gradients, but those gradients are formed through a series of multiplications that can exhibit the same optimization instabilities as the constant scalar. Identity skip-connections don't have this problem.
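Spelling that out with the paper's algebra: if every shortcut in the stack is scaled by the same constant c instead of 1, then unrolling the recursion gives, roughly (ignoring how the residual terms themselves get rescaled),

x_L = c^(L-l) * x_l + (a sum of residual terms)

and the gradient picks up the same c^(L-l) factor on its direct path. With 50 layers between l and L, c = 1.1 gives a factor of roughly 117, while c = 0.9 gives roughly 0.005: the direct signal either explodes or effectively disappears.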
And the various empirical tests bear this out: as the authors got closer to identity in the skip connection, things worked better! For example, when testing gating for the skip connections, where the shortcut is scaled by 1 - sigmoid(Wx + b), they found that the more negative the b term was (i.e. the closer that 1 - sigmoid(Wx + b) factor was to 1, and the closer the shortcut was to identity), the better the model did.
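As a sketch of what that shortcut-only gating variant looks like (my illustration in PyTorch-style code, not the authors' implementation; gate_layer stands in for the 1x1 convolution that produces the gate, and residual_fn for the weight stack F):

```python
import torch

def gated_shortcut_unit(x, residual_fn, gate_layer):
    # Shortcut-only gating: the skip connection is scaled by 1 - sigmoid(Wx + b)
    # while the residual branch is left unscaled.
    g = torch.sigmoid(gate_layer(x))        # gate values in (0, 1)
    # A very negative bias b drives sigmoid(Wx + b) toward 0, so the shortcut
    # factor (1 - g) approaches 1, i.e. the block approaches a pure identity skip.
    return (1.0 - g) * x + residual_fn(x)
```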
Importantly, empirical results improved as the skip-connection function approached identity, despite the fact that some of the skip-connection variants being tested had higher representational capacity (i.e. more parameters). This fits with the broader hypothesis that the unimpeded skip connection through the model is itself useful: if the optimal outcome is to let the skip connection pass through unimpeded, then learning an identity function is always going to be worse than just setting identity directly. The learned version will only ever approximate identity in the best case, and in the worst case will have all sorts of weird local minima errors [6]. By contrast, if identity signaling did not matter, we would expect more parameters to uniformly improve the outcomes.
What about those ReLU activations?
Earlier, I said that the motivation behind the ResNet block is to make it easier to learn identity functions. But in the original ResNet paper, there is a ReLU nonlinearity after the addition in each ResNet block, ResNet(x) = ReLU(F(x) + x). The authors call this post-activation: the activation happens after the addition of the residual. In the post-activation architecture, the ReLU blocks a direct signal between ResNet blocks. If the sum F(x) + x is negative, there's no way to pass that signal on to the next layer, and we lose the linear dependency on x across layers. The authors’ primary thesis is that unmodulated signal is good; as a result, the ReLU activation has to go.
One naive option is to move the ReLU onto just the learned weight stack, ResNet(x) = ReLU(F(x)) + x. But, again, the ReLU blocks all negative signals: the residual branch can only ever add non-negative values, so activations become monotonically non-decreasing as you go through the layers. To avoid this, the authors shift everything down. They put the ReLU at the beginning of the learned weight stack, ResNet(x) = F(ReLU(x)) + x.
The same logic applies to the batch norm layers, resulting in the construction sketched at the beginning of this review. The authors call this pre-activation.
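Putting the pieces together, here is a minimal sketch of a pre-activation residual block, assuming a PyTorch implementation with matching input and output channels (my illustration, not the authors' code):

```python
import torch
import torch.nn as nn

class PreActBlock(nn.Module):
    # A minimal pre-activation residual block: BN and ReLU come *before* each
    # conv, and nothing touches the output after the addition.
    def __init__(self, channels):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        out = self.conv1(torch.relu(self.bn1(x)))
        out = self.conv2(torch.relu(self.bn2(out)))
        return out + x  # identity skip connection, no activation after the add
```

Stacking these blocks leaves the identity path unbroken from the input to the output of the stack, which is exactly the property the authors argue for.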
As expected, the modified ResNet is empirically better. The authors also find that the impact of pre- vs post-activation is more pronounced on larger models. They hypothesize that this is because ReLU, with its gate on negative outputs, likely slows gradients early in the training cycle, and that this is more likely to happen across more layers in a larger, randomly initialized model [7].
Insight
This is probably one of my favorite papers in the set so far, because it is an ML interpretability paper! In a world without a formal theory, interpretability papers are all about empirics. The core of the field is trial and error, building and testing theories from observation. Science as it was meant to be! In a very real sense this feels a lot like biology — poking and prodding at things to see what sticks.
Many of the insights from the original paper above apply here. The model is more explicitly learning operations that are applied to some shared state — by reframing all operations as a sum, you can fully disentangle layers from each other and treat the entire model as a pipeline instead. This serves to solve the capacity / bottleneck issues faced by standard feedforward nets, much in the same way LSTMs help resolve capacity issues in standard RNNs. More generally, this paper echoes many of the exciting / useful properties of LSTMs, and, later, A Mathematical Framework for Transformer Circuits.
Pulling on that thread a bit more, this paper further motivates the results found in the RNN Regularization paper. To recap, the authors of that paper found that dropout on hidden-state connections in RNNs made the entire model perform significantly worse. In my mind, hidden state connections in RNNs map almost exactly to residual connections in these deep convolutional networks. And, lo and behold, the authors of this paper find that dropout on the residuals also performs worse. Here, they provide additional motivation: they argue that dropout on the shortcut is statistically equivalent to scaling the skip connection by a constant, something we already showed was likely to be unstable and bad for learning.
As a final note, it's interesting that this paper came out in 2016, just 6 months after the initial ResNet paper. In my opinion, it is very closely aligned with a lot of threads that LSTMs solved about 2 decades prior. I'm not sure why it took so long for these ideas to cross-pollinate. Reading the paper, the authors never really explicitly cite or draw on LSTMs as inspiration, so it’s fully possible they landed on the idea independently. Maybe the computer vision guys weren't reading the language modelling papers and vice versa? Maybe it was because of AI winter? Unclear. Still, by the time I was formally studying ML models (starting around 2014/2015), it was standard practice to learn about LSTMs and conv nets in the same classroom, so you'd think the ah-ha moment would have come sooner.
[1] Quoting from the original ResNet paper: "We can also use a square matrix Ws in Eqn.(1). But we will show by experiments that the identity mapping is sufficient for addressing the degradation problem and is economical, and thus Ws is only used when matching dimensions."
[2] Stacked linear functions are also linear. As a result, stacked linear layers without any non-linearities learn a function that can be learned by a single linear layer. This is why you need non-linearities in the first place — your model doesn’t get any additional representational capacity from only having stacked linear layers, even if you have a ton of them. (A quick numerical check of this is included after these footnotes.)
[3] The default two-layer MLP construction is MLP(x) = ReLU(W2 * ReLU(W1*x + b1) + b2). Notice how the output of layer one is multiplied by the weight matrix W2 of layer two.
[4] Remember that H(x) is the function we want the ResNet block as a whole to learn, and F(x) is what the weights within the ResNet block learn.
[5] The original paper was all about learning an identity function. In the default ResNet construction, the model just has to learn F(x) = 0 for the ResNet block to represent identity, something that was really easy to do because ReLU activations push any negative values to 0. If you wanted a ResNet block with a scaled skip-connection to represent an identity function, the weights would have to learn F(x) = x - cx. This is a much more complicated function than F(x) = 0!
[6] More generally, neural networks are function approximators. If you know what the function should be, use that instead of trying to approximate it through learning.
[7] More generally, a direct signal path down the model means there is an opportunity for every layer to have direct influence on the outcome, and for gradients to be applied ‘directly’ without modulation. In a traditional deep model without identity skip connections, this isn't true — each layer's gradients are modulated by every other layer that comes above it.
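A quick numerical illustration of the point in footnote [2] (the shapes and names here are arbitrary):

```python
import numpy as np

# Two stacked linear layers with no non-linearity in between collapse into a
# single linear layer, because matrix multiplication is associative.
rng = np.random.default_rng(0)
W1 = rng.normal(size=(4, 3))   # first "layer"
W2 = rng.normal(size=(2, 4))   # second "layer"
x = rng.normal(size=3)

stacked = W2 @ (W1 @ x)        # apply layer one, then layer two
collapsed = (W2 @ W1) @ x      # the equivalent single linear layer
assert np.allclose(stacked, collapsed)
```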