Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2,107 changes: 2,107 additions & 0 deletions aiter/ops/flydsl/kernels/attn_bwd_mxfp8_gfx950.py

Large diffs are not rendered by default.

89 changes: 89 additions & 0 deletions aiter/ops/triton/_triton_kernels/attention/mha_fused_bwd.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,95 @@
)


@triton.jit
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reorg suggestion:

  • Add the new kernel to a brand new file, as you did with aiter/ops/triton/_triton_kernels/quant/mxfp8_quant.py. It can be something like aiter/ops/triton/_triton_kernels/attention/mxfp8_bwd_preprocess.py.
  • In a conversation we had via Teams, I discovered that the new kernel is only compatible with FlyDSL. The current organization mixes it with the Triton "fused" backward implementation.

def upcast_mxfp8(tensor, scale, BLOCK_M, BLOCK_D_POW2):
    """Upcast a block-scaled tile to float32 and apply its per-group scales.

    Each run of 32 consecutive elements along the last axis of ``tensor``
    shares one entry of ``scale``.

    Args:
        tensor: tile that is reshaped as (BLOCK_M, BLOCK_D_POW2).
        scale: 2-D tile broadcast against (BLOCK_M, BLOCK_D_POW2 // 32); its
            integer value is used as a float32 biased exponent.
            NOTE(review): presumably E8M0 scales stored as uint8 - confirm
            against the quantizer.
        BLOCK_M: number of rows in the tile.
        BLOCK_D_POW2: number of columns in the tile; divided by 32 to get the
            number of scale groups per row.

    Returns:
        float32 tile of shape (BLOCK_M, BLOCK_D_POW2) with scales applied.
    """
    # Shift the scale into the float32 exponent field (bit 23 upwards) and
    # reinterpret the bits, which yields the power of two it encodes without
    # calling exp2.
    exponent_bits = scale.to(tl.uint32) << 23
    scale_fp32 = exponent_bits.to(tl.float32, bitcast=True)

    # Expose the groups of 32 as their own axis so that one scale broadcasts
    # over exactly one group, then flatten back to the 2-D tile.
    grouped = tensor.to(tl.float32).reshape((BLOCK_M, BLOCK_D_POW2 // 32, 32))
    scaled = grouped * scale_fp32[:, :, None]
    return scaled.reshape((BLOCK_M, BLOCK_D_POW2))


@triton.jit
def _bwd_preprocess_mxfp8(
    o_ptr,
    o_scale_ptr,
    do_ptr,
    do_scale_ptr,
    delta_ptr,
    stride_o_b,
    stride_o_m,
    stride_o_k,
    stride_o_scale_b,
    stride_o_scale_m,
    stride_o_scale_k,
    stride_do_b,
    stride_do_m,
    stride_do_k,
    stride_do_scale_b,
    stride_do_scale_m,
    stride_do_scale_k,
    stride_delta_b,
    stride_delta_m,
    max_seqlen_q,
    BLOCK_M: tl.constexpr,
    BLOCK_D: tl.constexpr,
    BLOCK_D_POW2: tl.constexpr,
):
    """Compute delta[b, m] = sum_k O[b, m, k] * dO[b, m, k] from scaled inputs.

    Each program handles BLOCK_M consecutive rows (grid axis 0) of one batch
    entry (grid axis 1). O and dO are loaded together with their per-32-element
    scales, upcast to float32 via ``upcast_mxfp8``, multiplied elementwise and
    reduced over the head dimension.

    Args:
        o_ptr, do_ptr: base pointers of O and dO.
        o_scale_ptr, do_scale_ptr: base pointers of the matching scale
            tensors, holding one scale per 32 elements of the last dimension.
        delta_ptr: base pointer of the output, one value per (batch, row).
        stride_*: element strides of the corresponding tensors along the
            batch (b), row (m) and head/scale (k) dimensions.
        max_seqlen_q: number of valid rows; rows at or beyond it are masked.
        BLOCK_M: rows processed per program.
        BLOCK_D: actual head dimension; must be a multiple of 32.
        BLOCK_D_POW2: head dimension used for the tile width.
            NOTE(review): presumably BLOCK_D rounded up to a power of two -
            confirm against the launcher.
    """
    # BLOCK_D // 32 truncates. Without this check a trailing partial group
    # would have its scale masked to zero and silently drop out of delta.
    tl.static_assert(
        BLOCK_D % 32 == 0,
        "BLOCK_D must be a multiple of 32 (one scale per 32 elements)",
    )

    pid_m = tl.program_id(0)  # seqlen
    bid = tl.program_id(1)  # batch

    # Compute offsets
    BLOCK_D_SCALE: tl.constexpr = BLOCK_D // 32
    BLOCK_D_SCALE_POW2: tl.constexpr = BLOCK_D_POW2 // 32
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_k = tl.arange(0, BLOCK_D_POW2)
    offs_k_scale = tl.arange(0, BLOCK_D_SCALE_POW2)

    offs_o = (
        bid * stride_o_b + offs_m[:, None] * stride_o_m + offs_k[None, :] * stride_o_k
    )
    offs_o_scale = (
        bid * stride_o_scale_b
        + offs_m[:, None] * stride_o_scale_m
        + offs_k_scale[None, :] * stride_o_scale_k
    )
    offs_do = (
        bid * stride_do_b
        + offs_m[:, None] * stride_do_m
        + offs_k[None, :] * stride_do_k
    )
    offs_do_scale = (
        bid * stride_do_scale_b
        + offs_m[:, None] * stride_do_scale_m
        + offs_k_scale[None, :] * stride_do_scale_k
    )

    # Create masks: rows past the sequence end are always masked; columns are
    # masked only when the tile is wider than the real head dimension.
    mask_m = offs_m < max_seqlen_q
    mask = mask_m[:, None]
    mask_scale = mask
    PADDED_HEAD: tl.constexpr = BLOCK_D != BLOCK_D_POW2
    if PADDED_HEAD:
        mask &= offs_k[None, :] < BLOCK_D
        mask_scale &= offs_k_scale[None, :] < BLOCK_D_SCALE

    # Load the [BLOCK_M, BLOCK_D_POW2] data tiles and the
    # [BLOCK_M, BLOCK_D_SCALE_POW2] scale tiles. Masked-out entries read as
    # zero, so they add nothing to the reduction below.
    o = tl.load(o_ptr + offs_o, mask=mask, other=0.0)
    o_scale = tl.load(o_scale_ptr + offs_o_scale, mask=mask_scale, other=0.0)
    do = tl.load(do_ptr + offs_do, mask=mask, other=0.0)
    do_scale = tl.load(do_scale_ptr + offs_do_scale, mask=mask_scale, other=0.0)

    # Compute and write back delta (one value per row).
    o_fp32 = upcast_mxfp8(o, o_scale, BLOCK_M, BLOCK_D_POW2)
    do_fp32 = upcast_mxfp8(do, do_scale, BLOCK_M, BLOCK_D_POW2)
    delta = tl.sum(o_fp32 * do_fp32, axis=1)

    offs_delta = bid * stride_delta_b + offs_m * stride_delta_m
    tl.store(delta_ptr + offs_delta, delta, mask=mask_m)


@triton.jit(repr=_bwd_preprocess_repr)
def _bwd_preprocess(
o_ptr,
Expand Down
Loading
Loading