Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion mamba_ssm/ops/selective_scan_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@

from mamba_ssm.ops.triton.layer_norm import _layer_norm_fwd

import selective_scan_cuda
try:
import selective_scan_cuda
except ImportError:
selective_scan_cuda = None


class SelectiveScanFn(torch.autograd.Function):
Expand Down
4 changes: 2 additions & 2 deletions mamba_ssm/ops/triton/mamba3/angle_dt.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
configs=[
triton.Config({}, num_stages=s, num_warps=w)
for s in [1, 2, 3]
for w in [2, 4, 8]
for w in [1, 2, 4, 8]
],
key=["CHUNK_SIZE", "BLOCK_D", "HAS_INIT_STATE", "RETURN_OUTPUT_STATE", "IS_VARLEN"],
)
Expand Down Expand Up @@ -224,7 +224,7 @@ def angle_dt_fwd(
configs=[
triton.Config({}, num_stages=s, num_warps=w)
for s in [1, 2, 3]
for w in [2, 4, 8]
for w in [1, 2, 4, 8]
],
key=["CHUNK_SIZE", "BLOCK_D", "HAS_INIT_STATE", "HAS_GRAD_OUTPUT_STATE", "IS_VARLEN"],
)
Expand Down
39 changes: 30 additions & 9 deletions mamba_ssm/ops/triton/mamba3/mamba3_siso_bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,27 @@

import triton
import triton.language as tl
from mamba_ssm.ops.triton.mamba3.utils import cos_approx, sin_approx, sigmoid_approx
from mamba_ssm.ops.triton.mamba3.utils import (
cos_approx, sin_approx, sigmoid_approx,
_maxnreg, MAXNREG_VALUES, MAXNREG_VALUES_SMALL,
)

# =============================================================================
# dZ Kernel
# =============================================================================

@triton.autotune(
configs=[
triton.Config({"CHUNK_SIZE": cs}, num_stages=s, num_warps=w, maxnreg=r)
triton.Config({"CHUNK_SIZE": cs}, num_stages=s, num_warps=w, **_maxnreg(r))
for cs in [32, 64]
for s in [1, 2, 3]
for w in [2, 4, 8]
for r in [None, 128, 256]
for r in MAXNREG_VALUES
] + [
# Smaller configs for GPUs with limited register files (e.g. AMD RDNA4).
triton.Config({"CHUNK_SIZE": cs}, num_stages=1, num_warps=1, **_maxnreg(r))
for cs in [16, 32]
for r in MAXNREG_VALUES_SMALL
],
key=["HEADDIM_V"]
)
Expand Down Expand Up @@ -193,10 +201,14 @@ def grid(META):

@triton.autotune(
configs=[
triton.Config({}, num_stages=s, num_warps=w, maxnreg=r)
triton.Config({}, num_stages=s, num_warps=w, **_maxnreg(r))
for s in [1, 2, 3]
for w in [2, 4, 8]
for r in [None, 128, 256]
for r in MAXNREG_VALUES
] + [
# Smaller configs for GPUs with limited register files (e.g. AMD RDNA4).
triton.Config({}, num_stages=1, num_warps=1, **_maxnreg(r))
for r in MAXNREG_VALUES_SMALL
],
key=["CHUNK_SIZE", "HEADDIM_QK", "HEADDIM_V", "IS_VARLEN"]
)
Expand Down Expand Up @@ -811,10 +823,14 @@ def compute_dqkv(

@triton.autotune(
configs=[
triton.Config({}, num_stages=s, num_warps=w, maxnreg=r)
triton.Config({}, num_stages=s, num_warps=w, **_maxnreg(r))
for s in [1, 2, 3]
for w in [2, 4, 8]
for r in [None, 128, 256]
for r in MAXNREG_VALUES
] + [
# Smaller configs for GPUs with limited register files (e.g. AMD RDNA4).
triton.Config({}, num_stages=1, num_warps=1, **_maxnreg(r))
for r in MAXNREG_VALUES_SMALL
],
key=["CHUNK_SIZE", "BLOCK_HEADDIM_QK", "HEADDIM_QK", "GQA_RATIO"]
)
Expand Down Expand Up @@ -1418,11 +1434,16 @@ def apply_dk_state_post(
# =============================================================================
@triton.autotune(
configs=[
triton.Config({"CHUNK_SIZE": cs}, num_stages=s, num_warps=w, maxnreg=r)
triton.Config({"CHUNK_SIZE": cs}, num_stages=s, num_warps=w, **_maxnreg(r))
for cs in [64, 128, 256]
for s in [1, 2, 3]
for w in [2, 4, 8]
for r in [None, 128, 256]
for r in MAXNREG_VALUES
] + [
# Smaller configs for GPUs with limited register files (e.g. AMD RDNA4).
triton.Config({"CHUNK_SIZE": cs}, num_stages=1, num_warps=1, **_maxnreg(r))
for cs in [32, 64]
for r in MAXNREG_VALUES_SMALL
],
key=["HEADDIM_V", "HEADDIM_QK", "HAS_INPUT_STATE", "IS_VARLEN"]
)
Expand Down
17 changes: 13 additions & 4 deletions mamba_ssm/ops/triton/mamba3/mamba3_siso_fwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,26 @@

import triton
import triton.language as tl
from mamba_ssm.ops.triton.mamba3.utils import cos_approx, sin_approx, tanh_approx, silu, sigmoid_approx
from mamba_ssm.ops.triton.mamba3.utils import (
cos_approx, sin_approx, tanh_approx, silu, sigmoid_approx,
_maxnreg, MAXNREG_VALUES, MAXNREG_VALUES_SMALL,
)

@triton.autotune(
configs=[
triton.Config({}, num_stages=s, num_warps=w, maxnreg=r)
triton.Config({}, num_stages=s, num_warps=w, **_maxnreg(r))
for s in [1, 2, 3]
for w in [2, 4, 8]
for r in [None, 128, 256]
for r in MAXNREG_VALUES
] + [
# Configs targeting GPUs with smaller register files (e.g. AMD RDNA4).
# num_warps=1 halves per-wavefront register demand; num_stages=1 avoids
# extra live-range overlap from software pipelining.
triton.Config({}, num_stages=1, num_warps=1, **_maxnreg(r))
for r in MAXNREG_VALUES_SMALL
],
key=[
"CHUNK_SIZE", "HEADDIM_QK", "HEADDIM_V", "STORE_SSM_STATES_ADT_OUTV", "HAS_D",
"CHUNK_SIZE", "HEADDIM_QK", "HEADDIM_V", "STORE_SSM_STATES_ADT_OUTV", "HAS_D",
"HAS_Z", "HAS_INITIAL_STATES", "RETURN_FINAL_STATES", "IS_VARLEN"],
)
@triton.jit
Expand Down
2 changes: 1 addition & 1 deletion mamba_ssm/ops/triton/mamba3/mamba3_siso_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
configs=[
triton.Config({}, num_stages=s, num_warps=w)
for s in [1, 2, 3]
for w in [2, 4, 8]
for w in [1, 2, 4, 8]
],
key=[
"HEADDIM_QK", "HEADDIM_V", "HAS_D", "HAS_Z",],
Expand Down
Loading