Mamba2.step() handles D incorrectly when D_has_hdim=True #887

@GiftedNovaHD

Hi, I think there's a bug in Mamba2.step() when D_has_hdim=True.

In __init__, self.D is initialized as:

```python
self.D = nn.Parameter(torch.ones(self.d_ssm if self.D_has_hdim else self.nheads, device=device))
```

Thus, when D_has_hdim=True, self.D has shape (d_ssm,) = (nheads * headdim,).
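To make the two possible shapes concrete, here is a minimal sketch that mirrors only the size logic of the `__init__` line above. It uses NumPy as a stand-in for PyTorch, and the sizes (`nheads=4`, `headdim=3`) and the helper `init_D` are hypothetical, chosen for illustration.

```python
import numpy as np

# Hypothetical sizes for illustration only.
nheads, headdim = 4, 3
d_ssm = nheads * headdim

def init_D(D_has_hdim: bool) -> np.ndarray:
    # Hypothetical helper mirroring the size choice in __init__:
    # one skip coefficient per channel (d_ssm) when D_has_hdim is set,
    # otherwise one coefficient per head (nheads).
    return np.ones(d_ssm if D_has_hdim else nheads)

print(init_D(True).shape)   # (12,)
print(init_D(False).shape)  # (4,)
```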

As far as I can tell, forward() handles this correctly by reshaping self.D from (nheads * headdim,) to (h, p) = (nheads, headdim) before applying it.
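The sketch below illustrates the reshape described above, applied to the skip term `D * x` for a tensor `x` of shape (batch, nheads, headdim). It is a NumPy stand-in, not the library code: the function name `skip_forward_style` and the sizes are hypothetical. A row-major `reshape(nheads, headdim)` matches the `(h p) -> h p` grouping, with h as the outer axis.

```python
import numpy as np

# Hypothetical sizes for illustration only.
nheads, headdim, batch = 4, 3, 2
d_ssm = nheads * headdim
rng = np.random.default_rng(0)

x = rng.standard_normal((batch, nheads, headdim))  # (b, h, p)
D_flat = rng.standard_normal(d_ssm)                # D_has_hdim=True: shape (h * p,)

def skip_forward_style(x, D, D_has_hdim):
    """Hypothetical helper: apply the D skip term the way forward() is
    described in this issue."""
    if D_has_hdim:
        D = D.reshape(nheads, headdim)  # (h * p,) -> (h, p): one coefficient per channel
    else:
        D = D[:, None]                  # (h,) -> (h, 1): shared across headdim
    return D * x                        # broadcasts against (b, h, p)

y = skip_forward_style(x, D_flat, D_has_hdim=True)
print(y.shape)  # (2, 4, 3)
```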

However, step() does not seem to make the same distinction: the line in step() that applies D implicitly assumes that self.D is per head, with shape (nheads,), rather than per head dimension, with shape (nheads * headdim,).

So for D_has_hdim=True, step() seems inconsistent with forward(): forward() reshapes self.D from (h * p,) to (h, p), whereas step() treats it as if it were (h,).
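A minimal NumPy sketch of the mismatch and of one possible fix. It assumes that step() expands D per head to (h, 1) and multiplies it into `x` of shape (batch, nheads, headdim). Both helper names are hypothetical, and the sketch does not reproduce the actual step() code. In this stand-in, treating a (h * p,) D as per head fails to broadcast. Branching on `D_has_hdim` as forward() does yields the same result as the forward-style reshape.

```python
import numpy as np

# Hypothetical sizes for illustration only.
nheads, headdim, batch = 4, 3, 2
d_ssm = nheads * headdim
rng = np.random.default_rng(0)

x = rng.standard_normal((batch, nheads, headdim))  # (b, h, p) for one decoding step
D_flat = rng.standard_normal(d_ssm)                # D_has_hdim=True: shape (h * p,)

def skip_per_head_only(x, D):
    # Hypothetical helper for the assumption described in the issue:
    # D is treated as (h,), expanded to (h, 1), and shared across headdim.
    return D[:, None] * x

def skip_step_fixed(x, D, D_has_hdim):
    # Hypothetical fix: branch on D_has_hdim exactly as forward() does.
    D = D.reshape(nheads, headdim) if D_has_hdim else D[:, None]
    return D * x

try:
    skip_per_head_only(x, D_flat)  # (12, 1) vs (2, 4, 3): shapes cannot broadcast
    broke = False
except ValueError:
    broke = True
print(broke)  # True

y_step = skip_step_fixed(x, D_flat, D_has_hdim=True)
y_fwd = D_flat.reshape(nheads, headdim) * x  # forward-style reference
print(np.allclose(y_step, y_fwd))  # True
```

The fix keeps the per-head path unchanged (`D[:, None]`) and adds only the reshape for the per-channel case, so both settings of `D_has_hdim` apply D the same way in step() and forward().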

I can contribute a PR to fix this.
