SAE Training

Creator: Seonglae Cho
Created: 2025 Jan 21 13:28
Edited: 2025 Mar 8 12:29

Reconstruction

z = W_{enc}(x_l - b_{dec}) + b_{enc}\\
\hat{x}_l = W_{dec}\,z + b_{dec} = \sum_i s_i f_i
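A minimal sketch of this encoder/decoder in PyTorch (assumed framework). A ReLU nonlinearity on the encoder is assumed, as is standard for SAEs, even though the equation above omits it; the dimension names d_model and d_sae are illustrative.

```python
import torch
import torch.nn as nn

class SparseAutoencoder(nn.Module):
    """Minimal SAE matching the reconstruction equations above."""

    def __init__(self, d_model: int, d_sae: int):
        super().__init__()
        self.W_enc = nn.Parameter(torch.randn(d_model, d_sae) * 0.01)
        self.W_dec = nn.Parameter(torch.randn(d_sae, d_model) * 0.01)
        self.b_enc = nn.Parameter(torch.zeros(d_sae))
        self.b_dec = nn.Parameter(torch.zeros(d_model))

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        # z = ReLU(W_enc (x_l - b_dec) + b_enc); the decoder bias is subtracted first
        return torch.relu((x - self.b_dec) @ self.W_enc + self.b_enc)

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        # x_hat = W_dec z + b_dec, i.e. a sparse sum of decoder directions s_i f_i
        return z @ self.W_dec + self.b_dec

    def forward(self, x: torch.Tensor):
        z = self.encode(x)
        return self.decode(z), z
```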

Loss

\mathcal{L}(\mathbf{x}) := \underbrace{\left\|\mathbf{x} - \hat{\mathbf{x}}\bigl(f(\mathbf{x})\bigr)\right\|_2^2}_{\mathcal{L}_{\text{reconstruct}}} + \underbrace{\lambda\, \mathcal{S}\bigl(f(\mathbf{x})\bigr)}_{\mathcal{L}_{\text{sparsity}}}
Sometimes the decoder column norms are also folded into the sparsity term, so that the penalty on each feature scales with the norm of its decoder direction:
\mathcal{L} = \left\|\mathbf{x} - \hat{\mathbf{x}}\right\|_2^2 + \lambda \sum_{i} f_i(\mathbf{x}) \cdot \left\|W_{dec,:,i}\right\|_2
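A sketch of both loss variants, assuming an L1 penalty for S and the SparseAutoencoder layout above (decoder directions stored as rows of W_dec); lambda_ stands for the sparsity coefficient λ.

```python
import torch

def sae_loss(x, x_hat, z, W_dec, lambda_, weight_by_decoder_norm=False):
    # L_reconstruct: squared L2 reconstruction error, averaged over the batch
    recon = (x - x_hat).pow(2).sum(dim=-1).mean()
    if weight_by_decoder_norm:
        # lambda * sum_i f_i(x) * ||W_dec,:,i||_2
        # (each feature activation weighted by the norm of its decoder direction)
        sparsity = (z * W_dec.norm(dim=-1)).sum(dim=-1).mean()
    else:
        # Plain L1 penalty on the feature activations (a common choice of S)
        sparsity = z.abs().sum(dim=-1).mean()
    return recon + lambda_ * sparsity
```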

Hyperparameters

  • expansion_factor: the ratio of the SAE dictionary size to the input activation dimension (d_sae = expansion_factor × d_model)
The decoder bias b_dec is sometimes subtracted from the input activation to center it before encoding, which helps the SAE learn meaningful features (see the sketch below).
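A small illustration with hypothetical values, reusing the SparseAutoencoder sketch above, of how expansion_factor sets the dictionary width and how b_dec can be initialized to center the inputs.

```python
import torch

d_model = 768                       # width of the hooked activations (e.g. GPT-2 small residual stream)
expansion_factor = 16               # hyperparameter: dictionary size relative to d_model
d_sae = expansion_factor * d_model  # 12288 latent features

sae = SparseAutoencoder(d_model, d_sae)

# Optional centering: initialize b_dec to the mean activation so that
# (x - b_dec) is roughly zero-centered when it enters the encoder.
example_batch = torch.randn(4096, d_model)  # stand-in for cached model activations
with torch.no_grad():
    sae.b_dec.copy_(example_batch.mean(dim=0))
```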
SAE Training Techniques

SAE Training Factors

SAEs are trained with an L2 reconstruction loss together with several hyperparameters (a minimal training loop is sketched below).
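A minimal training-loop sketch for this objective. The Adam optimizer and the learning rate and lambda values are assumptions, `activation_batches` is a hypothetical iterable of cached model activations, and `sae` / `sae_loss` come from the sketches above.

```python
import torch

optimizer = torch.optim.Adam(sae.parameters(), lr=3e-4)  # learning rate is another key hyperparameter
lambda_ = 5.0                                            # sparsity coefficient

for x in activation_batches:  # hypothetical stream of cached activations
    x_hat, z = sae(x)
    loss = sae_loss(x, x_hat, z, sae.W_dec, lambda_)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Commonly the decoder directions are renormalized to unit L2 norm so the
    # L1 penalty cannot be gamed by shrinking feature activations while
    # growing the decoder weights.
    with torch.no_grad():
        sae.W_dec.data /= sae.W_dec.data.norm(dim=-1, keepdim=True)
```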

Techniques
From OpenAI's SAE scaling work: "Specifically, we use a layer 5/6 of the way into the network for GPT-4 series models, and we use layer 8 (3/4 of the way) for GPT-2 small. We use a context length of 64 tokens for all experiments."
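A hypothetical config reflecting that setup; the key names here are illustrative rather than any particular library's API.

```python
# Example settings for training an SAE on GPT-2 small activations.
sae_training_config = {
    "model_name": "gpt2",   # GPT-2 small
    "hook_layer": 8,        # ~3/4 of the way through the 12 layers
    "context_length": 64,   # tokens per sequence when caching activations
    "expansion_factor": 16,
}
```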

Recommendations