abbrev: SAE

see also: landscape

Often consists of a one-layer MLP (a linear map followed by a ReLU) that is trained on a subset of the datasets the main LLM was trained on.

empirical example: if we wish to interpret all features related to the author Camus, we might train an SAE on all available text by Camus to interpret “similar” features from Llama-3.1.

definition

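We wish to decompose a model’s activation $x \in \mathbb{R}^n$ into a sparse, linear combination of feature directions:

$$
x \approx x_0 + \sum_{i=1}^{M} f_i(x)\, d_i
$$

where the $d_i$ are $M \gg n$ unit-norm feature directions and the $f_i(x) \ge 0$ are sparse coefficients.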

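Thus, the baseline architecture of SAEs is a linear autoencoder with an L1 penalty on the activations:

$$
\begin{aligned}
f(x) &:= \mathrm{ReLU}\big(W_{\text{enc}}(x - b_{\text{dec}}) + b_{\text{enc}}\big) \\
\hat{x}(f) &:= W_{\text{dec}} f + b_{\text{dec}}
\end{aligned}
$$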

It is trained to reconstruct a large dataset of model activations $x \sim \mathcal{D}$, while constraining the hidden representation $f$ to be sparse.

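The L1 norm of the features, weighted by a coefficient $\lambda$, is added to the reconstruction error to form the training loss:

$$
\mathcal{L}(x) := \lVert x - \hat{x}(f(x)) \rVert_2^2 + \lambda \lVert f(x) \rVert_1
$$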
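A minimal NumPy sketch of the baseline forward pass and loss, assuming the usual form $f(x) = \mathrm{ReLU}(W_{\text{enc}}(x - b_{\text{dec}}) + b_{\text{enc}})$ and $\hat{x} = W_{\text{dec}} f + b_{\text{dec}}$. The sizes `n` and `M`, the random weights, and the value of `lam` are illustrative assumptions, not values from any trained SAE.

```python
import numpy as np

def sae_forward(x, W_enc, b_enc, W_dec, b_dec):
    """Baseline SAE: encode with a ReLU, decode with a linear map."""
    f = np.maximum(W_enc @ (x - b_dec) + b_enc, 0.0)  # sparse, non-negative features
    x_hat = W_dec @ f + b_dec                          # reconstruction
    return f, x_hat

def sae_loss(x, f, x_hat, lam):
    """Squared reconstruction error plus lam times the L1 norm of the features."""
    return np.sum((x - x_hat) ** 2) + lam * np.sum(np.abs(f))

rng = np.random.default_rng(0)
n, M = 4, 16  # hypothetical sizes: activation dimension n, dictionary size M
W_enc = rng.normal(size=(M, n))
W_dec = rng.normal(size=(n, M))
W_dec /= np.linalg.norm(W_dec, axis=0, keepdims=True)  # unit-norm decoder columns
b_enc, b_dec = np.zeros(M), np.zeros(n)

x = rng.normal(size=n)  # stand-in for one model activation
f, x_hat = sae_forward(x, W_enc, b_enc, W_dec, b_dec)
loss = sae_loss(x, f, x_hat, lam=0.1)
```

In practice the weights would be trained by gradient descent on this loss over many activations; the sketch only evaluates it once.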

intuition

We want to maximize reconstruction fidelity at a given sparsity level, as measured by L0. In practice this is done by training on a mixture of reconstruction fidelity and L1 regularization.

We can reduce the sparsity loss term without affecting reconstruction by scaling down the feature activations and scaling up the norm of the decoder weights, so the norms of the decoder columns need to be constrained during training.
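The scaling loophole can be checked numerically. The sketch below uses a hypothetical factor `alpha` and made-up feature values: dividing the activations by `alpha` while multiplying the decoder by `alpha` leaves the reconstruction unchanged but shrinks the L1 term. `normalize_columns` is one possible way to close the loophole, assumed here for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
n, M = 4, 8                      # hypothetical sizes
W_dec = rng.normal(size=(n, M))
f = np.array([0.0, 1.5, 0.0, 0.0, 0.3, 0.0, 0.0, 2.0])  # made-up sparse activations

alpha = 10.0                     # hypothetical rescaling factor
f_small, W_big = f / alpha, W_dec * alpha

# Same reconstruction, ten times smaller L1 penalty.
same_reconstruction = np.allclose(W_dec @ f, W_big @ f_small)
l1_before, l1_after = np.abs(f).sum(), np.abs(f_small).sum()

def normalize_columns(W):
    """Rescale each decoder column to unit L2 norm."""
    return W / np.linalg.norm(W, axis=0, keepdims=True)

W_unit = normalize_columns(W_big)
```

With unit-norm columns the decoder can no longer absorb the scale, so the L1 term measures feature magnitudes in a fixed unit.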

Idea: the output of the encoder has two roles

  • detects which features are active: L1 is crucial to ensure sparsity in the decomposition
  • estimates the magnitudes of active features: L1 is an unwanted bias

Gated SAE

see also: paper

A Pareto improvement over baseline SAE training that limits the bias introduced by the L1 penalty (Rajamanoharan et al., 2024)

A clear consequence of this bias during training is shrinkage (Sharkey, 2024). [1]

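The idea is to use a gated ReLU encoder (Dauphin et al., 2017; Shazeer, 2020):

$$
\tilde{f}(x) := \underbrace{\mathbb{1}\big[\overbrace{W_{\text{gate}}(x - b_{\text{dec}}) + b_{\text{gate}}}^{\pi_{\text{gate}}(x)} > 0\big]}_{f_{\text{gate}}(x)} \odot \underbrace{\mathrm{ReLU}\big(W_{\text{mag}}(x - b_{\text{dec}}) + b_{\text{mag}}\big)}_{f_{\text{mag}}(x)}
$$

where $\mathbb{1}[\cdot > 0]$ is the (pointwise) Heaviside step function and $\odot$ denotes elementwise multiplication.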

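| term | annotations |
| --- | --- |
| $f_{\text{gate}}$ | which features are deemed to be active |
| $f_{\text{mag}}$ | feature activation magnitudes (for features that have been deemed to be active) |
| $\pi_{\text{gate}}(x)$ | the $f_{\text{gate}}$ sub-layer’s pre-activations |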
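A small sketch of the gated encoder, assuming the form $\tilde{f}(x) = \mathbb{1}[\pi_{\text{gate}}(x) > 0] \odot \mathrm{ReLU}(W_{\text{mag}}(x - b_{\text{dec}}) + b_{\text{mag}})$ with $\pi_{\text{gate}}(x) = W_{\text{gate}}(x - b_{\text{dec}}) + b_{\text{gate}}$. The function name and the toy weights are hypothetical, chosen so the arithmetic can be followed by hand.

```python
import numpy as np

def gated_encoder(x, W_gate, b_gate, W_mag, b_mag, b_dec):
    """Gated ReLU encoder: a binary gate path times a ReLU magnitude path."""
    x_centered = x - b_dec
    pi_gate = W_gate @ x_centered + b_gate                # gate pre-activations
    f_gate = (pi_gate > 0).astype(x.dtype)                # which features are active
    f_mag = np.maximum(W_mag @ x_centered + b_mag, 0.0)   # how large they are
    return f_gate * f_mag, pi_gate

# Toy example with two features and identity weights.
W_gate = np.eye(2)
b_gate = np.array([-1.0, -1.0])
W_mag = np.eye(2)
b_mag = np.zeros(2)
b_dec = np.zeros(2)

x = np.array([2.0, 0.5])
f_tilde, pi_gate = gated_encoder(x, W_gate, b_gate, W_mag, b_mag, b_dec)
# pi_gate is [1.0, -0.5], so only the first feature passes the gate: f_tilde is [2.0, 0.0]
```

The second feature has a positive magnitude estimate (0.5) but is zeroed because its gate pre-activation is negative, which is the separation of roles the gate is meant to provide.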

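To offset the increase in parameters, use weight sharing: define $W_{\text{mag}}$ in terms of $W_{\text{gate}}$ with a vector-valued rescaling parameter $r_{\text{mag}} \in \mathbb{R}^M$:

$$
(W_{\text{mag}})_{ij} := \big(\exp(r_{\text{mag}})\big)_i \cdot (W_{\text{gate}})_{ij}
$$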

Figure 3: Gated SAE with weight sharing between gating and magnitude paths

Figure 4: With weight sharing, a gated encoder becomes a single-layer linear encoder with a JumpReLU (Erichson et al., 2019) activation function
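The equivalence in Figure 4 can be checked numerically. The sketch below assumes the weight-sharing scheme $W_{\text{mag}} = \exp(r_{\text{mag}}) \odot W_{\text{gate}}$ (applied row-wise) from Rajamanoharan et al. (2024), and uses a threshold $\theta = \max(b_{\text{mag}} - \exp(r_{\text{mag}}) \odot b_{\text{gate}}, 0)$ that follows from requiring both the gate and the ReLU to be open. All weights are random placeholders.

```python
import numpy as np

def jump_relu(z, theta):
    """JumpReLU: pass z through unchanged where z > theta, else output 0."""
    return np.where(z > theta, z, 0.0)

rng = np.random.default_rng(0)
n, M = 3, 6  # hypothetical sizes
W_gate = rng.normal(size=(M, n))
b_gate, b_mag = rng.normal(size=M), rng.normal(size=M)
r_mag = rng.normal(size=M)
b_dec = rng.normal(size=n)

# Weight sharing: each row of W_gate is rescaled by exp(r_mag).
W_mag = np.exp(r_mag)[:, None] * W_gate

def gated_tied(x):
    """Gated encoder with tied weights."""
    xc = x - b_dec
    gate = (W_gate @ xc + b_gate) > 0
    mag = np.maximum(W_mag @ xc + b_mag, 0.0)
    return gate * mag

def jump_form(x):
    """Same encoder written as one linear layer followed by a JumpReLU."""
    theta = np.maximum(b_mag - np.exp(r_mag) * b_gate, 0.0)
    return jump_relu(W_mag @ (x - b_dec) + b_mag, theta)

xs = rng.normal(size=(100, n))
max_gap = max(np.max(np.abs(gated_tied(x) - jump_form(x))) for x in xs)
```

The two forms agree up to floating-point error on these random inputs, which is what the figure caption claims for the tied-weight case.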

feature suppression

See also: link

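The loss function of SAEs combines an MSE reconstruction loss with a sparsity term:

$$
\mathcal{L}(x) = \frac{1}{n} \lVert x - \hat{x}(f(x)) \rVert_2^2 + \lambda \lVert f(x) \rVert_1
$$

where $n$ is the dimension of $x$, $f(x)$ are the feature activations, $\hat{x}$ is the reconstruction, and $\lambda$ is the sparsity coefficient.

The reconstruction is not perfect, given that only one of the two terms rewards reconstruction. The sparsity term rewards smaller values of $f(x)$, so feature activations will be suppressed.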
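A one-feature worked example makes the suppression concrete. Assume a single unit-norm feature direction and an input whose true activation is $a$; the loss reduces to $(a - f)^2 + \lambda f$ for $f \ge 0$. Setting the derivative $-2(a - f) + \lambda$ to zero gives $f^* = \max(a - \lambda/2, 0)$, which is always below $a$. The values of `a` and `lam` below are arbitrary.

```python
import numpy as np

def l1_optimal_activation(a, lam):
    """Minimiser over f >= 0 of (a - f)**2 + lam * f for one unit-norm feature."""
    return max(a - lam / 2.0, 0.0)

a, lam = 2.0, 1.0                        # arbitrary true activation and L1 coefficient
f_star = l1_optimal_activation(a, lam)   # closed form: 2.0 - 0.5 = 1.5

# Cross-check the closed form with a brute-force grid search.
grid = np.linspace(0.0, 3.0, 3001)
f_grid = grid[np.argmin((a - grid) ** 2 + lam * grid)]
```

The optimum sits at 1.5 rather than the true activation 2.0: the L1 term pulls every active feature down by $\lambda/2$ in this simplified setting.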

How do we fix feature suppression in training SAEs?

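Introduce an element-wise scaling factor per feature in between the encoder and decoder, represented by a vector $s$:

$$
\begin{aligned}
f_s(x) &:= s \odot f(x) \\
\hat{x} &:= W_{\text{dec}} f_s(x) + b_{\text{dec}}
\end{aligned}
$$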
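A sketch of how such a scaling vector acts, assuming it multiplies the feature activations element-wise before decoding. The toy numbers (a shrunken activation of 1.5 for a true activation of 2.0) and the hand-picked value of `s` are illustrative only; how `s` is fitted in practice is not shown here.

```python
import numpy as np

def rescaled_reconstruction(f, s, W_dec, b_dec):
    """Decode after scaling each feature activation by its own factor s_i."""
    return W_dec @ (s * f) + b_dec

# Toy setup: one feature whose decoder direction is the first basis vector.
W_dec = np.array([[1.0], [0.0]])
b_dec = np.zeros(2)
x = np.array([2.0, 0.0])   # input with true activation 2.0
f = np.array([1.5])        # shrunken activation produced under an L1 penalty

s = np.array([2.0 / 1.5])  # hand-picked factor that undoes the shrinkage
x_hat_plain = rescaled_reconstruction(f, np.ones(1), W_dec, b_dec)
x_hat_scaled = rescaled_reconstruction(f, s, W_dec, b_dec)
```

Without scaling the reconstruction undershoots the input; with the per-feature factor it matches. The scaling only corrects magnitudes and leaves the encoder and decoder directions unchanged.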

Dauphin, Y. N., Fan, A., Auli, M., & Grangier, D. (2017). Language Modeling with Gated Convolutional Networks. arXiv preprint arXiv:1612.08083
Erichson, N. B., Yao, Z., & Mahoney, M. W. (2019). JumpReLU: A Retrofit Defense Strategy for Adversarial Attacks. arXiv preprint arXiv:1904.03750
Rajamanoharan, S., Conmy, A., Smith, L., Lieberum, T., Varma, V., Kramár, J., Shah, R., & Nanda, N. (2024). Improving Dictionary Learning with Gated Sparse Autoencoders. arXiv preprint arXiv:2404.16014
Sharkey, L. (2024). Addressing Feature Suppression in SAEs. AI Alignment Forum.
Shazeer, N. (2020). GLU Variants Improve Transformer. arXiv preprint arXiv:2002.05202

Footnotes

  1. If we hold $\hat{x}(\cdot)$ fixed, the L1 penalty pushes $f(x) \to 0$, while the reconstruction loss pushes $f(x)$ high enough to produce an accurate reconstruction.
    The optimal value is somewhere in between.
    However, rescaling the shrunken feature activations (Sharkey, 2024) is not necessarily enough to overcome the bias induced by L1: an SAE might have learnt sub-optimal encoder and decoder directions that are not improved by such a fix.