# Basic Forward Pass Functionality

```python
def forward(
    self,
    x: torch.Tensor,
) -> torch.Tensor:
    feature_acts = self.encode(x)
    sae_out = self.decode(feature_acts)

    # TEMP
    if self.use_error_term:
        with torch.no_grad():
            # Recompute everything without hooks to get the true error term.
            # Otherwise, the output with the error term would always equal the input,
            # even for causal interventions that affect x_reconstruct.
            # This is in a no_grad context to detach the error, so we can compute SAE
            # feature gradients (e.g. for attribution patching). See A.3 in
            # https://arxiv.org/pdf/2403.19647.pdf for more detail.
            # NOTE: we can't just use `sae_error = input - x_reconstruct.detach()` or
            # something simpler, since then intervening on features (e.g. ablating them)
            # would still result in a perfect reconstruction.
            with _disable_hooks(self):
                feature_acts_clean = self.encode(x)
                x_reconstruct_clean = self.decode(feature_acts_clean)
            sae_error = self.hook_sae_error(x - x_reconstruct_clean)
        sae_out = sae_out + sae_error

    return self.hook_sae_output(sae_out)
```
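To see why the error term matters, here is a minimal self-contained sketch. It is not the SAELens implementation: `ToySAE` and `feature_mask` are hypothetical names, and the hook points (`hook_sae_error`, `hook_sae_output`) and `_disable_hooks` are replaced by a simple mask-based intervention that is temporarily removed for the clean recompute. With the error term, the SAE is an exact pass-through when no intervention is applied, while feature interventions still propagate to the output.

```python
import torch
import torch.nn as nn


class ToySAE(nn.Module):
    """Toy SAE illustrating the error-term trick (not the SAELens class)."""

    def __init__(self, d_in: int, d_sae: int, use_error_term: bool = True):
        super().__init__()
        self.W_enc = nn.Parameter(torch.randn(d_in, d_sae) * 0.1)
        self.W_dec = nn.Parameter(torch.randn(d_sae, d_in) * 0.1)
        self.b_enc = nn.Parameter(torch.zeros(d_sae))
        self.b_dec = nn.Parameter(torch.zeros(d_in))
        self.use_error_term = use_error_term
        self.feature_mask = None  # stands in for a hook-based causal intervention

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        acts = torch.relu((x - self.b_dec) @ self.W_enc + self.b_enc)
        if self.feature_mask is not None:
            acts = acts * self.feature_mask  # intervene on feature activations
        return acts

    def decode(self, acts: torch.Tensor) -> torch.Tensor:
        return acts @ self.W_dec + self.b_dec

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        sae_out = self.decode(self.encode(x))
        if self.use_error_term:
            with torch.no_grad():
                # Clean recompute with the intervention disabled (analogous to
                # _disable_hooks), so the error captures only what the SAE
                # genuinely fails to reconstruct.
                mask, self.feature_mask = self.feature_mask, None
                x_reconstruct_clean = self.decode(self.encode(x))
                self.feature_mask = mask
                sae_error = x - x_reconstruct_clean
            sae_out = sae_out + sae_error
        return sae_out


x = torch.randn(4, 16)
sae = ToySAE(d_in=16, d_sae=64)

# With no intervention, the error term makes the SAE an exact identity map.
assert torch.allclose(sae(x), x, atol=1e-5)

# Ablating a feature now shifts the output by exactly that feature's contribution,
# instead of being washed out by a recomputed "perfect" reconstruction.
sae.feature_mask = torch.ones(64)
sae.feature_mask[0] = 0.0
print((sae(x) - x).norm())
```

Because the error is computed under `torch.no_grad()`, gradients flow only through the (possibly intervened) feature activations, which is what attribution patching relies on.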