There seems to be an issue at https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba2.py#L319
When `D_has_hdim=True`, `self.D` has shape `(d_ssm,) = (n_heads * headdim,)`.
So the `rearrange` unsqueezes the last dimension: `(n_heads * headdim,) -> (n_heads * headdim, 1)`.
However, `x` and `y` have shape `(b h p)`, which cannot be broadcast against the rearranged `D` of shape `(n_heads * headdim, 1)`.
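The mismatch can be reproduced with plain NumPy, whose broadcasting rules match PyTorch's for this case. This is a minimal sketch: the sizes are arbitrary, the helper `skip_connection` is hypothetical, and `D[:, None]` stands in for `rearrange(D, "h -> h 1")`.

```python
import numpy as np

# Illustrative sizes (hypothetical, not taken from any Mamba2 config)
batch, nheads, headdim = 2, 4, 8
d_ssm = nheads * headdim

x = np.ones((batch, nheads, headdim))   # x in step(): (b h p)
y = np.zeros((batch, nheads, headdim))  # y in step(): (b h p)


def skip_connection(y, x, D):
    """Mimic `y + rearrange(D, "h -> h 1") * x`; D[:, None] is that unsqueeze."""
    D_col = D[:, None]
    try:
        return D_col.shape, y + D_col * x
    except ValueError:
        # NumPy (like PyTorch) refuses to broadcast (d_ssm, 1) against (b, h, p)
        return D_col.shape, None


# D_has_hdim=False: D is (n_heads,) -> (n_heads, 1), which broadcasts fine
shape_ok, out_ok = skip_connection(y, x, np.ones(nheads))
print(shape_ok, out_ok.shape)  # (4, 1) (2, 4, 8)

# D_has_hdim=True: D is (n_heads * headdim,) -> (n_heads * headdim, 1), which fails
shape_bad, out_bad = skip_connection(y, x, np.ones(d_ssm))
print(shape_bad, out_bad)  # (32, 1) None
```

Aligning shapes from the right, `(32, 1)` against `(2, 4, 8)` compares `32` with `4`, so the multiplication fails.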
This behaviour is inconsistent with the `forward()` method.
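A possible fix is to reshape `D` per head when `D_has_hdim=True`, assuming `step()` should mirror `forward()`'s handling of `D` (in PyTorch with einops, roughly `rearrange(self.D.to(dtype), "(h p) -> h p", p=self.headdim)`). The NumPy sketch below uses a hypothetical helper `skip_term` and checks that the per-head reshape matches scaling the flattened `(b, h*p)` input channel by channel.

```python
import numpy as np


def skip_term(D, x, headdim, D_has_hdim):
    """Compute D * x for one decoding step, with x of shape (b, h, p).

    Hypothetical helper mirroring how forward() is assumed to treat D.
    """
    if D_has_hdim:
        D_b = D.reshape(-1, headdim)  # like rearrange(D, "(h p) -> h p", p=headdim)
    else:
        D_b = D[:, None]  # like rearrange(D, "h -> h 1")
    return D_b * x


batch, nheads, headdim = 2, 4, 8
x = np.random.default_rng(0).standard_normal((batch, nheads, headdim))

# Per-channel D: each of the n_heads * headdim channels gets its own scale
D = np.arange(nheads * headdim, dtype=float)
out = skip_term(D, x, headdim, D_has_hdim=True)

# Same result as scaling the flattened (b, h*p) input channel by channel
flat = D * x.reshape(batch, -1)
print(np.allclose(out.reshape(batch, -1), flat))  # True
```

Keeping the `D_has_hdim=False` branch as `"h -> h 1"` leaves the existing per-head behaviour unchanged.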
If needed, I can contribute a PR for this.
Related issue: #887