Fix b_dec being double-counted in the TopK AuxK loss #134

Open
robbiebusinessacc wants to merge 1 commit into EleutherAI:main from robbiebusinessacc:contrib/fix-auxk-bdec-double-count

@robbiebusinessacc

Fixes #132.

The AuxK loss (dead-latent revival, Appendix B.1) is computed against the
residual target e = y - sae_out, where sae_out already includes the
decoder bias b_dec. The second decoder pass that produces e_hat, however,
called self.decode(...), and decode() adds b_dec (return y + self.b_dec). This double-counts b_dec: e_hat ends up shifted by an extra
b_dec, so dead latents are pulled toward e - b_dec rather than e, and an
unintended gradient is placed on b_dec.
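
To make the offset concrete, here is a minimal NumPy sketch of the two ways of computing e_hat. It is not the repository's code: the toy `decode` function, the dense activations, and the unnormalized squared-error loss are all simplifying assumptions chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, n_latents = 4, 8

W_dec = rng.normal(size=(n_latents, d_in))  # decoder weights, one row per latent
b_dec = rng.normal(size=d_in)               # nonzero decoder bias
acts = rng.random(n_latents)                # stand-in AuxK activations (dense here)
e = rng.normal(size=d_in)                   # stand-in residual target e = y - sae_out


def decode(a):
    """Toy decode(): linear map plus the decoder bias."""
    return a @ W_dec + b_dec


e_hat_buggy = decode(acts)   # second pass through decode() re-adds b_dec
e_hat_fixed = acts @ W_dec   # bias-free decoder pass

# The buggy prediction is offset from the bias-free one by exactly b_dec.
shift = e_hat_buggy - e_hat_fixed

# Equivalent view: the buggy loss is the bias-free loss against e - b_dec.
loss_buggy = ((e_hat_buggy - e) ** 2).sum()
loss_shifted_target = ((e_hat_fixed - (e - b_dec)) ** 2).sum()

# Gradient of sum((e_hat - e)**2) with respect to b_dec in the buggy path.
# In the bias-free path e_hat does not depend on b_dec, so this term vanishes.
grad_b_dec_buggy = 2.0 * (e_hat_buggy - e)
```

In this toy setup `shift` equals `b_dec`, and the two loss values agree, which is the "pulled toward e - b_dec" effect described above.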

A nonzero b_dec is the normal case in practice: b_dec is initialized to the
data mean, which is generally nonzero, so every TopK SAE/transcoder trained
with auxk_alpha > 0 is affected.

Fix

Compute e_hat by calling decoder_impl directly with W_dec:

    e_hat = decoder_impl(auxk_indices, auxk_acts.to(self.dtype), self.W_dec.mT)

skipping the + b_dec that decode() adds. When b_dec == 0 the result is
unchanged; when b_dec != 0 the AuxK target is now correct.
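
The sketch below illustrates the idea of a bias-free sparse decode in NumPy. The `sparse_decode` helper is a hypothetical stand-in, not the library's decoder_impl (whose argument layout differs), and the toy `decode` wrapper is likewise assumed for illustration only.

```python
import numpy as np


def sparse_decode(indices, acts, W_dec):
    """Hypothetical bias-free sparse decode.

    indices: (batch, k) int array of selected latents
    acts:    (batch, k) float array of their activations
    W_dec:   (n_latents, d_in) decoder weights
    Returns the activation-weighted sum of the selected decoder rows.
    """
    # W_dec[indices] has shape (batch, k, d_in).
    return np.einsum("bk,bkd->bd", acts, W_dec[indices])


def decode(indices, acts, W_dec, b_dec):
    """Toy decode(): the sparse decode plus the decoder bias."""
    return sparse_decode(indices, acts, W_dec) + b_dec


rng = np.random.default_rng(0)
batch, k, n_latents, d_in = 3, 2, 6, 4
W_dec = rng.normal(size=(n_latents, d_in))
indices = rng.integers(0, n_latents, size=(batch, k))
acts = rng.random((batch, k))

e_hat = sparse_decode(indices, acts, W_dec)  # what the AuxK path should use

zero_bias = np.zeros(d_in)
nonzero_bias = rng.normal(size=d_in)

# With a zero bias the two paths agree; with a nonzero bias they differ by it.
same_when_zero = np.allclose(e_hat, decode(indices, acts, W_dec, zero_bias))
offset = decode(indices, acts, W_dec, nonzero_bias) - e_hat
```

Each row of `offset` is the bias itself, which is why the change is a no-op only when b_dec == 0.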

Testing

Added tests/test_auxk_loss.py, a CPU-only test (uses the eager decoder
fallback, no GPU required) that:

  • builds a SparseCoder with a nonzero b_dec,
  • runs forward(x, dead_mask=all-True),
  • asserts out.auxk_loss equals the manually decoded target without re-adding
    b_dec, and
  • asserts the previous (buggy) + b_dec formulation produces a different
    value.
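
For readers without the repository checked out, the shape of that check can be sketched against a toy model. Everything below is a hypothetical illustration: `ToySparseCoder` is a simplified NumPy stand-in rather than the library's SparseCoder, the loss is left unnormalized, and this is not the contents of tests/test_auxk_loss.py.

```python
import numpy as np


class ToySparseCoder:
    """Hypothetical, heavily simplified TopK autoencoder (illustration only)."""

    def __init__(self, d_in, n_latents, k, seed=0):
        rng = np.random.default_rng(seed)
        self.W_enc = rng.normal(size=(n_latents, d_in))
        self.W_dec = rng.normal(size=(n_latents, d_in))
        self.b_dec = rng.normal(size=d_in)  # deliberately nonzero
        self.k = k

    def pre_acts(self, x):
        return np.maximum((x - self.b_dec) @ self.W_enc.T, 0.0)

    def _decode_no_bias(self, pre, k):
        idx = np.argsort(-pre, axis=-1)[:, :k]
        acts = np.take_along_axis(pre, idx, axis=-1)
        return np.einsum("bk,bkd->bd", acts, self.W_dec[idx])

    def auxk_loss(self, x, k_aux, readd_bias=False):
        pre = self.pre_acts(x)
        sae_out = self._decode_no_bias(pre, self.k) + self.b_dec
        e = x - sae_out  # residual target (autoencoder case, y = x)
        # Every latent is treated as dead, so AuxK picks from all of them.
        e_hat = self._decode_no_bias(pre, k_aux)
        if readd_bias:  # reproduces the double-counted b_dec
            e_hat = e_hat + self.b_dec
        return float(((e_hat - e) ** 2).sum())


def manual_auxk_loss(model, x, k_aux):
    """Independent dense reference: zero all but the top activations, then matmul."""
    pre = model.pre_acts(x)

    def keep_topk(k):
        idx = np.argsort(-pre, axis=-1)[:, :k]
        dense = np.zeros_like(pre)
        np.put_along_axis(dense, idx, np.take_along_axis(pre, idx, axis=-1), axis=-1)
        return dense

    sae_out = keep_topk(model.k) @ model.W_dec + model.b_dec
    e_hat = keep_topk(k_aux) @ model.W_dec  # b_dec is not re-added
    return float(((e_hat - (x - sae_out)) ** 2).sum())


model = ToySparseCoder(d_in=4, n_latents=8, k=2)
x = np.random.default_rng(1).normal(size=(5, 4))

fixed = model.auxk_loss(x, k_aux=4)
buggy = model.auxk_loss(x, k_aux=4, readd_bias=True)
expected = manual_auxk_loss(model, x, k_aux=4)
```

The two assertions from the list above then correspond to `fixed` matching `expected` and `buggy` differing from it.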

The test fails against the previous code (Expected 2.0363 but got 2.7278)
and passes with the fix. ruff check, ruff format --check, and black --check are all clean on the changed files.

The AuxK loss target e = y - sae_out already accounts for b_dec, since
sae_out is produced by decode() which adds b_dec. Computing e_hat via
self.decode() added b_dec a second time, pulling dead latents toward
(e - b_dec) and placing an unintended gradient on b_dec whenever
b_dec != 0 (the normal case after init centers it on the data mean).

Call decoder_impl directly so e_hat does not re-add b_dec.
@CLAassistant

CLAassistant commented May 29, 2026

CLA assistant check
All committers have signed the CLA.

Development

Successfully merging this pull request may close these issues.

b_dec is incorrectly added to topk aux loss
