abbrev: SAE
see also: landscape
Often consists of a single MLP layer (a linear map followed by a ReLU) trained on activations collected over a subset of the data the main LLM was trained on.
empirical example: if we wish to interpret all features related to the author Camus, we might train an SAE on all available text of Camus in order to extract "similar" features from Llama-3.1
definition
We wish to decompose a model's activation $x \in \mathbb{R}^n$ into a sparse, linear combination of feature directions:

$$x \approx x_0 + \sum_{i=1}^{M} f_i(x)\, d_i$$

where the $d_i$ are $M \gg n$ latent unit-norm feature directions and the sparse coefficients $f_i(x) \ge 0$ are the corresponding feature activations of $x$.
Thus, the baseline architecture of SAEs is a linear autoencoder with an L1 penalty on the activations:

$$f(x) := \mathrm{ReLU}\big(W_{\mathrm{enc}}(x - b_{\mathrm{dec}}) + b_{\mathrm{enc}}\big), \qquad \hat{x}(f) := W_{\mathrm{dec}} f + b_{\mathrm{dec}}$$

It is trained to reconstruct a large dataset of model activations $x \sim \mathcal{D}$ while constraining the hidden representation $f$ to be sparse.

The L1 norm with coefficient $\lambda$ is used to construct the loss during training:

$$\mathcal{L}(x) := \underbrace{\lVert x - \hat{x}(f(x)) \rVert_2^2}_{\text{reconstruction}} + \underbrace{\lambda \lVert f(x) \rVert_1}_{\text{sparsity}}$$
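A minimal PyTorch sketch of this baseline architecture and training loss (the `BaselineSAE` name and layout are illustrative, not from any particular library):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class BaselineSAE(nn.Module):
    """Linear autoencoder with a ReLU encoder and an L1 sparsity penalty."""

    def __init__(self, d_model: int, d_features: int):
        super().__init__()
        self.W_enc = nn.Parameter(torch.randn(d_model, d_features) * 0.01)
        self.b_enc = nn.Parameter(torch.zeros(d_features))
        self.W_dec = nn.Parameter(torch.randn(d_features, d_model) * 0.01)
        self.b_dec = nn.Parameter(torch.zeros(d_model))

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        # f(x) := ReLU(W_enc (x - b_dec) + b_enc)
        return F.relu((x - self.b_dec) @ self.W_enc + self.b_enc)

    def decode(self, f: torch.Tensor) -> torch.Tensor:
        # x_hat(f) := W_dec f + b_dec
        return f @ self.W_dec + self.b_dec

    def loss(self, x: torch.Tensor, l1_coeff: float) -> torch.Tensor:
        f = self.encode(x)
        x_hat = self.decode(f)
        reconstruction = (x - x_hat).pow(2).sum(dim=-1).mean()
        sparsity = f.abs().sum(dim=-1).mean()
        return reconstruction + l1_coeff * sparsity
```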
intuition
We want the best reconstruction fidelity at a given sparsity level (as measured by L0), which the training objective approximates via a mixture of reconstruction fidelity and L1 regularization.
The sparsity loss term can be reduced without affecting reconstruction by scaling up the norms of the decoder weights, so the decoder's feature directions are constrained to unit norm during training (a minimal sketch follows the list below).
Idea: the output of the encoder has two roles

- it detects which features are active ⇐ the L1 penalty is crucial to ensure sparsity in the decomposition
- it estimates the magnitudes of active features ⇐ here the L1 penalty is an unwanted bias
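A minimal sketch of the decoder-norm constraint mentioned above, assuming the `BaselineSAE` module from the previous snippet; renormalizing after each optimizer step is one common way to enforce it:

```python
@torch.no_grad()
def renormalize_decoder(sae: BaselineSAE) -> None:
    # Constrain each feature direction d_i (a row of W_dec in this layout) to unit
    # L2 norm, so the L1 term cannot be gamed by shrinking f and growing W_dec.
    norms = sae.W_dec.norm(dim=-1, keepdim=True).clamp_min(1e-8)
    sae.W_dec.div_(norms)
```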
Gated SAE
see also: paper
a Pareto improvement over the baseline architecture that addresses the bias introduced by the L1 penalty (Rajamanoharan et al., 2024)
A clear consequence of this bias during training is shrinkage (Sharkey, 2024)[^1]
The idea is to use a gated ReLU encoder (Dauphin et al., 2017; Shazeer, 2020):

$$\tilde{f}(x) := \mathbb{1}\big[\pi_{\mathrm{gate}}(x) > 0\big] \odot \mathrm{ReLU}\big(W_{\mathrm{mag}}(x - b_{\mathrm{dec}}) + b_{\mathrm{mag}}\big)$$

where $\mathbb{1}[\cdot > 0]$ is the (pointwise) Heaviside step function and $\odot$ denotes elementwise multiplication.
| term | annotations |
| --- | --- |
| $f_{\mathrm{gate}}(x) := \mathbb{1}\big[\pi_{\mathrm{gate}}(x) > 0\big]$ | which features are deemed to be active |
| $f_{\mathrm{mag}}(x) := \mathrm{ReLU}\big(W_{\mathrm{mag}}(x - b_{\mathrm{dec}}) + b_{\mathrm{mag}}\big)$ | feature activation magnitudes (for features that have been deemed to be active) |
| $\pi_{\mathrm{gate}}(x) := W_{\mathrm{gate}}(x - b_{\mathrm{dec}}) + b_{\mathrm{gate}}$ | the gating sub-layer's pre-activations |
To negate the increase in parameters, use weight sharing: express $W_{\mathrm{mag}}$ in terms of $W_{\mathrm{gate}}$ with a vector-valued rescaling parameter $r_{\mathrm{mag}} \in \mathbb{R}^M$:

$$(W_{\mathrm{mag}})_{ij} := \big(\exp(r_{\mathrm{mag}})\big)_i \cdot (W_{\mathrm{gate}})_{ij}$$
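A minimal PyTorch sketch of the gated encoder with this weight sharing (the `GatedEncoder` name and layout are illustrative, not taken from the paper's code):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GatedEncoder(nn.Module):
    """Gated SAE encoder: a gating path decides which features are active, and a
    magnitude path (sharing W_gate up to a per-feature rescale) sizes them."""

    def __init__(self, d_model: int, d_features: int):
        super().__init__()
        self.W_gate = nn.Parameter(torch.randn(d_model, d_features) * 0.01)
        self.b_gate = nn.Parameter(torch.zeros(d_features))
        self.r_mag = nn.Parameter(torch.zeros(d_features))  # W_mag = exp(r_mag) * W_gate
        self.b_mag = nn.Parameter(torch.zeros(d_features))
        self.b_dec = nn.Parameter(torch.zeros(d_model))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        centred = x - self.b_dec
        pi_gate = centred @ self.W_gate + self.b_gate  # gating pre-activations
        f_gate = (pi_gate > 0).to(x.dtype)             # Heaviside: which features fire
        pi_mag = centred @ (self.W_gate * torch.exp(self.r_mag)) + self.b_mag
        f_mag = F.relu(pi_mag)                          # how strongly they fire
        # Note: the hard gate has zero gradient; the paper trains the gating path
        # with an L1 penalty and an auxiliary reconstruction loss on ReLU(pi_gate),
        # which is not shown here.
        return f_gate * f_mag                           # elementwise gate ⊙ magnitude
```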
Figure 3: Gated SAE with weight sharing between gating and magnitude paths
Figure 4: A gated encoder becomes a single-layer linear encoder with a JumpReLU (Erichson et al., 2019) activation function
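For reference, a JumpReLU with threshold $\theta$ zeroes inputs at or below the threshold and passes larger inputs through unchanged; a one-line sketch:

```python
import torch


def jump_relu(z: torch.Tensor, theta: torch.Tensor) -> torch.Tensor:
    # Output z wherever z exceeds the threshold theta, and 0 elsewhere.
    return z * (z > theta).to(z.dtype)
```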
feature suppression
See also: link
The loss function of SAEs combines an MSE reconstruction loss with a sparsity term:

$$\mathcal{L}(x) := \lVert x - \hat{x}(f(x)) \rVert_2^2 + \lambda \lVert f(x) \rVert_1$$

The reconstruction is not perfect, given that only one of the two terms rewards reconstruction. To reduce the L1 term, the encoder learns to output feature activations smaller than the true values, i.e. features are suppressed.
illustrated example
consider one binary feature in one dimension, where $x = 1$ with probability $p$ and $x = 0$ otherwise. Ideally, the optimal SAE would extract a feature activation of $f(x) = x$ and have decoder weight $W_{\mathrm{dec}} = 1$.

However, if we train the SAE by optimizing the loss function $\mathcal{L}$, and say the encoder outputs feature activation $a$ if $x = 1$ and $0$ otherwise, then, ignoring bias terms, the optimization problem becomes:

$$a^{*} = \arg\min_{a \ge 0}\; p\big[(1 - a)^2 + \lambda a\big] = 1 - \frac{\lambda}{2}$$

so the learnt activation is shrunk below the true value of $1$.
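A quick numerical sanity check of this closed-form optimum (a toy sketch; the values of `p` and `lam` are arbitrary):

```python
import torch

p, lam = 0.2, 0.5                          # arbitrary feature probability and L1 coefficient
a = torch.linspace(0.0, 1.5, 100_001)      # candidate feature activations a >= 0
loss = p * ((1 - a) ** 2 + lam * a)        # expected loss of the one-dimensional toy example
print(a[loss.argmin()].item())             # ~0.75 = 1 - lam / 2: shrunk below the true value 1
```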
How do we fix feature suppression in training SAEs?
introduce an element-wise scaling factor per feature in between the encoder and decoder, represented by a vector $s$:

$$\hat{x} := W_{\mathrm{dec}}\big(s \odot f(x)\big) + b_{\mathrm{dec}}$$
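A minimal sketch of this fix, assuming the imports and the `BaselineSAE` class from the earlier snippet (the `ScaledSAE` name and `scale` attribute are illustrative):

```python
class ScaledSAE(BaselineSAE):
    """BaselineSAE with a learned per-feature scale s between encoder and decoder."""

    def __init__(self, d_model: int, d_features: int):
        super().__init__(d_model, d_features)
        self.scale = nn.Parameter(torch.ones(d_features))  # the vector s, initialized to 1

    def decode(self, f: torch.Tensor) -> torch.Tensor:
        # x_hat := W_dec (s ⊙ f(x)) + b_dec. The L1 penalty in BaselineSAE.loss is
        # applied to the unscaled f, so s can undo shrinkage without a sparsity cost.
        return (self.scale * f) @ self.W_dec + self.b_dec
```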
Footnotes
[^1]: If we hold $\hat{x}(\cdot)$ fixed, the L1 penalty pushes the feature activations $f(x)$ toward zero, while the reconstruction loss pushes $f(x)$ high enough to produce an accurate reconstruction; the optimal value lies somewhere in between. However, rescaling the shrunk feature activations (Sharkey, 2024) is not necessarily enough to overcome the bias induced by L1: the SAE may have learnt sub-optimal encoder and decoder directions that rescaling alone does not fix.