Fix b_dec being double-counted in the TopK AuxK loss#134
Open
robbiebusinessacc wants to merge 1 commit into
Open
Fix b_dec being double-counted in the TopK AuxK loss#134robbiebusinessacc wants to merge 1 commit into
robbiebusinessacc wants to merge 1 commit into
Conversation
The AuxK loss target e = y - sae_out already accounts for b_dec, since sae_out is produced by decode() which adds b_dec. Computing e_hat via self.decode() added b_dec a second time, pulling dead latents toward (e - b_dec) and placing an unintended gradient on b_dec whenever b_dec != 0 (the normal case after init centers it on the data mean). Call decoder_impl directly so e_hat does not re-add b_dec.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Fixes #132.
The AuxK loss (dead-latent revival, Appendix B.1) is computed against the
residual target
e = y - sae_out, wheresae_outalready includes thedecoder bias
b_dec. The second decoder pass that producese_hat, however,called
self.decode(...), anddecode()addsb_dec(return y + self.b_dec). This double-countsb_dec:e_hatends up shifted by an extrab_dec, so dead latents are pulled towarde - b_decrather thane, and anunintended gradient is placed on
b_dec.This is the normal case in practice, since
b_decis initialized to the datamean (nonzero), so every TopK SAE/transcoder trained with
auxk_alpha > 0isaffected.
Fix
Compute
e_hatby callingdecoder_impldirectly withW_dec—e_hat = decoder_impl(auxk_indices, auxk_acts.to(self.dtype), self.W_dec.mT)—skipping the
+ b_decthatdecode()adds. Whenb_dec == 0the result isunchanged; when
b_dec != 0the AuxK target is now correct.Testing
Added
tests/test_auxk_loss.py, a CPU-only test (uses the eager decoderfallback, no GPU required) that:
SparseCoderwith a nonzerob_dec,forward(x, dead_mask=all-True),out.auxk_lossequals the manually decoded target without re-addingb_dec, and+ b_decformulation produces a differentvalue.
The test fails against the previous code (
Expected 2.0363 but got 2.7278)and passes with the fix.
ruff check,ruff format --check, andblack --checkare all clean on the changed files.