Distilling a Tiny Model for Fast Interpretability
I trained a ~40M-parameter student model to approximate Gemma 4B representations, and found that it preserved enough structure to support SAE-based analysis
I was inspired by Anthropic’s recent work on emotion vectors to explore some hypotheses I had about interpretability signals in my personal datasets. And while doing the work, I began to wonder how well smaller models, distilled specifically to capture interpretable representations, might perform for these kinds of analyses.
The models I’d been working with in my own experiments were not particularly large (mostly Gemma and Qwen, in the 4B-parameter range), but it still took one of my local systems, with an RTX 5090, several hours to process a set of 100k documents and produce the representations I wanted to look at for one of the Gemma models. Just inference, no training. So it seemed plausible to me that a distilled model might compute the same kinds of representations many times faster, and that this could be useful, especially when scaling up to frontier-scale models, as long as the distilled student model’s representations were reasonably reliable.
Anyway, given that coding agents now make answering simple research questions so much easier, I decided to see whether this sort of distillation was practical. But to explain what I did, I should first explain the kinds of analyses I was doing in this interpretability work, and the representations I wanted a distilled student model to capture.
There are many different ways of getting at LLM interpretability, and I had been working with contrastive methods, which require text with clear labels across the various classes of signal you are investigating. With such labeled data, you can identify representations for a class of interest by contrasting what the LLM does when processing documents of that class against what it does when processing documents from another class (or set of classes, or an aggregate text distribution). This sort of contrastive approach is also what Anthropic used to identify emotion vectors in Claude.
One of the most common ways of defining contrast is the difference of class means. First, for each document, you take the mean of a transformer block’s residual stream activations across all unmasked tokens, so that each document yields a single vector of hidden-layer size. You need to choose a particular layer of the transformer in advance (or run preliminary experiments to select one), although the empirical consensus seems to be that middle layers work best.
Then you average these vectors over the documents of one class (the class of interest) and subtract the average of the same vectors over the second class, or whatever the baseline distribution is, usually with some normalization and perhaps dimensionality reduction. That gives you a single vector that represents the class in question: in Anthropic’s case, say, an emotion like happy or sad.
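The two steps above (per-document mean pooling, then a difference of class means) can be sketched in a few lines of NumPy. This is a minimal illustration, not the code used for these experiments: the function names are hypothetical, it assumes residual activations have already been extracted as a `(seq_len, d_model)` array per document along with a 0/1 token mask, and it uses plain L2 normalization as a stand-in for whatever normalization you prefer.

```python
import numpy as np


def mean_pool(residuals: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Average one document's residual activations over its unmasked tokens.

    residuals: (seq_len, d_model) activations from the chosen layer.
    mask: (seq_len,) with 1 for real tokens, 0 for padding/masked tokens.
    """
    m = mask.astype(residuals.dtype)
    # Guard against an all-masked document to avoid dividing by zero.
    return (residuals * m[:, None]).sum(axis=0) / max(m.sum(), 1.0)


def contrast_vector(pos_docs: np.ndarray, neg_docs: np.ndarray) -> np.ndarray:
    """Difference of class means over mean-pooled document vectors, L2-normalized.

    pos_docs: (n_pos, d_model) vectors for the class of interest.
    neg_docs: (n_neg, d_model) vectors for the baseline class or distribution.
    """
    diff = pos_docs.mean(axis=0) - neg_docs.mean(axis=0)
    return diff / np.linalg.norm(diff)


# Tiny example: the padded third token is ignored by the mask.
doc = mean_pool(np.array([[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]]), np.array([1, 1, 0]))
print(doc)
```

In practice you would call `mean_pool` once per document, stack the results per class, and pass the two stacks to `contrast_vector`.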
I wanted to train a tiny transformer model to capture this sort of representation: more specifically, the post-block residual stream (the token representation after a transformer block’s updates) from layer 17 of the Gemma 4B model, which I’d been using for my earlier experiments. My original plan was to start by training the student transformer on each document’s mean-pooled residual for simplicity, and then move on to having the student predict the token-level residuals directly. However, training on the document-level mean is much faster, and my early experiments performed well enough[1] that I had little reason to try the much longer training process. So I trained the student (~40M params, 4 layers) for 20 epochs[2] on Gemma representations computed for 100k of the 5M+ documents in the RAID corpus. I didn’t optimize much; I just chose relatively standard hyperparameters, ran some sanity checks, and watched for the validation loss to level off.
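The training target itself is simple. Per the notes on the training setup, the student was trained with MSE loss against the z-normalized teacher residual vector. The sketch below shows one way to build such targets, assuming per-dimension z-normalization with mean and standard deviation estimated on the training set (an assumption; the exact statistics aren't specified). The student network itself is omitted, and `student_out` is a hypothetical placeholder for its 2,560-dimensional output.

```python
import numpy as np


def fit_znorm(teacher_vecs: np.ndarray, eps: float = 1e-6):
    """Per-dimension mean/std of teacher document vectors (assumed normalization scheme).

    teacher_vecs: (n_docs, d_model) mean-pooled Gemma residuals for the training set.
    """
    return teacher_vecs.mean(axis=0), teacher_vecs.std(axis=0) + eps


def znorm(x: np.ndarray, mu: np.ndarray, sigma: np.ndarray) -> np.ndarray:
    """Shift and scale each dimension using the reference statistics."""
    return (x - mu) / sigma


def mse(pred: np.ndarray, target: np.ndarray) -> float:
    """Mean squared error over all documents and dimensions."""
    return float(np.mean((pred - target) ** 2))


teacher = np.array([[1.0, 2.0], [3.0, 6.0]])
mu, sigma = fit_znorm(teacher)
targets = znorm(teacher, mu, sigma)   # approximately [[-1, -1], [1, 1]]
student_out = np.zeros_like(targets)  # hypothetical placeholder for student predictions
print(round(mse(student_out, targets), 4))
```

Normalizing the targets this way keeps dimensions with large shared offsets from dominating the loss, which is consistent with the note that z-normalization isolates finer-grained detail.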
RAID was among the datasets I’d been exploring as part of my interpretability work, and it is perhaps the largest and best-known dataset for evaluating how well models distinguish AI-generated from human-written text. (As it happens, though, this dataset was not a core piece of the work I had been doing.) In any case, when running some baselines, I’d been struck by how well a relatively simple contrastive approach along the lines I’ve described worked on RAID: building a vector that represents the documents of one class (here, AI-generated) against all those of the other class (human-written) gave quite strong AUC on held-out RAID data. It was 0.95 in one representative experiment I had tried[3].
The student model I trained gave similar AUC on the same held-out data (0.93) using the same sort of contrast approach. And again, there’s no real supervised training involved, just the aggregation of positive and negative examples in the representation space.
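To score held-out documents with a contrast vector, one simple choice (assumed here, not necessarily the exact scoring used) is to project each document vector onto the contrast direction and compute AUC directly on those projections, with no trained classifier. The hypothetical helper below computes AUC as the probability that a randomly chosen positive outscores a randomly chosen negative, counting ties as half. It builds a full pairwise comparison matrix, so it is only meant for small illustrative sets.

```python
import numpy as np


def contrast_scores(doc_vecs: np.ndarray, contrast: np.ndarray) -> np.ndarray:
    """Project (n_docs, d_model) document vectors onto a contrast direction."""
    return doc_vecs @ contrast


def auc(pos_scores, neg_scores) -> float:
    """ROC AUC via pairwise comparison (Mann-Whitney U divided by n_pos * n_neg)."""
    pos = np.asarray(pos_scores, dtype=float)[:, None]
    neg = np.asarray(neg_scores, dtype=float)[None, :]
    return float((pos > neg).mean() + 0.5 * (pos == neg).mean())


contrast = np.array([1.0, 0.0])
ai_docs = np.array([[2.0, 1.0], [0.5, -1.0]])
human_docs = np.array([[-1.0, 3.0], [1.0, 0.0]])
print(auc(contrast_scores(ai_docs, contrast), contrast_scores(human_docs, contrast)))
```

For large evaluation sets, a rank-based implementation such as scikit-learn's `roc_auc_score` avoids the quadratic memory cost.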
So the student model was capturing useful information. And it was ~40M parameters, 100 times smaller and more than 100 times faster than Gemma 4B. In fact, it was fast enough that I was able to use it to generate vector representations of all the documents in the RAID training corpus in a bit more than an hour, which would have taken almost a week on the 5090 system with Gemma 4B.
Bridging Contrast Vectors and Sparse Autoencoder Dictionaries
A contrast vector by itself is not necessarily the end state of an interpretability analysis. There are a number of other techniques you can use with or alongside contrast vectors to better understand the mixture of signals a contrast vector represents, or to isolate some nuance not easily captured by the high-level labels.
Sparse autoencoder (SAE) dictionaries are one helpful tool in this vein. They can be applied to the token residual stream (or, less directly, to a document representation or contrast vector, as I will explain in a moment) to find the dictionary features that dominate such vectors. And because each dictionary feature comes with a human-interpretable label, this gives you a set of interpretable descriptions; that’s the point of the dictionary.
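For context on what an SAE does at the token level, here is a toy encode/decode pair. It assumes a JumpReLU-style SAE, the architecture used in Google DeepMind's original Gemma Scope release; the post doesn't state which SAE architecture applies here, and the weights below are tiny hand-picked values rather than a real dictionary. The encoder's per-feature threshold is a nonlinearity, while the decoder is purely linear.

```python
import numpy as np


def sae_encode(x, w_enc, b_enc, threshold):
    """JumpReLU-style encoder: keep a pre-activation only where it exceeds
    that feature's learned threshold, otherwise output zero.

    x: (d_model,) token residual; w_enc: (d_model, n_features).
    """
    pre = x @ w_enc + b_enc
    return np.where(pre > threshold, pre, 0.0)


def sae_decode(f, w_dec, b_dec):
    """Linear decoder: each active feature adds its decoder row (a direction in
    residual space) scaled by its activation. w_dec: (n_features, d_model)."""
    return f @ w_dec + b_dec


# Toy dictionary with d_model=2 and three features.
w_enc = np.array([[1.0, 0.0, 1.0], [0.0, 1.0, -1.0]])
b_enc = np.zeros(3)
threshold = np.array([0.5, 3.0, 0.0])
w_dec = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
b_dec = np.zeros(2)

f = sae_encode(np.array([1.0, 2.0]), w_enc, b_enc, threshold)
print(f, sae_decode(f, w_dec, b_dec))
```

Here only the first feature clears its threshold; the second feature's pre-activation (2.0) is positive but below its threshold (3.0), so it is zeroed.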
It’s an enormous amount of effort to construct these SAE dictionaries, and they are not available for all models, or all layers within a model. For example, in Gemma 4B, only layers 12, 14, and 17 among the middle layers have SAEs. This is why I was using layer 17 for my original work and selected it as the target layer for the distilled student model in this tangential investigation.
In any case, it struck me as potentially useful to have a model much faster than Gemma 4B that could produce very similar mean-pooled residual representations. For one thing, these representations could be used as features in supervised methods, or the small student model behind them could be finetuned with a classification head (which I actually tried, with some success[4]). But more interesting to me, I wondered whether, with a bit of work, you could map this mean-pooled representation back onto Gemma’s SAE[5], which had been so arduously constructed.
You shouldn’t simply run the SAE encoder, which is nonlinear, on document-level means, because that gives it inputs unlike the token-level residuals it was trained on; I did try this out of curiosity, and the results were unusable. But you can use the SAE decoder weight matrix as a dictionary of interpretable directions and project representations onto those directions. Basically, you ask how much a given vector aligns with each learned feature direction, and in practice this gives useful-looking results for normalized document representations[6] and contrast vectors.
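The decoder-projection bridge can be sketched as follows. This is a minimal version under stated assumptions: the decoder matrix is taken to have one row per feature (shape `(n_features, d_model)`), rows are re-normalized to unit length in case they aren't already, and a document vector is optionally z-normalized against hypothetical reference statistics `mu` and `sigma` before projecting. The function names are illustrative, not the CLI tool's actual API.

```python
import numpy as np


def project_onto_decoder(vec, w_dec, mu=None, sigma=None):
    """Score every SAE feature by how strongly `vec` aligns with its decoder direction.

    vec: (d_model,) document or contrast vector.
    w_dec: (n_features, d_model) decoder weights, one row per feature (assumed layout).
    mu, sigma: optional reference mean/std for z-normalizing a document vector.
    """
    if mu is not None and sigma is not None:
        vec = (vec - mu) / sigma
    directions = w_dec / np.linalg.norm(w_dec, axis=1, keepdims=True)
    return directions @ vec


def top_features(scores, k=20):
    """Indices of the k highest-scoring features, strongest first."""
    return np.argsort(scores)[::-1][:k]


w_dec = np.array([[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]])
scores = project_onto_decoder(np.array([0.0, 3.0]), w_dec)
print(top_features(scores, k=2))
```

Unlike the encoder, this projection is linear, so it is equally well defined for a single token residual, a document mean, or a difference of means.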
Using this method, I compared SAE features from the student vectors against those from the original Gemma vectors. The agreement was substantial but not complete: the document-level Pearson correlation of all 16k SAE feature scores between student and Gemma was 0.79 on held-out RAID, and 0.58 on an out-of-distribution Pangram dataset[7]. The mean top-20 SAE feature overlap was about 0.60 on held-out RAID and 0.32 on the Pangram data. The closer agreement on RAID made sense, given that’s what I’d used to train the student. And a top-20 overlap of 0.32 isn’t necessarily as bad as it sounds, considering the process was ranking 16k features. I should also note that, at the individual document level, these SAE features are more like a fingerprint than a topic model, and not always easy to inspect, given that many of them describe abstractions like “end of sentence punctuation” or “qualities and judgments”. In any case, there was something to it: we could produce a similar set of SAE features for a document with a dramatically smaller model.
For example, a document about gardening would produce within the top 20 SAE features a few telling activity- or nature-oriented features, like “nature and harmony” or “state of relaxation”. Similarly, an email requesting that students register for new classes in programming and creative writing included SAE features like “calls to action with exclamation” (“new class offerings!”), “introductions”, and “code snippets”. And the synopsis of a friend’s gay sports romance novella produced “competitive athletic mindset” and “men and masculinity”. On the whole, SAE decompositions on the student representation seemed to carry over quite generally.
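The two agreement metrics reported above can be computed per document as sketched below. This assumes each document has a 16k-length vector of feature scores from both the student and Gemma, and that the per-document values are then averaged over the evaluation set; exactly how the reported numbers were aggregated is an assumption here. Top-k overlap is the fraction of indices shared between the two top-k sets.

```python
import numpy as np


def pearson(a, b) -> float:
    """Pearson correlation between two feature-score vectors."""
    return float(np.corrcoef(a, b)[0, 1])


def topk_overlap(a, b, k=20) -> float:
    """Fraction of the top-k features by score that the two vectors share."""
    top_a = set(np.argsort(a)[::-1][:k].tolist())
    top_b = set(np.argsort(b)[::-1][:k].tolist())
    return len(top_a & top_b) / k


def mean_agreement(student_scores, teacher_scores, k=20):
    """Average both metrics over documents; inputs are (n_docs, n_features)."""
    corrs = [pearson(s, t) for s, t in zip(student_scores, teacher_scores)]
    overlaps = [topk_overlap(s, t, k) for s, t in zip(student_scores, teacher_scores)]
    return float(np.mean(corrs)), float(np.mean(overlaps))


a = np.arange(30.0)
print(topk_overlap(a, a), topk_overlap(a, -a))
```

With 16k features, a random pair of top-20 sets would share almost nothing in expectation, which is why even a modest overlap carries signal.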
None of this was originally intended as publishable work, but I thought others might find elements of it useful. So I packaged up the distilled student model as a lightweight CLI tool[8] that can be used to:
Apply an SAE decomposition to a given input document and report the resulting features and strength of the signal
Generate vector representations of documents compatible with the Gemma SAE dictionary using the method I’ve described
Classify the provenance of text as human or AI-generated (using the contrast vectors I computed with the student transformer over RAID).
For all of these use cases, the speed of the model is what makes it most interesting.
If I have time and/or curiosity in the future, I may explore training a student model on token-level residuals and a broader mixture of text corpora, and see how much that improves out-of-distribution metrics. But on the whole, I was surprised by how well this quick pass at distillation worked, and I’ll likely be using the bridge I found between contrast vectors and SAEs in the future.
[1] Similar performance to Gemma on contrast prediction on held-out RAID data, and 0.76 cosine similarity to Gemma vectors after z-normalization, which removes the dominant shared structure and isolates finer-grained representational detail.
[2] I trained with MSE loss against the z-normalized teacher residual vector. The student is a 4-layer transformer encoder with hidden size 256, 4 attention heads, 128-dimensional token embeddings, and a 2,560-dimensional linear output.
[3] Top models on this dataset reach 0.99+ AUC (see the leaderboard), so this is surprisingly strong but not competitive with the state of the art.
[4] A supervised model that used the distilled student as a base and was finetuned on 80% of the RAID training data got >0.99 AUC on the held-out 20%. This model is also available in the GitHub repo linked above.
[5] Specifically, the smallest (16k-feature) variant.
[6] With direct projection of unnormalized document vectors, the top SAE features stayed the same for most queries, dominated by the common residual signal. With z-normalization, the projections instead reflect deviations from a reference distribution.
[7] The Pangram benchmark dataset (n=1976).