Reconstruction

$$z = W_{enc}(x_l - b_{dec}) + b_{enc}$$

$$\hat{x}_l = W_{dec}\,z + b_{dec} = \sum_i s_i f_i$$

Loss

$$\mathcal{L}(\mathbf{x}) := \underbrace{\left\|\mathbf{x} - \hat{\mathbf{x}}\bigl(f(\mathbf{x})\bigr)\right\|_2^2}_{\mathcal{L}_{\text{reconstruct}}} + \underbrace{\lambda\,\mathcal{S}\bigl(f(\mathbf{x})\bigr)}_{\mathcal{L}_{\text{sparsity}}}$$
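To make this concrete, below is a minimal PyTorch sketch of the encoder/decoder pair and the loss. It assumes ReLU latents and takes the sparsity function $\mathcal{S}$ to be the L1 norm of the latent activations, which is one common choice; the class and function names are illustrative, not from the source.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SparseAutoencoder(nn.Module):
    """One-hidden-layer SAE: ReLU encoder, linear decoder."""

    def __init__(self, d_model: int, n_latents: int):
        super().__init__()
        self.W_enc = nn.Parameter(torch.randn(d_model, n_latents) * 0.01)
        self.b_enc = nn.Parameter(torch.zeros(n_latents))
        self.W_dec = nn.Parameter(torch.randn(n_latents, d_model) * 0.01)
        self.b_dec = nn.Parameter(torch.zeros(d_model))

    def encode(self, x):
        # z = ReLU(W_enc (x - b_dec) + b_enc); subtracting b_dec centers the input
        return F.relu((x - self.b_dec) @ self.W_enc + self.b_enc)

    def decode(self, z):
        # x_hat = W_dec z + b_dec, i.e. a sparse sum of decoder feature directions
        return z @ self.W_dec + self.b_dec

    def forward(self, x):
        z = self.encode(x)
        return self.decode(z), z

def sae_loss(x, x_hat, z, lam):
    # L = ||x - x_hat||_2^2 + lambda * S(f(x)); S is taken to be the L1 norm here
    recon = (x - x_hat).pow(2).sum(dim=-1).mean()
    sparsity = z.abs().sum(dim=-1).mean()
    return recon + lam * sparsity
```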
Sometimes the norm of each decoder column is also folded into the sparsity term, so that the decoder weights contribute to the sparsity penalty as well:

$$\mathcal{L} = \left\|\mathbf{x} - \hat{\mathbf{x}}\right\|_2^2 + \lambda \sum_{i} f_i(\mathbf{x}) \cdot \left\|W_{dec,:,i}\right\|_2$$
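A sketch of this variant of the penalty, assuming the same layout as in the code above (`W_dec` stored as `(n_latents, d_model)`, so each row is one feature direction):

```python
def sae_loss_decoder_norm(x, x_hat, z, W_dec, lam):
    # L = ||x - x_hat||_2^2 + lambda * sum_i f_i(x) * ||W_dec[:, i]||_2
    # W_dec is (n_latents, d_model) here, so the column norm in the formula
    # corresponds to the L2 norm of each row.
    recon = (x - x_hat).pow(2).sum(dim=-1).mean()
    dec_norms = W_dec.norm(dim=-1)                   # (n_latents,)
    sparsity = (z * dec_norms).sum(dim=-1).mean()    # z >= 0 after the ReLU
    return recon + lam * sparsity
```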
Hyperparameters

The decoder bias $b_{dec}$ is sometimes subtracted from the input activation to center it, which helps the SAE learn meaningful features. The SAE is then trained with this L2 reconstruction loss plus the sparsity penalty, and several hyperparameters (e.g. the sparsity coefficient $\lambda$ and the number of latent features) have to be chosen.
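For illustration, a bare-bones training loop over a stream of model activations might look like the following; `activation_batches` is a hypothetical iterable, and the hyperparameter values are placeholders, not settings reported by the source.

```python
# Hypothetical setup: d_model=768 activations, 32x expansion of latents.
sae = SparseAutoencoder(d_model=768, n_latents=768 * 32)
opt = torch.optim.Adam(sae.parameters(), lr=1e-4)
lam = 5.0  # sparsity coefficient; typically swept

for x in activation_batches:        # x: (batch, d_model) residual-stream activations
    x_hat, z = sae(x)
    loss = sae_loss(x, x_hat, z, lam)
    opt.zero_grad()
    loss.backward()
    opt.step()
```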
Techniques

Specifically, we use a layer 5/6 of the way into the network for GPT-4 series models, and we use layer 8 (3/4 of the way) for GPT-2 small. We use a context length of 64 tokens for all experiments.
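As a sketch of how such activations might be collected (assuming the HuggingFace transformers GPT-2 implementation; the block index and the choice of residual-stream read point are assumptions, not prescribed here):

```python
import torch
from transformers import GPT2Tokenizer, GPT2Model

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2Model.from_pretrained("gpt2").eval()

captured = []

def save_hidden(module, inputs, output):
    # Each GPT2Block returns a tuple; element 0 is the residual-stream hidden state
    captured.append(output[0].detach())

# Hook one of the later blocks of GPT-2 small (the index is illustrative; the exact
# hook point depends on how "layer 8" is counted and on pre- vs post-block reads)
handle = model.h[8].register_forward_hook(save_hidden)

enc = tokenizer("Some text to collect activations from.",
                return_tensors="pt", truncation=True, max_length=64)
with torch.no_grad():
    model(**enc)
handle.remove()

acts = captured[0]  # shape: (1, seq_len <= 64, 768)
```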