offset corrupts gradients silently for large batch/dstate configs
The pointer offset into the x buffer is computed with an all-int32 multiply in both the forward and backward selective-scan CUDA kernels.
When batch * dim * n_chunks * dstate exceeds INT32_MAX (2^31 - 1), the product overflows and can wrap to a negative value.
The kernel then reads and writes memory located before x_ptr, silently corrupting adjacent tensors or triggering a CUDA illegal-address fault.
Fixed in PR #883.
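
The wraparound described above can be reproduced on the host with a minimal sketch. The helper names offset_i64 and offset_i32_wrapped are hypothetical and do not come from the kernel source, and widening to 64 bits is shown as a common remedy, assumed here rather than taken from the PR. The 32-bit version forms the product in unsigned arithmetic, because signed overflow is undefined behavior in C++, and then reinterprets the bits on the assumption of a two's-complement target.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Hypothetical helper: the offset computed entirely in 64-bit arithmetic.
std::int64_t offset_i64(int batch, int dim, int n_chunks, int dstate) {
    assert(batch >= 0 && dim >= 0 && n_chunks >= 0 && dstate >= 0);
    // Widening the first operand promotes every later multiply to 64 bits.
    return static_cast<std::int64_t>(batch) * dim * n_chunks * dstate;
}

// Hypothetical helper: emulates what an all-32-bit multiply typically yields.
// The product is formed in unsigned arithmetic (which wraps modulo 2^32) and
// its bits are then reinterpreted as a signed 32-bit value.
std::int32_t offset_i32_wrapped(int batch, int dim, int n_chunks, int dstate) {
    assert(batch >= 0 && dim >= 0 && n_chunks >= 0 && dstate >= 0);
    std::uint32_t product = static_cast<std::uint32_t>(batch);
    product *= static_cast<std::uint32_t>(dim);
    product *= static_cast<std::uint32_t>(n_chunks);
    product *= static_cast<std::uint32_t>(dstate);
    std::int32_t wrapped;
    std::memcpy(&wrapped, &product, sizeof wrapped);  // bit-for-bit copy
    return wrapped;
}
```

With the illustrative values batch = 48, dim = 4096, n_chunks = 64, dstate = 256, offset_i64 returns 3221225472, while offset_i32_wrapped returns -1073741824, an index that points before the start of the buffer.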