Can Mechanistic Interpretability be Applied to Connectomics?
A big problem in deep learning is that neural networks are black boxes: they cannot be ‘interpreted’. This makes them hard to work with, both because we cannot easily say how to improve them and because we cannot easily prevent them from doing things we do not want them to do. People build up intuition for how these models work, but there’s no unified theory that lets us make formal predictions the way a civil engineer can predict how the tensile strength of different materials affects the total weight of a bridge.
One reason deep neural networks are black boxes is that they are highly entangled. At any given layer, ‘semantic’ meaning is spread across many or all of the neurons in that layer. For example, there is no single neuron inside ChatGPT that encodes the concept ‘dog’. That concept is represented by a set of neurons: a bit of neuron A, a bit of neuron B, and so on. Other concepts may use different combinations of neurons. ‘Refrigerator’ may require a bit of neuron A, a bit of neuron C, and so on. Notice the overlap: ‘dog’, ‘refrigerator’, and a bunch of other concepts may all depend on the same individual neuron, while each concept still has its own unique combination of neurons.¹ Each neuron is ‘polysemantic’, meaning it may have many meanings. That behavior makes ML models really powerful, because they can represent more concepts than they have neurons. But it also makes these models really hard to understand.
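To make the ‘a bit of neuron A, a bit of neuron B’ idea concrete, here is a toy sketch in Python with NumPy. Everything in it is made up for illustration (three neurons, four concepts, the specific weights, and the hypothetical helpers `concepts_using` and `read_out`); it is not how ChatGPT or any real model is actually wired. It shows each concept as a direction spread across several neurons, each neuron taking part in several concepts, and a layer holding more concepts than it has neurons that still decodes cleanly when only one concept is active at a time.

```python
import numpy as np

# Hypothetical toy layer: 3 neurons (A, B, C) representing 4 concepts.
# Each concept is a *direction* in neuron space (a weighted mix of neurons),
# so the layer holds more concepts than it has neurons. Numbers are made up.
NEURONS = ["A", "B", "C"]
raw = {
    "dog":          [0.8, 0.6, 0.0],
    "refrigerator": [0.6, 0.0, 0.8],
    "algebra":      [0.0, 0.8, 0.6],
    "bridge":       [1.0, -1.0, 1.0],
}
# Normalize every concept direction to unit length.
CONCEPTS = {name: np.array(v) / np.linalg.norm(v) for name, v in raw.items()}


def concepts_using(neuron: str, tol: float = 1e-9) -> list[str]:
    """Concepts whose direction puts nonzero weight on this neuron."""
    i = NEURONS.index(neuron)
    return [name for name, d in CONCEPTS.items() if abs(d[i]) > tol]


def read_out(activation: np.ndarray) -> str:
    """Guess the active concept by projecting onto every concept direction.
    This only works reliably when few concepts are active at once (sparsity)."""
    scores = {name: float(activation @ d) for name, d in CONCEPTS.items()}
    return max(scores, key=scores.get)


for n in NEURONS:
    print(n, "is used by", concepts_using(n))

# Activating a single concept: the layer's activation equals its direction.
print("decoded:", read_out(CONCEPTS["dog"]))
```

Every neuron here is polysemantic (neuron A participates in ‘dog’, ‘refrigerator’, and ‘bridge’), yet no concept is lost, because concepts are directions rather than individual neurons.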
One angle of attack on neural network interpretability is to try to build a deep understanding of each individual piece of the network, and from that reconstruct the behavior of the whole system. This is called ‘mechanistic interpretability’. It’s a bit like trying to learn how a clock works by taking it apart and understanding each piece. If we could find a way to decompose a neural network so that each neuron relates to a single semantic concept (i.e., is monosemantic), we might get a better understanding of how all these pieces interact with each other.
About a year ago, Anthropic published a paper on extracting monosemantic features from a small transformer language model; see the related Astral Codex Ten (ACX) post, which does a pretty good job summarizing the key points. A few months ago, a follow-up paper applied the same feature-extraction technique to a large production language model (Claude 3 Sonnet).
I’m not going to go too far into the weeds here, but the basic idea is that by training a sparse autoencoder on the internal activations of a fixed (already trained) neural network, you can discover distinct combinations of neurons that correspond to monosemantic features. For example, you can find which linear combinations of neurons correspond to concepts like "Happy", "Algebra", or "DNA" in an LLM. The folks over at Anthropic used this technique to modify their models in predictable ways, for example by amplifying a single feature so that the model brought up the Golden Gate Bridge in nearly every response.
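For readers who want slightly more than the one-paragraph version, here is a minimal NumPy sketch of a sparse autoencoder (SAE), under simplifying assumptions: synthetic vectors stand in for a frozen model’s activations, the SAE is one ReLU encoder layer plus a linear decoder trained with an L1 sparsity penalty by plain full-batch gradient descent, and decoder columns are renormalized to unit length after each step. Anthropic’s real setup differs in many details and runs at vastly larger scale, so treat this as an illustration of the technique, not their implementation. The `steer` helper is a hypothetical, simplified take on feature clamping: it sets one feature to a chosen value and adds the resulting change back into the original activation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for a frozen network's activations: each vector is a
# sparse, nonnegative mix of a few hidden "true" feature directions.
D_NEURONS, N_TRUE, N_SAMPLES = 8, 16, 2000
true_dirs = rng.normal(size=(N_TRUE, D_NEURONS))
true_dirs /= np.linalg.norm(true_dirs, axis=1, keepdims=True)
coeffs = rng.uniform(0.0, 1.0, size=(N_SAMPLES, N_TRUE))
coeffs *= rng.random((N_SAMPLES, N_TRUE)) < 0.1  # each feature active ~10% of the time
X = coeffs @ true_dirs  # (N_SAMPLES, D_NEURONS)

# Sparse autoencoder: f = ReLU(W_enc x + b_enc), x_hat = W_dec f + b_dec
K = 32           # dictionary size: more features than neurons
LAM = 0.05       # weight of the L1 sparsity penalty
LR, STEPS = 0.02, 2000

W_dec = rng.normal(size=(D_NEURONS, K))
W_dec /= np.linalg.norm(W_dec, axis=0, keepdims=True)  # unit-norm dictionary atoms
W_enc = W_dec.T.copy()
b_enc = np.zeros(K)
b_dec = np.zeros(D_NEURONS)


def encode(x):
    return np.maximum(x @ W_enc.T + b_enc, 0.0)


def decode(f):
    return f @ W_dec.T + b_dec


def loss(x):
    """Mean squared reconstruction error plus L1 penalty on feature activations."""
    f = encode(x)
    recon = np.sum((decode(f) - x) ** 2, axis=-1).mean()
    return recon + LAM * np.abs(f).sum(axis=-1).mean()


history = [loss(X)]
n = len(X)
for _ in range(STEPS):
    pre = X @ W_enc.T + b_enc
    F = np.maximum(pre, 0.0)
    R = F @ W_dec.T + b_dec - X          # reconstruction residual
    dXhat = 2.0 * R / n                  # d(mean squared error)/d(x_hat)
    dF = dXhat @ W_dec + LAM / n         # add L1 gradient (F >= 0)
    dpre = dF * (pre > 0)                # backprop through the ReLU
    W_dec -= LR * (dXhat.T @ F)
    b_dec -= LR * dXhat.sum(axis=0)
    W_enc -= LR * (dpre.T @ X)
    b_enc -= LR * dpre.sum(axis=0)
    W_dec /= np.linalg.norm(W_dec, axis=0, keepdims=True)  # keep atoms unit-norm
    history.append(loss(X))

print(f"loss {history[0]:.3f} -> {history[-1]:.3f}")


def steer(x, feature, value):
    """Clamp one feature to `value` while keeping the SAE's reconstruction error.
    Equal to decode(clamped f) + (x - decode(f)), since decode is affine."""
    f = encode(x)
    return x + (value - f[feature]) * W_dec[:, feature]


steered = steer(X[0], feature=3, value=5.0)
```

The L1 term is what pushes most feature activations to zero for any given input; the decoder renormalization stops the model from dodging that penalty by shrinking activations and inflating decoder weights.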
The process is data hungry, but overall it seems to be pretty effective. The Anthropic team was able to identify thousands of such ‘sets’ of neurons, each corresponding to a specific, semantically meaningful pattern of activation.
—
Before I was meaningfully working in AI, I was working on connectomics at the BioNet Lab at Columbia. Connectomics is a pretty niche field: it’s the study of neural circuit architecture and how that architecture gives rise to specific processing capabilities, generally (at least in our work) by evaluating in silico simulations. In layman’s terms, we built biologically accurate models of fruit fly brain circuits in code, and then tried to learn things about those circuits by running the simulations.
Brain circuitry exhibits many of the same polysemantic behaviors as deep learning models. There is no single ‘dog’ neuron in my head; when I see a dog, a whole chunk of my brain lights up at once. The same goes for ‘refrigerator’, or anything else.
Hopefully you can see where I’m going with this…
In theory, if you trained a sparse autoencoder on the simulated activity of an existing connectome model, you might be able to dictionary-learn your way into mapping neural circuits to a wide variety of monosemantic concepts. For example, you might be able to take a detailed model of a retina, attach the autoencoder downstream of it, and then figure out which circuits within the retina are used for contrast detection, brightness detection, edge detection, or maybe even detecting specific objects.
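Here is a toy, heavily simplified sketch of that proposal in Python with NumPy. Everything in it is an assumption for illustration: a hand-built 1-D ‘retina’ of 16 photoreceptors feeding ON/OFF center-surround cells and one diffuse luminance cell stands in for a real connectome model, the stimuli are uniform fields with optional step edges, and `center_surround`, `make_stimuli`, `train_sae`, and `corr` are hypothetical helpers. After training a minimal sparse autoencoder on the recorded cell activity, it scores each learned feature by its correlation with stimulus properties we know by construction (luminance, edge present), then lists the cells each best-matching feature reads from most strongly.

```python
import numpy as np

rng = np.random.default_rng(1)
P = 16  # photoreceptors in a hypothetical 1-D toy retina


def center_surround(pos: int, sign: float) -> np.ndarray:
    """Weights of a toy ON (sign=+1) or OFF (sign=-1) center-surround cell."""
    w = np.zeros(P)
    w[pos], w[pos - 1], w[pos + 1] = 1.0, -0.5, -0.5
    return sign * w


# Fixed toy "connectome": ON/OFF cells at odd positions plus one diffuse cell.
cells = [center_surround(p, s) for p in range(1, P - 1, 2) for s in (1.0, -1.0)]
cells.append(np.full(P, 1.0 / P))
W_ret = np.array(cells)  # (15 cells, P photoreceptors)


def make_stimuli(n: int):
    """Uniform fields of random brightness, half with a step edge added."""
    stim = np.repeat(rng.uniform(0.0, 1.0, size=(n, 1)), P, axis=1)
    has_edge = rng.random(n) < 0.5
    for i in np.flatnonzero(has_edge):
        stim[i, rng.integers(2, P - 2):] += rng.uniform(0.5, 1.0)
    return stim, stim.mean(axis=1), has_edge.astype(float)


def train_sae(X, k=24, lam=0.02, lr=0.02, steps=1500, seed=0):
    """Minimal ReLU sparse autoencoder (L1 penalty, full-batch gradient descent)."""
    r = np.random.default_rng(seed)
    n, d = X.shape
    W_dec = r.normal(size=(d, k))
    W_dec /= np.linalg.norm(W_dec, axis=0)
    W_enc, b_enc, b_dec = W_dec.T.copy(), np.zeros(k), np.zeros(d)
    for _ in range(steps):
        pre = X @ W_enc.T + b_enc
        F = np.maximum(pre, 0.0)
        g = 2.0 * (F @ W_dec.T + b_dec - X) / n
        dpre = (g @ W_dec + lam / n) * (pre > 0)
        W_dec -= lr * (g.T @ F)
        b_dec -= lr * g.sum(axis=0)
        W_enc -= lr * (dpre.T @ X)
        b_enc -= lr * dpre.sum(axis=0)
        W_dec /= np.linalg.norm(W_dec, axis=0)
    return W_enc, b_enc


def corr(a, b):
    """Pearson correlation; returns 0.0 for a constant (e.g. dead) feature."""
    a, b = a - a.mean(), b - b.mean()
    denom = np.sqrt((a @ a) * (b @ b))
    return 0.0 if denom == 0 else float(a @ b / denom)


stim, luminance, has_edge = make_stimuli(3000)
activity = np.maximum(stim @ W_ret.T, 0.0)  # rectified cell responses
W_enc, b_enc = train_sae(activity)
feats = np.maximum(activity @ W_enc.T + b_enc, 0.0)

for name, label in [("luminance", luminance), ("edge present", has_edge)]:
    scores = [corr(feats[:, j], label) for j in range(feats.shape[1])]
    best = int(np.argmax(scores))
    top_cells = np.argsort(-np.abs(W_enc[best]))[:3]
    print(f"{name}: feature {best} (r={scores[best]:.2f}), "
          f"reads mostly from cells {top_cells.tolist()}")
```

Correlating features against known labels only works here because the stimuli were generated with those labels attached; with a real connectome model you would more likely inspect each feature’s top-activating stimuli by hand and then trace which cells and synapses it depends on.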
I don’t have access to the compute clusters needed for large-scale connectome modeling these days, but I bet someone out there does. I’m hoping that someone with more resources or expertise can pick up this thread (and maybe shoot me an email to let me know whether this makes any sense at all).
¹ I like to borrow the word ‘omnigenic’ from molecular biology as an analogy. Neuron behavior in a DNN is a bit like the fact that there is no single gene that encodes height. Rather, a person’s height is shaped by a huge number of genes (thousands of genetic variants have been linked to it), which all come together to determine how tall someone ends up being. Similarly, many, and maybe all, of the neurons in a layer contribute a bit to each semantic representation.