
Fix double rmsnorm on B/C in MambaInnerFn.backward at checkpoint_lvl=1#902

Open
Chessing234 wants to merge 1 commit into state-spaces:main from Chessing234:fix/double-rmsnorm-backward

Conversation

@Chessing234
Contributor

Summary

  • When checkpoint_lvl >= 1, MambaInnerFn.forward() applies rmsnorm to B and C, then saves the normalized tensors. backward() loads these already-normalized tensors and applies rmsnorm again, producing incorrect gradients due to double normalization.
  • Fix: set B and C to None before ctx.save_for_backward (matching the existing pattern for delta), and recompute them from x_dbl in backward() before applying rmsnorm exactly once.
  • B_proj_bias and C_proj_bias are now stored on ctx so recomputation can correctly include projection biases when present.
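For intuition on why the double application matters: rmsnorm is only (approximately) idempotent when its weight is uniform, so a trained, non-uniform `b_rms_weight`/`c_rms_weight` makes the second pass change both values and gradients. A pure-Python toy illustrating this (the `rmsnorm` here and the numbers are illustrative, not the library's kernel):

```python
import math

def rmsnorm(x, weight, eps=1e-6):
    """Toy RMSNorm: scale x to unit RMS, then apply an elementwise weight."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [w * v / rms for w, v in zip(weight, x)]

x = [1.0, 2.0, 3.0]
weight = [0.5, 1.0, 2.0]       # non-uniform, as a trained rms weight would be

once = rmsnorm(x, weight)      # what forward() computes and saves
twice = rmsnorm(once, weight)  # what the buggy backward() then recomputes

print(once)
print(twice)                   # differs: the saved tensor was already normalized
```

Note that with a uniform weight the second pass happens to cancel (up to `eps`), which can mask the bug in quick spot checks.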

Fixes #885

Test plan

  • Verify gradient correctness by comparing checkpoint_lvl=0 (no recomputation, known correct) with checkpoint_lvl=1 (recomputation) when b_rms_weight and c_rms_weight are set
  • Run torch.autograd.gradcheck on MambaInnerFn with rmsnorm enabled
  • Confirm no regression when b_rms_weight / c_rms_weight are None (original code path unchanged)
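The first check above can be sketched as a stand-alone mock of the save-None-and-recompute pattern the PR describes. The `Ctx` class and function shapes here are illustrative stand-ins, not the actual `MambaInnerFn` (which also recomputes `delta` and handles the projection biases):

```python
import math

def rmsnorm(x, weight, eps=1e-6):
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [w * v / rms for w, v in zip(weight, x)]

class Ctx:
    """Stand-in for torch's autograd function context."""
    def save_for_backward(self, *tensors):
        self.saved = tensors

def forward(ctx, x_dbl, weight, checkpoint_lvl):
    B = rmsnorm(x_dbl, weight)
    # The fix: at checkpoint_lvl >= 1, save None instead of the normalized
    # tensor, mirroring what the code already does for delta.
    ctx.save_for_backward(x_dbl, None if checkpoint_lvl >= 1 else B)
    ctx.weight = weight
    return B

def backward(ctx):
    x_dbl, B = ctx.saved
    if B is None:
        # Recompute from x_dbl so rmsnorm is applied exactly once.
        B = rmsnorm(x_dbl, ctx.weight)
    return B  # stand-in for "B as consumed by the gradient computation"

x_dbl = [1.0, 2.0, 3.0]
w = [0.5, 1.0, 2.0]
ctx0, ctx1 = Ctx(), Ctx()
forward(ctx0, x_dbl, w, checkpoint_lvl=0)
forward(ctx1, x_dbl, w, checkpoint_lvl=1)
assert backward(ctx0) == backward(ctx1)  # recomputed path matches the baseline
```

The buggy variant would instead save the normalized `B` at `checkpoint_lvl=1` and call `rmsnorm` on it again in `backward`, which is exactly the mismatch the comparison against `checkpoint_lvl=0` is meant to catch.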

🤖 Generated with Claude Code

When checkpoint_lvl >= 1, forward() normalizes B and C with rmsnorm then
saves the already-normalized tensors via ctx.save_for_backward. backward()
then loads these tensors and applies rmsnorm again, resulting in double
normalization that produces incorrect gradients.

Fix by setting B and C to None before saving (matching the existing pattern
for delta), and recomputing them from x_dbl in backward before applying
rmsnorm exactly once. B_proj_bias and C_proj_bias are now stored on ctx
so the recomputation can correctly include projection biases.

Fixes state-spaces#885

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>


Development

Successfully merging this pull request may close these issues.

[bug] double rmsnorm on B/C in MambaInnerFn.backward at checkpoint_lvl=1
