maturity
A research preview from Anthropic; this is still very much a work in progress.
See also: reproduction on Gemma 2B and GitHub.
A variant of the sparse autoencoder that reads from and writes to multiple layers (Lindsey et al., 2024).
Crosscoders produce shared features across layers and even across models.
Crosscoders are useful for:
- cross-layer features: resolve cross-layer superposition
- circuit simplification: remove redundant features from analysis and enable circuits to jump across many uninteresting identity circuit connections
- model diffing: produce shared sets of features across models. This includes one model across training, and also completely independent models with different architectures.
cross-layer superposition
Given the additive nature of transformers' residual stream (each layer adds its output to the stream rather than replacing it), adjacent layers in larger transformers can be thought of as “almost parallel”.
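A toy sketch of why additivity makes adjacent layers look alike. It assumes, purely for illustration, that each block writes a small random update that does not depend on the stream; real attention and MLP outputs do depend on it. The names `toy_residual_stream` and `update_scale` are hypothetical.

```python
import numpy as np

def toy_residual_stream(d_model=512, n_layers=4, update_scale=0.1, seed=0):
    """Return residual-stream states r_0..r_n, where each block adds (not replaces)."""
    rng = np.random.default_rng(seed)
    r = rng.standard_normal(d_model)  # r_0, e.g. the token embedding
    states = [r.copy()]
    for _ in range(n_layers):
        # Hypothetical block output: a small random vector added to the stream.
        update = update_scale * rng.standard_normal(d_model)
        r = r + update
        states.append(r.copy())
    return states

def cosine(u, v):
    return float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v)))

states = toy_residual_stream()
for l in range(len(states) - 1):
    # Consecutive states stay highly aligned because most of r_{l+1} is r_l itself.
    print(f"cos(r_{l}, r_{l + 1}) = {cosine(states[l], states[l + 1]):.3f}")
```

Because each state is the previous one plus a small update, a feature direction present at one layer is largely still present at the next, which is the intuition behind treating adjacent layers as near-parallel.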
intuition
Under the superposition hypothesis, a feature is a linear combination of neurons at any given layer.
If we think of adjacent layers as “almost parallel branches that potentially have superposition between them”, then we can apply dictionary learning to them jointly 1
A current drawback of sparse autoencoders is that each one has to be trained on the activations of a specific layer to extract features. When we train them on the residual stream at every layer, we end up with many duplicate features across layers.
Crosscoders can simplify circuits, provided we use an appropriate architecture 2
Autoencoders and transcoders are special cases of crosscoders:
- autoencoders: read and predict the same layer
- transcoders: read from layer $n$ and predict layer $n+1$
Crosscoders read from and write to many layers, subject to causality constraints.
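One way to picture the special cases is as a boolean read/write mask over layer pairs. This is a sketch: `allowed_writes` is a hypothetical helper, and the `"causal"` rule (never write to a layer before the one read) is one simple assumed form of a causality constraint, not necessarily the exact variant used in the paper.

```python
import numpy as np

def allowed_writes(n_layers, kind):
    """Boolean matrix M[r, w]: may a feature that reads layer r write to layer w?"""
    r = np.arange(n_layers)[:, None]  # read layer (rows)
    w = np.arange(n_layers)[None, :]  # write layer (columns)
    if kind == "sae":         # read and predict the same layer
        return w == r
    if kind == "transcoder":  # read layer n, predict layer n + 1
        return w == r + 1
    if kind == "acausal":     # any layer to any layer
        return np.ones((n_layers, n_layers), dtype=bool)
    if kind == "causal":      # assumed constraint: only write at or after the read layer
        return w >= r
    raise ValueError(f"unknown kind: {kind}")

print(allowed_writes(4, "transcoder").astype(int))
```

The SAE mask is the diagonal, the transcoder mask is the first superdiagonal, and a general crosscoder fills in more of the matrix, as far as its causality constraint allows.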
crosscoders
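We compute the vector of feature activations $f(x_j)$ on data point $x_j$ by summing over contributions from the activations $a^l(x_j)$ of different layers, for layers $l \in L$:
$$f(x_j) = \text{ReLU}\left(\sum_{l \in L} W^{l}_{\text{enc}}\, a^{l}(x_j) + b_{\text{enc}}\right)$$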
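A minimal NumPy sketch of the encoder for a single data point, assuming per-layer encoder matrices of shape `(n_features, d_model)` and one shared bias; `crosscoder_encode` is a hypothetical name, not an API from the paper.

```python
import numpy as np

def crosscoder_encode(acts, W_enc, b_enc):
    """f(x_j) = ReLU(sum_l W_enc^l a^l(x_j) + b_enc).

    acts:  list of per-layer activations a^l(x_j), each (d_model,)
    W_enc: list of per-layer encoders W_enc^l, each (n_features, d_model)
    b_enc: shared encoder bias, (n_features,)
    """
    pre = np.array(b_enc, dtype=float)
    for a_l, W_l in zip(acts, W_enc):
        pre += W_l @ a_l  # contribution of layer l
    return np.maximum(pre, 0.0)  # ReLU

rng = np.random.default_rng(0)
n_layers, d_model, n_features = 3, 8, 16
acts = [rng.standard_normal(d_model) for _ in range(n_layers)]
W_enc = [rng.standard_normal((n_features, d_model)) for _ in range(n_layers)]
b_enc = np.zeros(n_features)
f = crosscoder_encode(acts, W_enc, b_enc)
print(f.shape)
```

Summing per-layer contributions is the same as running one SAE encoder on the concatenation of all layers' activations (with the `W_enc^l` placed side by side), which is one way to see "dictionary learning applied jointly".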
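Each layer's activations are then reconstructed as $a^{l}(x_j)' = W^{l}_{\text{dec}} f(x_j) + b^{l}_{\text{dec}}$, and we have the loss:
$$\mathcal{L} = \sum_{l \in L} \left\lVert a^{l}(x_j) - a^{l}(x_j)' \right\rVert^2 + \sum_{l \in L} \sum_i f_i(x_j) \left\lVert W^{l}_{\text{dec},i} \right\rVert$$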
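A sketch of both loss terms for one data point, assuming per-layer decoders of shape `(d_model, n_features)` so that column `i` is $W^{l}_{\text{dec},i}$. The two terms are returned separately; any sparsity coefficient used to combine them in training is an extra choice not shown in the formula above. `crosscoder_loss` is a hypothetical name.

```python
import numpy as np

def crosscoder_loss(acts, f, W_dec, b_dec):
    """Return (reconstruction term, sparsity term) for one data point.

    acts:  list of a^l(x_j), each (d_model,)
    f:     feature activations f(x_j), (n_features,)
    W_dec: list of W_dec^l, each (d_model, n_features); column i is W_dec,i^l
    b_dec: list of b_dec^l, each (d_model,)
    """
    recon, sparsity = 0.0, 0.0
    for a_l, W_l, b_l in zip(acts, W_dec, b_dec):
        a_hat = W_l @ f + b_l                       # a^l(x_j)'
        recon += float(np.sum((a_l - a_hat) ** 2))  # squared L2 error at layer l
        col_norms = np.linalg.norm(W_l, axis=0)     # ||W_dec,i^l|| for every feature i
        sparsity += float(f @ col_norms)            # sum_i f_i(x_j) ||W_dec,i^l||
    return recon, sparsity

# Hand-checkable case: one layer, d_model = 2, two features.
acts = [np.array([1.0, 1.0])]
f = np.array([1.0, 2.0])
W_dec = [np.array([[3.0, 0.0], [4.0, 0.0]])]  # feature 0 decodes to (3, 4), feature 1 to (0, 0)
b_dec = [np.zeros(2)]
print(crosscoder_loss(acts, f, W_dec, b_dec))  # (13.0, 5.0)
```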
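The regularization term can be rewritten as:
$$\sum_{l \in L} \sum_i f_i(x_j) \left\lVert W^{l}_{\text{dec},i} \right\rVert = \sum_i f_i(x_j) \left( \sum_{l \in L} \left\lVert W^{l}_{\text{dec},i} \right\rVert \right)$$
That is, we weight the L1 regularization penalty by the L1 norm of the per-layer decoder weight norms, $\sum_{l \in L} \lVert W^{l}_{\text{dec},i} \rVert$ 3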
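A quick numerical check of the rewrite, alongside the L2 alternative, $\sum_i f_i(x_j) \sqrt{\sum_{l} \lVert W^{l}_{\text{dec},i} \rVert^2}$. The random shapes and the helper `sparsity_penalties` are illustrative assumptions.

```python
import numpy as np

def sparsity_penalties(f, W_dec):
    """Return (double sum, L1-weighted form, L2-weighted alternative).

    f:     non-negative feature activations f(x_j), (n_features,)
    W_dec: list of per-layer decoders W_dec^l, each (d_model, n_features)
    """
    # per_layer[l, i] = ||W_dec,i^l||, the L2 norm of feature i's decoder vector at layer l
    per_layer = np.stack([np.linalg.norm(W_l, axis=0) for W_l in W_dec])
    double_sum = float(np.sum(per_layer @ f))                   # sum_l sum_i f_i ||W_dec,i^l||
    l1_form = float(f @ per_layer.sum(axis=0))                  # sum_i f_i * sum_l ||W_dec,i^l||
    l2_form = float(f @ np.sqrt((per_layer ** 2).sum(axis=0)))  # sum_i f_i * sqrt(sum_l ||.||^2)
    return double_sum, l1_form, l2_form

rng = np.random.default_rng(0)
f = rng.random(16)  # non-negative, as after a ReLU
W_dec = [rng.standard_normal((8, 16)) for _ in range(3)]
print(sparsity_penalties(f, W_dec))
```

The first two values agree (they are the same sum in a different order), and the L2 value is never larger than the L1 value for non-negative activations.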
We use L1 because of:
- baseline loss comparison: the L1 version keeps the loss comparable to the baseline of training per-layer SAEs. The L2 version would exhibit lower loss than the sum of per-layer SAE losses, since it would effectively obtain a loss “bonus” by spreading features across layers.
- layer-wise sparsity surfaces layer-specific features: empirical results from model diffing show that L1 uncovers a mix of shared and model-specific features, whereas L2 tends to uncover only shared features.
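The “bonus” can be seen with simple arithmetic. Assume (for illustration only) one feature that appears in $k$ layers with the same activation and a unit decoder norm at each layer, and that per-layer SAEs would learn one copy of it per layer with those same values.

```python
import math

def penalty_for_shared_feature(k, activation=1.0, per_layer_norm=1.0):
    """Penalty for one feature present in k layers with equal per-layer decoder norms.

    Returns (sum over k per-layer SAEs, crosscoder with L1 weighting, crosscoder with L2 weighting).
    """
    sae_total = k * activation * per_layer_norm              # k separate copies, one per SAE
    l1 = activation * (k * per_layer_norm)                   # f_i * sum_l ||W_dec,i^l||
    l2 = activation * math.sqrt(k * per_layer_norm ** 2)     # f_i * sqrt(sum_l ||W_dec,i^l||^2)
    return sae_total, l1, l2

for k in (1, 4, 16):
    print(k, penalty_for_shared_feature(k))
# 1 (1.0, 1.0, 1.0)
# 4 (4.0, 4.0, 2.0)
# 16 (16.0, 16.0, 4.0)
```

Under L1 the crosscoder pays exactly what the per-layer SAEs pay together, while under L2 the penalty grows only like $\sqrt{k}$, so spreading a feature over more layers is rewarded.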
good to explore:
- strictly causal crosscoders to capture MLP computation and treat computation performed by attention layers as linear
- combine strictly causal crosscoders for MLP outputs with weakly causal crosscoders for attention outputs
- interpretable attention replacement layers that could be used in combination with strictly causal crosscoders for a “replacement model”
see also: model stitching and SVCCA
(Laakso & Cottrell, 2000) propose comparing representations by transforming them into representations of distances between data points. 4
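A sketch in the spirit of the distance-based idea: compute each representation's pairwise distance matrix over the same data points, then correlate the two matrices. The choice of Euclidean distance and Pearson correlation over the upper triangle is an assumption here, not a claim about the paper's exact procedure; the function names are hypothetical.

```python
import numpy as np

def pairwise_distances(X):
    """Euclidean distance matrix for the rows of X, shape (n_points, dim)."""
    sq = np.sum(X ** 2, axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * X @ X.T
    return np.sqrt(np.maximum(d2, 0.0))  # clip tiny negatives caused by round-off

def distance_similarity(X, Y):
    """Correlate the upper triangles of the two distance matrices.

    X and Y represent the same n data points and may have different widths.
    """
    DX, DY = pairwise_distances(X), pairwise_distances(Y)
    iu = np.triu_indices(X.shape[0], k=1)  # each unordered pair once, no diagonal
    return float(np.corrcoef(DX[iu], DY[iu])[0, 1])

rng = np.random.default_rng(0)
X = rng.standard_normal((50, 16))
Q, _ = np.linalg.qr(rng.standard_normal((16, 16)))  # random orthogonal matrix
print(distance_similarity(X, X @ Q))                      # rotated copy: about 1.0
print(distance_similarity(X, rng.standard_normal((50, 32))))  # unrelated, wider: much lower
```

Because only distances are compared, the measure ignores rotations and does not require the two representations to have the same dimensionality.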
How do features change over model training? When do they form?
As we make a model wider, do we get more features? Or are they largely the same, just packed less densely?
Bibliography
- Gorton, L. (2024). The Missing Curve Detectors of InceptionV1: Applying Sparse Autoencoders to InceptionV1 Early Vision. arXiv preprint arXiv:2406.03662.
- Laakso, A., & Cottrell, G. (2000). Content and cluster analysis: Assessing representational similarity in neural systems. Philosophical Psychology, 13(1), 47–76. https://doi.org/10.1080/09515080050002726
- Lindsey, J., Templeton, A., Marcus, J., Conerly, T., Batson, J., & Olah, C. (2024). Sparse Crosscoders for Cross-Layer Features and Model Diffing. Transformer Circuits Thread.
Notes
- (Gorton, 2024) notes that cross-branch superposition is significant when interpreting models with parallel branches (InceptionV1) ↩
- The causal description a crosscoder provides likely differs from that of the underlying model. ↩
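- $\lVert W^{l}_{\text{dec},i} \rVert$ is the L2 norm of a single feature's decoder vector at a given layer. In principle, one might have expected to use the L2 norm of the per-layer norms, $\sqrt{\sum_{l \in L} \lVert W^{l}_{\text{dec},i} \rVert^2}$ ↩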
- Chris Olah’s blog post explains how t-SNE can be used to visualize collections of networks in a function space. ↩