whirlwind tour, initial exploration, glossary

The subfield of alignment that focuses on reverse-engineering neural networks, especially LLMs

Given the curse of dimensionality, the question remains: how do we hope to understand a function over such a large space without spending an exponential amount of time? 1

inference

application in the wild: Goodfire and Transluce

How would we do inference with SAEs?

idea: treat SAEs as a logit_processor, though there are currently some bottlenecks with logit_processor in vLLM, similar to guided decoding

Currently (before v1), logit_processor is applied row-wise, meaning logits are processed before being passed down to the scheduling group 2
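As a rough sketch (assuming vLLM's pre-v1 convention of row-wise logits processors, i.e. callables of the form `(token_ids, logits) -> logits`), an SAE-derived intervention could be packaged as such a callable; the `bias` vector here is a hypothetical, precomputed vocabulary-sized adjustment, not an actual Goodfire/Transluce API:

```python
from typing import List

import torch


def make_sae_logit_processor(bias: torch.Tensor, scale: float = 1.0):
    """Return a row-wise processor that adds a fixed, SAE-derived bias to the logits."""

    def processor(token_ids: List[int], logits: torch.Tensor) -> torch.Tensor:
        # `logits` is the vocabulary-sized row for a single sequence, processed
        # before being handed back to the scheduler (hence "row-wise").
        return logits + scale * bias.to(logits.dtype)

    return processor


# hypothetical usage: SamplingParams(logits_processors=[make_sae_logit_processor(bias)])
```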

steering

refers to the process of manually modifying certain activations and hidden states of the neural net to influence its outputs

For example, the following is a toy example of how a decoder-only transformer (e.g. GPT-2) generates text given the prompt “The weather in California is”

flowchart LR
  A[The weather in California is] --> B[H0] --> D[H1] --> E[H2] --> C[... hot]

To steer the model, we modify the $H_2$ layer by amplifying certain features with scale 20 (call it $H_{3}$) 3

flowchart LR
  A[The weather in California is] --> B[H0] --> D[H1] --> E[H3] --> C[... cold]
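A minimal sketch of this kind of steering with a PyTorch forward hook on GPT-2 (assuming Hugging Face `transformers`; the `steering_vector` here is a random placeholder standing in for a learned feature direction such as an SAE decoder column):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

hidden_dim = model.config.n_embd
steering_vector = torch.randn(hidden_dim)  # placeholder for a learned feature direction
scale = 20.0                               # the "scale 20" amplification from the toy example


def steer(module, inputs, output):
    # GPT-2 blocks return a tuple whose first element is the hidden states (B, T, E);
    # add the scaled steering vector at every position.
    hidden = output[0] + scale * steering_vector.to(output[0].dtype)
    return (hidden,) + output[1:]


handle = model.transformer.h[2].register_forward_hook(steer)  # intervene on H2
inputs = tokenizer("The weather in California is", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=5)
handle.remove()
print(tokenizer.decode(out[0]))
```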

One usually uses techniques such as sparse autoencoders to decompose model activations into a set of interpretable features.

For feature ablation, we observe that feature activations can be strengthened or weakened to directly influence the model’s outputs

For example, (Panickssery et al., 2024) use contrastive activation additions to steer Llama 2

contrastive activation additions

intuition: using a contrast pair for steering vector additions at certain activation layers

Uses the mean difference, which produces a difference vector similar to PCA:

Given a dataset $\mathcal{D}$ of prompts $p$ with positive completion $c_p$ and negative completion $c_n$, we calculate the mean-difference vector $v_\text{MD}$ at layer $L$ as follows:

$$v_\text{MD} = \frac{1}{\mid \mathcal{D} \mid} \sum_{p,c_p,c_n \in \mathcal{D}} a_L(p,c_p) - a_L(p, c_n)$$
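A minimal sketch of computing $v_\text{MD}$, assuming a hypothetical helper `activation_at_layer(prompt, completion, layer)` that returns $a_L(p, c)$:

```python
import torch


def mean_difference_vector(dataset, layer, activation_at_layer):
    """dataset: iterable of (prompt, positive_completion, negative_completion) triples."""
    diffs = []
    for p, c_pos, c_neg in dataset:
        # per-pair difference a_L(p, c_p) - a_L(p, c_n)
        diffs.append(
            activation_at_layer(p, c_pos, layer) - activation_at_layer(p, c_neg, layer)
        )
    # v_MD = (1 / |D|) * sum of per-pair activation differences
    return torch.stack(diffs).mean(dim=0)
```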

implication

by steering existing learned representations of behaviors, CAA results in better out-of-distribution generalization than basic supervised finetuning of the entire model.

sparse autoencoders

abbrev: SAE

see also: landscape

Often consists of a single MLP layer with a ReLU activation, trained on a subset of the data the main LLM was trained on.

empirical example: if we wish to interpret all features related to the author Camus, we might want to train an SAE on all available text by Camus to interpret “similar” features from Llama-3.1

definition

We wish to decompose a model’s activation $x \in \mathbb{R}^n$ into a sparse, linear combination of feature directions:

$$\begin{aligned} x \sim x_{0} + &\sum_{i=1}^{M} f_i(x) d_i \\[8pt] \because \quad &d_i,\ M \gg n:\text{ latent unit-norm feature directions} \\ &f_i(x) \ge 0: \text{ corresponding feature activation for }x \end{aligned}$$

Thus, the baseline architecture of SAEs is a linear autoencoder with L1 penalty on the activations:

$$\begin{aligned} f(x) &\coloneqq \text{ReLU}(W_\text{enc}(x - b_\text{dec}) + b_\text{enc}) \\ \hat{x}(f) &\coloneqq W_\text{dec} f(x) + b_\text{dec} \end{aligned}$$

training it to reconstruct a large dataset of model activations $x \sim \mathcal{D}$, constraining the hidden representation $f$ to be sparse

An L1 norm with coefficient $\lambda$ is used to construct the loss during training:

$$\begin{aligned} \mathcal{L}(x) &\coloneqq \| x-\hat{x}(f(x)) \|_2^2 + \lambda \| f(x) \|_1 \\[8pt] &\because \|x-\hat{x}(f(x)) \|_2^2 : \text{ reconstruction loss} \end{aligned}$$
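A minimal sketch of this baseline SAE and its loss in PyTorch (initialization scheme and training loop omitted; names are illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SparseAutoencoder(nn.Module):
    def __init__(self, n: int, m: int):
        super().__init__()
        self.W_enc = nn.Parameter(torch.randn(m, n) * 0.01)  # (M, n)
        self.W_dec = nn.Parameter(torch.randn(n, m) * 0.01)  # (n, M); columns are feature directions d_i
        self.b_enc = nn.Parameter(torch.zeros(m))
        self.b_dec = nn.Parameter(torch.zeros(n))

    def forward(self, x):
        f = F.relu((x - self.b_dec) @ self.W_enc.T + self.b_enc)  # f(x)
        x_hat = f @ self.W_dec.T + self.b_dec                     # x_hat(f)
        return x_hat, f


def sae_loss(x, x_hat, f, lam: float = 1e-3):
    recon = (x - x_hat).pow(2).sum(dim=-1)   # ||x - x_hat(f(x))||_2^2
    sparsity = f.abs().sum(dim=-1)           # ||f(x)||_1
    return (recon + lam * sparsity).mean()
```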

intuition

We want high reconstruction fidelity at a given sparsity level (as measured by L0), which training approximates via a mixture of reconstruction fidelity and L1 regularization.

We could reduce the sparsity loss term without affecting reconstruction by scaling up the norm of the decoder weights, which is why the norms of the columns of $W_\text{dec}$ are constrained during training

Idea: the output of the encoder $f(x)$ has two roles

  • detects which features are active: the L1 penalty is crucial to ensure sparsity in the decomposition
  • estimates the magnitudes of active features: here the L1 penalty is an unwanted bias

Gated SAE

a Pareto improvement over the baseline SAE that reduces the bias introduced by the L1 penalty (Rajamanoharan et al., 2024)

A clear consequence of this bias during training is shrinkage (Sharkey, 2024) 4

The idea is to use a gated ReLU encoder (Dauphin et al., 2017; Shazeer, 2020):

$$\tilde{f}(\mathbf{x}) \coloneqq \underbrace{\mathbb{1}[\underbrace{(\mathbf{W}_{\text{gate}}(\mathbf{x} - \mathbf{b}_{\text{dec}}) + \mathbf{b}_{\text{gate}}) > 0}_{\pi_{\text{gate}}(\mathbf{x})}]}_{f_{\text{gate}}(\mathbf{x})} \odot \underbrace{\text{ReLU}(\mathbf{W}_{\text{mag}}(\mathbf{x} - \mathbf{b}_{\text{dec}}) + \mathbf{b}_{\text{mag}})}_{f_{\text{mag}}(\mathbf{x})}$$

where $\mathbb{1}[\bullet > 0]$ is the (point-wise) Heaviside step function and $\odot$ denotes element-wise multiplication.

| term | annotations |
| --- | --- |
| $f_\text{gate}$ | which features are deemed to be active |
| $f_\text{mag}$ | feature activation magnitudes (for features that have been deemed active) |
| $\pi_\text{gate}(x)$ | the $f_\text{gate}$ sub-layer’s pre-activations |

To negate the increase in parameters, use weight sharing:

Scale $W_\text{mag}$ in terms of $W_\text{gate}$ with a vector-valued rescaling parameter $r_\text{mag} \in \mathbb{R}^M$:

$$(W_\text{mag})_{ij} \coloneqq (\exp (r_\text{mag}))_i \cdot (W_\text{gate})_{ij}$$
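A minimal sketch of the gated encoder with this weight sharing (PyTorch; names are illustrative, decoder omitted):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GatedEncoder(nn.Module):
    def __init__(self, n: int, m: int):
        super().__init__()
        self.W_gate = nn.Parameter(torch.randn(m, n) * 0.01)
        self.r_mag = nn.Parameter(torch.zeros(m))    # W_mag = exp(r_mag)_i * W_gate (weight sharing)
        self.b_gate = nn.Parameter(torch.zeros(m))
        self.b_mag = nn.Parameter(torch.zeros(m))
        self.b_dec = nn.Parameter(torch.zeros(n))

    def forward(self, x):
        centered = x - self.b_dec
        pi_gate = centered @ self.W_gate.T + self.b_gate       # pi_gate(x): gating pre-activations
        f_gate = (pi_gate > 0).to(x.dtype)                      # Heaviside: which features are active
        W_mag = torch.exp(self.r_mag)[:, None] * self.W_gate    # shared weights, rescaled per feature
        f_mag = F.relu(centered @ W_mag.T + self.b_mag)         # magnitude estimate
        return f_gate * f_mag                                   # f_tilde(x) = f_gate ⊙ f_mag
```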

Figure 3: Gated SAE with weight sharing between gating and magnitude paths

Figure 4: A gated encoder becomes a single-layer linear encoder with JumpReLU (Erichson et al., 2019) activation function $\sigma_\theta$

feature suppression

See also: link

The loss function of SAEs combines an MSE reconstruction loss with a sparsity term:

$$\begin{aligned} L(x, f(x), y) &= \|y-x\|^2/d + c\mid f(x) \mid \\[8pt] &\because d: \text{ dimensionality of }x \end{aligned}$$

The reconstruction is not perfect: because the loss trades reconstruction off against the sparsity penalty, smaller values of $f(x)$ are preferred, and features are suppressed

How do we fix feature suppression in training SAEs?

Introduce an element-wise scaling factor per feature between the encoder and decoder, represented by the vector $s$:

$$\begin{aligned} f(x) &= \text{ReLU}(W_e x + b_e) \\ f_s(x) &= s \odot f(x) \\ y &= W_d f_s(x) + b_d \end{aligned}$$
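A minimal sketch of this fix, with a learned per-feature scale $s$ between encoder and decoder (PyTorch; illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ScaledSAE(nn.Module):
    def __init__(self, d: int, m: int):
        super().__init__()
        self.enc = nn.Linear(d, m)            # W_e, b_e
        self.dec = nn.Linear(m, d)            # W_d, b_d
        self.s = nn.Parameter(torch.ones(m))  # element-wise scaling factor per feature

    def forward(self, x):
        f = F.relu(self.enc(x))     # f(x)
        f_s = self.s * f            # f_s(x) = s ⊙ f(x)
        y = self.dec(f_s)           # y = W_d f_s(x) + b_d
        return y, f                 # f is returned for the sparsity term
```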

sparse crosscoders

maturity

a research preview from Anthropic; this is still very much a work in progress

see also reproduction on Gemma 2B and github

A variant of sparse autoencoders that reads from and writes to multiple layers (Lindsey et al., 2024)

Crosscoders produce shared features across layers and even across models

They help resolve:

  • cross-layer features: resolve cross-layer superposition

  • circuit simplification: remove redundant features from analysis and jump over many uninteresting identity circuit connections

  • model diffing: produce shared sets of features across models. This includes comparing one model across training, as well as completely independent models with different architectures.

motivations

cross-layer superposition

given the additive properties of transformers’ residual stream, adjacent layers in larger transformers can be thought of as “almost parallel”

if we think of adjacent layers as being “almost parallel branches that potentially have superposition between them”, then we can apply dictionary learning jointly 5

persistent features and complexity

A current drawback of sparse autoencoders is that we have to train them against specific activation layers to extract features. If we do this per layer of the residual stream, we end up with lots of duplicate features across layers.

Crosscoders can simplify the circuit given that we use an appropriate architecture 6

setup.

Autoencoders and transcoders as special cases of crosscoders.

  • autoencoders: read and predict the same layer
  • transcoders: read from layer $n$ and predict layer $n+1$

Crosscoders read/write to many layers, subject to causality constraints.

crosscoders

We compute the vector of feature activations $f(x_j)$ on data point $x_j$ by summing contributions from the activations $a^l(x_j)$ at different layers $l \in L$:

$$\begin{aligned} f(x_j) &= \text{ReLU}\left(\sum_{l\in L}W_{\text{enc}}^l a^l(x_j) + b_{\text{enc}}\right) \\[8pt] &\because W^l_{\text{enc}} : \text{ encoder weights at layer } l \\[8pt] &\because a^l(x_j) : \text{ activation on datapoint } x_j \text{ at layer } l \end{aligned}$$

We have the loss:

$$L = \sum_{l\in L} \|a^l(x_j) - a^{l'}(x_j)\|^2 + \sum_{l\in L}\sum_i f_i(x_j) \|W^l_{\text{dec,i}}\|$$

and regularization can be rewritten as:

$$\sum_{l\in L}\sum_{i} f_i(x_j) \|W^l_{\text{dec,i}}\| = \sum_{i} f_i(x_j)\left(\sum_{l \in L} \|W^l_\text{dec,i}\|\right)$$

i.e., we weight the L1 regularization penalty by the L1 norm of the per-layer decoder weight norms $\sum_{l\in L} \|W^l_\text{dec,i}\|$ 7
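A minimal sketch of a crosscoder and this loss (PyTorch; assumes activations arrive as a dict mapping layer index to a `(batch, n)` tensor, with all layers sharing activation dimension `n`; names are illustrative):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Crosscoder(nn.Module):
    def __init__(self, layers, n: int, m: int):
        super().__init__()
        self.layers = list(layers)
        self.W_enc = nn.ParameterDict({str(l): nn.Parameter(torch.randn(m, n) * 0.01) for l in self.layers})
        self.W_dec = nn.ParameterDict({str(l): nn.Parameter(torch.randn(n, m) * 0.01) for l in self.layers})
        self.b_enc = nn.Parameter(torch.zeros(m))
        self.b_dec = nn.ParameterDict({str(l): nn.Parameter(torch.zeros(n)) for l in self.layers})

    def forward(self, acts):
        # f(x_j) = ReLU( sum_l W_enc^l a^l(x_j) + b_enc )
        f = F.relu(sum(acts[l] @ self.W_enc[str(l)].T for l in self.layers) + self.b_enc)
        recon = {l: f @ self.W_dec[str(l)].T + self.b_dec[str(l)] for l in self.layers}
        return f, recon

    def loss(self, acts, f, recon, lam: float = 1e-3):
        mse = sum((recon[l] - acts[l]).pow(2).sum(dim=-1) for l in self.layers)
        # per-feature penalty weighted by sum_l ||W_dec,i^l|| (L1-of-norms)
        dec_norms = sum(self.W_dec[str(l)].norm(dim=0) for l in self.layers)  # shape (m,)
        l1 = (f * dec_norms).sum(dim=-1)
        return (mse + lam * l1).mean()
```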

We use L1 due to

  • baseline loss comparison: L2 exhibits lower loss than the sum of per-layer SAE losses, as the model would effectively obtain a loss “bonus” by spreading features across layers

  • layer-wise sparsity surfaces layer-specific features: empirical results from model diffing show that L1 uncovers a mix of shared and model-specific features, whereas L2 tends to uncover only shared features.

variants

good to explore:

  1. strictly causal crosscoders to capture MLP computation and treat computation performed by attention layers as linear
  2. combine strictly causal crosscoders for MLP outputs with weakly causal crosscoders for attention outputs
  3. interpretable attention replacement layers that could be used in combination with strictly causal crosscoders for a “replacement model”

model diffing

see also: model stitching and SVCCA

(Laakso & Cottrell, 2000) propose comparing representations by transforming them into representations of the distances between data points. 8

questions

How do features change over model training? When do they form?

As we make a model wider, do we get more features? Or are they largely the same, just packed less densely?


superposition hypothesis

tl/dr

the phenomenon where a neural network represents more than $n$ features in an $n$-dimensional space

Linear representations can encode more features than there are dimensions. As sparsity increases, models use superposition to represent more features than they have dimensions.

neural networks “want to represent more features than they have neurons”.

When features are sparse, superposition allows compression beyond what a linear model can do, at the cost of interference that requires non-linear filtering.

reasoning: a “noisy simulation”, where small neural networks exploit feature sparsity and properties of high-dimensional spaces to approximately simulate much larger, much sparser neural networks

In a sense, superposition is a form of lossy compression

importance

  • sparsity: how frequently does the feature appear in the input?

  • importance: how useful is it for lowering loss?

over-complete basis

reasoning for the set of $n$ directions 9

features

A property of an input to the model

When we talk about features (Elhage et al., 2022, see “Empirical Phenomena”), the theory is built around several observed empirical phenomena:

  1. Word embeddings: have directions which correspond to semantic properties (Mikolov et al., 2013). For example:
    V(king) - V(man) = V(monarch)
  2. Latent space: similar vector arithmetic and interpretable directions have also been found in generative adversarial networks.

We can define features as properties of inputs which a sufficiently large neural network will reliably dedicate a neuron to represent (Elhage et al., 2022, see “Features as Directions”)

ablation

refers to the process of removing a subset of a model’s parameters to evaluate the effect on its predictions.

idea: delete one activation of the network to see how performance on a task changes.

  • zero ablation or pruning: Deletion by setting activations to zero
  • mean ablation: Deletion by setting activations to the mean of the dataset
  • random ablation or resampling: Deletion by replacing activations with those sampled from another input
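A minimal sketch of zero and mean ablation implemented as a forward hook (PyTorch; assumes the hooked module returns a plain tensor, and `mean_activation` would be precomputed over a dataset):

```python
from typing import Optional

import torch


def make_ablation_hook(mode: str = "zero", mean_activation: Optional[torch.Tensor] = None):
    def hook(module, inputs, output):
        if mode == "zero":
            return torch.zeros_like(output)           # zero ablation / pruning
        if mode == "mean":
            return mean_activation.expand_as(output)  # mean ablation
        return output                                 # no-op fallback

    return hook


# hypothetical usage: handle = model.some_layer.register_forward_hook(make_ablation_hook("zero"))
```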

residual stream

flowchart LR
  A[Token] --> B[Embeddings] --> C[x0]
  C[x0] --> E[H] --> D[x1]
  C[x0] --> D
  D --> F[MLP] --> G[x2]
  D --> G[x2]
  G --> I[...] --> J[unembed] --> X[logits]

residual stream $x_{0}$ has dimension $(C, E)$ where

  • $C$: the number of tokens in the context window, and
  • $E$: the embedding dimension.

The attention mechanism $H$ processes the given residual stream $x_{0}$, and the result is added back to form $x_{1}$:

$$x_{1} = H(x_{0}) + x_{0}$$
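A minimal sketch of this bookkeeping with explicit shapes (PyTorch; `attn` and `mlp` are stand-in sublayers, not a faithful GPT-2 block):

```python
import torch
import torch.nn as nn

C, E = 8, 64                                     # context length, embedding dimension
attn = nn.MultiheadAttention(E, num_heads=4, batch_first=True)
mlp = nn.Sequential(nn.Linear(E, 4 * E), nn.GELU(), nn.Linear(4 * E, E))

x0 = torch.randn(1, C, E)                        # residual stream after embeddings, shape (C, E) per sequence
h, _ = attn(x0, x0, x0)                          # H(x0)
x1 = h + x0                                      # x1 = H(x0) + x0: result added back to the stream
x2 = mlp(x1) + x1                                # MLP output is added back the same way
print(x0.shape, x1.shape, x2.shape)              # all (1, 8, 64): the stream keeps shape (C, E)
```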

grokking

See also: writeup, code, circuit threads

A phenomenon discovered by (Power et al., 2022) where, on small algorithmic tasks like modular addition, a model will initially memorise the training data, but after a long time will suddenly learn to generalise to unseen data

empirical claims

related to phase change

References

  • Dauphin, Y. N., Fan, A., Auli, M., & Grangier, D. (2017). Language Modeling with Gated Convolutional Networks. arXiv preprint arXiv:1612.08083 arxiv
  • Erichson, N. B., Yao, Z., & Mahoney, M. W. (2019). JumpReLU: A Retrofit Defense Strategy for Adversarial Attacks. arXiv preprint arXiv:1904.03750 arxiv
  • Rajamanoharan, S., Conmy, A., Smith, L., Lieberum, T., Varma, V., Kramár, J., Shah, R., & Nanda, N. (2024). Improving Dictionary Learning with Gated Sparse Autoencoders. arXiv preprint arXiv:2404.16014 arxiv
  • Sharkey, L. (2024). Addressing Feature Suppression in SAEs. AI Alignment Forum. [post]
  • Shazeer, N. (2020). GLU Variants Improve Transformer. arXiv preprint arXiv:2002.05202 arxiv
  • Gorton, L. (2024). The Missing Curve Detectors of InceptionV1: Applying Sparse Autoencoders to InceptionV1 Early Vision. arXiv preprint arXiv:2406.03662 arxiv
  • Laakso, A., & Cottrell, G. (2000). Content and cluster analysis: Assessing representational similarity in neural systems. Philosophical Psychology, 13(1), 47–76. https://doi.org/10.1080/09515080050002726
  • Lindsey, J., Templeton, A., Marcus, J., Conerly, T., Batson, J., & Olah, C. (2024). Sparse Crosscoders for Cross-Layer Features and Model Diffing. Transformer Circuits Thread. [link]
  • Elhage, N., Hume, T., Olsson, C., Schiefer, N., Henighan, T., Kravec, S., Hatfield-Dodds, Z., Lasenby, R., Drain, D., Chen, C., Grosse, R., McCandlish, S., Kaplan, J., Amodei, D., Wattenberg, M., & Olah, C. (2022). Toy Models of Superposition. Transformer Circuits Thread. [link]
  • Mikolov, T., Yih, W., & Zweig, G. (2013). Linguistic Regularities in Continuous Space Word Representations. In L. Vanderwende, H. Daumé III, & K. Kirchhoff (Eds.), Proceedings of the 2013 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (pp. 746–751). Association for Computational Linguistics. https://aclanthology.org/N13-1090
  • Panickssery, N., Gabrieli, N., Schulz, J., Tong, M., Hubinger, E., & Turner, A. M. (2024). Steering Llama 2 via Contrastive Activation Addition. arXiv preprint arXiv:2312.06681 arxiv
  • Power, A., Burda, Y., Edwards, H., Babuschkin, I., & Misra, V. (2022). Grokking: Generalization Beyond Overfitting on Small Algorithmic Datasets. arXiv preprint arXiv:2201.02177 arxiv

Footnotes

  1. good read from Lawrence C for ambitious mech interp.

  2. the benchmark was run against vllm#0.6.3.dev236+g48138a84, with all configuration specified in the pull request.

  3. An example steering function can be:

    $$H_{3} = H_{2} + \text{steering\_strength} * \text{SAE}.W_{\text{dec}}[20] * \text{max\_activation}$$
  4. If we hold $\hat{x}(\bullet)$ fixed, the L1 penalty pushes $f(x) \to 0$, while the reconstruction loss pushes $f(x)$ high enough to produce an accurate reconstruction.

    The optimal value is somewhere in between.

    However, rescaling the shrunk feature activations (Sharkey, 2024) is not necessarily enough to overcome the bias induced by L1: an SAE might learn sub-optimal encoder and decoder directions that are not improved by this fix.

  5. (Gorton, 2024) applies SAEs to study InceptionV1, where cross-branch superposition is significant in interpreting models with parallel branches

  6. The causal description it provides likely differs from that of the underlying model.

  7. $\|W_\text{dec,i}^l\|$ is the L2 norm of a single feature’s decoder vector at a given layer.

    In principle, one might have expected to use the L2 norm of the per-layer norms, $\sqrt{\sum_{l \in L} \|W_\text{dec,i}^l\|^2}$

  8. Chris Colah’s blog post explains how t-SNE can be used to visualize collections of networks in a function space.

  9. Even though features still correspond to directions, the set of interpretable directions is larger than the number of dimensions