Hi, I think there's a bug in `Mamba2.step()` when we set `D_has_hdim=True`.
In `__init__`, `self.D` is initialized as

```python
self.D = nn.Parameter(torch.ones(self.d_ssm if self.D_has_hdim else self.nheads, device=device))
```
Thus, when `D_has_hdim=True`, `self.D` has shape `(d_ssm,) = (nheads * headdim,)`.
The forward path appears to handle this correctly by reshaping `D` to `(h, p)`.
But I don't think `step()` makes the same distinction: this line implicitly assumes that `self.D` is per head, with shape `(nheads,)`, rather than per head-dim, with shape `(nheads * headdim,)`.
So for `D_has_hdim=True`, `step()` seems inconsistent with `forward()`: `forward()` reshapes `self.D` from `(h * p,)` to `(h, p)`, whereas `step()` treats it as if it were `(h,)`.
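To illustrate, here is a minimal sketch of a skip connection that handles both cases the way `forward()` does. The helper name, signature, and shapes are my assumptions for illustration, not the actual Mamba2 code:

```python
import torch

def apply_D_skip(y, x, D, D_has_hdim):
    """Hypothetical helper: add the D skip connection y + D * x
    for a single decoding step.

    y, x: (batch, nheads, headdim)
    D:    (nheads * headdim,) if D_has_hdim else (nheads,)
    """
    nheads, headdim = x.shape[1], x.shape[2]
    if D_has_hdim:
        # Mirror forward(): reshape (h * p,) -> (h, p) before broadcasting.
        return y + x * D.view(nheads, headdim)
    # Per-head D: (h,) -> (h, 1) so it broadcasts over headdim.
    return y + x * D.unsqueeze(-1)
```

The current `step()` effectively always takes the second branch, which only works when `D` has shape `(nheads,)`.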
I can contribute a PR to fix this.