From c894f918b8f87b3f3ce725a02cda551e42f86ea7 Mon Sep 17 00:00:00 2001 From: Lukasz Burzawa Date: Fri, 8 May 2026 17:07:20 +0000 Subject: [PATCH 1/8] add mxfp8 quants and bwd preprocess --- .../attention/mha_fused_bwd.py | 89 ++++++ .../_triton_kernels/quant/mxfp8_quant.py | 273 ++++++++++++++++++ aiter/ops/triton/attention/mha_fused_bwd.py | 68 +++++ aiter/ops/triton/quant/mxfp8_quant.py | 132 +++++++++ op_tests/triton_tests/attention/test_mha.py | 31 ++ .../triton_tests/quant/test_quant_mxfp8.py | 176 +++++++++++ 6 files changed, 769 insertions(+) create mode 100644 aiter/ops/triton/_triton_kernels/quant/mxfp8_quant.py create mode 100644 aiter/ops/triton/quant/mxfp8_quant.py create mode 100644 op_tests/triton_tests/quant/test_quant_mxfp8.py diff --git a/aiter/ops/triton/_triton_kernels/attention/mha_fused_bwd.py b/aiter/ops/triton/_triton_kernels/attention/mha_fused_bwd.py index ff5b96b77c..13d6e55f44 100644 --- a/aiter/ops/triton/_triton_kernels/attention/mha_fused_bwd.py +++ b/aiter/ops/triton/_triton_kernels/attention/mha_fused_bwd.py @@ -29,6 +29,95 @@ ) +@triton.jit +def upcast_mxfp8(tensor, scale, BLOCK_M, BLOCK_D_POW2): + scale = (scale.to(tl.uint32) << 23).to(tl.float32, bitcast=True) + tensor = tensor.to(tl.float32) + tensor = tensor.reshape((BLOCK_M, BLOCK_D_POW2 // 32, 32)) + tensor = tensor * scale[:, :, None] + tensor = tensor.reshape((BLOCK_M, BLOCK_D_POW2)) + return tensor + + +@triton.jit +def _bwd_preprocess_mxfp8( + o_ptr, + o_scale_ptr, + do_ptr, + do_scale_ptr, + delta_ptr, + stride_o_b, + stride_o_m, + stride_o_k, + stride_o_scale_b, + stride_o_scale_m, + stride_o_scale_k, + stride_do_b, + stride_do_m, + stride_do_k, + stride_do_scale_b, + stride_do_scale_m, + stride_do_scale_k, + stride_delta_b, + stride_delta_m, + max_seqlen_q, + BLOCK_M: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_D_POW2: tl.constexpr, +): + pid_m = tl.program_id(0) # seqlen + bid = tl.program_id(1) # batch + + # Compute offsets + BLOCK_D_SCALE: tl.constexpr = BLOCK_D 
// 32 + BLOCK_D_SCALE_POW2: tl.constexpr = BLOCK_D_POW2 // 32 + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, BLOCK_D_POW2) + offs_k_scale = tl.arange(0, BLOCK_D_SCALE_POW2) + + offs_o = ( + bid * stride_o_b + offs_m[:, None] * stride_o_m + offs_k[None, :] * stride_o_k + ) + offs_o_scale = ( + bid * stride_o_scale_b + + offs_m[:, None] * stride_o_scale_m + + offs_k_scale[None, :] * stride_o_scale_k + ) + offs_do = ( + bid * stride_do_b + + offs_m[:, None] * stride_do_m + + offs_k[None, :] * stride_do_k + ) + offs_do_scale = ( + bid * stride_do_scale_b + + offs_m[:, None] * stride_do_scale_m + + offs_k_scale[None, :] * stride_do_scale_k + ) + + # create masks (rows past max_seqlen_q, plus padded head columns when BLOCK_D != BLOCK_D_POW2) + mask_m = offs_m < max_seqlen_q + mask = mask_m[:, None] + mask_scale = mask + PADDED_HEAD: tl.constexpr = BLOCK_D != BLOCK_D_POW2 + if PADDED_HEAD: + mask &= offs_k[None, :] < BLOCK_D + mask_scale &= offs_k_scale[None, :] < BLOCK_D_SCALE + + # load values as [BLOCK_M, BLOCK_D_POW2] and scales as [BLOCK_M, BLOCK_D_SCALE_POW2] + o = tl.load(o_ptr + offs_o, mask=mask, other=0.0) + o_scale = tl.load(o_scale_ptr + offs_o_scale, mask=mask_scale, other=0.0) + do = tl.load(do_ptr + offs_do, mask=mask, other=0.0) + do_scale = tl.load(do_scale_ptr + offs_do_scale, mask=mask_scale, other=0.0) + + # compute and write-back to delta + o_fp32 = upcast_mxfp8(o, o_scale, BLOCK_M, BLOCK_D_POW2) + do_fp32 = upcast_mxfp8(do, do_scale, BLOCK_M, BLOCK_D_POW2) + delta = tl.sum(o_fp32 * do_fp32, axis=1) + + offs_delta = bid * stride_delta_b + offs_m * stride_delta_m + tl.store(delta_ptr + offs_delta, delta, mask=mask_m) + + +@triton.jit(repr=_bwd_preprocess_repr) def _bwd_preprocess( o_ptr, diff --git a/aiter/ops/triton/_triton_kernels/quant/mxfp8_quant.py b/aiter/ops/triton/_triton_kernels/quant/mxfp8_quant.py new file mode 100644 index 0000000000..8e3ec93371 --- /dev/null +++ b/aiter/ops/triton/_triton_kernels/quant/mxfp8_quant.py @@ -0,0 +1,273 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc.
All rights reserved. + +import triton +import triton.language as tl + + +@triton.jit +def _get_max_quant_val(dtype: tl.constexpr): + if dtype == tl.float8e5: + return 57344.0 + elif dtype == tl.float8e4nv: + return 448.0 + else: + tl.static_assert(False, f"Invalid {dtype=}") + + +@triton.jit +def _compute_mx_quant_and_scale( + src_tensor, + valid_src_mask, + mx_tensor_dtype: tl.constexpr, + SCALE_ROUNDING_MODE: tl.constexpr = 0, +): + BLOCK_SIZE_OUT_DIM: tl.constexpr = src_tensor.shape[0] + BLOCK_SIZE_QUANT_DIM: tl.constexpr = src_tensor.shape[1] + BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = src_tensor.shape[1] // 32 + + # Explicit cast to fp32 since most ops are not supported on bfloat16. We avoid needless conversions to and from bf16 + f32_tensor = src_tensor.to(tl.float32) + abs_tensor = tl.abs(f32_tensor) + abs_tensor = tl.where( + valid_src_mask, abs_tensor, -1.0 + ) # Don't consider padding tensors in scale computation + abs_tensor = tl.reshape( + abs_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 32] + ) + max_val = tl.max(abs_tensor, axis=2, keep_dims=True) + dequant_scale = max_val / _get_max_quant_val(mx_tensor_dtype) + if SCALE_ROUNDING_MODE == 0: + # ROUND_UP + # compute 2 ** ceil(log2(dequant_scale)) + # Adding 0x007FFFFF adds exponent by 1 unless mantissa is all zeros + # A corner case: exponent is 0xFF that will overflow but that's already + # NaN so assume we don't care. 
+ dequant_scale_exponent = ( + dequant_scale.to(tl.uint32, bitcast=True) + 0x007FFFFF + ) & 0x7F800000 + else: + # ROUND_DOWN + # compute 2 ** floor(log2(dequant_scale)) + assert SCALE_ROUNDING_MODE == 1 + dequant_scale_exponent = dequant_scale.to(tl.uint32, bitcast=True) & 0x7F800000 + dequant_scale_rounded = dequant_scale_exponent.to(tl.float32, bitcast=True) + quant_scale = tl.where(dequant_scale_rounded == 0, 0, 1.0 / dequant_scale_rounded) + + f32_tensor = tl.reshape( + f32_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 32] + ) + quant_tensor = f32_tensor * quant_scale + + # Reshape the tensors after scaling + quant_tensor = quant_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM]) + # Set the invalid portions of the tensor to 0. This will ensure that any padding tensors are 0 in the mx format. + quant_tensor = tl.where(valid_src_mask, quant_tensor, 0) + dequant_scale_exponent = dequant_scale_exponent.reshape( + [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE] + ) + + # First, we simply extract the exponent part of the scales and store the result + dequant_scale_exponent = (dequant_scale_exponent >> 23).to(tl.uint8) + # Now we must convert the tensors to the mx format. + out_tensor = quant_tensor.to(mx_tensor_dtype) + + return out_tensor, dequant_scale_exponent + + +@triton.jit +def _downcast_to_mxfp8( + mx_tensor_ptr, + stride_mxt_outer, + stride_mxt_quant: tl.constexpr, + mx_scale_ptr, + stride_mx_scale_outer, + stride_mx_scale_quant, + src_ptr, + stride_src_outer, + stride_src_quant, + outer_dim, + quant_dim, + BLOCK_SIZE_OUT_DIM: tl.constexpr, + BLOCK_SIZE_QUANT_DIM: tl.constexpr, + SCALE_ROUNDING_MODE: tl.constexpr, +): + + tl.static_assert( + stride_mxt_quant == 1, f"Output stride, {stride_mxt_quant=} must be 1." 
+ ) + tl.static_assert( + BLOCK_SIZE_QUANT_DIM % 32 == 0, + f"{BLOCK_SIZE_QUANT_DIM=} must be a multiple of 32", + ) + + mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty + tl.static_assert( + mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5, + f"Invalid {mx_tensor_dtype=}. Must be float8.", + ) + + src_dtype: tl.constexpr = src_ptr.dtype.element_ty + tl.static_assert( + mx_scale_ptr.dtype.element_ty == tl.uint8, + f"{mx_scale_ptr.dtype.element_ty=} must be uint8", + ) + tl.static_assert( + (src_dtype == tl.float32) + or (src_dtype == tl.bfloat16) + or (src_dtype == tl.float16), + f"{src_dtype=} must be float32 or bfloat16 or float16", + ) + + outer_block = tl.program_id(0).to(tl.int64) + quant_block = tl.program_id(1).to(tl.int64) + + BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // 32 + BLOCK_SIZE_QUANT_MX_TENSOR: tl.constexpr = BLOCK_SIZE_QUANT_DIM + + start_src_quant = quant_block * BLOCK_SIZE_QUANT_DIM + start_mx_scale_quant = quant_block * BLOCK_SIZE_QUANT_MX_SCALE + start_mx_quant = quant_block * BLOCK_SIZE_QUANT_MX_TENSOR + start_out = outer_block * BLOCK_SIZE_OUT_DIM + + src_ptr += start_src_quant * stride_src_quant + start_out * stride_src_outer + mx_scale_ptr += ( + start_mx_scale_quant * stride_mx_scale_quant + start_out * stride_mx_scale_outer + ) + mx_tensor_ptr += start_mx_quant * stride_mxt_quant + start_out * stride_mxt_outer + + offs_src_quant = tl.arange(0, BLOCK_SIZE_QUANT_DIM)[None, :].to(tl.int64) + offs_mxt_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_TENSOR)[None, :].to(tl.int64) + offs_scale_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_SCALE)[None, :].to(tl.int64) + offs_outer = tl.arange(0, BLOCK_SIZE_OUT_DIM)[:, None].to(tl.int64) + + mask_src_quant = start_src_quant + offs_src_quant < quant_dim + mask_n = start_out + offs_outer < outer_dim + full_mask_src = mask_src_quant & mask_n + + mask_mxt_quant = start_mx_quant + offs_mxt_quant < quant_dim + full_mask_mxt = mask_mxt_quant & mask_n + + 
scale_mask_k = start_mx_scale_quant + offs_scale_quant < tl.cdiv(quant_dim, 32) + full_scale_mask = scale_mask_k & mask_n + + src_tensor_offsets = ( + offs_src_quant * stride_src_quant + offs_outer * stride_src_outer + ) + mx_scale_offsets = ( + offs_scale_quant * stride_mx_scale_quant + offs_outer * stride_mx_scale_outer + ) + mx_tensor_offsets = ( + offs_mxt_quant * stride_mxt_quant + offs_outer * stride_mxt_outer + ) + src_tensor = tl.load(src_ptr + src_tensor_offsets, mask=full_mask_src) + + out_tensor, scale_tensor = _compute_mx_quant_and_scale( + src_tensor, full_mask_src, mx_tensor_dtype, SCALE_ROUNDING_MODE + ) + + tl.store(mx_scale_ptr + mx_scale_offsets, scale_tensor, mask=full_scale_mask) + tl.store(mx_tensor_ptr + mx_tensor_offsets, out_tensor, mask=full_mask_mxt) + + +@triton.jit +def _upcast_from_mxfp8( + out_ptr, + stride_o_outer, + stride_o_quant: tl.constexpr, + mx_scale_ptr, + stride_scale_outer, + stride_scale_quant, + mx_tensor_ptr, + stride_tensor_outer, + stride_tensor_quant: tl.constexpr, + outer_dim, + quant_dim, + BLOCK_SIZE_OUT_DIM: tl.constexpr, + BLOCK_SIZE_QUANT_DIM: tl.constexpr, +): + + tl.static_assert( + stride_o_quant == 1, "the weight must be contiguous in the k dimension for mx" + ) + tl.static_assert( + BLOCK_SIZE_QUANT_DIM % 32 == 0, "BLOCK_SIZE_K must be a multiple of 32" + ) + # mxfp8 keeps one value per element (no fp4-style byte packing); the input dtype is checked below + mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty + dst_dtype: tl.constexpr = out_ptr.dtype.element_ty + tl.static_assert( + dst_dtype == tl.float32 or dst_dtype == tl.float16 or dst_dtype == tl.bfloat16 + ) + tl.static_assert( + (mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5) + or mx_tensor_dtype == dst_dtype, + "mx_tensor_ptr must be float8 or dst_dtype", + ) + tl.static_assert( + mx_scale_ptr.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8" + ) + + BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // 32 +
BLOCK_SIZE_QUANT_MX_TENSOR: tl.constexpr = BLOCK_SIZE_QUANT_DIM + + # Compute starting indices for the quantized (packed) dimension and the outer dimension. + outer_block = tl.program_id(0).to(tl.int64) + quant_block = tl.program_id(1).to(tl.int64) + + start_mxt_quant = quant_block * BLOCK_SIZE_QUANT_MX_TENSOR + start_out_quant = quant_block * BLOCK_SIZE_QUANT_DIM + start_mx_scale_quant = quant_block * BLOCK_SIZE_QUANT_MX_SCALE + start_out = outer_block * BLOCK_SIZE_OUT_DIM + + mx_tensor_ptr += ( + start_mxt_quant * stride_tensor_quant + start_out * stride_tensor_outer + ) + mx_scale_ptr += ( + start_mx_scale_quant * stride_scale_quant + start_out * stride_scale_outer + ) + out_ptr += start_out * stride_o_outer + start_out_quant * stride_o_quant + + # Compute offsets and masks. + offs_src_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_TENSOR)[None, :].to(tl.int64) + offs_out_quant = tl.arange(0, BLOCK_SIZE_QUANT_DIM)[None, :].to(tl.int64) + offs_outer = tl.arange(0, BLOCK_SIZE_OUT_DIM)[:, None].to(tl.int64) + offs_scale = tl.arange(0, BLOCK_SIZE_QUANT_MX_SCALE)[None, :].to(tl.int64) + + mask_outer = start_out + offs_outer < outer_dim + mask_out_quant = start_out_quant + offs_out_quant < quant_dim + full_mask_out = mask_out_quant & mask_outer + + mask_src_quant = start_mxt_quant + offs_src_quant < quant_dim + full_mask_src = mask_src_quant & mask_outer + + mask_scale = start_mx_scale_quant + offs_scale < tl.cdiv(quant_dim, 32) + full_scale_mask = mask_scale & mask_outer + + tensor_offsets = ( + offs_src_quant * stride_tensor_quant + offs_outer * stride_tensor_outer + ) + scale_offsets = offs_scale * stride_scale_quant + offs_outer * stride_scale_outer + out_offsets = offs_out_quant * stride_o_quant + offs_outer * stride_o_outer + + # Load the packed tensor and scale. + tensor = tl.load(mx_tensor_ptr + tensor_offsets, mask=full_mask_src) + scale = tl.load(mx_scale_ptr + scale_offsets, mask=full_scale_mask) + + # Upcast the scale. 
+ dst_scale = (scale.to(tl.uint32) << 23).to(tl.float32, bitcast=True) + + # Now upcast the tensor. + dst_tensor = tensor.to(tl.float32) + + # Reshape for proper broadcasting: the scale was stored with a 32-sized "inner" grouping. + dst_tensor = dst_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 32]) + dst_scale = dst_scale.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 1]) + scale = scale.reshape(dst_scale.shape) + + out_tensor = dst_tensor * dst_scale + # Correct any NaNs encoded via the scale. + out_tensor = tl.where(scale == 0xFF, float("nan"), out_tensor) + out_tensor = out_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM]) + out_tensor = out_tensor.to(dst_dtype) + tl.store(out_ptr + out_offsets, out_tensor, mask=full_mask_out) diff --git a/aiter/ops/triton/attention/mha_fused_bwd.py b/aiter/ops/triton/attention/mha_fused_bwd.py index 634d8a0a2c..b7eb191656 100644 --- a/aiter/ops/triton/attention/mha_fused_bwd.py +++ b/aiter/ops/triton/attention/mha_fused_bwd.py @@ -8,6 +8,7 @@ from aiter.ops.triton.utils.types import _is_fp8 from aiter.ops.triton.utils.logger import AiterTritonLogger from aiter.ops.triton._triton_kernels.attention.mha_fused_bwd import ( + _bwd_preprocess_mxfp8, _bwd_preprocess, _bwd_kernel_dkdvdq_causal, _bwd_kernel_dkdvdq_noncausal, @@ -18,6 +19,73 @@ _LOGGER = AiterTritonLogger() +def bwd_preprocess_mxfp8( + o: torch.Tensor, + o_scale: torch.Tensor, + do: torch.Tensor, + do_scale: torch.Tensor, + config: Optional[Dict[str, any]] = None, +): + """ + Backward mx8 preprocess function. + + Args: + o (torch.Tensor): Output from forward pass. Shape (..., seqlen, head_dim) + o_scale (torch.Tensor): MX scales for o computed along head dimension. Shape (..., seqlen, head_dim // 32) + do (torch.Tensor): Output gradient. Shape (..., seqlen, head_dim) + do_scale (torch.Tensor): MX scales for do computed along head dimension. Shape (..., seqlen, head_dim // 32) + config (Optional[Dict[str, any]]): Kernel tuning parameters.
+ + Returns: + torch.Tensor: float32 delta = rowsum(do * o) over head_dim. Shape (batch, seqlen), leading dims flattened. + """ + + # get strides and shape + if o.dim() > 3: # flatten batch and number of heads dimensions + o = o.reshape(-1, o.shape[-2], o.shape[-1]) + o_scale = o_scale.reshape(-1, o_scale.shape[-2], o_scale.shape[-1]) + do = do.reshape(-1, do.shape[-2], do.shape[-1]) + do_scale = do_scale.reshape(-1, do_scale.shape[-2], do_scale.shape[-1]) + batch, seqlen, head_dim = o.shape + + # BLOCK_D, BLOCK_D_POW2 + # padding for head_dim. Power of 2 or 16 + BLOCK_D_POW2 = triton.next_power_of_2(head_dim) + BLOCK_D_POW2 = max(BLOCK_D_POW2, 16) + + # init delta on the same device as o + delta = torch.empty((batch, seqlen), dtype=torch.float32, device=o.device) + + # preprocess + # compute D(delta) = rowsum(dO*O). Note, multiplication is element-wise. + if config is None: + config = _get_config() + + pre_grid = ( + triton.cdiv(seqlen, config["preprocess_kernel"]["PRE_BLOCK"]), + batch, + ) + + _bwd_preprocess_mxfp8[pre_grid]( + o, + o_scale, + do, + do_scale, + delta, + *o.stride(), + *o_scale.stride(), + *do.stride(), + *do_scale.stride(), + *delta.stride(), + seqlen, + BLOCK_M=config["preprocess_kernel"]["PRE_BLOCK"], + BLOCK_D=head_dim, + BLOCK_D_POW2=BLOCK_D_POW2, + ) + + return delta + + def flash_attn_fused_backward( do: torch.Tensor, q: torch.Tensor, diff --git a/aiter/ops/triton/quant/mxfp8_quant.py b/aiter/ops/triton/quant/mxfp8_quant.py new file mode 100644 index 0000000000..a5d0a2a635 --- /dev/null +++ b/aiter/ops/triton/quant/mxfp8_quant.py @@ -0,0 +1,132 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
+ +import triton +import torch +from aiter.ops.triton._triton_kernels.quant.mxfp8_quant import ( + _downcast_to_mxfp8, + _upcast_from_mxfp8, +) +from aiter.ops.triton.utils.logger import AiterTritonLogger + +__all__ = [ + "downcast_to_mxfp8", + "upcast_from_mxfp8", +] + + +_LOGGER = AiterTritonLogger() + + +def downcast_to_mxfp8( + src_tensor: torch.Tensor, + out_quant_type: torch.dtype, + axis: int, + SCALE_ROUNDING_MODE: int = 0, +): + """ + Convert the src tensor to mxfp8 format. The src tensor is quantized along the axis dimension. + + Each group of 32 consecutive values along axis shares one uint8 scale (a biased power-of-two exponent), + so the returned scale tensor has ceil(L / 32) entries along axis for L source values. + + out_quant_type must be torch.float8_e4m3fn or torch.float8_e5m2; the float8 values are stored + in their respective formats. Returns the tuple (quantized tensor, uint8 scale tensor). + """ + ndim = src_tensor.ndim + assert -ndim <= axis < ndim, f"Invalid axis {axis=}" + axis = axis if axis >= 0 else axis + ndim + # downcast + src_tensor = src_tensor.transpose(axis, src_tensor.ndim - 1) + L = src_tensor.shape[-1] + out_shape = src_tensor.shape[:-1] + (L,) + out_scale_shape = src_tensor.shape[:-1] + (triton.cdiv(L, 32),) + + out_quant_tensor = src_tensor.new_empty(out_shape, dtype=out_quant_type) + out_scale = src_tensor.new_empty(out_scale_shape, dtype=torch.uint8) + + kernel_src_tensor = src_tensor.reshape(-1, src_tensor.shape[-1]) + kernel_quant_tensor = out_quant_tensor.view(-1, out_quant_tensor.shape[-1]) + kernel_scale = out_scale.view(-1, out_scale.shape[-1]) + + BLOCK_OUT_DIM = 128 + BLOCK_QUANT_DIM = 32 + grid_out = triton.cdiv(kernel_src_tensor.shape[0], BLOCK_OUT_DIM) + grid_quant = triton.cdiv(kernel_src_tensor.shape[1], BLOCK_QUANT_DIM) + + _downcast_to_mxfp8[(grid_out, grid_quant)]( + kernel_quant_tensor, + *kernel_quant_tensor.stride(), + kernel_scale, + *kernel_scale.stride(), + kernel_src_tensor, + *kernel_src_tensor.stride(), + *kernel_src_tensor.shape, +
BLOCK_OUT_DIM, + BLOCK_QUANT_DIM, + SCALE_ROUNDING_MODE, + num_warps=8, + ) + + out_quant_tensor = out_quant_tensor.transpose(axis, src_tensor.ndim - 1) + out_scale = out_scale.transpose(axis, src_tensor.ndim - 1) + return out_quant_tensor, out_scale + + +def upcast_from_mxfp8( + tensor: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype, axis: int +): + """ + Upcasts an mxfp8 tensor back to float16, bfloat16 or float32. + + The function assumes that the tensors were quantized along the given axis. + It permutes the tensor so that the quantized axis is last, reshapes to 2D, + launches the Triton upcast kernel, and then unpermutes back to the original order. + """ + ndim = tensor.ndim + assert -ndim <= axis < ndim, f"Invalid axis {axis=}" + axis = axis if axis >= 0 else axis + ndim + assert tensor.ndim == scale.ndim, ( + f"Weight and scale must have the same number of dimensions. " + f"Got {tensor.ndim=} and {scale.ndim=}" + ) + # dtype checks + assert tensor.dtype in { + torch.float8_e5m2, + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + }, f"Invalid tensor dtype {tensor.dtype=}" + assert scale.dtype == torch.uint8, f"Invalid scale dtype {scale.dtype=}" + assert dtype in ( + torch.float16, + torch.bfloat16, + torch.float32, + ), f"Invalid output dtype {dtype=}" + # upcast + logical_quant_dim = tensor.shape[axis] + tensor = tensor.transpose(axis, tensor.ndim - 1).contiguous() + scale = scale.transpose(axis, scale.ndim - 1).contiguous() + out = torch.empty( + (*tensor.shape[:-1], logical_quant_dim), dtype=dtype, device=tensor.device + ) + reshaped_out = out.view(-1, out.shape[-1]) + reshaped_tensor = tensor.view(-1, tensor.shape[-1]) + reshaped_scale = scale.view(-1, scale.shape[-1]) + BLOCK_OUT_DIM = 128 + BLOCK_QUANT_DIM = 32 + blocks_out_dim = triton.cdiv(reshaped_out.shape[0], BLOCK_OUT_DIM) + blocks_quant_dim = triton.cdiv(reshaped_out.shape[1], BLOCK_QUANT_DIM) + _upcast_from_mxfp8[(blocks_out_dim, blocks_quant_dim)]( + reshaped_out, + *reshaped_out.stride(),
+ reshaped_scale, + *reshaped_scale.stride(), + reshaped_tensor, + *reshaped_tensor.stride(), + *reshaped_out.shape, + BLOCK_OUT_DIM, + BLOCK_QUANT_DIM, + num_warps=8, + ) + out = out.transpose(axis, scale.ndim - 1).contiguous() + return out diff --git a/op_tests/triton_tests/attention/test_mha.py b/op_tests/triton_tests/attention/test_mha.py index 73721efe2b..4480cae9e7 100644 --- a/op_tests/triton_tests/attention/test_mha.py +++ b/op_tests/triton_tests/attention/test_mha.py @@ -10,6 +10,7 @@ mha_set_use_fused_bwd_kernel, mha_set_use_int64_strides, ) +from aiter.ops.triton.attention.mha_fused_bwd import bwd_preprocess_mxfp8 from aiter.test_mha_common import ( attention_ref, attention_ref_with_tol, @@ -17,6 +18,7 @@ generate_qkv, ) from op_tests.triton_tests.attention.mha_test_utils import pad_rearrange_dropout_mask +from aiter.ops.triton.quant.mxfp8_quant import downcast_to_mxfp8, upcast_from_mxfp8 logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) @@ -445,6 +447,35 @@ def test_mha_varlen_with_dropout( ) +@pytest.mark.parametrize("BATCH", [1, 4, 32, 128]) +@pytest.mark.parametrize("SEQLEN", [512, 1024, 2048]) +@pytest.mark.parametrize("HEAD_SZ", [64, 128]) +def test_mha_backward_preprocess_mxfp8( + BATCH: int, + SEQLEN: int, + HEAD_SZ: int, +): + torch.cuda.empty_cache() + torch.manual_seed(20) + + o_fp32 = torch.randn(BATCH, SEQLEN, HEAD_SZ, device="cuda", dtype=torch.float32) + do_fp32 = torch.randn(BATCH, SEQLEN, HEAD_SZ, device="cuda", dtype=torch.float32) + o_fp8, o_scale = downcast_to_mxfp8(o_fp32, torch.float8_e4m3fn, -1) + do_fp8, do_scale = downcast_to_mxfp8(do_fp32, torch.float8_e4m3fn, -1) + o_fp32 = upcast_from_mxfp8(o_fp8, o_scale, torch.float32, -1) + do_fp32 = upcast_from_mxfp8(do_fp8, do_scale, torch.float32, -1) + + triton_out = bwd_preprocess_mxfp8( + o_fp8, + o_scale, + do_fp8, + do_scale, + ) + torch_out = (o_fp32 * do_fp32).sum(-1) + + torch.testing.assert_close(triton_out, torch_out, atol=0.01, rtol=0.01) + + # 
Production shapes based on real models: # HQ=32, HK=8: Llama 3 8B (GQA 4:1) # HQ=64, HK=8: Llama 3 70B (GQA 8:1) diff --git a/op_tests/triton_tests/quant/test_quant_mxfp8.py b/op_tests/triton_tests/quant/test_quant_mxfp8.py new file mode 100644 index 0000000000..e81d3f78da --- /dev/null +++ b/op_tests/triton_tests/quant/test_quant_mxfp8.py @@ -0,0 +1,176 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +import torch +import pytest +from aiter.ops.triton.quant.mxfp8_quant import downcast_to_mxfp8, upcast_from_mxfp8 + + +def get_max_quant_val(dtype): + if dtype == torch.float8_e4m3fn: + return 448.0 + else: + return 57344.0 + + +def torch_downcast_to_mxfp8( + x: torch.Tensor, dtype: torch.dtype, axis: int, SCALE_ROUNDING_MODE: int = 0 +): + # returns tensor and scale in fp32 post quantization + + ndim = x.ndim + assert -ndim <= axis < ndim, f"Invalid axis {axis=}" + axis = axis if axis >= 0 else axis + ndim + + x = x.to(torch.float32) + x = x.transpose(axis, x.ndim - 1) + orig_shape = x.shape + quant_dim = orig_shape[-1] + pad_length = 32 - quant_dim % 32 + if pad_length == 32: + pad_length = 0 + padding = torch.empty(x.shape[:-1] + (pad_length,), dtype=x.dtype, device="cuda") + padding.fill_(-1.0) + x_padded = torch.cat((x, padding), -1) + x_abs_padded = torch.cat((torch.abs(x), padding), -1) + padded_shape = x_padded.shape + + new_shape = padded_shape[:-1] + (padded_shape[-1] // 32, 32) + x_padded = x_padded.reshape(new_shape) + x_abs_padded = x_abs_padded.reshape(new_shape) + scale = torch.amax(x_abs_padded, -1) + scale = scale / get_max_quant_val(dtype) + if SCALE_ROUNDING_MODE == 0: + scale = (scale.view(torch.int32) + 0x007FFFFF) & 0x7F800000 + else: + scale = scale.view(torch.int32) & 0x7F800000 + scale = scale.view(torch.float32) + + scale_inv = torch.where(scale == 0.0, 0.0, 1.0 / scale).unsqueeze(-1) + x_padded = x_padded * scale_inv + x_padded = x_padded.reshape(padded_shape) + x = 
x_padded[..., :quant_dim].clone() + x = x.to(dtype).to(torch.float32) + x = x.transpose(axis, x.ndim - 1) + scale = scale.transpose(axis, x.ndim - 1) + return x, scale + + +def upcast_scale(scale): + scale = scale.to(torch.int32) << 23 + scale = scale.view(torch.float32) + return scale + + +def torch_upcast_from_mxfp8( + x: torch.Tensor, + scale: torch.Tensor, + dtype: torch.dtype, + axis: int, +): + ndim = x.ndim + assert -ndim <= axis < ndim, f"Invalid axis {axis=}" + axis = axis if axis >= 0 else axis + ndim + + x = x.to(torch.float32) + x = x.transpose(axis, x.ndim - 1) + scale = scale.transpose(axis, x.ndim - 1) + orig_shape = x.shape + quant_dim = orig_shape[-1] + pad_length = 32 - quant_dim % 32 + if pad_length == 32: + pad_length = 0 + padding = torch.empty(x.shape[:-1] + (pad_length,), dtype=x.dtype, device="cuda") + padding.fill_(-1.0) + x_padded = torch.cat((x, padding), -1) + padded_shape = x_padded.shape + + new_shape = padded_shape[:-1] + (padded_shape[-1] // 32, 32) + x_padded = x_padded.reshape(new_shape) + scale = upcast_scale(scale).unsqueeze(-1) + x_padded = x_padded * scale + x_padded = x_padded.reshape(padded_shape) + x = x_padded[..., :quant_dim].clone() + x = x.transpose(axis, x.ndim - 1) + x = x.to(dtype) + return x + + +@pytest.mark.parametrize( + "shape, axis", + [ + ((1, 4), -1), + ((1, 28), -1), + ((1, 32), -1), + ((1, 64), -1), + ((1, 68), -1), + ((2, 4), -1), + ((2, 28), -1), + ((2, 32), -1), + ((2, 200, 64), 1), + ((2, 68), -1), + ((128, 4), 0), + ((128, 28), -1), + ((128, 32), -1), + ((128, 64), -1), + ((128, 68), -1), + ((256, 32), -1), + ((160, 40), -1), + ((280, 20), -1), + ], +) +@pytest.mark.parametrize("in_dtype", [torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("out_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) +@pytest.mark.parametrize("SCALE_ROUNDING_MODE", [0, 1]) +def test_downcast_to_mxfp8(shape, axis, in_dtype, out_dtype, SCALE_ROUNDING_MODE): + torch.cuda.empty_cache() + torch.manual_seed(20) + x = 
torch.randn(*shape, dtype=in_dtype, device="cuda") + + out_torch, out_scale_torch = torch_downcast_to_mxfp8( + x, out_dtype, axis, SCALE_ROUNDING_MODE + ) + out_triton, out_scale_triton = downcast_to_mxfp8( + x, out_dtype, axis, SCALE_ROUNDING_MODE + ) + out_triton = out_triton.to(torch.float32) + out_scale_triton = upcast_scale(out_scale_triton) + + torch.testing.assert_close(out_triton, out_torch, atol=0.01, rtol=0.01) + torch.testing.assert_close(out_scale_triton, out_scale_torch, atol=0.01, rtol=0.01) + + +@pytest.mark.parametrize( + "shape, axis", + [ + ((1, 4), -1), + ((1, 28), -1), + ((1, 32), -1), + ((1, 64), -1), + ((1, 68), -1), + ((2, 4), -1), + ((2, 28), -1), + ((2, 32), -1), + ((2, 200, 64), 1), + ((2, 68), -1), + ((128, 4), 0), + ((128, 28), -1), + ((128, 32), -1), + ((128, 64), -1), + ((128, 68), -1), + ((256, 32), -1), + ((160, 40), -1), + ((280, 20), -1), + ], +) +@pytest.mark.parametrize("in_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32]) +def test_upcast_from_mxfp8(shape, axis, in_dtype, out_dtype): + torch.cuda.empty_cache() + torch.manual_seed(20) + x = torch.randn(*shape, dtype=out_dtype, device="cuda") + x, x_scale = downcast_to_mxfp8(x, in_dtype, axis) + out_triton = upcast_from_mxfp8(x, x_scale, out_dtype, axis) + out_torch = torch_upcast_from_mxfp8(x, x_scale, out_dtype, axis) + + torch.testing.assert_close(out_triton, out_torch, atol=0.01, rtol=0.01) From 10caa55e58e3509f907cd0605dc9e246dec0b59b Mon Sep 17 00:00:00 2001 From: Lukasz Burzawa Date: Fri, 8 May 2026 22:52:41 +0000 Subject: [PATCH 2/8] add attn bwd main kernel in flydsl --- .../flydsl/kernels/attn_bwd_mxfp8_gfx950.py | 1637 +++++++++++++++++ .../test_attn_bwd_mxfp8_gfx950.py | 181 ++ .../flydsl/bench_attn_bwd_mxfp8_gfx950.py | 202 ++ op_tests/op_benchmarks/flydsl/utils.py | 368 ++++ 4 files changed, 2388 insertions(+) create mode 100644 aiter/ops/flydsl/kernels/attn_bwd_mxfp8_gfx950.py create mode 100644 
op_tests/flydsl_tests/test_attn_bwd_mxfp8_gfx950.py create mode 100644 op_tests/op_benchmarks/flydsl/bench_attn_bwd_mxfp8_gfx950.py create mode 100644 op_tests/op_benchmarks/flydsl/utils.py diff --git a/aiter/ops/flydsl/kernels/attn_bwd_mxfp8_gfx950.py b/aiter/ops/flydsl/kernels/attn_bwd_mxfp8_gfx950.py new file mode 100644 index 0000000000..ea2960b746 --- /dev/null +++ b/aiter/ops/flydsl/kernels/attn_bwd_mxfp8_gfx950.py @@ -0,0 +1,1637 @@ +"""Attn bwd kernel using the @flyc.kernel API.""" + +import flydsl.compiler as flyc +import flydsl.expr as fx +from flydsl.compiler.kernel_function import CompilationContext + +from flydsl.expr import range_constexpr +from flydsl.runtime.device import get_rocm_arch as get_hip_arch +from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr + +from flydsl._mlir import ir +from flydsl._mlir.dialects import scf + +from flydsl.expr import arith, vector, math as fx_math, const_expr +from flydsl.expr import gpu +from flydsl.expr import buffer_ops, rocdl +from flydsl.expr.typing import T + +from aiter.ops.flydsl.kernels.mfma_preshuffle_pipeline import ( + buffer_copy_gmem16_dwordx4, + tile_chunk_coord_i32, + swizzle_xor16, +) + +def lds_transpose_load(lds_memref, elem_offset): + """Transpose-load from LDS memref via ds_read_tr8_b64 (gfx950). + + Args: + lds_memref: LDS memref value (address-space 3), typically from + ``SmemPtr.get()`` or ``get_op_result_or_value(...)``. + elem_offset: Per-lane linearized element offset into the memref + (ArithValue / ir.Value of index type / Python int). + + Returns: + Loaded and transposed vector ``ir.Value``. 
+ """ + from flydsl._mlir.dialects import llvm, memref + from flydsl.expr.arith import _to_raw + from flydsl.expr.utils.arith import ArithValue as AV + + lds_ptr_ty = ir.Type.parse("!llvm.ptr<3>") + raw_memref = arith.unwrap(lds_memref) + lds_base = memref.extract_aligned_pointer_as_index(raw_memref) + + byte_off = AV(arith.unwrap(elem_offset, index=True)) + total_byte_idx = AV(lds_base) + byte_off + addr_i32 = _to_raw(arith.index_cast(T.i32, total_byte_idx)) + ptr_val = llvm.inttoptr(lds_ptr_ty, addr_i32) + + result_type=T.i32x2 + result = llvm.call_intrinsic(result_type, "llvm.amdgcn.ds.read.tr8.b64", [ptr_val], [], []) + return result + + +def compile_attn_bwd_mxfp8_gfx950( + *, + seqlen: int, + head_dim: int, + tile_m: int, + tile_n: int, + tile_head: int, + sm_scale: float, + causal: bool = False, + waves_per_eu: int = None, +): + """Compile the attention backward mx8 kernel using the @flyc.kernel API. + + Returns a JitFunction that auto-compiles and executes when called. + Compile-time constants: seqlen, head_dim, tile_m/n/head + Runtime parameters: batch + + """ + + elem_bytes = 1 + tile_head_mx = tile_head // 32 + tile_m_mx = tile_m // 32 + tile_n_mx = tile_n // 32 + + gpu_arch = get_hip_arch() + + allocator_pong = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem0") + allocator_ping = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem1") + allocator_k_quant_head = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem_k_quant_head") + allocator_k_scale_head = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem_k_scale_head") + allocator_k_quant_n = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem_k_quant_n") + allocator_k_scale_n = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem_k_scale_n") + allocator_v = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem_v") + allocator_v_scale = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem_v_scale") + allocator_ppt_shuffle = SmemAllocator(None, arch=gpu_arch, 
global_sym_name="smem_ppt_shuffle") + allocator_ppt_scale_shuffle = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem_ppt_scale_shuffle") + allocator_dst_shuffle = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem_dst_shuffle") + allocator_dst_scale_shuffle = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem_dst_scale_shuffle") + allocator_ds_shuffle = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem_ds_shuffle") + allocator_ds_scale_shuffle = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem_ds_scale_shuffle") + + num_waves = 4 + wave_size = 64 + total_threads = 256 + + bytes_per_tile_qo = int(tile_m) * int(tile_head) + bytes_per_thread_qo = bytes_per_tile_qo // total_threads + qo_load_bytes = 16 + + bytes_per_tile_kv = int(tile_n) * int(tile_head) + bytes_per_thread_kv = bytes_per_tile_kv // total_threads + kv_load_bytes = 16 + + bytes_per_tile_qo_scale = (int(tile_m) * int(tile_head)) // 32 + bytes_per_thread_qo_scale = max(1, bytes_per_tile_qo_scale // total_threads) + + bytes_per_tile_kv_scale = (int(tile_n) * int(tile_head)) // 32 + bytes_per_thread_kv_scale = max(1, bytes_per_tile_kv_scale // total_threads) + + def _elem_type(): + return T.f8 + + def _vec16_type(): + return T.f8x16 + + # ── LDS sizing (pure Python, no MLIR ops) ──────────────────────────────── + lds_qo_tile_bytes = int(tile_m) * int(tile_head) + lds_qo_scale_tile_bytes = (int(tile_m) * int(tile_head)) // 32 + lds_k_tile_bytes = int(tile_n) * int(tile_head) + lds_k_scale_head_tile_bytes = int(tile_n) * int(tile_head_mx) + lds_k_scale_n_tile_bytes = int(tile_n_mx) * int(tile_head) + lds_v_tile_bytes = int(tile_n) * int(tile_head) + lds_v_scale_tile_bytes = int(tile_n) * int(tile_head_mx) + lds_ppt_tile_bytes = int(tile_n) * int(tile_m) + lds_ppt_scale_tile_bytes = int(tile_n) * int(tile_m_mx) + lds_dst_tile_bytes = int(tile_n) * int(tile_m) + lds_dst_scale_tile_bytes = int(tile_n) * int(tile_m_mx) + lds_ds_tile_bytes = int(tile_m) * int(tile_n) + 
lds_ds_scale_tile_bytes = int(tile_m) * int(tile_n_mx) + + buffer_size_bytes = lds_qo_tile_bytes * 4 + lds_qo_scale_tile_bytes * 4 + + lds_pong_offset = allocator_pong._align(allocator_pong.ptr, 16) + allocator_pong.ptr = lds_pong_offset + buffer_size_bytes + lds_q_quant_head_pong_offset = lds_pong_offset + lds_q_scale_head_pong_offset = lds_q_quant_head_pong_offset + lds_qo_tile_bytes + lds_q_quant_m_pong_offset = lds_q_scale_head_pong_offset + lds_qo_scale_tile_bytes + lds_q_scale_m_pong_offset = lds_q_quant_m_pong_offset + lds_qo_tile_bytes + lds_do_quant_head_pong_offset = lds_q_scale_m_pong_offset + lds_qo_scale_tile_bytes + lds_do_scale_head_pong_offset = lds_do_quant_head_pong_offset + lds_qo_tile_bytes + lds_do_quant_m_pong_offset = lds_do_scale_head_pong_offset + lds_qo_scale_tile_bytes + lds_do_scale_m_pong_offset = lds_do_quant_m_pong_offset + lds_qo_tile_bytes + + lds_ping_offset = allocator_ping._align(allocator_ping.ptr, 16) + allocator_ping.ptr = lds_ping_offset + buffer_size_bytes + lds_q_quant_head_ping_offset = lds_ping_offset + lds_q_scale_head_ping_offset = lds_q_quant_head_ping_offset + lds_qo_tile_bytes + lds_q_quant_m_ping_offset = lds_q_scale_head_ping_offset + lds_qo_scale_tile_bytes + lds_q_scale_m_ping_offset = lds_q_quant_m_ping_offset + lds_qo_tile_bytes + lds_do_quant_head_ping_offset = lds_q_scale_m_ping_offset + lds_qo_scale_tile_bytes + lds_do_scale_head_ping_offset = lds_do_quant_head_ping_offset + lds_qo_tile_bytes + lds_do_quant_m_ping_offset = lds_do_scale_head_ping_offset + lds_qo_scale_tile_bytes + lds_do_scale_m_ping_offset = lds_do_quant_m_ping_offset + lds_qo_tile_bytes + + lds_k_quant_head_offset = allocator_k_quant_head._align(allocator_k_quant_head.ptr, 16) + allocator_k_quant_head.ptr = lds_k_quant_head_offset + lds_k_tile_bytes + + lds_k_scale_head_offset = allocator_k_scale_head._align(allocator_k_scale_head.ptr, 16) + allocator_k_scale_head.ptr = lds_k_scale_head_offset + lds_k_scale_head_tile_bytes + + 
lds_k_quant_n_offset = allocator_k_quant_n._align(allocator_k_quant_n.ptr, 16) + allocator_k_quant_n.ptr = lds_k_quant_n_offset + lds_k_tile_bytes + + lds_k_scale_n_offset = allocator_k_scale_n._align(allocator_k_scale_n.ptr, 16) + allocator_k_scale_n.ptr = lds_k_scale_n_offset + lds_k_scale_n_tile_bytes + + lds_v_offset = allocator_v._align(allocator_v.ptr, 16) + allocator_v.ptr = lds_v_offset + lds_v_tile_bytes + + lds_v_scale_offset = allocator_v_scale._align(allocator_v_scale.ptr, 16) + allocator_v_scale.ptr = lds_v_scale_offset + lds_v_scale_tile_bytes + + lds_ppt_shuffle_offset = allocator_ppt_shuffle._align(allocator_ppt_shuffle.ptr, 16) + allocator_ppt_shuffle.ptr = lds_ppt_shuffle_offset + lds_ppt_tile_bytes + + lds_ppt_scale_shuffle_offset = allocator_ppt_scale_shuffle._align(allocator_ppt_scale_shuffle.ptr, 16) + allocator_ppt_scale_shuffle.ptr = lds_ppt_scale_shuffle_offset + lds_ppt_scale_tile_bytes + + lds_dst_shuffle_offset = allocator_dst_shuffle._align(allocator_dst_shuffle.ptr, 16) + allocator_dst_shuffle.ptr = lds_dst_shuffle_offset + lds_dst_tile_bytes + + lds_dst_scale_shuffle_offset = allocator_dst_scale_shuffle._align(allocator_dst_scale_shuffle.ptr, 16) + allocator_dst_scale_shuffle.ptr = lds_dst_scale_shuffle_offset + lds_dst_scale_tile_bytes + + lds_ds_shuffle_offset = allocator_ds_shuffle._align(allocator_ds_shuffle.ptr, 16) + allocator_ds_shuffle.ptr = lds_ds_shuffle_offset + lds_ds_tile_bytes + + lds_ds_scale_shuffle_offset = allocator_ds_scale_shuffle._align(allocator_ds_scale_shuffle.ptr, 16) + allocator_ds_scale_shuffle.ptr = lds_ds_scale_shuffle_offset + lds_ds_scale_tile_bytes + + # ── Kernel function ──────────────────────────────────────────────────── + @flyc.kernel + def kernel_attn_bwd( + arg_dq: fx.Tensor, + arg_dk: fx.Tensor, + arg_dv: fx.Tensor, + arg_q_quant_head: fx.Tensor, + arg_q_scale_head: fx.Tensor, + arg_q_quant_m: fx.Tensor, + arg_q_scale_m: fx.Tensor, + arg_k_quant_head: fx.Tensor, + arg_k_scale_head: fx.Tensor, + 
arg_k_quant_n: fx.Tensor, + arg_k_scale_n: fx.Tensor, + arg_v: fx.Tensor, + arg_v_scale: fx.Tensor, + arg_do_quant_head: fx.Tensor, + arg_do_scale_head: fx.Tensor, + arg_do_quant_m: fx.Tensor, + arg_do_scale_m: fx.Tensor, + arg_m: fx.Tensor, + arg_D: fx.Tensor, + batch: fx.Int32 + ): + + # ---- Types ---- + zero_f = arith.constant(0.0, type=T.f32) + acc_init = arith.constant_vector(0.0, T.f32x4) + log2e = arith.constant(1.4426950408889634, type=T.f32) + c_sm_scale = arith.constant(sm_scale, type=T.f32) + fp8_max_rcp = arith.constant(1.0 / 448.0, type=T.f32) + + tx = gpu.thread_id("x") + bx = gpu.block_id("x") + by = gpu.block_id("y") + + # ---- LDS (separate ping/pong buffers) ---- + base_ptr_pong = allocator_pong.get_base() + base_ptr_ping = allocator_ping.get_base() + base_ptr_k_quant_head = allocator_k_quant_head.get_base() + base_ptr_k_scale_head = allocator_k_scale_head.get_base() + base_ptr_k_quant_n = allocator_k_quant_n.get_base() + base_ptr_k_scale_n = allocator_k_scale_n.get_base() + base_ptr_v = allocator_v.get_base() + base_ptr_v_scale = allocator_v_scale.get_base() + base_ptr_ppt_shuffle = allocator_ppt_shuffle.get_base() + base_ptr_ppt_scale_shuffle = allocator_ppt_scale_shuffle.get_base() + base_ptr_dst_shuffle = allocator_dst_shuffle.get_base() + base_ptr_dst_scale_shuffle = allocator_dst_scale_shuffle.get_base() + base_ptr_ds_shuffle = allocator_ds_shuffle.get_base() + base_ptr_ds_scale_shuffle = allocator_ds_scale_shuffle.get_base() + + lds_q_quant_head_pong = SmemPtr( + base_ptr_pong, lds_q_quant_head_pong_offset, T.f8, shape=(tile_m * tile_head,) + ).get() + lds_q_quant_head_ping = SmemPtr( + base_ptr_ping, lds_q_quant_head_ping_offset, T.f8, shape=(tile_m * tile_head,) + ).get() + lds_q_scale_head_pong = SmemPtr( + base_ptr_pong, lds_q_scale_head_pong_offset, T.i8, shape=(tile_m * tile_head_mx,) + ).get() + lds_q_scale_head_ping = SmemPtr( + base_ptr_ping, lds_q_scale_head_ping_offset, T.i8, shape=(tile_m * tile_head_mx,) + ).get() + 
lds_q_quant_m_pong = SmemPtr( + base_ptr_pong, lds_q_quant_m_pong_offset, T.f8, shape=(tile_m * tile_head,) + ).get() + lds_q_quant_m_ping = SmemPtr( + base_ptr_ping, lds_q_quant_m_ping_offset, T.f8, shape=(tile_m * tile_head,) + ).get() + lds_q_scale_m_pong = SmemPtr( + base_ptr_pong, lds_q_scale_m_pong_offset, T.i8, shape=(tile_m_mx * tile_head,) + ).get() + lds_q_scale_m_ping = SmemPtr( + base_ptr_ping, lds_q_scale_m_ping_offset, T.i8, shape=(tile_m_mx * tile_head,) + ).get() + lds_do_quant_head_pong = SmemPtr( + base_ptr_pong, lds_do_quant_head_pong_offset, T.f8, shape=(tile_m * tile_head,) + ).get() + lds_do_quant_head_ping = SmemPtr( + base_ptr_ping, lds_do_quant_head_ping_offset, T.f8, shape=(tile_m * tile_head,) + ).get() + lds_do_scale_head_pong = SmemPtr( + base_ptr_pong, lds_do_scale_head_pong_offset, T.i8, shape=(tile_m * tile_head_mx,) + ).get() + lds_do_scale_head_ping = SmemPtr( + base_ptr_ping, lds_do_scale_head_ping_offset, T.i8, shape=(tile_m * tile_head_mx,) + ).get() + lds_do_quant_m_pong = SmemPtr( + base_ptr_pong, lds_do_quant_m_pong_offset, T.f8, shape=(tile_m * tile_head,) + ).get() + lds_do_quant_m_ping = SmemPtr( + base_ptr_ping, lds_do_quant_m_ping_offset, T.f8, shape=(tile_m * tile_head,) + ).get() + lds_do_scale_m_pong = SmemPtr( + base_ptr_pong, lds_do_scale_m_pong_offset, T.i8, shape=(tile_head * tile_m_mx,) + ).get() + lds_do_scale_m_ping = SmemPtr( + base_ptr_ping, lds_do_scale_m_ping_offset, T.i8, shape=(tile_head * tile_m_mx,) + ).get() + lds_k_quant_head = SmemPtr( + base_ptr_k_quant_head, lds_k_quant_head_offset, T.f8, shape=(tile_n * tile_head,) + ).get() + lds_k_scale_head = SmemPtr( + base_ptr_k_scale_head, lds_k_scale_head_offset, T.i8, shape=(tile_n * tile_head_mx,) + ).get() + lds_k_quant_n = SmemPtr( + base_ptr_k_quant_n, lds_k_quant_n_offset, T.f8, shape=(tile_n * tile_head,) + ).get() + lds_k_scale_n = SmemPtr( + base_ptr_k_scale_n, lds_k_scale_n_offset, T.i8, shape=(tile_n_mx * tile_head,) + ).get() + lds_v = SmemPtr( 
+ base_ptr_v, lds_v_offset, T.f8, shape=(tile_n * tile_head,) + ).get() + lds_v_scale = SmemPtr( + base_ptr_v_scale, lds_v_scale_offset, T.i8, shape=(tile_n * tile_head_mx,) + ).get() + lds_ppt_shuffle = SmemPtr( + base_ptr_ppt_shuffle, lds_ppt_shuffle_offset, T.f8, shape=(tile_n * tile_m,) + ).get() + lds_ppt_scale_shuffle = SmemPtr( + base_ptr_ppt_scale_shuffle, lds_ppt_scale_shuffle_offset, T.i8, shape=(tile_n * tile_m_mx,) + ).get() + lds_dst_shuffle = SmemPtr( + base_ptr_dst_shuffle, lds_dst_shuffle_offset, T.f8, shape=(tile_n * tile_m,) + ).get() + lds_dst_scale_shuffle = SmemPtr( + base_ptr_dst_scale_shuffle, lds_dst_scale_shuffle_offset, T.i8, shape=(tile_n * tile_m_mx,) + ).get() + lds_ds_shuffle = SmemPtr( + base_ptr_ds_shuffle, lds_ds_shuffle_offset, T.f8, shape=(tile_m * tile_n,) + ).get() + lds_ds_scale_shuffle = SmemPtr( + base_ptr_ds_scale_shuffle, lds_ds_scale_shuffle_offset, T.i8, shape=(tile_m * tile_n_mx,) + ).get() + + + # ---- Buffer resources (runtime byte sizes for OOB protection) ---- + head_dim_mx = head_dim // 32 + global_buffer_size = fx.Index(batch * seqlen * head_dim) + global_buffer_size_scale = fx.Index(batch * seqlen * head_dim_mx) + q_nrec = arith.index_cast(T.i64, global_buffer_size) + q_scale_head_nrec = arith.index_cast(T.i64, global_buffer_size_scale) + q_scale_m_nrec = arith.index_cast(T.i64, global_buffer_size_scale) + k_nrec = arith.index_cast(T.i64, global_buffer_size) + k_scale_head_nrec = arith.index_cast(T.i64, global_buffer_size_scale) + k_scale_n_nrec = arith.index_cast(T.i64, global_buffer_size_scale) + v_nrec = arith.index_cast(T.i64, global_buffer_size) + v_scale_nrec = arith.index_cast(T.i64, global_buffer_size_scale) + do_nrec = arith.index_cast(T.i64, global_buffer_size) + do_scale_head_nrec = arith.index_cast(T.i64, global_buffer_size_scale) + do_scale_m_nrec = arith.index_cast(T.i64, global_buffer_size_scale) + dq_nrec = arith.index_cast(T.i64, global_buffer_size * 4) + dk_nrec = arith.index_cast(T.i64, 
global_buffer_size * 2) + dv_nrec = arith.index_cast(T.i64, global_buffer_size * 2) + m_nrec = arith.index_cast(T.i64, fx.Index(batch * seqlen * 4)) + D_nrec = arith.index_cast(T.i64, fx.Index(batch * seqlen * 4)) + + q_quant_head_rsrc = buffer_ops.create_buffer_resource(arg_q_quant_head, max_size=False, num_records_bytes=q_nrec) + q_scale_head_rsrc = buffer_ops.create_buffer_resource(arg_q_scale_head, max_size=False, num_records_bytes=q_scale_head_nrec) + q_quant_m_rsrc = buffer_ops.create_buffer_resource(arg_q_quant_m, max_size=False, num_records_bytes=q_nrec) + q_scale_m_rsrc = buffer_ops.create_buffer_resource(arg_q_scale_m, max_size=False, num_records_bytes=q_scale_m_nrec) + k_quant_head_rsrc = buffer_ops.create_buffer_resource(arg_k_quant_head, max_size=False, num_records_bytes=k_nrec) + k_scale_head_rsrc = buffer_ops.create_buffer_resource(arg_k_scale_head, max_size=False, num_records_bytes=k_scale_head_nrec) + k_quant_n_rsrc = buffer_ops.create_buffer_resource(arg_k_quant_n, max_size=False, num_records_bytes=k_nrec) + k_scale_n_rsrc = buffer_ops.create_buffer_resource(arg_k_scale_n, max_size=False, num_records_bytes=k_scale_n_nrec) + v_rsrc = buffer_ops.create_buffer_resource(arg_v, max_size=False, num_records_bytes=v_nrec) + v_scale_rsrc = buffer_ops.create_buffer_resource(arg_v_scale, max_size=False, num_records_bytes=v_scale_nrec) + do_quant_head_rsrc = buffer_ops.create_buffer_resource(arg_do_quant_head, max_size=False, num_records_bytes=do_nrec) + do_scale_head_rsrc = buffer_ops.create_buffer_resource(arg_do_scale_head, max_size=False, num_records_bytes=do_scale_head_nrec) + do_quant_m_rsrc = buffer_ops.create_buffer_resource(arg_do_quant_m, max_size=False, num_records_bytes=do_nrec) + do_scale_m_rsrc = buffer_ops.create_buffer_resource(arg_do_scale_m, max_size=False, num_records_bytes=do_scale_m_nrec) + dq_rsrc = buffer_ops.create_buffer_resource(arg_dq, max_size=False, num_records_bytes=dq_nrec) + dk_rsrc = buffer_ops.create_buffer_resource(arg_dk, 
max_size=False, + num_records_bytes=dk_nrec) + dv_rsrc = buffer_ops.create_buffer_resource(arg_dv, max_size=False, + num_records_bytes=dv_nrec) + m_rsrc = buffer_ops.create_buffer_resource(arg_m, max_size=False, + num_records_bytes=m_nrec) + D_rsrc = buffer_ops.create_buffer_resource(arg_D, max_size=False, num_records_bytes=D_nrec) + + global_offset_n = bx * tile_n + global_offset_n_mx = global_offset_n // 32 + + # ---- Wave / lane decomposition ---- + layout_wave_lane = fx.make_layout((4, wave_size), (64, 1)) + coord_wave_lane = fx.idx2crd(tx, layout_wave_lane) + wave_id = fx.get(coord_wave_lane, 0) + lane_id = fx.get(coord_wave_lane, 1) + + layout_lane16 = fx.make_layout((4, 16), (16, 1)) + coord_lane16 = fx.idx2crd(lane_id, layout_lane16) + lane_div_16 = fx.get(coord_lane16, 0) + lane_mod_16 = fx.get(coord_lane16, 1) + + layout_lane2 = fx.make_layout((8, 2), (2, 1)) + coord_lane2 = fx.idx2crd(lane_mod_16, layout_lane2) + lane_div_2 = fx.get(coord_lane2, 0) + lane_mod_2 = fx.get(coord_lane2, 1) + + # wave partitioning for qk, p, dp, ds + ps_m_num_waves = 2 + ps_n_num_waves = 2 + ps_wave_layout = fx.make_layout((ps_m_num_waves, ps_n_num_waves), (ps_n_num_waves, 1)) + ps_coord = fx.idx2crd(wave_id, ps_wave_layout) + ps_m_wave_id = fx.get(ps_coord, 0) + ps_n_wave_id = fx.get(ps_coord, 1) + ps_m_per_wave = tile_m // ps_m_num_waves + ps_m_mx_per_wave = tile_m_mx // ps_m_num_waves + ps_m_num_subtiles = ps_m_per_wave // 16 + ps_n_per_wave = tile_n // ps_n_num_waves + ps_n_mx_per_wave = tile_n_mx // ps_n_num_waves + ps_n_num_subtiles = ps_n_per_wave // 16 + ps_n_accs = ps_n_num_subtiles * ps_m_num_subtiles + + # wave partitioning for dv gemm + dv_n_num_waves = 2 + dv_head_num_waves = 2 + dv_wave_layout = fx.make_layout((dv_n_num_waves, dv_head_num_waves), (dv_head_num_waves, 1)) + dv_coord = fx.idx2crd(wave_id, dv_wave_layout) + dv_n_wave_id = fx.get(dv_coord, 0) + dv_head_wave_id = fx.get(dv_coord, 1) + dv_n_per_wave = tile_n // dv_n_num_waves + dv_n_num_subtiles = 
def lds_load_8b_transposed(curr_row_lds, col_base, lds_stride, lds_buffer, swizzle=16):
    """Transpose-load from a row-major LDS tile at (row, column).

    Args:
        curr_row_lds: Row inside the LDS tile.
        col_base: Unswizzled column of the read.
        lds_stride: Elements per LDS row.
        lds_buffer: LDS memref to read from.
        swizzle: When equal to 16 the column is passed through
            ``swizzle_xor16``; any other value leaves it untouched.

    Returns:
        Whatever ``lds_transpose_load`` returns for the computed offset.
    """
    column = col_base
    if swizzle == 16:
        column = swizzle_xor16(curr_row_lds, column, lds_stride // swizzle)
    # Odd/even lanes of a pair read adjacent 8-element halves.
    column = column + lane_mod_2 * 8
    flat_offset = curr_row_lds * lds_stride + column
    return lds_transpose_load(lds_buffer, flat_offset)
def lds_scale_load(row, col, lds_stride, lds_buffer):
    """Read one i8 scale from LDS at ``row * lds_stride + col`` and zero-extend it to i32."""
    flat_index = row * lds_stride + col
    one_byte = vector.load_op(T.vec(1, T.i8), lds_buffer, [flat_index])
    scale_i8 = vector.extract(one_byte, static_position=[0], dynamic_position=[])
    return scale_i8.extui(T.i32)
def prefetch_q_quant_head_tile(offset_m):
    """Issue this thread's 16-byte global loads for one head-quantized Q tile.

    Args:
        offset_m: First row of the tile; added to each thread-local row.

    Returns:
        A list of ``num_qo_loads`` values, each the loaded 16 bytes bitcast
        to ``T.i32x4``.
    """
    loaded_chunks = []
    for chunk in range_constexpr(num_qo_loads):
        local_row, local_col_dw = qo_tile_chunk_coord_i32(chunk)
        global_row = offset_m + local_row
        # head_dim_div4 is head_dim // 4: with 1-byte elements the index
        # counts 4-byte units, matching the dword column from the layout.
        dword_index = batch_offset_tensor + global_row * head_dim_div4 + local_col_dw
        raw_16b = load_q_quant_head_16(dword_index)
        loaded_chunks.append(vector.bitcast(T.i32x4, raw_16b))
    return loaded_chunks
def prefetch_v_tile():
    """Issue this thread's 16-byte global loads for the block's V tile.

    Rows start at ``global_offset_n`` (the tile owned by this block).

    Returns:
        A list of ``num_kv_loads`` values, each the loaded 16 bytes bitcast
        to ``T.i32x4``.
    """
    loaded_chunks = []
    for chunk in range_constexpr(num_kv_loads):
        local_row, local_col_dw = kv_tile_chunk_coord_i32(chunk)
        global_row = global_offset_n + local_row
        # Index counts 4-byte units: head_dim_div4 of them per row.
        dword_index = batch_offset_tensor + global_row * head_dim_div4 + local_col_dw
        raw_16b = load_v_16(dword_index)
        loaded_chunks.append(vector.bitcast(T.i32x4, raw_16b))
    return loaded_chunks
def prefetch_q_scale_head_tile(offset_m):
    """Load this thread's share of the Q per-head-block scale tile from global memory.

    Two compile-time cases are handled, selected by
    ``bytes_per_thread_qo_scale``:

    * 1 byte per thread: a single i8 load; when the tile holds fewer scale
      bytes than there are threads, the thread index wraps modulo the tile
      size so every thread still reads a valid byte.
    * otherwise (written for 2 bytes per thread): one i16 load at half the
      byte index, bitcast to two i8 values.

    Args:
        offset_m: First row of the tile.

    Returns:
        A ``vector<1 x i8>`` or ``T.i8x2`` value, depending on the case.
    """
    # Byte offset of the tile's first scale; shared by every branch.
    tile_base = batch_offset_scale + offset_m * head_dim_mx
    if const_expr(bytes_per_thread_qo_scale == 1):
        if const_expr(bytes_per_tile_qo_scale < total_threads):
            byte_index = tile_base + tx % bytes_per_tile_qo_scale
        else:
            byte_index = tile_base + tx
        scale_vec = buffer_ops.buffer_load(q_scale_head_rsrc, byte_index, vec_width=1, dtype=T.i8)
        scale_vec = vector.from_elements(T.vec(1, T.i8), [scale_vec])
    else:  # two scale bytes per thread
        half_index = (tile_base + tx * bytes_per_thread_qo_scale) // 2
        scale_vec = buffer_ops.buffer_load(q_scale_head_rsrc, half_index, vec_width=1, dtype=T.i16)
        scale_vec = vector.from_elements(T.vec(1, T.i16), [scale_vec])
        scale_vec = vector.bitcast(T.i8x2, scale_vec)
    return scale_vec
def prefetch_k_scale_n_tile():
    """Load this thread's share of the K scale tile that is blocked along N.

    The tile starts at row ``global_offset_n_mx`` (``global_offset_n // 32``)
    with ``head_dim`` scale bytes per row. Two compile-time cases are handled,
    selected by ``bytes_per_thread_kv_scale``:

    * 1 byte per thread: a single i8 load, with the thread index wrapped
      modulo the tile size when the tile is smaller than the thread count.
    * otherwise (written for 2 bytes per thread): one i16 load at half the
      byte index, bitcast to two i8 values.

    Returns:
        A ``vector<1 x i8>`` or ``T.i8x2`` value, depending on the case.
    """
    # Byte offset of the tile's first scale; shared by every branch.
    tile_base = batch_offset_scale + global_offset_n_mx * head_dim
    if const_expr(bytes_per_thread_kv_scale == 1):
        if const_expr(bytes_per_tile_kv_scale < total_threads):
            byte_index = tile_base + tx % bytes_per_tile_kv_scale
        else:
            byte_index = tile_base + tx
        scale_vec = buffer_ops.buffer_load(k_scale_n_rsrc, byte_index, vec_width=1, dtype=T.i8)
        scale_vec = vector.from_elements(T.vec(1, T.i8), [scale_vec])
    else:  # two scale bytes per thread
        half_index = (tile_base + tx * bytes_per_thread_kv_scale) // 2
        scale_vec = buffer_ops.buffer_load(k_scale_n_rsrc, half_index, vec_width=1, dtype=T.i16)
        scale_vec = vector.from_elements(T.vec(1, T.i16), [scale_vec])
        scale_vec = vector.bitcast(T.i8x2, scale_vec)
    return scale_vec
def store_q_tile_to_lds(vec_q_parts, lds_buffer):
    """Write this thread's prefetched 16-byte Q chunks into a swizzled LDS tile.

    Args:
        vec_q_parts: Per-chunk values, in the same chunk order that
            ``qo_tile_chunk_coord_i32`` enumerates.
        lds_buffer: Destination LDS memref, ``tile_head`` elements per row.
    """
    for chunk in range_constexpr(num_qo_loads):
        local_row, local_col_dw = qo_tile_chunk_coord_i32(chunk)
        # Chunk columns come back in dwords; the swizzle operates on bytes.
        swizzled_col = swizzle_xor16(local_row, local_col_dw * c4, tile_head_div16)
        lds_index = local_row * tile_head + swizzled_col
        payload = vector.bitcast(_vec16_type(), vec_q_parts[chunk])
        vector.store(payload, lds_buffer, [lds_index])
def store_q_scale_tile_to_lds(vec_scale, lds_buffer):
    """Store this thread's Q scale vector into LDS at ``tx * bytes_per_thread_qo_scale``.

    Args:
        vec_scale: Scale vector previously loaded for this thread.
        lds_buffer: Destination LDS memref for the scale tile.
    """
    lds_index = tx * bytes_per_thread_qo_scale
    if total_threads > bytes_per_tile_qo_scale:
        # More threads than scale bytes: wrap so surplus threads land on
        # in-range slots instead of writing past the tile.
        lds_index = lds_index % bytes_per_tile_qo_scale
    vector.store(vec_scale, lds_buffer, [lds_index])
def pack_i64x4_to_i32x8(x0, x1, x2, x3):
    """Pack four i64 values into a single 8 x i32 vector by bitcast."""
    as_i64x4 = vector.from_elements(T.vec(4, T.i64), [x0, x1, x2, x3])
    return vector.bitcast(T.vec(8, T.i32), as_i64x4)
#fx.printf("ni={}, mi={}", ni, mi) + acc_idx = mi * ps_n_num_subtiles + ni + current_accs_list[acc_idx] = rocdl.mfma_scale_f32_16x16x128_f8f6f4( + mfma_res_ty, + [a128, b128, current_accs_list[acc_idx], + 0, 0, 0, a_scale, 0, b_scale], + ) + return current_accs_list + + + def softmax(accs_in, offset_m): + # inputs are tile_m x tile_n shape + + accs_out = [acc_init] * ps_n_accs + + for mi in range_constexpr(ps_m_num_subtiles): + global_m_norm_idx = batch_offset_mD + offset_m + ps_m_wave_id * ps_m_per_wave + mi * 16 + lane_div_16 * 4 #+ ii + m_norm_vector = buffer_ops.buffer_load(m_rsrc, global_m_norm_idx, vec_width=4) + + for ni in range_constexpr(ps_n_num_subtiles): + + acc_idx = mi * ps_n_num_subtiles + ni + acc = accs_in[acc_idx] + + vals_f32 = [] + for ii in range_constexpr(4): + val_f32 = vector.extract(acc, static_position=[ii], dynamic_position=[]) + m_norm = vector.extract(m_norm_vector, static_position=[ii], dynamic_position=[]) + val_f32 = val_f32 * c_sm_scale + val_f32 = val_f32 - m_norm + val_f32 = val_f32 * log2e + val_f32 = rocdl.exp2(T.f32, val_f32) + if causal: + global_m = offset_m + ps_m_wave_id * ps_m_per_wave + mi * 16 + lane_div_16 * 4 + ii + global_n = global_offset_n + ps_n_wave_id * ps_n_per_wave + ni * 16 + lane_mod_16 + needs_mask = arith.cmpi(arith.CmpIPredicate.ugt, global_n, global_m) + mask_if = scf.IfOp(needs_mask, [T.f32], has_else=True) + with ir.InsertionPoint(mask_if.then_block): + scf.YieldOp([arith.constant(0.0, type=T.f32)]) + with ir.InsertionPoint(mask_if.else_block): + scf.YieldOp([val_f32]) + val_f32 = mask_if.results[0] + vals_f32.append(val_f32) + vals_f32_vector = vector.from_elements(T.f32x4, vals_f32) + accs_out[acc_idx] = vals_f32_vector + + return accs_out + + + def compute_dv(accs_in, lds_a_buffer, lds_a_scale_buffer, lds_b_buffer, lds_b_scale_buffer): + current_accs_list = list(accs_in) + mfma_res_ty = T.f32x4 + num_subtiles_reduction = max(1, tile_m // 128) + for ku128 in range_constexpr(num_subtiles_reduction): + 
ku0 = ku128 * 2 + ku1 = ku0 + 1 + + lds_a_col0 = ku0 * 64 + lane_div_16 * 16 # 16 elements packed per lane, 64 per wave + lds_a_col1 = ku1 * 64 + lane_div_16 * 16 + + lds_b_row0 = ku0 * 64 + lane_div_16 * 16 + lane_div_2 + lds_b_row1 = ku0 * 64 + lane_div_16 * 16 + 8 + lane_div_2 + lds_b_row2 = ku1 * 64 + lane_div_16 * 16 + lane_div_2 + lds_b_row3 = ku1 * 64 + lane_div_16 * 16 + 8 + lane_div_2 + + lds_a_scale_col = lane_div_16 + lds_b_scale_row = lane_div_16 + if const_expr(tile_m == 64): + lds_a_scale_col = lds_a_scale_col % 2 + lds_b_scale_row = lds_b_scale_row % 2 + + for ni in range_constexpr(dv_n_num_subtiles): + lds_a_row = dv_n_wave_id * dv_n_per_wave + ni * 16 + lane_mod_16 + a0, a1 = lds_load_packs_k64(lds_a_row, lds_a_col0, tile_m, lds_a_buffer) + if const_expr(tile_m == 128): + a2, a3 = lds_load_packs_k64(lds_a_row, lds_a_col1, tile_m, lds_a_buffer) + else: + a2 = a3 = fx.Int64(0) + a128 = pack_i64x4_to_i32x8(a0, a1, a2, a3) + a_scale = lds_scale_load(lds_a_row, lds_a_scale_col, tile_m_mx, lds_a_scale_buffer) + + for hi in range_constexpr(dv_head_num_subtiles): + lds_b_col = dv_head_wave_id * dv_head_per_wave + hi * 16 + b0 = lds_load_packs_k32_transposed(lds_b_row0, lds_b_col, tile_head, lds_b_buffer) + b1 = lds_load_packs_k32_transposed(lds_b_row1, lds_b_col, tile_head, lds_b_buffer) + if const_expr(tile_m == 128): + b2 = lds_load_packs_k32_transposed(lds_b_row2, lds_b_col, tile_head, lds_b_buffer) + b3 = lds_load_packs_k32_transposed(lds_b_row3, lds_b_col, tile_head, lds_b_buffer) + else: + b2 = fx.Int64(0) + b3 = fx.Int64(0) + b128 = pack_i64x4_to_i32x8(b0, b1, b2, b3) + + lds_b_scale_col = dv_head_wave_id * dv_head_per_wave + hi * 16 + lane_mod_16 + b_scale = lds_scale_load(lds_b_scale_row, lds_b_scale_col, tile_head, lds_b_scale_buffer) + + acc_idx = ni * dv_head_num_subtiles + hi + current_accs_list[acc_idx] = rocdl.mfma_scale_f32_16x16x128_f8f6f4( + mfma_res_ty, + [a128, b128, current_accs_list[acc_idx], + 0, 0, 0, a_scale, 0, b_scale], + ) + 
return current_accs_list + + + def compute_dp(lds_a_buffer, lds_a_scale_buffer, lds_b_buffer, lds_b_scale_buffer): + current_accs_list = [acc_init] * ps_n_accs + mfma_res_ty = T.f32x4 + ku0 = 0 + ku1 = 1 + lds_col0 = ku0 * 64 + lane_div_16 * 16 # 16 elements packed per lane, 64 per wave + lds_col1 = ku1 * 64 + lane_div_16 * 16 + lds_scale_col = lane_div_16 + if const_expr(tile_head == 64): + lds_scale_col = lds_scale_col % 2 + + for mi in range_constexpr(ps_m_num_subtiles): + lds_a_row = ps_m_wave_id * ps_m_per_wave + mi * 16 + lane_mod_16 + a0, a1 = lds_load_packs_k64(lds_a_row, lds_col0, tile_head, lds_a_buffer) + if const_expr(tile_head == 128): + a2, a3 = lds_load_packs_k64(lds_a_row, lds_col1, tile_head, lds_a_buffer) + else: + a2 = a3 = fx.Int64(0) + a128 = pack_i64x4_to_i32x8(a0, a1, a2, a3) + a_scale = lds_scale_load(lds_a_row, lds_scale_col, tile_head_mx, lds_a_scale_buffer) + + for ni in range_constexpr(ps_n_num_subtiles): + lds_b_row = ps_n_wave_id * ps_n_per_wave + ni * 16 + lane_mod_16 + b0, b1 = lds_load_packs_k64(lds_b_row, lds_col0, tile_head, lds_b_buffer) + if const_expr(tile_head == 128): + b2, b3 = lds_load_packs_k64(lds_b_row, lds_col1, tile_head, lds_b_buffer) + else: + b2 = b3 = fx.Int64(0) + b128 = pack_i64x4_to_i32x8(b0, b1, b2, b3) + b_scale = lds_scale_load(lds_b_row, lds_scale_col, tile_head_mx, lds_b_scale_buffer) + + acc_idx = mi * ps_n_num_subtiles + ni + current_accs_list[acc_idx] = rocdl.mfma_scale_f32_16x16x128_f8f6f4( + mfma_res_ty, + [a128, b128, current_accs_list[acc_idx], + 0, 0, 0, a_scale, 0, b_scale], + ) + return current_accs_list + + + def compute_ds(dp_accs, p_accs, offset_m): + # inputs are tile_m x tile_n shape + + accs_out = [acc_init] * ps_n_accs + + for mi in range_constexpr(ps_m_num_subtiles): + + global_D_idx = batch_offset_mD + offset_m + ps_m_wave_id * ps_m_per_wave + mi * 16 + lane_div_16 * 4 + D_vector = buffer_ops.buffer_load(D_rsrc, global_D_idx, vec_width=4) + + for ni in range_constexpr(ps_n_num_subtiles): 
+ + acc_idx = mi * ps_n_num_subtiles + ni + dp_f32x4 = dp_accs[acc_idx] + p_f32x4 = p_accs[acc_idx] + + vals_f32 = [] + for ii in range_constexpr(4): + dp_f32 = vector.extract(dp_f32x4, static_position=[ii], dynamic_position=[]) + p_f32 = vector.extract(p_f32x4, static_position=[ii], dynamic_position=[]) + D = vector.extract(D_vector, static_position=[ii], dynamic_position=[]) + ds_f32 = p_f32 * (dp_f32 - D) + vals_f32.append(ds_f32) + + vals_f32_vector = vector.from_elements(T.f32x4, vals_f32) + accs_out[acc_idx] = vals_f32_vector + + return accs_out + + + def wave_reduce_max_4threads(x): + width_i32 = arith.constant(64, type=T.i32) + w = x + for sh in [32, 16]: + off = arith.constant(sh, type=T.i32) + peer = w.shuffle_xor(off, width_i32) + w = w.maximumf(peer) + return w + + + def mxquant_m_and_store_to_lds(accs_in, lds_buffer, lds_buffer_scale): + # inputs are tile_m x tile_n shape + + for mi in range_constexpr(ps_m_num_subtiles // 2): + for ni in range_constexpr(ps_n_num_subtiles): + + acc_idx0 = (mi * 2) * ps_n_num_subtiles + ni + acc_idx1 = (mi * 2 + 1) * ps_n_num_subtiles + ni + acc0 = accs_in[acc_idx0] + acc1 = accs_in[acc_idx1] + + vals_subtile0 = [] + vals_subtile1 = [] + vals_abs = [] + for ii in range_constexpr(4): + val0 = vector.extract(acc0, static_position=[ii], dynamic_position=[]) + vals_subtile0.append(val0) + val1 = vector.extract(acc1, static_position=[ii], dynamic_position=[]) + vals_subtile1.append(val1) + val0_abs = fx_math.absf(val0) + val1_abs = fx_math.absf(val1) + vals_abs.append(val0_abs) + vals_abs.append(val1_abs) + + vals_abs_vector = vector.from_elements(T.vec(8, T.f32), vals_abs) + val_max = vector.reduction(T.f32, "maxnumf", vals_abs_vector) + val_max = wave_reduce_max_4threads(val_max) + val_max = val_max * fp8_max_rcp + val_max = arith.bitcast(T.i32, val_max) + val_max = val_max + arith.constant(0x007FFFFF, type=T.i32) + val_max = val_max & arith.constant(0x7F800000, type=T.i32) + val_max_f32 = arith.bitcast(T.f32, val_max) + 
val_max_rcp = arith.select(val_max_f32 == zero_f, zero_f, rocdl.rcp(T.f32, val_max_f32)) + scale = val_max >> 23 + scale = arith.trunci(T.i8, scale) + scale_vector = vector.from_elements(T.vec(1, T.i8), [scale]) + + for ii in range_constexpr(4): + vals_subtile0[ii] = vals_subtile0[ii] * val_max_rcp + vals_subtile1[ii] = vals_subtile1[ii] * val_max_rcp + + val_f8_packed_i32 = rocdl.cvt_pk_fp8_f32(T.i32, vals_subtile0[0], vals_subtile0[1], fx.Int32(0), False) + val_f8_packed_i32 = rocdl.cvt_pk_fp8_f32(T.i32, vals_subtile0[2], vals_subtile0[3], val_f8_packed_i32, True) + val_f8_packed_i32_vector = vector.from_elements(T.vec(1, T.i32), [val_f8_packed_i32]) + val_f8x4_subtile0 = vector.bitcast(T.f8x4, val_f8_packed_i32_vector) + + val_f8_packed_i32 = rocdl.cvt_pk_fp8_f32(T.i32, vals_subtile1[0], vals_subtile1[1], fx.Int32(0), False) + val_f8_packed_i32 = rocdl.cvt_pk_fp8_f32(T.i32, vals_subtile1[2], vals_subtile1[3], val_f8_packed_i32, True) + val_f8_packed_i32_vector = vector.from_elements(T.vec(1, T.i32), [val_f8_packed_i32]) + val_f8x4_subtile1 = vector.bitcast(T.f8x4, val_f8_packed_i32_vector) + + lds_row = ps_n_wave_id * ps_n_per_wave + ni * 16 + lane_mod_16 + lds_col_base0 = ps_m_wave_id * ps_m_per_wave + (mi * 2) * 16 #+ lane_div_16 * 4 + lds_col_base1 = ps_m_wave_id * ps_m_per_wave + (mi * 2 + 1) * 16 #+ lane_div_16 * 4 + lds_col0 = swizzle_xor16(lds_row, lds_col_base0, tile_m_div16) + lds_col1 = swizzle_xor16(lds_row, lds_col_base1, tile_m_div16) + lds_col0 = lds_col0 + lane_div_16 * 4 + lds_col1 = lds_col1 + lane_div_16 * 4 + lds_scale_col = ps_m_wave_id * ps_m_mx_per_wave + mi + lds_idx0 = lds_row * tile_m + lds_col0 + lds_idx1 = lds_row * tile_m + lds_col1 + lds_scale_idx = lds_row * tile_m_mx + lds_scale_col + + vector.store(val_f8x4_subtile0, lds_buffer, [lds_idx0]) + vector.store(val_f8x4_subtile1, lds_buffer, [lds_idx1]) + vector.store(scale_vector, lds_buffer_scale, [lds_scale_idx]) + + + def wave_reduce_max_16threads(x): + width_i32 = 
arith.constant(64, type=T.i32) + w = x + for sh in [8, 4, 2, 1]: + off = arith.constant(sh, type=T.i32) + peer = w.shuffle_xor(off, width_i32) + w = w.maximumf(peer) + return w + + + def mxquant_n_and_store_to_lds(accs_in, lds_buffer, lds_buffer_scale): + # inputs are tile_m x tile_n shape + + for mi in range_constexpr(ps_m_num_subtiles): + for ni in range_constexpr(ps_n_num_subtiles // 2): + + acc_idx0 = mi * ps_n_num_subtiles + ni * 2 + acc_idx1 = mi * ps_n_num_subtiles + ni * 2 + 1 + acc0 = accs_in[acc_idx0] + acc1 = accs_in[acc_idx1] + + vals_subtile0 = [] + vals_subtile1 = [] + scales = [] + for ii in range_constexpr(4): + val0 = vector.extract(acc0, static_position=[ii], dynamic_position=[]) + val1 = vector.extract(acc1, static_position=[ii], dynamic_position=[]) + val0_abs = fx_math.absf(val0) + val1_abs = fx_math.absf(val1) + val_max = arith.maximumf(val0_abs, val1_abs) + val_max = wave_reduce_max_16threads(val_max) + val_max = val_max * fp8_max_rcp + val_max = arith.bitcast(T.i32, val_max) + val_max = val_max + arith.constant(0x007FFFFF, type=T.i32) + val_max = val_max & arith.constant(0x7F800000, type=T.i32) + val_max_f32 = arith.bitcast(T.f32, val_max) + val_max_rcp = arith.select(val_max_f32 == zero_f, zero_f, rocdl.rcp(T.f32, val_max_f32)) + val0_quant = val0 * val_max_rcp + vals_subtile0.append(val0_quant) + val1_quant = val1 * val_max_rcp + vals_subtile1.append(val1_quant) + scale = val_max >> 23 + scale = arith.trunci(T.i8, scale) + scales.append(scale) + + val_f8_packed_i32 = rocdl.cvt_pk_fp8_f32(T.i32, vals_subtile0[0], vals_subtile0[1], fx.Int32(0), False) + val_f8_packed_i32 = rocdl.cvt_pk_fp8_f32(T.i32, vals_subtile0[2], vals_subtile0[3], val_f8_packed_i32, True) + val_f8_packed_i32_vector = vector.from_elements(T.vec(1, T.i32), [val_f8_packed_i32]) + val_f8x4_subtile0 = vector.bitcast(T.f8x4, val_f8_packed_i32_vector) + + val_f8_packed_i32 = rocdl.cvt_pk_fp8_f32(T.i32, vals_subtile1[0], vals_subtile1[1], fx.Int32(0), False) + val_f8_packed_i32 
= rocdl.cvt_pk_fp8_f32(T.i32, vals_subtile1[2], vals_subtile1[3], val_f8_packed_i32, True) + val_f8_packed_i32_vector = vector.from_elements(T.vec(1, T.i32), [val_f8_packed_i32]) + val_f8x4_subtile1 = vector.bitcast(T.f8x4, val_f8_packed_i32_vector) + + lds_row0 = ps_n_wave_id * ps_n_per_wave + (ni * 2) * 16 + lane_mod_16 + lds_row1 = ps_n_wave_id * ps_n_per_wave + (ni * 2 + 1) * 16 + lane_mod_16 + lds_col_base = ps_m_wave_id * ps_m_per_wave + mi * 16 #+ lane_div_16 * 4 + lds_col0 = swizzle_xor16(lds_row0, lds_col_base, tile_m_div16) + lds_col1 = swizzle_xor16(lds_row1, lds_col_base, tile_m_div16) + lds_col0 = lds_col0 + lane_div_16 * 4 + lds_col1 = lds_col1 + lane_div_16 * 4 + lds_idx0 = lds_row0 * tile_m + lds_col0 + lds_idx1 = lds_row1 * tile_m + lds_col1 + vector.store(val_f8x4_subtile0, lds_buffer, [lds_idx0]) + vector.store(val_f8x4_subtile1, lds_buffer, [lds_idx1]) + + for ii in range_constexpr(4): + lds_scale_row = ps_m_wave_id * ps_m_per_wave + mi * 16 + lane_div_16 * 4 + ii + lds_scale_col = ps_n_wave_id * ps_n_mx_per_wave + ni + lds_scale_idx = lds_scale_row * tile_n_mx + lds_scale_col + + scale_vector = vector.from_elements(T.vec(1, T.i8), [scales[ii]]) + vector.store(scale_vector, lds_buffer_scale, [lds_scale_idx]) + + + def compute_dk(accs_in, lds_a_buffer, lds_a_scale_buffer, lds_b_buffer, lds_b_scale_buffer): + current_accs_list = list(accs_in) + mfma_res_ty = T.f32x4 + num_subtiles_reduction = max(1, tile_m // 128) + for ku128 in range_constexpr(num_subtiles_reduction): + ku0 = ku128 * 2 + ku1 = ku0 + 1 + + lds_a_col0 = ku0 * 64 + lane_div_16 * 16 # 16 elements packed per lane, 64 per wave + lds_a_col1 = ku1 * 64 + lane_div_16 * 16 + + lds_b_row0 = ku0 * 64 + lane_div_16 * 16 + lane_div_2 + lds_b_row1 = ku0 * 64 + lane_div_16 * 16 + 8 + lane_div_2 + lds_b_row2 = ku1 * 64 + lane_div_16 * 16 + lane_div_2 + lds_b_row3 = ku1 * 64 + lane_div_16 * 16 + 8 + lane_div_2 + + lds_a_scale_col = lane_div_16 + lds_b_scale_row = lane_div_16 + if tile_m == 64: + 
lds_a_scale_col = lds_a_scale_col % 2 + lds_b_scale_row = lds_b_scale_row % 2 + + for ni in range_constexpr(dk_num_subtiles_n): + lds_a_row = dk_n_wave_id * dk_n_per_wave + ni * 16 + lane_mod_16 + a0, a1 = lds_load_packs_k64(lds_a_row, lds_a_col0, tile_m, lds_a_buffer) + if const_expr(tile_m == 128): + a2, a3 = lds_load_packs_k64(lds_a_row, lds_a_col1, tile_m, lds_a_buffer) + else: + a2 = fx.Int64(0) + a3 = fx.Int64(0) + a128 = pack_i64x4_to_i32x8(a0, a1, a2, a3) + a_scale = lds_scale_load(lds_a_row, lds_a_scale_col, tile_m_mx, lds_a_scale_buffer) + + for hi in range_constexpr(dk_num_subtiles_head): + lds_b_col = dk_head_wave_id * dk_head_per_wave + hi * 16 #+ lane_mod_2 * 8 + b0 = lds_load_packs_k32_transposed(lds_b_row0, lds_b_col, tile_head, lds_b_buffer) + b1 = lds_load_packs_k32_transposed(lds_b_row1, lds_b_col, tile_head, lds_b_buffer) + if const_expr(tile_m == 128): + b2 = lds_load_packs_k32_transposed(lds_b_row2, lds_b_col, tile_head, lds_b_buffer) + b3 = lds_load_packs_k32_transposed(lds_b_row3, lds_b_col, tile_head, lds_b_buffer) + else: + b2 = fx.Int64(0) + b3 = fx.Int64(0) + b128 = pack_i64x4_to_i32x8(b0, b1, b2, b3) + + lds_b_scale_col = dk_head_wave_id * dk_head_per_wave + hi * 16 + lane_mod_16 + b_scale = lds_scale_load(lds_b_scale_row, lds_b_scale_col, tile_head, lds_b_scale_buffer) + + acc_idx = ni * dk_num_subtiles_head + hi + current_accs_list[acc_idx] = rocdl.mfma_scale_f32_16x16x128_f8f6f4( + mfma_res_ty, + [a128, b128, current_accs_list[acc_idx], + 0, 0, 0, a_scale, 0, b_scale], + ) + return current_accs_list + + + def compute_dq(lds_a_buffer, lds_a_scale_buffer, lds_b_buffer, lds_b_scale_buffer): + # (m, n) @ (n, head) = (m, head) + + current_accs_list = [acc_init] * dq_n_accs + mfma_res_ty = T.f32x4 + + num_subtiles_reduction = max(1, tile_n // 128) + for ku128 in range_constexpr(num_subtiles_reduction): + ku0 = ku128 * 2 + ku1 = ku0 + 1 + + lds_a_row0 = ku0 * 64 + lane_div_16 * 16 + lane_div_2 + lds_a_row1 = ku0 * 64 + lane_div_16 * 16 + 8 
+ lane_div_2 + lds_a_row2 = ku1 * 64 + lane_div_16 * 16 + lane_div_2 + lds_a_row3 = ku1 * 64 + lane_div_16 * 16 + 8 + lane_div_2 + + lds_b_row0 = ku0 * 64 + lane_div_16 * 16 + lane_div_2 + lds_b_row1 = ku0 * 64 + lane_div_16 * 16 + 8 + lane_div_2 + lds_b_row2 = ku1 * 64 + lane_div_16 * 16 + lane_div_2 + lds_b_row3 = ku1 * 64 + lane_div_16 * 16 + 8 + lane_div_2 + + lds_a_scale_col = lane_div_16 + lds_b_scale_row = lane_div_16 + if const_expr(tile_n == 64): + lds_a_scale_col = lds_a_scale_col % 2 + lds_b_scale_row = lds_b_scale_row % 2 + + for mi in range_constexpr(dq_num_subtiles_m): + lds_a_col = dq_m_wave_id * dq_m_per_wave + mi * 16 #+ lane_mod_2 * 8 + a0 = lds_load_packs_k32_transposed(lds_a_row0, lds_a_col, tile_m, lds_a_buffer) + a1 = lds_load_packs_k32_transposed(lds_a_row1, lds_a_col, tile_m, lds_a_buffer) + if const_expr(tile_n == 128): + a2 = lds_load_packs_k32_transposed(lds_a_row2, lds_a_col, tile_m, lds_a_buffer) + a3 = lds_load_packs_k32_transposed(lds_a_row3, lds_a_col, tile_m, lds_a_buffer) + else: + a2 = fx.Int64(0) + a3 = fx.Int64(0) + + a128 = pack_i64x4_to_i32x8(a0, a1, a2, a3) + lds_a_scale_row = dq_m_wave_id * dq_m_per_wave + mi * 16 + lane_mod_16 + a_scale = lds_scale_load(lds_a_scale_row, lds_a_scale_col, tile_n_mx, lds_a_scale_buffer) + + for hi in range_constexpr(dq_num_subtiles_head): + lds_b_col = dq_head_wave_id * dq_head_per_wave + hi * 16 #+ lane_mod_2 * 8 + b0 = lds_load_packs_k32_transposed(lds_b_row0, lds_b_col, tile_head, lds_b_buffer) + b1 = lds_load_packs_k32_transposed(lds_b_row1, lds_b_col, tile_head, lds_b_buffer) + if const_expr(tile_n == 128): + b2 = lds_load_packs_k32_transposed(lds_b_row2, lds_b_col, tile_head, lds_b_buffer) + b3 = lds_load_packs_k32_transposed(lds_b_row3, lds_b_col, tile_head, lds_b_buffer) + else: + b2 = fx.Int64(0) + b3 = fx.Int64(0) + b128 = pack_i64x4_to_i32x8(b0, b1, b2, b3) + + lds_b_scale_col = dq_head_wave_id * dq_head_per_wave + hi * 16 + lane_mod_16 + b_scale = lds_scale_load(lds_b_scale_row, 
lds_b_scale_col, tile_head, lds_b_scale_buffer) + + acc_idx = mi * dq_num_subtiles_head + hi + current_accs_list[acc_idx] = rocdl.mfma_scale_f32_16x16x128_f8f6f4( + mfma_res_ty, + [a128, b128, current_accs_list[acc_idx], + 0, 0, 0, a_scale, 0, b_scale], + ) + return current_accs_list + + + def store_dq_atomic(final_accs, offset_m): + for mi in range_constexpr(dq_num_subtiles_m): + for hi in range_constexpr(dq_num_subtiles_head): + for ii in range_constexpr(4): + global_row = offset_m + dq_m_wave_id * dq_m_per_wave + mi * 16 + lane_div_16 * 4 + ii + global_col = dq_head_wave_id * dq_head_per_wave + hi * 16 + lane_mod_16 + global_idx = batch_offset_dq + global_row * head_dim + global_col + global_idx_bytes = global_idx * 4 + + acc_idx = mi * dq_num_subtiles_head + hi + acc = final_accs[acc_idx] + val_f32 = vector.extract(acc, static_position=[ii], dynamic_position=[]) + val_f32 = val_f32 * c_sm_scale + rocdl.raw_ptr_buffer_atomic_fadd(val_f32, dq_rsrc, fx.Int32(global_idx_bytes), fx.Int32(0), fx.Int32(0)) + #buffer_ops.buffer_store(val_f32, dq_rsrc, global_idx) + + + def store_dk_bf16(final_accs): + for ni in range_constexpr(dk_num_subtiles_n): + for hi in range_constexpr(dk_num_subtiles_head): + acc_idx = ni * dk_num_subtiles_head + hi + acc = final_accs[acc_idx] + for ii in range_constexpr(4): + + global_row = global_offset_n + dk_n_wave_id * dk_n_per_wave + ni * 16 + lane_div_16 * 4 + ii + global_col = dk_head_wave_id * dk_head_per_wave + hi * 16 + lane_mod_16 + global_idx = batch_offset_dkdv + global_row * head_dim + global_col + + acc_idx = ni * dk_num_subtiles_head + hi + acc = final_accs[acc_idx] + val_f32 = vector.extract(acc, static_position=[ii], dynamic_position=[]) + val_f32 = val_f32 * c_sm_scale + val_bf16 = arith.trunc_f(T.bf16, val_f32) + buffer_ops.buffer_store(val_bf16, dk_rsrc, global_idx) + + + def store_dv_bf16(final_accs): + for ni in range_constexpr(dv_n_num_subtiles): + for hi in range_constexpr(dv_head_num_subtiles): + acc_idx = ni * 
dv_head_num_subtiles + hi + acc = final_accs[acc_idx] + for ii in range_constexpr(4): + + global_row = global_offset_n + dv_n_wave_id * dv_n_per_wave + ni * 16 + lane_div_16 * 4 + ii + global_col = dv_head_wave_id * dv_head_per_wave + hi * 16 + lane_mod_16 + global_idx = batch_offset_dkdv + global_row * head_dim + global_col + + acc_idx = ni * dv_head_num_subtiles + hi + acc = final_accs[acc_idx] + val_f32 = vector.extract(acc, static_position=[ii], dynamic_position=[]) + val_bf16 = arith.trunc_f(T.bf16, val_f32) + buffer_ops.buffer_store(val_bf16, dv_rsrc, global_idx) + + + # ── Scheduling hints ────────────────────────────────────────────── + rocdl.sched_barrier(0) + + def hot_loop_scheduler(): + rocdl.sched_barrier(0) + return + + # ── Main pipeline ───────────────────────────────────────────────── + + def _pack_state(dk, dv): + return list(dk) + list(dv) + + def _unpack_state(vals): + dk = list(vals[:dk_n_accs]) + dv = list(vals[dk_n_accs:]) + return dk, dv + + def pingpong(offset_m, inner_state): + dk, dv = _unpack_state(inner_state) + + next_offset_m = offset_m + tile_m + next_offset_m_mx = next_offset_m // 32 + store_q_tile_to_lds(prefetch_q_quant_head_tile(next_offset_m), lds_q_quant_head_ping) + store_q_scale_tile_to_lds(prefetch_q_scale_head_tile(next_offset_m), lds_q_scale_head_ping) + store_q_tile_to_lds(prefetch_q_quant_m_tile(next_offset_m), lds_q_quant_m_ping) + store_q_scale_tile_to_lds(prefetch_q_scale_m_tile(next_offset_m_mx), lds_q_scale_m_ping) + store_do_tile_to_lds(prefetch_do_quant_head_tile(next_offset_m), lds_do_quant_head_ping) + store_do_scale_tile_to_lds(prefetch_do_scale_head_tile(next_offset_m), lds_do_scale_head_ping) + store_do_tile_to_lds(prefetch_do_quant_m_tile(next_offset_m), lds_do_quant_m_ping) + store_do_scale_tile_to_lds(prefetch_do_scale_m_tile(next_offset_m_mx), lds_do_scale_m_ping) + + qk = compute_qk(lds_q_quant_head_pong, lds_q_scale_head_pong, lds_k_quant_head, lds_k_scale_head) + p = softmax(qk, offset_m) + 
mxquant_m_and_store_to_lds(p, lds_ppt_shuffle, lds_ppt_scale_shuffle) + gpu.barrier() + dv = compute_dv(dv, lds_ppt_shuffle, lds_ppt_scale_shuffle, lds_do_quant_m_pong, lds_do_scale_m_pong) + dp = compute_dp(lds_do_quant_head_pong, lds_do_scale_head_pong, lds_v, lds_v_scale) + ds = compute_ds(dp, p, offset_m) + mxquant_m_and_store_to_lds(ds, lds_dst_shuffle, lds_dst_scale_shuffle) + mxquant_n_and_store_to_lds(ds, lds_ds_shuffle, lds_ds_scale_shuffle) + gpu.barrier() + dk = compute_dk(dk, lds_dst_shuffle, lds_dst_scale_shuffle, lds_q_quant_m_pong, lds_q_scale_m_pong) + dq = compute_dq(lds_ds_shuffle, lds_ds_scale_shuffle, lds_k_quant_n, lds_k_scale_n) + store_dq_atomic(dq, offset_m) + hot_loop_scheduler() + gpu.barrier() + + next_offset_m = offset_m + (tile_m * 2) + next_offset_m_mx = next_offset_m // 32 + store_q_tile_to_lds(prefetch_q_quant_head_tile(next_offset_m), lds_q_quant_head_pong) + store_q_scale_tile_to_lds(prefetch_q_scale_head_tile(next_offset_m), lds_q_scale_head_pong) + store_q_tile_to_lds(prefetch_q_quant_m_tile(next_offset_m), lds_q_quant_m_pong) + store_q_scale_tile_to_lds(prefetch_q_scale_m_tile(next_offset_m_mx), lds_q_scale_m_pong) + store_do_tile_to_lds(prefetch_do_quant_head_tile(next_offset_m), lds_do_quant_head_pong) + store_do_scale_tile_to_lds(prefetch_do_scale_head_tile(next_offset_m), lds_do_scale_head_pong) + store_do_tile_to_lds(prefetch_do_quant_m_tile(next_offset_m), lds_do_quant_m_pong) + store_do_scale_tile_to_lds(prefetch_do_scale_m_tile(next_offset_m_mx), lds_do_scale_m_pong) + + qk = compute_qk(lds_q_quant_head_ping, lds_q_scale_head_ping, lds_k_quant_head, lds_k_scale_head) + p = softmax(qk, offset_m + tile_m) + mxquant_m_and_store_to_lds(p, lds_ppt_shuffle, lds_ppt_scale_shuffle) + gpu.barrier() + dv = compute_dv(dv, lds_ppt_shuffle, lds_ppt_scale_shuffle, lds_do_quant_m_ping, lds_do_scale_m_ping) + dp = compute_dp(lds_do_quant_head_ping, lds_do_scale_head_ping, lds_v, lds_v_scale) + ds = compute_ds(dp, p, offset_m + tile_m) + 
mxquant_m_and_store_to_lds(ds, lds_dst_shuffle, lds_dst_scale_shuffle) + mxquant_n_and_store_to_lds(ds, lds_ds_shuffle, lds_ds_scale_shuffle) + gpu.barrier() + dk = compute_dk(dk, lds_dst_shuffle, lds_dst_scale_shuffle, lds_q_quant_m_ping, lds_q_scale_m_ping) + dq = compute_dq(lds_ds_shuffle, lds_ds_scale_shuffle, lds_k_quant_n, lds_k_scale_n) + store_dq_atomic(dq, offset_m + tile_m) + hot_loop_scheduler() + gpu.barrier() + + return _pack_state(dk, dv) + + if const_expr(causal): + start_m = (global_offset_n // (tile_m * 2)) * (tile_m * 2) + else: + start_m = fx.Index(0) + start_m_mx = start_m // 32 + + store_q_tile_to_lds(prefetch_q_quant_head_tile(start_m), lds_q_quant_head_pong) + store_q_scale_tile_to_lds(prefetch_q_scale_head_tile(start_m), lds_q_scale_head_pong) + store_q_tile_to_lds(prefetch_q_quant_m_tile(start_m), lds_q_quant_m_pong) + store_q_scale_tile_to_lds(prefetch_q_scale_m_tile(start_m_mx), lds_q_scale_m_pong) + store_k_tile_to_lds(prefetch_k_quant_head_tile(), lds_k_quant_head) + store_k_scale_tile_to_lds(prefetch_k_scale_head_tile(), lds_k_scale_head) + store_k_tile_to_lds(prefetch_k_quant_n_tile(), lds_k_quant_n) + store_k_scale_tile_to_lds(prefetch_k_scale_n_tile(), lds_k_scale_n) + store_v_tile_to_lds(prefetch_v_tile(), lds_v) + store_v_scale_tile_to_lds(prefetch_v_scale_tile(), lds_v_scale) + store_do_tile_to_lds(prefetch_do_quant_head_tile(start_m), lds_do_quant_head_pong) + store_do_scale_tile_to_lds(prefetch_do_scale_head_tile(start_m), lds_do_scale_head_pong) + store_do_tile_to_lds(prefetch_do_quant_m_tile(start_m), lds_do_quant_m_pong) + store_do_scale_tile_to_lds(prefetch_do_scale_m_tile(start_m_mx), lds_do_scale_m_pong) + gpu.barrier() + dk = [acc_init] * dk_n_accs + dv = [acc_init] * dv_n_accs + + num_tiles_loop = seqlen // tile_m + if const_expr((num_tiles_loop % 2) == 1): + upper_bound = seqlen - tile_m + init_state = _pack_state(dk, dv) + for iv, inner in range(start_m, upper_bound, tile_m * 2, init=init_state): + results = yield 
pingpong(iv, inner) + dk, dv = _unpack_state(results) + + curr_m = arith.index(seqlen - tile_m) + qk = compute_qk(lds_q_quant_head_pong, lds_q_scale_head_pong, lds_k_quant_head, lds_k_scale_head) + p = softmax(qk, curr_m) + mxquant_m_and_store_to_lds(p, lds_ppt_shuffle, lds_ppt_scale_shuffle) + gpu.barrier() + dv = compute_dv(dv, lds_ppt_shuffle, lds_ppt_scale_shuffle, lds_do_quant_m_pong, lds_do_scale_m_pong) + dp = compute_dp(lds_do_quant_head_pong, lds_do_scale_head_pong, lds_v, lds_v_scale) + ds = compute_ds(dp, p, curr_m) + mxquant_m_and_store_to_lds(ds, lds_dst_shuffle, lds_dst_scale_shuffle) + mxquant_n_and_store_to_lds(ds, lds_ds_shuffle, lds_ds_scale_shuffle) + gpu.barrier() + dk = compute_dk(dk, lds_dst_shuffle, lds_dst_scale_shuffle, lds_q_quant_m_pong, lds_q_scale_m_pong) + dq = compute_dq(lds_ds_shuffle, lds_ds_scale_shuffle, lds_k_quant_n, lds_k_scale_n) + store_dq_atomic(dq, curr_m) + else: + upper_bound = seqlen - (tile_m * 2) + init_state = _pack_state(dk, dv) + for iv, inner in range(start_m, upper_bound, tile_m * 2, init=init_state): + results = yield pingpong(iv, inner) + dk, dv = _unpack_state(results) + + curr_m = arith.index(seqlen - tile_m * 2) + last_m = arith.index(seqlen - tile_m) + last_m_mx = last_m // 32 + store_q_tile_to_lds(prefetch_q_quant_head_tile(last_m), lds_q_quant_head_ping) + store_q_scale_tile_to_lds(prefetch_q_scale_head_tile(last_m), lds_q_scale_head_ping) + store_q_tile_to_lds(prefetch_q_quant_m_tile(last_m), lds_q_quant_m_ping) + store_q_scale_tile_to_lds(prefetch_q_scale_m_tile(last_m_mx), lds_q_scale_m_ping) + store_do_tile_to_lds(prefetch_do_quant_head_tile(last_m), lds_do_quant_head_ping) + store_do_scale_tile_to_lds(prefetch_do_scale_head_tile(last_m), lds_do_scale_head_ping) + store_do_tile_to_lds(prefetch_do_quant_m_tile(last_m), lds_do_quant_m_ping) + store_do_scale_tile_to_lds(prefetch_do_scale_m_tile(last_m_mx), lds_do_scale_m_ping) + + qk = compute_qk(lds_q_quant_head_pong, lds_q_scale_head_pong, 
lds_k_quant_head, lds_k_scale_head) + p = softmax(qk, curr_m) + mxquant_m_and_store_to_lds(p, lds_ppt_shuffle, lds_ppt_scale_shuffle) + gpu.barrier() + dv = compute_dv(dv, lds_ppt_shuffle, lds_ppt_scale_shuffle, lds_do_quant_m_pong, lds_do_scale_m_pong) + dp = compute_dp(lds_do_quant_head_pong, lds_do_scale_head_pong, lds_v, lds_v_scale) + ds = compute_ds(dp, p, curr_m) + mxquant_m_and_store_to_lds(ds, lds_dst_shuffle, lds_dst_scale_shuffle) + mxquant_n_and_store_to_lds(ds, lds_ds_shuffle, lds_ds_scale_shuffle) + gpu.barrier() + dk = compute_dk(dk, lds_dst_shuffle, lds_dst_scale_shuffle, lds_q_quant_m_pong, lds_q_scale_m_pong) + dq = compute_dq(lds_ds_shuffle, lds_ds_scale_shuffle, lds_k_quant_n, lds_k_scale_n) + store_dq_atomic(dq, curr_m) + + hot_loop_scheduler() + gpu.barrier() + + curr_m = last_m + qk = compute_qk(lds_q_quant_head_ping, lds_q_scale_head_ping, lds_k_quant_head, lds_k_scale_head) + p = softmax(qk, curr_m) + mxquant_m_and_store_to_lds(p, lds_ppt_shuffle, lds_ppt_scale_shuffle) + gpu.barrier() + dv = compute_dv(dv, lds_ppt_shuffle, lds_ppt_scale_shuffle, lds_do_quant_m_ping, lds_do_scale_m_ping) + dp = compute_dp(lds_do_quant_head_ping, lds_do_scale_head_ping, lds_v, lds_v_scale) + ds = compute_ds(dp, p, curr_m) + mxquant_m_and_store_to_lds(ds, lds_dst_shuffle, lds_dst_scale_shuffle) + mxquant_n_and_store_to_lds(ds, lds_ds_shuffle, lds_ds_scale_shuffle) + gpu.barrier() + dk = compute_dk(dk, lds_dst_shuffle, lds_dst_scale_shuffle, lds_q_quant_m_ping, lds_q_scale_m_ping) + dq = compute_dq(lds_ds_shuffle, lds_ds_scale_shuffle, lds_k_quant_n, lds_k_scale_n) + store_dq_atomic(dq, curr_m) + + store_dk_bf16(dk) + store_dv_bf16(dv) + + + # ── Host launcher ────────────────────────────────────────────────────── + _cache_tag = (tile_m, tile_n, head_dim) + + @flyc.jit + def launch_attn_bwd( + arg_dq: fx.Tensor, + arg_dk: fx.Tensor, + arg_dv: fx.Tensor, + arg_q_quant_head: fx.Tensor, + arg_q_scale_head: fx.Tensor, + arg_q_quant_m: fx.Tensor, + arg_q_scale_m: 
fx.Tensor, + arg_k_quant_head: fx.Tensor, + arg_k_scale_head: fx.Tensor, + arg_k_quant_n: fx.Tensor, + arg_k_scale_n: fx.Tensor, + arg_v: fx.Tensor, + arg_v_scale: fx.Tensor, + arg_do_quant_head: fx.Tensor, + arg_do_scale_head: fx.Tensor, + arg_do_quant_m: fx.Tensor, + arg_do_scale_m: fx.Tensor, + arg_m: fx.Tensor, + arg_D: fx.Tensor, + batch: fx.Int32, + stream: fx.Stream, + ): + _ = _cache_tag + allocator_pong.finalized = False + allocator_ping.finalized = False + allocator_k_quant_head.finalized = False + allocator_k_scale_head.finalized = False + allocator_k_quant_n.finalized = False + allocator_k_scale_n.finalized = False + allocator_v.finalized = False + allocator_v_scale.finalized = False + allocator_ppt_shuffle.finalized = False + allocator_ppt_scale_shuffle.finalized = False + allocator_dst_shuffle.finalized = False + allocator_dst_scale_shuffle.finalized = False + allocator_ds_shuffle.finalized = False + allocator_ds_scale_shuffle.finalized = False + ctx = CompilationContext.get_current() + with ir.InsertionPoint(ctx.gpu_module_body): + allocator_pong.finalize() + allocator_ping.finalize() + allocator_k_quant_head.finalize() + allocator_k_scale_head.finalize() + allocator_k_quant_n.finalize() + allocator_k_scale_n.finalize() + allocator_v.finalize() + allocator_v_scale.finalize() + allocator_ppt_shuffle.finalize() + allocator_ppt_scale_shuffle.finalize() + allocator_dst_shuffle.finalize() + allocator_dst_scale_shuffle.finalize() + allocator_ds_shuffle.finalize() + allocator_ds_scale_shuffle.finalize() + + gx = seqlen // tile_n + gy = batch + + launcher = kernel_attn_bwd(arg_dq, arg_dk, arg_dv, arg_q_quant_head, arg_q_scale_head, arg_q_quant_m, arg_q_scale_m, arg_k_quant_head, arg_k_scale_head, arg_k_quant_n, arg_k_scale_n, arg_v, arg_v_scale, arg_do_quant_head, arg_do_scale_head, arg_do_quant_m, arg_do_scale_m, arg_m, arg_D, batch) + if waves_per_eu is not None: + _wpe = int(waves_per_eu) + if _wpe >= 1: + for op in ctx.gpu_module_body.operations: + if 
hasattr(op, 'attributes') and op.OPERATION_NAME == "gpu.func": + op.attributes["rocdl.waves_per_eu"] = ir.IntegerAttr.get(T.i32, _wpe) + launcher.launch( + grid=(gx, gy, 1), + block=(256, 1, 1), + stream=stream, + ) + + return launch_attn_bwd + + +__all__ = ["compile_attn_bwd_mxfp8_gfx950"] diff --git a/op_tests/flydsl_tests/test_attn_bwd_mxfp8_gfx950.py b/op_tests/flydsl_tests/test_attn_bwd_mxfp8_gfx950.py new file mode 100644 index 0000000000..8eff2b452a --- /dev/null +++ b/op_tests/flydsl_tests/test_attn_bwd_mxfp8_gfx950.py @@ -0,0 +1,181 @@ +#!/usr/bin/env python3 +"""Attention backward fp8 test — @flyc.kernel API. + +Kernel implementation lives in `kernels/attn_bwd_mxfp8_gfx950.py`. +""" + +import logging +import torch +import pytest + +from aiter.ops.triton.quant.mxfp8_quant import downcast_to_mxfp8, upcast_from_mxfp8 +from aiter.ops.flydsl.kernels.attn_bwd_mxfp8_gfx950 import compile_attn_bwd_mxfp8_gfx950 +from flydsl.runtime.device import get_rocm_arch + +logging.basicConfig(level=logging.INFO) + +if not torch.cuda.is_available(): + pytest.skip("CUDA/ROCm not available. Skipping GPU tests.", allow_module_level=True) + +ARCH = str(get_rocm_arch()) + +def check_result(test_out, ref_out, atol=0.01, rtol=0.01, pass_pct=95.0): + """Compare outputs and print result. 
Returns (passed, max_delta, pct_close).""" + close_mask = torch.isclose(test_out.float(), ref_out.float(), atol=atol, rtol=rtol) + pct_close = close_mask.float().mean().item() * 100 + passed = pct_close > pass_pct + if passed: + return True + + max_delta = (ref_out.float() - test_out.float()).abs().max().item() + print( + f" max_delta={max_delta:.4f}, {pct_close:.1f}% close (atol={atol}, rtol={rtol})" + ) + print(f" ref sample: {ref_out.reshape(-1)[:8]}") + print(f" test sample: {test_out.reshape(-1)[:8]}") + print(f" --> {'PASS' if passed else 'FAIL'}") + + +def mx_quant(x, dim=-1): + x_fp8, x_scale = downcast_to_mxfp8(x, torch.float8_e4m3fn, dim) + x_fp32 = upcast_from_mxfp8(x_fp8, x_scale, torch.float32, dim) + return x_fp32.contiguous(), x_fp8.contiguous(), x_scale.contiguous() + + +def run_torch(q_fp32_head, q_fp32_m, k_fp32_head, k_fp32_n, v, do_fp32_head, do_fp32_m, m, D, sm_scale, causal, dtype=torch.float32): + seqlen = q_fp32_head.shape[1] + device = q_fp32_head.device + v_f32 = v.to(torch.float32) + qk = torch.matmul(q_fp32_head, k_fp32_head.transpose(-2, -1)) * sm_scale + p = torch.exp(qk - m[:, :, None]) + if causal: + mask = torch.tril(torch.ones((seqlen, seqlen), device=device)) #.T + p[:, mask == 0] = 0.0 + + ppT, _, _ = mx_quant(p, -2) + ppT = ppT.transpose(-2, -1) + dv = torch.matmul(ppT, do_fp32_m) + dp = torch.matmul(do_fp32_head, v_f32.transpose(-2, -1)) + ds = p * (dp - D[:, :, None]) + dsT, _, _ = mx_quant(ds, -1) + dsT = dsT.transpose(-2, -1) + ds, _, _ = mx_quant(ds, -2) + dk = torch.matmul(dsT, q_fp32_m) * sm_scale + dq = torch.matmul(ds, k_fp32_n) * sm_scale + + return dq, dk, dv + + +@pytest.mark.parametrize("batch", [2, 8, 45, 256]) +@pytest.mark.parametrize("seqlen", [128, 1024, 1152, 4096]) +@pytest.mark.parametrize("head_dim", [64, 128]) +@pytest.mark.parametrize("tile_m", [64, 128]) +@pytest.mark.parametrize("tile_n", [64, 128]) +@pytest.mark.parametrize("causal", [False, True]) +def test_attn_bwd_flyc( + batch, seqlen, head_dim, + 
tile_m, tile_n, + causal, + waves_per_eu: int = 0, +): + tile_head = head_dim + if tile_m == 128 and tile_head == 128: + pytest.skip("Too large block size") + + torch.manual_seed(0) + + sm_scale = 0.5 + _wpe = int(waves_per_eu) + launch_fn = compile_attn_bwd_mxfp8_gfx950( + seqlen=seqlen, head_dim=head_dim, + tile_m=tile_m, tile_n=tile_n, tile_head=tile_head, + sm_scale=sm_scale, + causal=causal, + waves_per_eu=_wpe, + ) + + device = torch.device("cuda") + q_fp32 = torch.randn(batch, seqlen, head_dim, device=device, dtype=torch.float32) * 0.5 + k_fp32 = torch.randn(batch, seqlen, head_dim, device=device, dtype=torch.float32) * 0.5 + v_fp32 = torch.randn(batch, seqlen, head_dim, device=device, dtype=torch.float32) * 0.5 + o_fp32 = torch.randn(batch, seqlen, head_dim, device=device, dtype=torch.float32) * 0.5 + do_fp32 = torch.randn(batch, seqlen, head_dim, device=device, dtype=torch.float32) * 0.5 + + qk = q_fp32 @ k_fp32.transpose(-2, -1) + qk = qk * sm_scale + m = qk.max(dim=-1)[0] + p = (qk - m[:, :, None]).exp() + l = p.sum(dim=-1) + p = p / l[:, :, None] + o_fp32 = torch.matmul(p, v_fp32) + m = m + torch.log(l) + D = (o_fp32 * do_fp32).sum(dim=-1) + + q_fp32_head, q_quant_head, q_scale_head = mx_quant(q_fp32, -1) + q_fp32_m, q_quant_m, q_scale_m = mx_quant(q_fp32, -2) + k_fp32_head, k_quant_head, k_scale_head = mx_quant(k_fp32, -1) + k_fp32_n, k_quant_n, k_scale_n = mx_quant(k_fp32, -2) + v_fp32, v_quant, v_scale = mx_quant(v_fp32) + do_fp32_head, do_quant_head, do_scale_head = mx_quant(do_fp32, -1) + do_fp32_m, do_quant_m, do_scale_m = mx_quant(do_fp32, -2) + + dq_ref, dk_ref, dv_ref = run_torch(q_fp32_head, q_fp32_m, k_fp32_head, k_fp32_n, v_fp32, do_fp32_head, do_fp32_m, m, D, sm_scale, causal, dtype=torch.float32) + dq_fly = torch.zeros((batch, seqlen, head_dim), dtype=torch.float32, device=device) + dk_fly = torch.zeros((batch, seqlen, head_dim), dtype=torch.bfloat16, device=device) + dv_fly = torch.zeros((batch, seqlen, head_dim), dtype=torch.bfloat16, 
device=device) + + def launch_kernel(dq, dk, dv, q_quant_head, q_scale_head, q_quant_m, q_scale_m, k_quant_head, k_scale_head, k_quant_n, k_scale_n, v, v_scale, do_quant_head, do_scale_head, do_quant_m, do_scale_m, m, D, batch): + launch_fn( + dq.contiguous().view(-1), + dk.contiguous().view(-1), + dv.contiguous().view(-1), + q_quant_head.contiguous().view(-1), + q_scale_head.contiguous().view(-1), + q_quant_m.contiguous().view(-1), + q_scale_m.contiguous().view(-1), + k_quant_head.contiguous().view(-1), + k_scale_head.contiguous().view(-1), + k_quant_n.contiguous().view(-1), + k_scale_n.contiguous().view(-1), + v.contiguous().view(-1), + v_scale.contiguous().view(-1), + do_quant_head.contiguous().view(-1), + do_scale_head.contiguous().view(-1), + do_quant_m.contiguous().view(-1), + do_scale_m.contiguous().view(-1), + m.contiguous().view(-1), + D.contiguous().view(-1), + batch, + torch.cuda.current_stream(), + ) + + launch_kernel( + dq_fly, + dk_fly, + dv_fly, + q_quant_head, + q_scale_head, + q_quant_m, + q_scale_m, + k_quant_head, + k_scale_head, + k_quant_n, + k_scale_n, + v_quant, + v_scale, + do_quant_head, + do_scale_head, + do_quant_m, + do_scale_m, + m, + D, + batch + ) + + dq_fly_fp32 = dq_fly.to(torch.float32) + dk_fly_fp32 = dk_fly.to(torch.float32) + dv_fly_fp32 = dv_fly.to(torch.float32) + + assert check_result(dq_fly_fp32, dq_ref, rtol=0.01, atol=0.01, pass_pct=99.0) + assert check_result(dk_fly_fp32, dk_ref, rtol=0.01, atol=0.01, pass_pct=99.0) + assert check_result(dv_fly_fp32, dv_ref, rtol=0.01, atol=0.01, pass_pct=99.0) diff --git a/op_tests/op_benchmarks/flydsl/bench_attn_bwd_mxfp8_gfx950.py b/op_tests/op_benchmarks/flydsl/bench_attn_bwd_mxfp8_gfx950.py new file mode 100644 index 0000000000..b3d7932eba --- /dev/null +++ b/op_tests/op_benchmarks/flydsl/bench_attn_bwd_mxfp8_gfx950.py @@ -0,0 +1,202 @@ +#!/usr/bin/env python3 +"""Attention backward test — @flyc.kernel API. + +Kernel implementation lives in `flydsl/kernels/attn_bwd_mxfp8_gfx950.py`. 
+This file is the perf and correctness harness. +""" + +import logging +import torch +from aiter.ops.flydsl.kernels.attn_bwd_mxfp8_gfx950 import compile_attn_bwd_mxfp8_gfx950 +from utils import run_perftest +from op_tests.flydsl_tests.test_attn_bwd_mxfp8_gfx950 import run_torch, mx_quant, check_result +from flydsl.runtime.device import get_rocm_arch + +logging.basicConfig(level=logging.INFO) +ARCH = str(get_rocm_arch()) +DEFAULT_BENCH_ITERS = 20 +DEFAULT_BENCH_WARMUP = 3 + +def bench_attn_bwd_flyc( + batch, seqlen, head_dim, + tile_m, tile_n, + causal, + test_graph, + bench_iters: int = DEFAULT_BENCH_ITERS, + bench_warmup: int = DEFAULT_BENCH_WARMUP, + waves_per_eu: int = 0, + check_correctness: bool = False +): + """Attention bwd using the @flyc.kernel / @flyc.jit API.""" + tile_head = head_dim + print("=" * 80) + print( + f"[flyc] Attention Backward Test (Tile: {tile_m}x{tile_n}x{tile_head})" + ) + print("=" * 80) + + sm_scale = 0.5 + _wpe = int(waves_per_eu) if waves_per_eu else 0 + launch_fn = compile_attn_bwd_mxfp8_gfx950( + seqlen=seqlen, head_dim=head_dim, + tile_m=tile_m, tile_n=tile_n, tile_head=tile_head, + sm_scale=sm_scale, + causal=causal, + waves_per_eu=_wpe, + ) + print(f"✓ Kernel prepared") + + device = torch.device("cuda") + q_fp32 = torch.randn(batch, seqlen, head_dim, device=device, dtype=torch.float32) * 0.5 + k_fp32 = torch.randn(batch, seqlen, head_dim, device=device, dtype=torch.float32) * 0.5 + v_fp32 = torch.randn(batch, seqlen, head_dim, device=device, dtype=torch.float32) * 0.5 + o_fp32 = torch.randn(batch, seqlen, head_dim, device=device, dtype=torch.float32) * 0.5 + do_fp32 = torch.randn(batch, seqlen, head_dim, device=device, dtype=torch.float32) * 0.5 + + q_fp32_head, q_quant_head, q_scale_head = mx_quant(q_fp32, -1) + q_fp32_m, q_quant_m, q_scale_m = mx_quant(q_fp32, -2) + k_fp32_head, k_quant_head, k_scale_head = mx_quant(k_fp32, -1) + k_fp32_n, k_quant_n, k_scale_n = mx_quant(k_fp32, -2) + v_fp32, v_quant, v_scale = 
mx_quant(v_fp32) + do_fp32_head, do_quant_head, do_scale_head = mx_quant(do_fp32, -1) + do_fp32_m, do_quant_m, do_scale_m = mx_quant(do_fp32, -2) + + qk = q_fp32 @ k_fp32.transpose(-2, -1) + qk = qk * sm_scale + m = qk.max(dim=-1)[0] + p = (qk - m[:, :, None]).exp() + l = p.sum(dim=-1) + m = m + torch.log(l) + D = (o_fp32 * do_fp32).sum(dim=-1) + + if check_correctness: + dq_ref, dk_ref, dv_ref = run_torch(q_fp32_head, q_fp32_m, k_fp32_head, k_fp32_n, v_fp32, do_fp32_head, do_fp32_m, m, D, sm_scale, causal) + dq_fly = torch.zeros((batch, seqlen, head_dim), dtype=torch.float32, device=device) + dk_fly = torch.zeros((batch, seqlen, head_dim), dtype=torch.bfloat16, device=device) + dv_fly = torch.zeros((batch, seqlen, head_dim), dtype=torch.bfloat16, device=device) + + def launch_kernel(dq, dk, dv, q_quant_head, q_scale_head, q_quant_m, q_scale_m, k_quant_head, k_scale_head, k_quant_n, k_scale_n, v, v_scale, do_quant_head, do_scale_head, do_quant_m, do_scale_m, m, D, batch): + launch_fn( + dq.contiguous().view(-1), + dk.contiguous().view(-1), + dv.contiguous().view(-1), + q_quant_head.contiguous().view(-1), + q_scale_head.contiguous().view(-1), + q_quant_m.contiguous().view(-1), + q_scale_m.contiguous().view(-1), + k_quant_head.contiguous().view(-1), + k_scale_head.contiguous().view(-1), + k_quant_n.contiguous().view(-1), + k_scale_n.contiguous().view(-1), + v.contiguous().view(-1), + v_scale.contiguous().view(-1), + do_quant_head.contiguous().view(-1), + do_scale_head.contiguous().view(-1), + do_quant_m.contiguous().view(-1), + do_scale_m.contiguous().view(-1), + m.contiguous().view(-1), + D.contiguous().view(-1), + batch, + torch.cuda.current_stream(), + ) + + bench_iters = max(2, int(bench_iters)) + bench_warmup = int(bench_warmup) + _, us = run_perftest( + launch_kernel, + dq_fly, + dk_fly, + dv_fly, + q_quant_head, + q_scale_head, + q_quant_m, + q_scale_m, + k_quant_head, + k_scale_head, + k_quant_n, + k_scale_n, + v_quant, + v_scale, + do_quant_head, + 
do_scale_head, + do_quant_m, + do_scale_m, + m, + D, + batch, + num_iters=bench_iters, + num_warmup=bench_warmup, + testGraph=test_graph, + ) + torch.cuda.synchronize() + + dq_fly.zero_() + launch_kernel( + dq_fly, + dk_fly, + dv_fly, + q_quant_head, + q_scale_head, + q_quant_m, + q_scale_m, + k_quant_head, + k_scale_head, + k_quant_n, + k_scale_n, + v_quant, + v_scale, + do_quant_head, + do_scale_head, + do_quant_m, + do_scale_m, + m, + D, + batch + ) + + dq_fly_fp32 = dq_fly.to(torch.float32) + dk_fly_fp32 = dk_fly.to(torch.float32) + dv_fly_fp32 = dv_fly.to(torch.float32) + + if check_correctness: + assert check_result(dq_fly_fp32, dq_ref, rtol=0.01, atol=0.01) + assert check_result(dk_fly_fp32, dk_ref, rtol=0.01, atol=0.01) + assert check_result(dv_fly_fp32, dv_ref, rtol=0.01, atol=0.01) + + bytes_moved = (7 + 4 + 2 * 2) * seqlen * head_dim + 2 * 4 * seqlen + flops = batch * (5 * 2 * seqlen * seqlen * head_dim + 5 * seqlen * seqlen + 2 * 3 * seqlen * seqlen) + if causal: + flops /= 2 + tflops = flops / (us / 1e6) / 1e12 + tbps = bytes_moved / 1e12 / (us / 1e6) + print(f"[flyc] Throughput: {us:.1f} us, {tflops:.2f} TFLOPS, BW: {tbps:.3f} TB/s") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Preshuffle GEMM benchmark") + parser.add_argument("--batch", type=int, default=1) + parser.add_argument("--seqlen", type=int, default=1024) + parser.add_argument("--head", type=int, default=128) + parser.add_argument("--tile_m", type=int, default=128) + parser.add_argument("--tile_n", type=int, default=128) + parser.add_argument("--causal", action="store_true", default=False) + parser.add_argument("--num_iters", type=int, default=DEFAULT_BENCH_ITERS) + parser.add_argument("--num_warmup", type=int, default=DEFAULT_BENCH_WARMUP) + parser.add_argument("--waves_per_eu", type=int, default=0, choices=[0, 1, 2, 3, 4]) + parser.add_argument("--test_graph", action="store_true", default=False) + 
parser.add_argument("--check_correctness", action="store_true", default=False) + args = parser.parse_args() + torch.set_default_device("cuda") + + bench_attn_bwd_flyc( + batch=args.batch, seqlen=args.seqlen, head_dim=args.head, + tile_m=args.tile_m, tile_n=args.tile_n, + causal=args.causal, + test_graph=bool(args.test_graph), + bench_iters=args.num_iters, + bench_warmup=args.num_warmup, + waves_per_eu=int(args.waves_per_eu), + check_correctness=args.check_correctness + ) diff --git a/op_tests/op_benchmarks/flydsl/utils.py b/op_tests/op_benchmarks/flydsl/utils.py new file mode 100644 index 0000000000..bbcf69d2bf --- /dev/null +++ b/op_tests/op_benchmarks/flydsl/utils.py @@ -0,0 +1,368 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# from https://github.com/ROCm/aiter/blob/main/aiter/test_common.py +import torch +import torch.profiler as tpf +import os +import copy +import time +import numpy as np +import pandas as pd +import logging + +logger = logging.getLogger("flydsl") + +pd.set_option("display.max_rows", 200) +## debug ## +# pd.set_option("display.max_rows", None) +# pd.set_option("display.max_columns", None) +# pd.set_option("display.width", None) +# pd.set_option("display.max_colwidth", None) +# pd.set_option("display.expand_frame_repr", False) + + +def perftest( + num_iters=20, num_warmup=3, testGraph=False, num_rotate_args=0, needTrace=False +): + def decorator(func): + def wrapper(*args, **kwargs): + # ROCm torch.profiler (ROCTracer) is not always stable when invoked repeatedly + # under pytest (multiple tests, repeated init/teardown). For unit tests, the + # profiler is not required; fall back to simple timing. 
+ # + num = num_rotate_args + if num < 1: + gpu_id = torch.cuda.current_device() + iter_used_memory, inputSize, _, _ = device_memory_profiling( + func, *args, **kwargs + ) + + properties = torch.cuda.get_device_properties(gpu_id) + free_memory = torch.cuda.mem_get_info(gpu_id)[0] + cache_size = min( + getattr(properties, "L2_cache_size", 4096 * 1024) * 64 * 128, + (free_memory - iter_used_memory + inputSize) * 0.9, + ) + cache_size = max(cache_size, 0) + num = int((cache_size + inputSize - 1) // inputSize) + num = min(num, num_iters) + + rotate_args = [ + (copy.deepcopy(args), copy.deepcopy(kwargs)) for _ in range(num - 1) + ] + [(args, kwargs)] + run_iters(num_warmup, func, *args, **kwargs) + torch.cuda.synchronize() + + if int(os.environ.get("FLYDSL_LOG_MORE", 0)): + latencies = [] + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + for _ in range(num_iters): + start_event.record() + data = func(*args, **kwargs) + end_event.record() + end_event.synchronize() + latencies.append(start_event.elapsed_time(end_event)) + avg = np.mean(latencies) * 1000 + logger.info(f"avg: {avg} us/iter from cuda.Event") + with tpf.profile( + activities=[tpf.ProfilerActivity.CPU, tpf.ProfilerActivity.CUDA], + profile_memory=False, + with_stack=False, + with_modules=True, + ) as prof: + data = run_iters_rotate(num_iters, func, rotate_args) + torch.cuda.synchronize() + torch.cuda.empty_cache() + avg = get_trace_perf(prof, num_iters) + + if testGraph: + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + data = run_iters_rotate(num_iters, func, rotate_args) + with tpf.profile( + activities=[tpf.ProfilerActivity.CPU, tpf.ProfilerActivity.CUDA], + profile_memory=True, + with_stack=True, + with_modules=True, + ) as prof: + run_iters(1, graph.replay) + avg = get_trace_perf(prof, num_iters) + logger.info(f"avg: {avg} us/iter with hipgraph") + + return data, avg + + return wrapper + + return decorator + + +def benchmark(): + def 
decorator(func): + def wrapper(*args, **kwargs): + callargs = log_args(func, *args, **kwargs) + ret = func(*args, **kwargs) + if ret is not None: + callargs.update(ret) + return callargs + + return wrapper + + return decorator + + +def device_memory_profiling(func, *args, **kwargs): + gpu_id = torch.cuda.current_device() + inputSize = ( + sum( + [ + el.nbytes + for el in args + if isinstance(el, torch.Tensor) and el.device.index == gpu_id + ] + ) + + 1 + ) + torch.cuda.reset_peak_memory_stats(gpu_id) + cuda_memory_before = ( + torch.cuda.mem_get_info(gpu_id)[1] - torch.cuda.mem_get_info(gpu_id)[0] + ) + torch_memory_before = torch.cuda.memory_reserved(gpu_id) + torch_peak_before = torch.cuda.memory_stats(gpu_id).get( + "allocated_bytes.all.peak", 0 + ) + non_torch_memory_before = cuda_memory_before - torch_memory_before + + data = func(*args, **kwargs) + + torch.cuda.reset_peak_memory_stats(gpu_id) + cuda_memory_after = ( + torch.cuda.mem_get_info(gpu_id)[1] - torch.cuda.mem_get_info(gpu_id)[0] + ) + torch_memory_after = torch.cuda.memory_reserved(gpu_id) + torch_peak_after = torch.cuda.memory_stats(gpu_id).get( + "allocated_bytes.all.peak", 0 + ) + non_torch_memory_after = cuda_memory_after - torch_memory_after + + torch_peak_increase = torch_peak_after - torch_peak_before + non_torch_increase = non_torch_memory_after - non_torch_memory_before + iter_used_memory = torch_peak_increase + non_torch_increase + inputSize + + return iter_used_memory, inputSize, torch_peak_increase, non_torch_increase + + +def run_iters(num_iters, func, *args, **kwargs): + data = None + for _ in range(num_iters): + data = func(*args, **kwargs) + return data + + +def run_iters_rotate(num_iters, func, rotate_args): + data = None + num_rotate_args = len(rotate_args) + for _ in range(num_iters): + args, kwargs = rotate_args[_ % num_rotate_args] + data = func(*args, **kwargs) + + return data + + +def run_perftest( + func, + *args, + num_iters=20, + num_warmup=3, + testGraph=False, + 
num_rotate_args=0, + needTrace=False, + **kwargs, +): + + @perftest( + num_iters=num_iters, + num_warmup=num_warmup, + testGraph=testGraph, + num_rotate_args=num_rotate_args, + needTrace=needTrace, + ) + def worker(*args, **kwargs): + return func(*args, **kwargs) + + return worker(*args, **kwargs) + + +def log_args(func, *args, **kwargs): + import inspect + + callargs = inspect.getcallargs(func, *args, **kwargs) + + prefix = f"calling {func.__name__}(" + blanks = " " * (len(prefix)) + + def getTensorInfo(el): + if isinstance(el, torch.Tensor): + return f"{el.shape} {el.dtype} {el.device} {hex(el.data_ptr())}" + elif isinstance(el, tuple): + viewNum = 5 + if len(el) > viewNum: + el = list(el[:viewNum]) + ["..."] + return f'\n{" "*(len(prefix)+31)}'.join( + ["("] + [f" {getTensorInfo(e)}" for e in el] + [")"] + ) + return el + + info = [f"{el:<28} = {getTensorInfo(callargs[el])}" for el in callargs] + info = f",\n{blanks}".join(info) + logger.info(f"\n{prefix}{info})") + return callargs + + +def post_process_data(df, num_iters, warm_iter=1): + """remove abnormal data""" + + device_df = df[df["device_type"].astype(str).str.contains("DeviceType.CUDA")] + # print("devicedf is ", device_df) + if device_df.empty: + return [], 0 + kernels_num = int(len(device_df) / num_iters) + + act_iters = num_iters + valid_n = len(device_df) + dropped_indexs = [] + if len(device_df) % num_iters == 0: + kernels_num = int(len(device_df) / num_iters) + else: + ##get correct kernel num + name_list = device_df["name"].tolist() + max_kernel_num = 20 + n = len(name_list) + for step in range(1, min(max_kernel_num, n // 2 + 1)): + sub_list = [name_list[i] for i in range(step)] + m = len(sub_list) + + valid_n = int(n / m) * m + pattern_match = all( + name_list[i] == sub_list[i % m] for i in range(int(n / m) * m) + ) + if pattern_match: + kernels_num = m + act_iters = valid_n / m + break + dropped_indexs = device_df.iloc[valid_n:].index.tolist() + if kernels_num == 0: + print("data missed, the 
time may be inaccurate!") + + test_df = device_df.iloc[:valid_n].reset_index() + grouped_kernel_df = test_df.groupby(test_df.index // kernels_num, sort=False).agg( + {"self_device_time_total": "sum", "index": list} + ) + + # rm warm iters + sum_df = grouped_kernel_df.iloc[warm_iter:].reset_index(drop=True) + out_range_idx = [] + if num_iters > 30: + # IQR to remove abnormal data + k = 1.5 + Q1 = sum_df["self_device_time_total"].quantile(0.25) + Q3 = sum_df["self_device_time_total"].quantile(0.75) + IQR = Q3 - Q1 + lower = Q1 - k * IQR + upper = Q3 + k * IQR + out_range_idx = sum_df.index[ + (sum_df["self_device_time_total"] < lower) + | (sum_df["self_device_time_total"] > upper) + ].tolist() + out_range_num = len(out_range_idx) + + indices = {idx for i in out_range_idx for idx in sum_df.iloc[i]["index"]} + + index_sublists = grouped_kernel_df["index"].head(warm_iter).tolist() + indices_to_add = [idx for sublist in index_sublists for idx in sublist] + indices.update(indices_to_add) + indices.update(dropped_indexs) + if int(os.environ.get("FLYDSL_LOG_MORE", 0)): + logger.info(f"abnormal data indices: {indices}") + for i in indices: + logger.info(f"abnormal data: {df.iloc[i]['self_device_time_total']}") + return list(indices), out_range_num + warm_iter + num_iters - act_iters + + +def get_trace_perf(prof, num_iters): + assert num_iters > 1 + warm_iter = 1 + num_iters -= warm_iter + df = [] + cols = [ + "name", + "self_cpu_time_total", + "self_device_time_total", + "device_type", + "device_index", + ] + for el in prof.events(): + df.append([getattr(el, x, None) for x in cols]) + df = pd.DataFrame(df, columns=cols) + ###remove abnormal data + dropped_num = warm_iter + dropped_indexs, dropped_num = post_process_data( + df, num_iters + warm_iter, warm_iter + ) + df = df.drop(dropped_indexs) + iter_init = 0 # warm_iter dropped + df["cnt"] = 1 + rets = [] + + for name, d in df.groupby("name", sort=False): + kernel_num_per_iter = iter_init + if 
str(d["device_type"].iat[0]).split(".")[-1] != "CUDA": + kernel_num_per_iter = 1 + r = d.iloc[kernel_num_per_iter:][ + ["cnt", "self_cpu_time_total", "self_device_time_total"] + ].sum() + if not r.empty: + device_type = str(d["device_type"].iat[0]).split(".")[-1] + r["name"] = name + r["device_type"] = device_type + r["device_index"] = str(d["device_index"].iat[0]) + if device_type == "CUDA": + r["device_time_sum"] = r["self_device_time_total"] + r["host_time_sum"] = 0 + else: + r["host_time_sum"] = r["self_device_time_total"] + r["device_time_sum"] = 0 + rets.append(r) + df = pd.DataFrame(rets) + cols = [ + "name", + "cnt", + "host_time_sum", + "device_time_sum", + "device_type", + "device_index", + ] + cols = [el for el in cols if el in df.columns] + df = df[(df.host_time_sum > 0) | (df.device_time_sum > 0)] + + timerList = [ + "host_time_sum", + "device_time_sum", + ] + df = df[cols].sort_values(timerList, ignore_index=True) + actual_iters = num_iters + warm_iter - dropped_num + if df.empty: + logger.info("no valida data after post process!") + + avg_name = "[avg us/iter]" + for el in timerList: + if el == "host_time_sum": + df.at[avg_name, el] = df[el].sum() / num_iters + else: + df.at[avg_name, el] = df[el].sum() / actual_iters + if int(os.environ.get("FLYDSL_LOG_MORE", 0)): + pd.set_option("display.expand_frame_repr", False) + pd.set_option("display.max_colwidth", 90) + pd.set_option("display.float_format", "{:,.1f}".format) + logger.info(f"{df}") + return df.at[avg_name, "device_time_sum"] \ No newline at end of file From d6fd3a2e1ad429c00eaadf62ad29b565c7c64ba4 Mon Sep 17 00:00:00 2001 From: Lukasz Burzawa Date: Fri, 8 May 2026 16:06:16 -0700 Subject: [PATCH 3/8] Update aiter/ops/flydsl/kernels/attn_bwd_mxfp8_gfx950.py Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- aiter/ops/flydsl/kernels/attn_bwd_mxfp8_gfx950.py | 1 - 1 file changed, 1 deletion(-) diff --git 
a/aiter/ops/flydsl/kernels/attn_bwd_mxfp8_gfx950.py b/aiter/ops/flydsl/kernels/attn_bwd_mxfp8_gfx950.py index ea2960b746..31618e06bd 100644 --- a/aiter/ops/flydsl/kernels/attn_bwd_mxfp8_gfx950.py +++ b/aiter/ops/flydsl/kernels/attn_bwd_mxfp8_gfx950.py @@ -93,7 +93,6 @@ def compile_attn_bwd_mxfp8_gfx950( allocator_ds_shuffle = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem_ds_shuffle") allocator_ds_scale_shuffle = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem_ds_scale_shuffle") - num_waves = 4 wave_size = 64 total_threads = 256 From 0fc8327c77c1ef7f96907710ede4432131dde24a Mon Sep 17 00:00:00 2001 From: Lukasz Burzawa Date: Wed, 13 May 2026 16:24:20 +0000 Subject: [PATCH 4/8] add gqa support --- .../flydsl/kernels/attn_bwd_mxfp8_gfx950.py | 185 +++++++++++------- .../test_attn_bwd_mxfp8_gfx950.py | 86 +++++--- .../flydsl/bench_attn_bwd_mxfp8_gfx950.py | 50 +++-- 3 files changed, 206 insertions(+), 115 deletions(-) diff --git a/aiter/ops/flydsl/kernels/attn_bwd_mxfp8_gfx950.py b/aiter/ops/flydsl/kernels/attn_bwd_mxfp8_gfx950.py index 31618e06bd..47fd212e7e 100644 --- a/aiter/ops/flydsl/kernels/attn_bwd_mxfp8_gfx950.py +++ b/aiter/ops/flydsl/kernels/attn_bwd_mxfp8_gfx950.py @@ -54,6 +54,8 @@ def lds_transpose_load(lds_memref, elem_offset): def compile_attn_bwd_mxfp8_gfx950( *, + num_heads_q: int, + num_heads_kv: int, seqlen: int, head_dim: int, tile_m: int, @@ -75,6 +77,7 @@ def compile_attn_bwd_mxfp8_gfx950( tile_head_mx = tile_head // 32 tile_m_mx = tile_m // 32 tile_n_mx = tile_n // 32 + gqa_size = num_heads_q // num_heads_kv gpu_arch = get_hip_arch() @@ -213,7 +216,15 @@ def kernel_attn_bwd( arg_do_scale_m: fx.Tensor, arg_m: fx.Tensor, arg_D: fx.Tensor, - batch: fx.Int32 + batch: fx.Int32, + stride_qo_batch: fx.Int32, + stride_qo_scale_batch: fx.Int32, + stride_kv_batch: fx.Int32, + stride_kv_scale_batch: fx.Int32, + stride_MD_batch: fx.Int32, + stride_qkvo_nheads: fx.Int32, + stride_qkvo_scale_nheads: fx.Int32, + stride_MD_nheads: fx.Int32 ): 
# ---- Types ---- @@ -226,6 +237,10 @@ def kernel_attn_bwd( tx = gpu.thread_id("x") bx = gpu.block_id("x") by = gpu.block_id("y") + bz = gpu.block_id("z") + batch_id = bz + head_q = bx + head_kv = head_q // gqa_size # ---- LDS (separate ping/pong buffers) ---- base_ptr_pong = allocator_pong.get_base() @@ -331,24 +346,26 @@ def kernel_attn_bwd( # ---- Buffer resources (runtime byte sizes for OOB protection) ---- head_dim_mx = head_dim // 32 - global_buffer_size = fx.Index(batch * seqlen * head_dim) - global_buffer_size_scale = fx.Index(batch * seqlen * head_dim_mx) - q_nrec = arith.index_cast(T.i64, global_buffer_size) - q_scale_head_nrec = arith.index_cast(T.i64, global_buffer_size_scale) - q_scale_m_nrec = arith.index_cast(T.i64, global_buffer_size_scale) - k_nrec = arith.index_cast(T.i64, global_buffer_size) - k_scale_head_nrec = arith.index_cast(T.i64, global_buffer_size_scale) - k_scale_n_nrec = arith.index_cast(T.i64, global_buffer_size_scale) - v_nrec = arith.index_cast(T.i64, global_buffer_size) - v_scale_nrec = arith.index_cast(T.i64, global_buffer_size_scale) - do_nrec = arith.index_cast(T.i64, global_buffer_size) - do_scale_head_nrec = arith.index_cast(T.i64, global_buffer_size_scale) - do_scale_m_nrec = arith.index_cast(T.i64, global_buffer_size_scale) - dq_nrec = arith.index_cast(T.i64, global_buffer_size * 4) - dk_nrec = arith.index_cast(T.i64, global_buffer_size * 2) - dv_nrec = arith.index_cast(T.i64, global_buffer_size * 2) - m_nrec = arith.index_cast(T.i64, fx.Index(batch * seqlen * 4)) - D_nrec = arith.index_cast(T.i64, fx.Index(batch * seqlen * 4)) + global_buffer_size_qo = fx.Index(batch * num_heads_q * seqlen * head_dim) + global_buffer_size_kv = fx.Index(batch * num_heads_kv * seqlen * head_dim) + global_buffer_size_qo_scale = fx.Index(batch * num_heads_q * seqlen * head_dim_mx) + global_buffer_size_kv_scale = fx.Index(batch * num_heads_kv * seqlen * head_dim_mx) + q_nrec = arith.index_cast(T.i64, global_buffer_size_qo) + q_scale_head_nrec = 
arith.index_cast(T.i64, global_buffer_size_qo_scale) + q_scale_m_nrec = arith.index_cast(T.i64, global_buffer_size_qo_scale) + k_nrec = arith.index_cast(T.i64, global_buffer_size_kv) + k_scale_head_nrec = arith.index_cast(T.i64, global_buffer_size_kv_scale) + k_scale_n_nrec = arith.index_cast(T.i64, global_buffer_size_kv_scale) + v_nrec = arith.index_cast(T.i64, global_buffer_size_kv) + v_scale_nrec = arith.index_cast(T.i64, global_buffer_size_kv_scale) + do_nrec = arith.index_cast(T.i64, global_buffer_size_qo) + do_scale_head_nrec = arith.index_cast(T.i64, global_buffer_size_qo_scale) + do_scale_m_nrec = arith.index_cast(T.i64, global_buffer_size_qo_scale) + dq_nrec = arith.index_cast(T.i64, global_buffer_size_qo * 4) + dk_nrec = arith.index_cast(T.i64, global_buffer_size_kv * 4) + dv_nrec = arith.index_cast(T.i64, global_buffer_size_kv * 4) + m_nrec = arith.index_cast(T.i64, fx.Index(batch * num_heads_q * seqlen * 4)) + D_nrec = arith.index_cast(T.i64, fx.Index(batch * num_heads_q * seqlen * 4)) q_quant_head_rsrc = buffer_ops.create_buffer_resource(arg_q_quant_head, max_size=False, num_records_bytes=q_nrec) q_scale_head_rsrc = buffer_ops.create_buffer_resource(arg_q_scale_head, max_size=False, num_records_bytes=q_scale_head_nrec) @@ -373,7 +390,7 @@ def kernel_attn_bwd( num_records_bytes=m_nrec) D_rsrc = buffer_ops.create_buffer_resource(arg_D, max_size=False, num_records_bytes=D_nrec) - global_offset_n = bx * tile_n + global_offset_n = by * tile_n global_offset_n_mx = global_offset_n // 32 # ---- Wave / lane decomposition ---- @@ -493,12 +510,14 @@ def lds_scale_load(row, col, lds_stride, lds_buffer): layout_kv_tile_div4 = fx.make_layout((tile_n, tile_head_dwords), (tile_head_dwords, 1)) c4 = fx.Index(4) tx_i32_base = tx * c4 - batch_offset_tensor = by * seqlen * head_dim_div4 - batch_offset_scale = by * seqlen * head_dim_mx - batch_offset_mD = by * seqlen - batch_offset_dq = by * seqlen * head_dim - batch_offset_dkdv = by * seqlen * head_dim + offset_qo_nheads 
= batch_id * fx.Index(stride_qo_batch) + head_q * fx.Index(stride_qkvo_nheads) + offset_qo_nheads_div4 = offset_qo_nheads // 4 + offset_kv_nheads = batch_id * fx.Index(stride_kv_batch) + head_kv * fx.Index(stride_qkvo_nheads) + offset_kv_nheads_div4 = offset_kv_nheads // 4 + offset_qo_scale_nheads = batch_id * fx.Index(stride_qo_scale_batch) + head_q * fx.Index(stride_qkvo_scale_nheads) + offset_kv_scale_nheads = batch_id * fx.Index(stride_kv_scale_batch) + head_kv * fx.Index(stride_qkvo_scale_nheads) + offset_MD_nheads = batch_id * fx.Index(stride_MD_batch) + head_q * fx.Index(stride_MD_nheads) def load_q_quant_head_16(idx_elem): return buffer_copy_gmem16_dwordx4( @@ -580,7 +599,7 @@ def prefetch_q_quant_head_tile(offset_m): for i in range_constexpr(num_qo_loads): row_q_local, col_q_local_i32 = qo_tile_chunk_coord_i32(i) row_q_global = offset_m + row_q_local - idx_elem = batch_offset_tensor + row_q_global * head_dim_div4 + col_q_local_i32 + idx_elem = offset_qo_nheads_div4 + row_q_global * head_dim_div4 + col_q_local_i32 q_16B = load_q_quant_head_16(idx_elem) parts.append(vector.bitcast(T.i32x4, q_16B)) return parts @@ -590,7 +609,7 @@ def prefetch_q_quant_m_tile(offset_m): for i in range_constexpr(num_qo_loads): row_q_local, col_q_local_i32 = qo_tile_chunk_coord_i32(i) row_q_global = offset_m + row_q_local - idx_elem = batch_offset_tensor + row_q_global * head_dim_div4 + col_q_local_i32 + idx_elem = offset_qo_nheads_div4 + row_q_global * head_dim_div4 + col_q_local_i32 q_16B = load_q_quant_m_16(idx_elem) parts.append(vector.bitcast(T.i32x4, q_16B)) return parts @@ -600,7 +619,7 @@ def prefetch_k_quant_head_tile(): for i in range_constexpr(num_kv_loads): row_k_local, col_k_local_i32 = kv_tile_chunk_coord_i32(i) row_k_global = global_offset_n + row_k_local - idx_elem = batch_offset_tensor + row_k_global * head_dim_div4 + col_k_local_i32 + idx_elem = offset_kv_nheads_div4 + row_k_global * head_dim_div4 + col_k_local_i32 k_16B = load_k_quant_head_16(idx_elem) 
parts.append(vector.bitcast(T.i32x4, k_16B)) return parts @@ -610,7 +629,7 @@ def prefetch_k_quant_n_tile(): for i in range_constexpr(num_kv_loads): row_k_local, col_k_local_i32 = kv_tile_chunk_coord_i32(i) row_k_global = global_offset_n + row_k_local - idx_elem = batch_offset_tensor + row_k_global * head_dim_div4 + col_k_local_i32 + idx_elem = offset_kv_nheads_div4 + row_k_global * head_dim_div4 + col_k_local_i32 k_16B = load_k_quant_n_16(idx_elem) parts.append(vector.bitcast(T.i32x4, k_16B)) return parts @@ -620,7 +639,7 @@ def prefetch_v_tile(): for i in range_constexpr(num_kv_loads): row_v_local, col_v_local_i32 = kv_tile_chunk_coord_i32(i) row_v_global = global_offset_n + row_v_local - idx_elem = batch_offset_tensor + row_v_global * head_dim_div4 + col_v_local_i32 + idx_elem = offset_kv_nheads_div4 + row_v_global * head_dim_div4 + col_v_local_i32 v_16B = load_v_16(idx_elem) parts.append(vector.bitcast(T.i32x4, v_16B)) return parts @@ -630,7 +649,7 @@ def prefetch_do_quant_head_tile(offset_m): for i in range_constexpr(num_qo_loads): row_do_local, col_do_local_i32 = qo_tile_chunk_coord_i32(i) row_do_global = offset_m + row_do_local - idx_elem = batch_offset_tensor + row_do_global * head_dim_div4 + col_do_local_i32 + idx_elem = offset_qo_nheads_div4 + row_do_global * head_dim_div4 + col_do_local_i32 do_16B = load_do_quant_head_16(idx_elem) parts.append(vector.bitcast(T.i32x4, do_16B)) return parts @@ -640,7 +659,7 @@ def prefetch_do_quant_m_tile(offset_m): for i in range_constexpr(num_qo_loads): row_do_local, col_do_local_i32 = qo_tile_chunk_coord_i32(i) row_do_global = offset_m + row_do_local - idx_elem = batch_offset_tensor + row_do_global * head_dim_div4 + col_do_local_i32 + idx_elem = offset_qo_nheads_div4 + row_do_global * head_dim_div4 + col_do_local_i32 do_16B = load_do_quant_m_16(idx_elem) parts.append(vector.bitcast(T.i32x4, do_16B)) return parts @@ -649,13 +668,13 @@ def prefetch_q_scale_head_tile(offset_m): vec_width = bytes_per_thread_qo_scale if 
const_expr(vec_width == 1): if const_expr(bytes_per_tile_qo_scale < total_threads): - idx_elem = batch_offset_scale + offset_m * head_dim_mx + tx % bytes_per_tile_qo_scale + idx_elem = offset_qo_scale_nheads + offset_m * head_dim_mx + tx % bytes_per_tile_qo_scale else: - idx_elem = batch_offset_scale + offset_m * head_dim_mx + tx + idx_elem = offset_qo_scale_nheads + offset_m * head_dim_mx + tx vec = buffer_ops.buffer_load(q_scale_head_rsrc, idx_elem, vec_width=1, dtype=T.i8) vec = vector.from_elements(T.vec(1, T.i8), [vec]) else: # vec_width=2 - idx_elem = (batch_offset_scale + offset_m * head_dim_mx + tx * vec_width) // 2 + idx_elem = (offset_qo_scale_nheads + offset_m * head_dim_mx + tx * vec_width) // 2 vec = buffer_ops.buffer_load(q_scale_head_rsrc, idx_elem, vec_width=1, dtype=T.i16) vec = vector.from_elements(T.vec(1, T.i16), [vec]) vec = vector.bitcast(T.i8x2, vec) @@ -665,13 +684,13 @@ def prefetch_q_scale_m_tile(offset_m): vec_width = bytes_per_thread_qo_scale if const_expr(vec_width == 1): if const_expr(bytes_per_tile_qo_scale < total_threads): - idx_elem = batch_offset_scale + offset_m * head_dim + tx % bytes_per_tile_qo_scale + idx_elem = offset_qo_scale_nheads + offset_m * head_dim + tx % bytes_per_tile_qo_scale else: - idx_elem = batch_offset_scale + offset_m * head_dim + tx + idx_elem = offset_qo_scale_nheads + offset_m * head_dim + tx vec = buffer_ops.buffer_load(q_scale_m_rsrc, idx_elem, vec_width=1, dtype=T.i8) vec = vector.from_elements(T.vec(1, T.i8), [vec]) else: # vec_width=2 - idx_elem = (batch_offset_scale + offset_m * head_dim + tx * vec_width) // 2 + idx_elem = (offset_qo_scale_nheads + offset_m * head_dim + tx * vec_width) // 2 vec = buffer_ops.buffer_load(q_scale_m_rsrc, idx_elem, vec_width=1, dtype=T.i16) vec = vector.from_elements(T.vec(1, T.i16), [vec]) vec = vector.bitcast(T.i8x2, vec) @@ -681,13 +700,13 @@ def prefetch_k_scale_head_tile(): vec_width = bytes_per_thread_kv_scale if const_expr(vec_width == 1): if 
const_expr(bytes_per_tile_kv_scale < total_threads): - idx_elem = batch_offset_scale + global_offset_n * head_dim_mx + tx % bytes_per_tile_kv_scale + idx_elem = offset_kv_scale_nheads + global_offset_n * head_dim_mx + tx % bytes_per_tile_kv_scale else: - idx_elem = batch_offset_scale + global_offset_n * head_dim_mx + tx + idx_elem = offset_kv_scale_nheads + global_offset_n * head_dim_mx + tx vec = buffer_ops.buffer_load(k_scale_head_rsrc, idx_elem, vec_width=1, dtype=T.i8) vec = vector.from_elements(T.vec(1, T.i8), [vec]) else: # vec_width=2 - idx_elem = (batch_offset_scale + global_offset_n * head_dim_mx + tx * vec_width) // 2 + idx_elem = (offset_kv_scale_nheads + global_offset_n * head_dim_mx + tx * vec_width) // 2 vec = buffer_ops.buffer_load(k_scale_head_rsrc, idx_elem, vec_width=1, dtype=T.i16) vec = vector.from_elements(T.vec(1, T.i16), [vec]) vec = vector.bitcast(T.i8x2, vec) @@ -697,13 +716,13 @@ def prefetch_k_scale_n_tile(): vec_width = bytes_per_thread_kv_scale if const_expr(vec_width == 1): if const_expr(bytes_per_tile_kv_scale < total_threads): - idx_elem = batch_offset_scale + global_offset_n_mx * head_dim + tx % bytes_per_tile_kv_scale + idx_elem = offset_kv_scale_nheads + global_offset_n_mx * head_dim + tx % bytes_per_tile_kv_scale else: - idx_elem = batch_offset_scale + global_offset_n_mx * head_dim + tx + idx_elem = offset_kv_scale_nheads + global_offset_n_mx * head_dim + tx vec = buffer_ops.buffer_load(k_scale_n_rsrc, idx_elem, vec_width=1, dtype=T.i8) vec = vector.from_elements(T.vec(1, T.i8), [vec]) else: # vec_width=2 - idx_elem = (batch_offset_scale + global_offset_n_mx * head_dim + tx * vec_width) // 2 + idx_elem = (offset_kv_scale_nheads + global_offset_n_mx * head_dim + tx * vec_width) // 2 vec = buffer_ops.buffer_load(k_scale_n_rsrc, idx_elem, vec_width=1, dtype=T.i16) vec = vector.from_elements(T.vec(1, T.i16), [vec]) vec = vector.bitcast(T.i8x2, vec) @@ -713,13 +732,13 @@ def prefetch_v_scale_tile(): vec_width = 
bytes_per_thread_kv_scale if const_expr(vec_width == 1): if const_expr(bytes_per_tile_kv_scale < total_threads): - idx_elem = batch_offset_scale + global_offset_n * head_dim_mx + tx % bytes_per_tile_kv_scale + idx_elem = offset_kv_scale_nheads + global_offset_n * head_dim_mx + tx % bytes_per_tile_kv_scale else: - idx_elem = batch_offset_scale + global_offset_n * head_dim_mx + tx + idx_elem = offset_kv_scale_nheads + global_offset_n * head_dim_mx + tx vec = buffer_ops.buffer_load(v_scale_rsrc, idx_elem, vec_width=1, dtype=T.i8) vec = vector.from_elements(T.vec(1, T.i8), [vec]) else: # vec_width=2 - idx_elem = (batch_offset_scale + global_offset_n * head_dim_mx + tx * vec_width) // 2 + idx_elem = (offset_kv_scale_nheads + global_offset_n * head_dim_mx + tx * vec_width) // 2 vec = buffer_ops.buffer_load(v_scale_rsrc, idx_elem, vec_width=1, dtype=T.i16) vec = vector.from_elements(T.vec(1, T.i16), [vec]) vec = vector.bitcast(T.i8x2, vec) @@ -729,13 +748,13 @@ def prefetch_do_scale_head_tile(offset_m): vec_width = bytes_per_thread_qo_scale if const_expr(vec_width == 1): if const_expr(bytes_per_tile_qo_scale < total_threads): - idx_elem = batch_offset_scale + offset_m * head_dim_mx + tx % bytes_per_tile_qo_scale + idx_elem = offset_qo_scale_nheads + offset_m * head_dim_mx + tx % bytes_per_tile_qo_scale else: - idx_elem = batch_offset_scale + offset_m * head_dim_mx + tx + idx_elem = offset_qo_scale_nheads + offset_m * head_dim_mx + tx vec = buffer_ops.buffer_load(do_scale_head_rsrc, idx_elem, vec_width=1, dtype=T.i8) vec = vector.from_elements(T.vec(1, T.i8), [vec]) else: # vec_width=2 - idx_elem = (batch_offset_scale + offset_m * head_dim_mx + tx * vec_width) // 2 + idx_elem = (offset_qo_scale_nheads + offset_m * head_dim_mx + tx * vec_width) // 2 vec = buffer_ops.buffer_load(do_scale_head_rsrc, idx_elem, vec_width=1, dtype=T.i16) vec = vector.from_elements(T.vec(1, T.i16), [vec]) vec = vector.bitcast(T.i8x2, vec) @@ -745,13 +764,13 @@ def 
prefetch_do_scale_m_tile(offset_m): vec_width = bytes_per_thread_qo_scale if const_expr(vec_width == 1): if const_expr(bytes_per_tile_qo_scale < total_threads): - idx_elem = batch_offset_scale + offset_m * head_dim + tx % bytes_per_tile_qo_scale + idx_elem = offset_qo_scale_nheads + offset_m * head_dim + tx % bytes_per_tile_qo_scale else: - idx_elem = batch_offset_scale + offset_m * head_dim + tx + idx_elem = offset_qo_scale_nheads + offset_m * head_dim + tx vec = buffer_ops.buffer_load(do_scale_m_rsrc, idx_elem, vec_width=1, dtype=T.i8) vec = vector.from_elements(T.vec(1, T.i8), [vec]) else: # vec_width=2 - idx_elem = (batch_offset_scale + offset_m * head_dim + tx * vec_width) // 2 + idx_elem = (offset_qo_scale_nheads + offset_m * head_dim + tx * vec_width) // 2 vec = buffer_ops.buffer_load(do_scale_m_rsrc, idx_elem, vec_width=1, dtype=T.i16) vec = vector.from_elements(T.vec(1, T.i16), [vec]) vec = vector.bitcast(T.i8x2, vec) @@ -888,7 +907,7 @@ def softmax(accs_in, offset_m): accs_out = [acc_init] * ps_n_accs for mi in range_constexpr(ps_m_num_subtiles): - global_m_norm_idx = batch_offset_mD + offset_m + ps_m_wave_id * ps_m_per_wave + mi * 16 + lane_div_16 * 4 #+ ii + global_m_norm_idx = offset_MD_nheads + offset_m + ps_m_wave_id * ps_m_per_wave + mi * 16 + lane_div_16 * 4 #+ ii m_norm_vector = buffer_ops.buffer_load(m_rsrc, global_m_norm_idx, vec_width=4) for ni in range_constexpr(ps_n_num_subtiles): @@ -1024,7 +1043,7 @@ def compute_ds(dp_accs, p_accs, offset_m): for mi in range_constexpr(ps_m_num_subtiles): - global_D_idx = batch_offset_mD + offset_m + ps_m_wave_id * ps_m_per_wave + mi * 16 + lane_div_16 * 4 + global_D_idx = offset_MD_nheads + offset_m + ps_m_wave_id * ps_m_per_wave + mi * 16 + lane_div_16 * 4 D_vector = buffer_ops.buffer_load(D_rsrc, global_D_idx, vec_width=4) for ni in range_constexpr(ps_n_num_subtiles): @@ -1330,7 +1349,7 @@ def store_dq_atomic(final_accs, offset_m): for ii in range_constexpr(4): global_row = offset_m + dq_m_wave_id * 
dq_m_per_wave + mi * 16 + lane_div_16 * 4 + ii global_col = dq_head_wave_id * dq_head_per_wave + hi * 16 + lane_mod_16 - global_idx = batch_offset_dq + global_row * head_dim + global_col + global_idx = offset_qo_nheads + global_row * head_dim + global_col global_idx_bytes = global_idx * 4 acc_idx = mi * dq_num_subtiles_head + hi @@ -1341,7 +1360,7 @@ def store_dq_atomic(final_accs, offset_m): #buffer_ops.buffer_store(val_f32, dq_rsrc, global_idx) - def store_dk_bf16(final_accs): + def store_dk_atomic(final_accs): for ni in range_constexpr(dk_num_subtiles_n): for hi in range_constexpr(dk_num_subtiles_head): acc_idx = ni * dk_num_subtiles_head + hi @@ -1350,17 +1369,20 @@ def store_dk_bf16(final_accs): global_row = global_offset_n + dk_n_wave_id * dk_n_per_wave + ni * 16 + lane_div_16 * 4 + ii global_col = dk_head_wave_id * dk_head_per_wave + hi * 16 + lane_mod_16 - global_idx = batch_offset_dkdv + global_row * head_dim + global_col + global_idx = offset_kv_nheads + global_row * head_dim + global_col acc_idx = ni * dk_num_subtiles_head + hi acc = final_accs[acc_idx] val_f32 = vector.extract(acc, static_position=[ii], dynamic_position=[]) val_f32 = val_f32 * c_sm_scale - val_bf16 = arith.trunc_f(T.bf16, val_f32) - buffer_ops.buffer_store(val_bf16, dk_rsrc, global_idx) + if const_expr(gqa_size == 1): + buffer_ops.buffer_store(val_f32, dk_rsrc, global_idx) + else: + global_idx_bytes = global_idx * 4 + rocdl.raw_ptr_buffer_atomic_fadd(val_f32, dk_rsrc, fx.Int32(global_idx_bytes), fx.Int32(0), fx.Int32(0)) - def store_dv_bf16(final_accs): + def store_dv_atomic(final_accs): for ni in range_constexpr(dv_n_num_subtiles): for hi in range_constexpr(dv_head_num_subtiles): acc_idx = ni * dv_head_num_subtiles + hi @@ -1369,13 +1391,16 @@ def store_dv_bf16(final_accs): global_row = global_offset_n + dv_n_wave_id * dv_n_per_wave + ni * 16 + lane_div_16 * 4 + ii global_col = dv_head_wave_id * dv_head_per_wave + hi * 16 + lane_mod_16 - global_idx = batch_offset_dkdv + global_row * 
head_dim + global_col + global_idx = offset_kv_nheads + global_row * head_dim + global_col acc_idx = ni * dv_head_num_subtiles + hi acc = final_accs[acc_idx] val_f32 = vector.extract(acc, static_position=[ii], dynamic_position=[]) - val_bf16 = arith.trunc_f(T.bf16, val_f32) - buffer_ops.buffer_store(val_bf16, dv_rsrc, global_idx) + if const_expr(gqa_size == 1): + buffer_ops.buffer_store(val_f32, dv_rsrc, global_idx) + else: + global_idx_bytes = global_idx * 4 + rocdl.raw_ptr_buffer_atomic_fadd(val_f32, dv_rsrc, fx.Int32(global_idx_bytes), fx.Int32(0), fx.Int32(0)) # ── Scheduling hints ────────────────────────────────────────────── @@ -1551,8 +1576,8 @@ def pingpong(offset_m, inner_state): dq = compute_dq(lds_ds_shuffle, lds_ds_scale_shuffle, lds_k_quant_n, lds_k_scale_n) store_dq_atomic(dq, curr_m) - store_dk_bf16(dk) - store_dv_bf16(dv) + store_dk_atomic(dk) + store_dv_atomic(dv) # ── Host launcher ────────────────────────────────────────────────────── @@ -1580,6 +1605,14 @@ def launch_attn_bwd( arg_m: fx.Tensor, arg_D: fx.Tensor, batch: fx.Int32, + stride_qo_batch: fx.Int32, + stride_qo_scale_batch: fx.Int32, + stride_kv_batch: fx.Int32, + stride_kv_scale_batch: fx.Int32, + stride_MD_batch: fx.Int32, + stride_qkvo_nheads: fx.Int32, + stride_qkvo_scale_nheads: fx.Int32, + stride_MD_nheads: fx.Int32, stream: fx.Stream, ): _ = _cache_tag @@ -1614,10 +1647,20 @@ def launch_attn_bwd( allocator_ds_shuffle.finalize() allocator_ds_scale_shuffle.finalize() - gx = seqlen // tile_n - gy = batch - - launcher = kernel_attn_bwd(arg_dq, arg_dk, arg_dv, arg_q_quant_head, arg_q_scale_head, arg_q_quant_m, arg_q_scale_m, arg_k_quant_head, arg_k_scale_head, arg_k_quant_n, arg_k_scale_n, arg_v, arg_v_scale, arg_do_quant_head, arg_do_scale_head, arg_do_quant_m, arg_do_scale_m, arg_m, arg_D, batch) + gx = num_heads_q + gy = seqlen // tile_n + gz = batch + + launcher = kernel_attn_bwd(arg_dq, arg_dk, arg_dv, arg_q_quant_head, arg_q_scale_head, arg_q_quant_m, arg_q_scale_m, 
arg_k_quant_head, arg_k_scale_head, arg_k_quant_n, arg_k_scale_n, arg_v, arg_v_scale, arg_do_quant_head, arg_do_scale_head, arg_do_quant_m, arg_do_scale_m, arg_m, arg_D, + batch, + stride_qo_batch, + stride_qo_scale_batch, + stride_kv_batch, + stride_kv_scale_batch, + stride_MD_batch, + stride_qkvo_nheads, + stride_qkvo_scale_nheads, + stride_MD_nheads) if waves_per_eu is not None: _wpe = int(waves_per_eu) if _wpe >= 1: @@ -1625,7 +1668,7 @@ def launch_attn_bwd( if hasattr(op, 'attributes') and op.OPERATION_NAME == "gpu.func": op.attributes["rocdl.waves_per_eu"] = ir.IntegerAttr.get(T.i32, _wpe) launcher.launch( - grid=(gx, gy, 1), + grid=(gx, gy, gz), block=(256, 1, 1), stream=stream, ) diff --git a/op_tests/flydsl_tests/test_attn_bwd_mxfp8_gfx950.py b/op_tests/flydsl_tests/test_attn_bwd_mxfp8_gfx950.py index 8eff2b452a..7397b8de91 100644 --- a/op_tests/flydsl_tests/test_attn_bwd_mxfp8_gfx950.py +++ b/op_tests/flydsl_tests/test_attn_bwd_mxfp8_gfx950.py @@ -42,39 +42,54 @@ def mx_quant(x, dim=-1): return x_fp32.contiguous(), x_fp8.contiguous(), x_scale.contiguous() -def run_torch(q_fp32_head, q_fp32_m, k_fp32_head, k_fp32_n, v, do_fp32_head, do_fp32_m, m, D, sm_scale, causal, dtype=torch.float32): - seqlen = q_fp32_head.shape[1] +def run_torch(q_fp32_head, q_fp32_m, k_fp32_head, k_fp32_n, v, do_fp32_head, do_fp32_m, m, D, sm_scale, causal, gqa_size): + batch = q_fp32_head.shape[0] + num_heads_q = q_fp32_head.shape[1] + num_heads_kv = num_heads_q // gqa_size + seqlen = q_fp32_head.shape[2] + head_dim = q_fp32_head.shape[3] device = q_fp32_head.device v_f32 = v.to(torch.float32) qk = torch.matmul(q_fp32_head, k_fp32_head.transpose(-2, -1)) * sm_scale - p = torch.exp(qk - m[:, :, None]) + p = torch.exp(qk - m[:, :, :, None]) if causal: - mask = torch.tril(torch.ones((seqlen, seqlen), device=device)) #.T - p[:, mask == 0] = 0.0 + mask = torch.tril(torch.ones((seqlen, seqlen), device=device)) + p[:, :, mask == 0] = 0.0 ppT, _, _ = mx_quant(p, -2) ppT = ppT.transpose(-2, 
-1) dv = torch.matmul(ppT, do_fp32_m) dp = torch.matmul(do_fp32_head, v_f32.transpose(-2, -1)) - ds = p * (dp - D[:, :, None]) + ds = p * (dp - D[:, :, :, None]) dsT, _, _ = mx_quant(ds, -1) dsT = dsT.transpose(-2, -1) ds, _, _ = mx_quant(ds, -2) dk = torch.matmul(dsT, q_fp32_m) * sm_scale dq = torch.matmul(ds, k_fp32_n) * sm_scale + dk = dk.view(batch, num_heads_kv, gqa_size, seqlen, head_dim).sum(dim=2) + dv = dv.view(batch, num_heads_kv, gqa_size, seqlen, head_dim).sum(dim=2) + return dq, dk, dv -@pytest.mark.parametrize("batch", [2, 8, 45, 256]) +@pytest.mark.parametrize("batch", [1, 4]) +@pytest.mark.parametrize( + "num_heads_q, num_heads_kv", + [ + (48, 48), + (64, 8), + (80, 20) + ], +) @pytest.mark.parametrize("seqlen", [128, 1024, 1152, 4096]) @pytest.mark.parametrize("head_dim", [64, 128]) @pytest.mark.parametrize("tile_m", [64, 128]) @pytest.mark.parametrize("tile_n", [64, 128]) @pytest.mark.parametrize("causal", [False, True]) def test_attn_bwd_flyc( - batch, seqlen, head_dim, - tile_m, tile_n, + batch, num_heads_q, num_heads_kv, seqlen, head_dim, + tile_m, tile_n, causal, waves_per_eu: int = 0, ): @@ -87,6 +102,7 @@ def test_attn_bwd_flyc( sm_scale = 0.5 _wpe = int(waves_per_eu) launch_fn = compile_attn_bwd_mxfp8_gfx950( + num_heads_q=num_heads_q, num_heads_kv=num_heads_kv, seqlen=seqlen, head_dim=head_dim, tile_m=tile_m, tile_n=tile_n, tile_head=tile_head, sm_scale=sm_scale, @@ -95,34 +111,40 @@ def test_attn_bwd_flyc( ) device = torch.device("cuda") - q_fp32 = torch.randn(batch, seqlen, head_dim, device=device, dtype=torch.float32) * 0.5 - k_fp32 = torch.randn(batch, seqlen, head_dim, device=device, dtype=torch.float32) * 0.5 - v_fp32 = torch.randn(batch, seqlen, head_dim, device=device, dtype=torch.float32) * 0.5 - o_fp32 = torch.randn(batch, seqlen, head_dim, device=device, dtype=torch.float32) * 0.5 - do_fp32 = torch.randn(batch, seqlen, head_dim, device=device, dtype=torch.float32) * 0.5 + gqa_size = num_heads_q // num_heads_kv + q_fp32 = 
torch.randn(batch, num_heads_q, seqlen, head_dim, device=device, dtype=torch.float32) * 0.5 + k_fp32 = torch.randn(batch, num_heads_kv, seqlen, head_dim, device=device, dtype=torch.float32) * 0.5 + v_fp32 = torch.randn(batch, num_heads_kv, seqlen, head_dim, device=device, dtype=torch.float32) * 0.5 + do_fp32 = torch.randn(batch, num_heads_q, seqlen, head_dim, device=device, dtype=torch.float32) * 0.5 + + q_fp32_head, q_quant_head, q_scale_head = mx_quant(q_fp32, -1) + q_fp32_m, q_quant_m, q_scale_m = mx_quant(q_fp32, -2) + k_fp32_head, k_quant_head, k_scale_head = mx_quant(k_fp32, -1) + k_fp32_n, k_quant_n, k_scale_n = mx_quant(k_fp32, -2) + v_fp32, v_quant, v_scale = mx_quant(v_fp32) + do_fp32_head, do_quant_head, do_scale_head = mx_quant(do_fp32, -1) + do_fp32_m, do_quant_m, do_scale_m = mx_quant(do_fp32, -2) + + k_fp32 = k_fp32.repeat_interleave(gqa_size, dim=1) + k_fp32_head = k_fp32_head.repeat_interleave(gqa_size, dim=1) + k_fp32_n = k_fp32_n.repeat_interleave(gqa_size, dim=1) + v_fp32 = v_fp32.repeat_interleave(gqa_size, dim=1) qk = q_fp32 @ k_fp32.transpose(-2, -1) qk = qk * sm_scale m = qk.max(dim=-1)[0] - p = (qk - m[:, :, None]).exp() + p = (qk - m[:, :, :, None]).exp() l = p.sum(dim=-1) - p = p / l[:, :, None] + p = p / l[:, :, :, None] o_fp32 = torch.matmul(p, v_fp32) m = m + torch.log(l) D = (o_fp32 * do_fp32).sum(dim=-1) - q_fp32_head, q_quant_head, q_scale_head = mx_quant(q_fp32, -1) - q_fp32_m, q_quant_m, q_scale_m = mx_quant(q_fp32, -2) - k_fp32_head, k_quant_head, k_scale_head = mx_quant(k_fp32, -1) - k_fp32_n, k_quant_n, k_scale_n = mx_quant(k_fp32, -2) - v_fp32, v_quant, v_scale = mx_quant(v_fp32) - do_fp32_head, do_quant_head, do_scale_head = mx_quant(do_fp32, -1) - do_fp32_m, do_quant_m, do_scale_m = mx_quant(do_fp32, -2) - - dq_ref, dk_ref, dv_ref = run_torch(q_fp32_head, q_fp32_m, k_fp32_head, k_fp32_n, v_fp32, do_fp32_head, do_fp32_m, m, D, sm_scale, causal, dtype=torch.float32) - dq_fly = torch.zeros((batch, seqlen, head_dim), 
dtype=torch.float32, device=device) - dk_fly = torch.zeros((batch, seqlen, head_dim), dtype=torch.bfloat16, device=device) - dv_fly = torch.zeros((batch, seqlen, head_dim), dtype=torch.bfloat16, device=device) + dq_ref, dk_ref, dv_ref = run_torch(q_fp32_head, q_fp32_m, k_fp32_head, k_fp32_n, v_fp32, do_fp32_head, do_fp32_m, m, D, sm_scale, causal, gqa_size) + + dq_fly = torch.zeros((batch, num_heads_q, seqlen, head_dim), dtype=torch.float32, device=device) + dk_fly = torch.zeros((batch, num_heads_kv, seqlen, head_dim), dtype=torch.float32, device=device) + dv_fly = torch.zeros((batch, num_heads_kv, seqlen, head_dim), dtype=torch.float32, device=device) def launch_kernel(dq, dk, dv, q_quant_head, q_scale_head, q_quant_m, q_scale_m, k_quant_head, k_scale_head, k_quant_n, k_scale_n, v, v_scale, do_quant_head, do_scale_head, do_quant_m, do_scale_m, m, D, batch): launch_fn( @@ -146,6 +168,14 @@ def launch_kernel(dq, dk, dv, q_quant_head, q_scale_head, q_quant_m, q_scale_m, m.contiguous().view(-1), D.contiguous().view(-1), batch, + q_quant_head.stride(0), + q_scale_head.stride(0), + k_quant_head.stride(0), + k_scale_head.stride(0), + m.stride(0), + q_quant_head.stride(1), + q_scale_head.stride(1), + m.stride(1), torch.cuda.current_stream(), ) diff --git a/op_tests/op_benchmarks/flydsl/bench_attn_bwd_mxfp8_gfx950.py b/op_tests/op_benchmarks/flydsl/bench_attn_bwd_mxfp8_gfx950.py index b3d7932eba..3f1ce84779 100644 --- a/op_tests/op_benchmarks/flydsl/bench_attn_bwd_mxfp8_gfx950.py +++ b/op_tests/op_benchmarks/flydsl/bench_attn_bwd_mxfp8_gfx950.py @@ -18,7 +18,7 @@ DEFAULT_BENCH_WARMUP = 3 def bench_attn_bwd_flyc( - batch, seqlen, head_dim, + batch, num_heads_q, num_heads_kv, seqlen, head_dim, tile_m, tile_n, causal, test_graph, @@ -38,7 +38,7 @@ def bench_attn_bwd_flyc( sm_scale = 0.5 _wpe = int(waves_per_eu) if waves_per_eu else 0 launch_fn = compile_attn_bwd_mxfp8_gfx950( - seqlen=seqlen, head_dim=head_dim, + num_heads_q=num_heads_q, num_heads_kv=num_heads_kv, 
seqlen=seqlen, head_dim=head_dim, tile_m=tile_m, tile_n=tile_n, tile_head=tile_head, sm_scale=sm_scale, causal=causal, @@ -47,11 +47,12 @@ def bench_attn_bwd_flyc( print(f"✓ Kernel prepared") device = torch.device("cuda") - q_fp32 = torch.randn(batch, seqlen, head_dim, device=device, dtype=torch.float32) * 0.5 - k_fp32 = torch.randn(batch, seqlen, head_dim, device=device, dtype=torch.float32) * 0.5 - v_fp32 = torch.randn(batch, seqlen, head_dim, device=device, dtype=torch.float32) * 0.5 - o_fp32 = torch.randn(batch, seqlen, head_dim, device=device, dtype=torch.float32) * 0.5 - do_fp32 = torch.randn(batch, seqlen, head_dim, device=device, dtype=torch.float32) * 0.5 + gqa_size = num_heads_q // num_heads_kv + q_fp32 = torch.randn(batch, num_heads_q, seqlen, head_dim, device=device, dtype=torch.float32) * 0.5 + k_fp32 = torch.randn(batch, num_heads_kv, seqlen, head_dim, device=device, dtype=torch.float32) * 0.5 + v_fp32 = torch.randn(batch, num_heads_kv, seqlen, head_dim, device=device, dtype=torch.float32) * 0.5 + o_fp32 = torch.randn(batch, num_heads_q, seqlen, head_dim, device=device, dtype=torch.float32) * 0.5 + do_fp32 = torch.randn(batch, num_heads_q, seqlen, head_dim, device=device, dtype=torch.float32) * 0.5 q_fp32_head, q_quant_head, q_scale_head = mx_quant(q_fp32, -1) q_fp32_m, q_quant_m, q_scale_m = mx_quant(q_fp32, -2) @@ -61,19 +62,24 @@ def bench_attn_bwd_flyc( do_fp32_head, do_quant_head, do_scale_head = mx_quant(do_fp32, -1) do_fp32_m, do_quant_m, do_scale_m = mx_quant(do_fp32, -2) - qk = q_fp32 @ k_fp32.transpose(-2, -1) + k_fp32 = k_fp32.repeat_interleave(gqa_size, dim=1) + k_fp32_head = k_fp32_head.repeat_interleave(gqa_size, dim=1) + k_fp32_n = k_fp32_n.repeat_interleave(gqa_size, dim=1) + v_fp32 = v_fp32.repeat_interleave(gqa_size, dim=1) + + qk = torch.matmul(q_fp32, k_fp32.transpose(-2, -1)) qk = qk * sm_scale m = qk.max(dim=-1)[0] - p = (qk - m[:, :, None]).exp() + p = (qk - m[:, :, :, None]).exp() l = p.sum(dim=-1) m = m + torch.log(l) D = 
(o_fp32 * do_fp32).sum(dim=-1) if check_correctness: - dq_ref, dk_ref, dv_ref = run_torch(q_fp32_head, q_fp32_m, k_fp32_head, k_fp32_n, v_fp32, do_fp32_head, do_fp32_m, m, D, sm_scale, causal) - dq_fly = torch.zeros((batch, seqlen, head_dim), dtype=torch.float32, device=device) - dk_fly = torch.zeros((batch, seqlen, head_dim), dtype=torch.bfloat16, device=device) - dv_fly = torch.zeros((batch, seqlen, head_dim), dtype=torch.bfloat16, device=device) + dq_ref, dk_ref, dv_ref = run_torch(q_fp32_head, q_fp32_m, k_fp32_head, k_fp32_n, v_fp32, do_fp32_head, do_fp32_m, m, D, sm_scale, causal, gqa_size) + dq_fly = torch.zeros((batch, num_heads_q, seqlen, head_dim), dtype=torch.float32, device=device) + dk_fly = torch.zeros((batch, num_heads_kv, seqlen, head_dim), dtype=torch.float32, device=device) + dv_fly = torch.zeros((batch, num_heads_kv, seqlen, head_dim), dtype=torch.float32, device=device) def launch_kernel(dq, dk, dv, q_quant_head, q_scale_head, q_quant_m, q_scale_m, k_quant_head, k_scale_head, k_quant_n, k_scale_n, v, v_scale, do_quant_head, do_scale_head, do_quant_m, do_scale_m, m, D, batch): launch_fn( @@ -97,6 +103,14 @@ def launch_kernel(dq, dk, dv, q_quant_head, q_scale_head, q_quant_m, q_scale_m, m.contiguous().view(-1), D.contiguous().view(-1), batch, + q_quant_head.stride(0), + q_scale_head.stride(0), + k_quant_head.stride(0), + k_scale_head.stride(0), + m.stride(0), + q_quant_head.stride(1), + q_scale_head.stride(1), + m.stride(1), torch.cuda.current_stream(), ) @@ -131,6 +145,8 @@ def launch_kernel(dq, dk, dv, q_quant_head, q_scale_head, q_quant_m, q_scale_m, torch.cuda.synchronize() dq_fly.zero_() + dk_fly.zero_() + dv_fly.zero_() launch_kernel( dq_fly, dk_fly, @@ -163,8 +179,8 @@ def launch_kernel(dq, dk, dv, q_quant_head, q_scale_head, q_quant_m, q_scale_m, assert check_result(dk_fly_fp32, dk_ref, rtol=0.01, atol=0.01) assert check_result(dv_fly_fp32, dv_ref, rtol=0.01, atol=0.01) - bytes_moved = (7 + 4 + 2 * 2) * seqlen * head_dim + 2 * 4 * seqlen - 
flops = batch * (5 * 2 * seqlen * seqlen * head_dim + 5 * seqlen * seqlen + 2 * 3 * seqlen * seqlen) + bytes_moved = (4 + 4) * batch * num_heads_q * seqlen * head_dim + (3 + 2 * 4) * batch * num_heads_kv * seqlen * head_dim + 2 * 4 * batch * num_heads_q * seqlen + flops = batch * num_heads_q * (5 * 2 * seqlen * seqlen * head_dim + 5 * seqlen * seqlen + 2 * 3 * seqlen * seqlen) if causal: flops /= 2 tflops = flops / (us / 1e6) / 1e12 @@ -177,6 +193,8 @@ def launch_kernel(dq, dk, dv, q_quant_head, q_scale_head, q_quant_m, q_scale_m, parser = argparse.ArgumentParser(description="Preshuffle GEMM benchmark") parser.add_argument("--batch", type=int, default=1) + parser.add_argument("--num_heads_q", type=int, default=128) + parser.add_argument("--num_heads_kv", type=int, default=128) parser.add_argument("--seqlen", type=int, default=1024) parser.add_argument("--head", type=int, default=128) parser.add_argument("--tile_m", type=int, default=128) @@ -191,7 +209,7 @@ def launch_kernel(dq, dk, dv, q_quant_head, q_scale_head, q_quant_m, q_scale_m, torch.set_default_device("cuda") bench_attn_bwd_flyc( - batch=args.batch, seqlen=args.seqlen, head_dim=args.head, + batch=args.batch, num_heads_q=args.num_heads_q, num_heads_kv=args.num_heads_kv, seqlen=args.seqlen, head_dim=args.head, tile_m=args.tile_m, tile_n=args.tile_n, causal=args.causal, test_graph=bool(args.test_graph), From b88b1944b22d105da56e53ed0d82bd94c0285e27 Mon Sep 17 00:00:00 2001 From: Lukasz Burzawa Date: Thu, 14 May 2026 00:27:13 +0000 Subject: [PATCH 5/8] support uneven sequences --- .../flydsl/kernels/attn_bwd_mxfp8_gfx950.py | 179 ++++++++---------- .../test_attn_bwd_mxfp8_gfx950.py | 2 +- 2 files changed, 85 insertions(+), 96 deletions(-) diff --git a/aiter/ops/flydsl/kernels/attn_bwd_mxfp8_gfx950.py b/aiter/ops/flydsl/kernels/attn_bwd_mxfp8_gfx950.py index 47fd212e7e..a370a89090 100644 --- a/aiter/ops/flydsl/kernels/attn_bwd_mxfp8_gfx950.py +++ b/aiter/ops/flydsl/kernels/attn_bwd_mxfp8_gfx950.py @@ -78,6 
+78,7 @@ def compile_attn_bwd_mxfp8_gfx950( tile_m_mx = tile_m // 32 tile_n_mx = tile_n // 32 gqa_size = num_heads_q // num_heads_kv + seqlen_rounded = ((seqlen + tile_m - 1) // tile_m) * tile_m gpu_arch = get_hip_arch() @@ -214,7 +215,7 @@ def kernel_attn_bwd( arg_do_scale_head: fx.Tensor, arg_do_quant_m: fx.Tensor, arg_do_scale_m: fx.Tensor, - arg_m: fx.Tensor, + arg_M: fx.Tensor, arg_D: fx.Tensor, batch: fx.Int32, stride_qo_batch: fx.Int32, @@ -343,52 +344,48 @@ def kernel_attn_bwd( base_ptr_ds_scale_shuffle, lds_ds_scale_shuffle_offset, T.i8, shape=(tile_m * tile_n_mx,) ).get() + offset_qo_nheads = batch_id * fx.Index(stride_qo_batch) + head_q * fx.Index(stride_qkvo_nheads) + offset_dq_nheads = offset_qo_nheads * 4 + offset_kv_nheads = batch_id * fx.Index(stride_kv_batch) + head_kv * fx.Index(stride_qkvo_nheads) + offset_dkdv_nheads = offset_kv_nheads * 4 + offset_qo_scale_nheads = batch_id * fx.Index(stride_qo_scale_batch) + head_q * fx.Index(stride_qkvo_scale_nheads) + offset_kv_scale_nheads = batch_id * fx.Index(stride_kv_scale_batch) + head_kv * fx.Index(stride_qkvo_scale_nheads) + offset_MD_nheads = (batch_id * fx.Index(stride_MD_batch) + head_q * fx.Index(stride_MD_nheads)) * 4 # ---- Buffer resources (runtime byte sizes for OOB protection) ---- head_dim_mx = head_dim // 32 - global_buffer_size_qo = fx.Index(batch * num_heads_q * seqlen * head_dim) - global_buffer_size_kv = fx.Index(batch * num_heads_kv * seqlen * head_dim) - global_buffer_size_qo_scale = fx.Index(batch * num_heads_q * seqlen * head_dim_mx) - global_buffer_size_kv_scale = fx.Index(batch * num_heads_kv * seqlen * head_dim_mx) - q_nrec = arith.index_cast(T.i64, global_buffer_size_qo) - q_scale_head_nrec = arith.index_cast(T.i64, global_buffer_size_qo_scale) - q_scale_m_nrec = arith.index_cast(T.i64, global_buffer_size_qo_scale) - k_nrec = arith.index_cast(T.i64, global_buffer_size_kv) - k_scale_head_nrec = arith.index_cast(T.i64, global_buffer_size_kv_scale) - k_scale_n_nrec = 
arith.index_cast(T.i64, global_buffer_size_kv_scale) - v_nrec = arith.index_cast(T.i64, global_buffer_size_kv) - v_scale_nrec = arith.index_cast(T.i64, global_buffer_size_kv_scale) - do_nrec = arith.index_cast(T.i64, global_buffer_size_qo) - do_scale_head_nrec = arith.index_cast(T.i64, global_buffer_size_qo_scale) - do_scale_m_nrec = arith.index_cast(T.i64, global_buffer_size_qo_scale) - dq_nrec = arith.index_cast(T.i64, global_buffer_size_qo * 4) - dk_nrec = arith.index_cast(T.i64, global_buffer_size_kv * 4) - dv_nrec = arith.index_cast(T.i64, global_buffer_size_kv * 4) - m_nrec = arith.index_cast(T.i64, fx.Index(batch * num_heads_q * seqlen * 4)) - D_nrec = arith.index_cast(T.i64, fx.Index(batch * num_heads_q * seqlen * 4)) - - q_quant_head_rsrc = buffer_ops.create_buffer_resource(arg_q_quant_head, max_size=False, num_records_bytes=q_nrec) - q_scale_head_rsrc = buffer_ops.create_buffer_resource(arg_q_scale_head, max_size=False, num_records_bytes=q_scale_head_nrec) - q_quant_m_rsrc = buffer_ops.create_buffer_resource(arg_q_quant_m, max_size=False, num_records_bytes=q_nrec) - q_scale_m_rsrc = buffer_ops.create_buffer_resource(arg_q_scale_m, max_size=False, num_records_bytes=q_scale_m_nrec) - k_quant_head_rsrc = buffer_ops.create_buffer_resource(arg_k_quant_head, max_size=False, num_records_bytes=k_nrec) - k_scale_head_rsrc = buffer_ops.create_buffer_resource(arg_k_scale_head, max_size=False, num_records_bytes=k_scale_head_nrec) - k_quant_n_rsrc = buffer_ops.create_buffer_resource(arg_k_quant_n, max_size=False, num_records_bytes=k_nrec) - k_scale_n_rsrc = buffer_ops.create_buffer_resource(arg_k_scale_n, max_size=False, num_records_bytes=k_scale_n_nrec) - v_rsrc = buffer_ops.create_buffer_resource(arg_v, max_size=False, num_records_bytes=v_nrec) - v_scale_rsrc = buffer_ops.create_buffer_resource(arg_v_scale, max_size=False, num_records_bytes=v_scale_nrec) - do_quant_head_rsrc = buffer_ops.create_buffer_resource(arg_do_quant_head, max_size=False, 
num_records_bytes=do_nrec) - do_scale_head_rsrc = buffer_ops.create_buffer_resource(arg_do_scale_head, max_size=False, num_records_bytes=do_scale_head_nrec) - do_quant_m_rsrc = buffer_ops.create_buffer_resource(arg_do_quant_m, max_size=False, num_records_bytes=do_nrec) - do_scale_m_rsrc = buffer_ops.create_buffer_resource(arg_do_scale_m, max_size=False, num_records_bytes=do_scale_m_nrec) - dq_rsrc = buffer_ops.create_buffer_resource(arg_dq, max_size=False, num_records_bytes=dq_nrec) - dk_rsrc = buffer_ops.create_buffer_resource(arg_dk, max_size=False, - num_records_bytes=dk_nrec) - dv_rsrc = buffer_ops.create_buffer_resource(arg_dv, max_size=False, - num_records_bytes=dv_nrec) - m_rsrc = buffer_ops.create_buffer_resource(arg_m, max_size=False, - num_records_bytes=m_nrec) - D_rsrc = buffer_ops.create_buffer_resource(arg_D, max_size=False, num_records_bytes=D_nrec) + global_buffer_size_tensor = fx.Index(seqlen * head_dim) + global_buffer_size_scale = fx.Index(seqlen * head_dim_mx) + q_nrec = arith.index_cast(T.i64, global_buffer_size_tensor) + q_scale_nrec = arith.index_cast(T.i64, global_buffer_size_scale) + k_nrec = arith.index_cast(T.i64, global_buffer_size_tensor) + k_scale_nrec = arith.index_cast(T.i64, global_buffer_size_scale) + v_nrec = arith.index_cast(T.i64, global_buffer_size_tensor) + v_scale_nrec = arith.index_cast(T.i64, global_buffer_size_scale) + do_nrec = arith.index_cast(T.i64, global_buffer_size_tensor) + do_scale_nrec = arith.index_cast(T.i64, global_buffer_size_scale) + output_nrec = arith.index_cast(T.i64, global_buffer_size_tensor * 4) + MD_nrec = arith.index_cast(T.i64, fx.Index(seqlen * 4)) + + q_quant_head_rsrc = buffer_ops.create_buffer_resource(arg_q_quant_head, max_size=False, num_records_bytes=q_nrec, base_byte_offset=offset_qo_nheads) + q_scale_head_rsrc = buffer_ops.create_buffer_resource(arg_q_scale_head, max_size=False, num_records_bytes=q_scale_nrec, base_byte_offset=offset_qo_scale_nheads) + q_quant_m_rsrc = 
buffer_ops.create_buffer_resource(arg_q_quant_m, max_size=False, num_records_bytes=q_nrec, base_byte_offset=offset_qo_nheads) + q_scale_m_rsrc = buffer_ops.create_buffer_resource(arg_q_scale_m, max_size=False, num_records_bytes=q_scale_nrec, base_byte_offset=offset_qo_scale_nheads) + k_quant_head_rsrc = buffer_ops.create_buffer_resource(arg_k_quant_head, max_size=False, num_records_bytes=k_nrec, base_byte_offset=offset_kv_nheads) + k_scale_head_rsrc = buffer_ops.create_buffer_resource(arg_k_scale_head, max_size=False, num_records_bytes=k_scale_nrec, base_byte_offset=offset_kv_scale_nheads) + k_quant_n_rsrc = buffer_ops.create_buffer_resource(arg_k_quant_n, max_size=False, num_records_bytes=k_nrec, base_byte_offset=offset_kv_nheads) + k_scale_n_rsrc = buffer_ops.create_buffer_resource(arg_k_scale_n, max_size=False, num_records_bytes=k_scale_nrec, base_byte_offset=offset_kv_scale_nheads) + v_rsrc = buffer_ops.create_buffer_resource(arg_v, max_size=False, num_records_bytes=v_nrec, base_byte_offset=offset_kv_nheads) + v_scale_rsrc = buffer_ops.create_buffer_resource(arg_v_scale, max_size=False, num_records_bytes=v_scale_nrec, base_byte_offset=offset_kv_scale_nheads) + do_quant_head_rsrc = buffer_ops.create_buffer_resource(arg_do_quant_head, max_size=False, num_records_bytes=do_nrec, base_byte_offset=offset_qo_nheads) + do_scale_head_rsrc = buffer_ops.create_buffer_resource(arg_do_scale_head, max_size=False, num_records_bytes=do_scale_nrec, base_byte_offset=offset_qo_scale_nheads) + do_quant_m_rsrc = buffer_ops.create_buffer_resource(arg_do_quant_m, max_size=False, num_records_bytes=do_nrec, base_byte_offset=offset_qo_nheads) + do_scale_m_rsrc = buffer_ops.create_buffer_resource(arg_do_scale_m, max_size=False, num_records_bytes=do_scale_nrec, base_byte_offset=offset_qo_scale_nheads) + dq_rsrc = buffer_ops.create_buffer_resource(arg_dq, max_size=False, num_records_bytes=output_nrec, base_byte_offset=offset_dq_nheads) + dk_rsrc = buffer_ops.create_buffer_resource(arg_dk, 
max_size=False, num_records_bytes=output_nrec, base_byte_offset=offset_dkdv_nheads) + dv_rsrc = buffer_ops.create_buffer_resource(arg_dv, max_size=False, num_records_bytes=output_nrec, base_byte_offset=offset_dkdv_nheads) + M_rsrc = buffer_ops.create_buffer_resource(arg_M, max_size=False, num_records_bytes=MD_nrec, base_byte_offset=offset_MD_nheads) + D_rsrc = buffer_ops.create_buffer_resource(arg_D, max_size=False, num_records_bytes=MD_nrec, base_byte_offset=offset_MD_nheads) global_offset_n = by * tile_n global_offset_n_mx = global_offset_n // 32 @@ -511,14 +508,6 @@ def lds_scale_load(row, col, lds_stride, lds_buffer): c4 = fx.Index(4) tx_i32_base = tx * c4 - offset_qo_nheads = batch_id * fx.Index(stride_qo_batch) + head_q * fx.Index(stride_qkvo_nheads) - offset_qo_nheads_div4 = offset_qo_nheads // 4 - offset_kv_nheads = batch_id * fx.Index(stride_kv_batch) + head_kv * fx.Index(stride_qkvo_nheads) - offset_kv_nheads_div4 = offset_kv_nheads // 4 - offset_qo_scale_nheads = batch_id * fx.Index(stride_qo_scale_batch) + head_q * fx.Index(stride_qkvo_scale_nheads) - offset_kv_scale_nheads = batch_id * fx.Index(stride_kv_scale_batch) + head_kv * fx.Index(stride_qkvo_scale_nheads) - offset_MD_nheads = batch_id * fx.Index(stride_MD_batch) + head_q * fx.Index(stride_MD_nheads) - def load_q_quant_head_16(idx_elem): return buffer_copy_gmem16_dwordx4( buffer_ops, vector, @@ -599,7 +588,7 @@ def prefetch_q_quant_head_tile(offset_m): for i in range_constexpr(num_qo_loads): row_q_local, col_q_local_i32 = qo_tile_chunk_coord_i32(i) row_q_global = offset_m + row_q_local - idx_elem = offset_qo_nheads_div4 + row_q_global * head_dim_div4 + col_q_local_i32 + idx_elem = row_q_global * head_dim_div4 + col_q_local_i32 q_16B = load_q_quant_head_16(idx_elem) parts.append(vector.bitcast(T.i32x4, q_16B)) return parts @@ -609,7 +598,7 @@ def prefetch_q_quant_m_tile(offset_m): for i in range_constexpr(num_qo_loads): row_q_local, col_q_local_i32 = qo_tile_chunk_coord_i32(i) row_q_global = 
offset_m + row_q_local - idx_elem = offset_qo_nheads_div4 + row_q_global * head_dim_div4 + col_q_local_i32 + idx_elem = row_q_global * head_dim_div4 + col_q_local_i32 q_16B = load_q_quant_m_16(idx_elem) parts.append(vector.bitcast(T.i32x4, q_16B)) return parts @@ -619,7 +608,7 @@ def prefetch_k_quant_head_tile(): for i in range_constexpr(num_kv_loads): row_k_local, col_k_local_i32 = kv_tile_chunk_coord_i32(i) row_k_global = global_offset_n + row_k_local - idx_elem = offset_kv_nheads_div4 + row_k_global * head_dim_div4 + col_k_local_i32 + idx_elem = row_k_global * head_dim_div4 + col_k_local_i32 k_16B = load_k_quant_head_16(idx_elem) parts.append(vector.bitcast(T.i32x4, k_16B)) return parts @@ -629,7 +618,7 @@ def prefetch_k_quant_n_tile(): for i in range_constexpr(num_kv_loads): row_k_local, col_k_local_i32 = kv_tile_chunk_coord_i32(i) row_k_global = global_offset_n + row_k_local - idx_elem = offset_kv_nheads_div4 + row_k_global * head_dim_div4 + col_k_local_i32 + idx_elem = row_k_global * head_dim_div4 + col_k_local_i32 k_16B = load_k_quant_n_16(idx_elem) parts.append(vector.bitcast(T.i32x4, k_16B)) return parts @@ -639,7 +628,7 @@ def prefetch_v_tile(): for i in range_constexpr(num_kv_loads): row_v_local, col_v_local_i32 = kv_tile_chunk_coord_i32(i) row_v_global = global_offset_n + row_v_local - idx_elem = offset_kv_nheads_div4 + row_v_global * head_dim_div4 + col_v_local_i32 + idx_elem = row_v_global * head_dim_div4 + col_v_local_i32 v_16B = load_v_16(idx_elem) parts.append(vector.bitcast(T.i32x4, v_16B)) return parts @@ -649,7 +638,7 @@ def prefetch_do_quant_head_tile(offset_m): for i in range_constexpr(num_qo_loads): row_do_local, col_do_local_i32 = qo_tile_chunk_coord_i32(i) row_do_global = offset_m + row_do_local - idx_elem = offset_qo_nheads_div4 + row_do_global * head_dim_div4 + col_do_local_i32 + idx_elem = row_do_global * head_dim_div4 + col_do_local_i32 do_16B = load_do_quant_head_16(idx_elem) parts.append(vector.bitcast(T.i32x4, do_16B)) return parts 
@@ -659,7 +648,7 @@ def prefetch_do_quant_m_tile(offset_m): for i in range_constexpr(num_qo_loads): row_do_local, col_do_local_i32 = qo_tile_chunk_coord_i32(i) row_do_global = offset_m + row_do_local - idx_elem = offset_qo_nheads_div4 + row_do_global * head_dim_div4 + col_do_local_i32 + idx_elem = row_do_global * head_dim_div4 + col_do_local_i32 do_16B = load_do_quant_m_16(idx_elem) parts.append(vector.bitcast(T.i32x4, do_16B)) return parts @@ -668,13 +657,13 @@ def prefetch_q_scale_head_tile(offset_m): vec_width = bytes_per_thread_qo_scale if const_expr(vec_width == 1): if const_expr(bytes_per_tile_qo_scale < total_threads): - idx_elem = offset_qo_scale_nheads + offset_m * head_dim_mx + tx % bytes_per_tile_qo_scale + idx_elem = offset_m * head_dim_mx + tx % bytes_per_tile_qo_scale else: - idx_elem = offset_qo_scale_nheads + offset_m * head_dim_mx + tx + idx_elem = offset_m * head_dim_mx + tx vec = buffer_ops.buffer_load(q_scale_head_rsrc, idx_elem, vec_width=1, dtype=T.i8) vec = vector.from_elements(T.vec(1, T.i8), [vec]) else: # vec_width=2 - idx_elem = (offset_qo_scale_nheads + offset_m * head_dim_mx + tx * vec_width) // 2 + idx_elem = (offset_m * head_dim_mx + tx * vec_width) // 2 vec = buffer_ops.buffer_load(q_scale_head_rsrc, idx_elem, vec_width=1, dtype=T.i16) vec = vector.from_elements(T.vec(1, T.i16), [vec]) vec = vector.bitcast(T.i8x2, vec) @@ -684,13 +673,13 @@ def prefetch_q_scale_m_tile(offset_m): vec_width = bytes_per_thread_qo_scale if const_expr(vec_width == 1): if const_expr(bytes_per_tile_qo_scale < total_threads): - idx_elem = offset_qo_scale_nheads + offset_m * head_dim + tx % bytes_per_tile_qo_scale + idx_elem = offset_m * head_dim + tx % bytes_per_tile_qo_scale else: - idx_elem = offset_qo_scale_nheads + offset_m * head_dim + tx + idx_elem = offset_m * head_dim + tx vec = buffer_ops.buffer_load(q_scale_m_rsrc, idx_elem, vec_width=1, dtype=T.i8) vec = vector.from_elements(T.vec(1, T.i8), [vec]) else: # vec_width=2 - idx_elem = 
(offset_qo_scale_nheads + offset_m * head_dim + tx * vec_width) // 2 + idx_elem = (offset_m * head_dim + tx * vec_width) // 2 vec = buffer_ops.buffer_load(q_scale_m_rsrc, idx_elem, vec_width=1, dtype=T.i16) vec = vector.from_elements(T.vec(1, T.i16), [vec]) vec = vector.bitcast(T.i8x2, vec) @@ -700,13 +689,13 @@ def prefetch_k_scale_head_tile(): vec_width = bytes_per_thread_kv_scale if const_expr(vec_width == 1): if const_expr(bytes_per_tile_kv_scale < total_threads): - idx_elem = offset_kv_scale_nheads + global_offset_n * head_dim_mx + tx % bytes_per_tile_kv_scale + idx_elem = global_offset_n * head_dim_mx + tx % bytes_per_tile_kv_scale else: - idx_elem = offset_kv_scale_nheads + global_offset_n * head_dim_mx + tx + idx_elem = global_offset_n * head_dim_mx + tx vec = buffer_ops.buffer_load(k_scale_head_rsrc, idx_elem, vec_width=1, dtype=T.i8) vec = vector.from_elements(T.vec(1, T.i8), [vec]) else: # vec_width=2 - idx_elem = (offset_kv_scale_nheads + global_offset_n * head_dim_mx + tx * vec_width) // 2 + idx_elem = (global_offset_n * head_dim_mx + tx * vec_width) // 2 vec = buffer_ops.buffer_load(k_scale_head_rsrc, idx_elem, vec_width=1, dtype=T.i16) vec = vector.from_elements(T.vec(1, T.i16), [vec]) vec = vector.bitcast(T.i8x2, vec) @@ -716,13 +705,13 @@ def prefetch_k_scale_n_tile(): vec_width = bytes_per_thread_kv_scale if const_expr(vec_width == 1): if const_expr(bytes_per_tile_kv_scale < total_threads): - idx_elem = offset_kv_scale_nheads + global_offset_n_mx * head_dim + tx % bytes_per_tile_kv_scale + idx_elem = global_offset_n_mx * head_dim + tx % bytes_per_tile_kv_scale else: - idx_elem = offset_kv_scale_nheads + global_offset_n_mx * head_dim + tx + idx_elem = global_offset_n_mx * head_dim + tx vec = buffer_ops.buffer_load(k_scale_n_rsrc, idx_elem, vec_width=1, dtype=T.i8) vec = vector.from_elements(T.vec(1, T.i8), [vec]) else: # vec_width=2 - idx_elem = (offset_kv_scale_nheads + global_offset_n_mx * head_dim + tx * vec_width) // 2 + idx_elem = 
(global_offset_n_mx * head_dim + tx * vec_width) // 2 vec = buffer_ops.buffer_load(k_scale_n_rsrc, idx_elem, vec_width=1, dtype=T.i16) vec = vector.from_elements(T.vec(1, T.i16), [vec]) vec = vector.bitcast(T.i8x2, vec) @@ -732,13 +721,13 @@ def prefetch_v_scale_tile(): vec_width = bytes_per_thread_kv_scale if const_expr(vec_width == 1): if const_expr(bytes_per_tile_kv_scale < total_threads): - idx_elem = offset_kv_scale_nheads + global_offset_n * head_dim_mx + tx % bytes_per_tile_kv_scale + idx_elem = global_offset_n * head_dim_mx + tx % bytes_per_tile_kv_scale else: - idx_elem = offset_kv_scale_nheads + global_offset_n * head_dim_mx + tx + idx_elem = global_offset_n * head_dim_mx + tx vec = buffer_ops.buffer_load(v_scale_rsrc, idx_elem, vec_width=1, dtype=T.i8) vec = vector.from_elements(T.vec(1, T.i8), [vec]) else: # vec_width=2 - idx_elem = (offset_kv_scale_nheads + global_offset_n * head_dim_mx + tx * vec_width) // 2 + idx_elem = (global_offset_n * head_dim_mx + tx * vec_width) // 2 vec = buffer_ops.buffer_load(v_scale_rsrc, idx_elem, vec_width=1, dtype=T.i16) vec = vector.from_elements(T.vec(1, T.i16), [vec]) vec = vector.bitcast(T.i8x2, vec) @@ -748,13 +737,13 @@ def prefetch_do_scale_head_tile(offset_m): vec_width = bytes_per_thread_qo_scale if const_expr(vec_width == 1): if const_expr(bytes_per_tile_qo_scale < total_threads): - idx_elem = offset_qo_scale_nheads + offset_m * head_dim_mx + tx % bytes_per_tile_qo_scale + idx_elem = offset_m * head_dim_mx + tx % bytes_per_tile_qo_scale else: - idx_elem = offset_qo_scale_nheads + offset_m * head_dim_mx + tx + idx_elem = offset_m * head_dim_mx + tx vec = buffer_ops.buffer_load(do_scale_head_rsrc, idx_elem, vec_width=1, dtype=T.i8) vec = vector.from_elements(T.vec(1, T.i8), [vec]) else: # vec_width=2 - idx_elem = (offset_qo_scale_nheads + offset_m * head_dim_mx + tx * vec_width) // 2 + idx_elem = (offset_m * head_dim_mx + tx * vec_width) // 2 vec = buffer_ops.buffer_load(do_scale_head_rsrc, idx_elem, vec_width=1, 
dtype=T.i16) vec = vector.from_elements(T.vec(1, T.i16), [vec]) vec = vector.bitcast(T.i8x2, vec) @@ -764,13 +753,13 @@ def prefetch_do_scale_m_tile(offset_m): vec_width = bytes_per_thread_qo_scale if const_expr(vec_width == 1): if const_expr(bytes_per_tile_qo_scale < total_threads): - idx_elem = offset_qo_scale_nheads + offset_m * head_dim + tx % bytes_per_tile_qo_scale + idx_elem = offset_m * head_dim + tx % bytes_per_tile_qo_scale else: - idx_elem = offset_qo_scale_nheads + offset_m * head_dim + tx + idx_elem = offset_m * head_dim + tx vec = buffer_ops.buffer_load(do_scale_m_rsrc, idx_elem, vec_width=1, dtype=T.i8) vec = vector.from_elements(T.vec(1, T.i8), [vec]) else: # vec_width=2 - idx_elem = (offset_qo_scale_nheads + offset_m * head_dim + tx * vec_width) // 2 + idx_elem = (offset_m * head_dim + tx * vec_width) // 2 vec = buffer_ops.buffer_load(do_scale_m_rsrc, idx_elem, vec_width=1, dtype=T.i16) vec = vector.from_elements(T.vec(1, T.i16), [vec]) vec = vector.bitcast(T.i8x2, vec) @@ -907,8 +896,8 @@ def softmax(accs_in, offset_m): accs_out = [acc_init] * ps_n_accs for mi in range_constexpr(ps_m_num_subtiles): - global_m_norm_idx = offset_MD_nheads + offset_m + ps_m_wave_id * ps_m_per_wave + mi * 16 + lane_div_16 * 4 #+ ii - m_norm_vector = buffer_ops.buffer_load(m_rsrc, global_m_norm_idx, vec_width=4) + global_m_norm_idx = offset_m + ps_m_wave_id * ps_m_per_wave + mi * 16 + lane_div_16 * 4 + m_norm_vector = buffer_ops.buffer_load(M_rsrc, global_m_norm_idx, vec_width=4) for ni in range_constexpr(ps_n_num_subtiles): @@ -1043,7 +1032,7 @@ def compute_ds(dp_accs, p_accs, offset_m): for mi in range_constexpr(ps_m_num_subtiles): - global_D_idx = offset_MD_nheads + offset_m + ps_m_wave_id * ps_m_per_wave + mi * 16 + lane_div_16 * 4 + global_D_idx = offset_m + ps_m_wave_id * ps_m_per_wave + mi * 16 + lane_div_16 * 4 D_vector = buffer_ops.buffer_load(D_rsrc, global_D_idx, vec_width=4) for ni in range_constexpr(ps_n_num_subtiles): @@ -1349,7 +1338,7 @@ def 
store_dq_atomic(final_accs, offset_m): for ii in range_constexpr(4): global_row = offset_m + dq_m_wave_id * dq_m_per_wave + mi * 16 + lane_div_16 * 4 + ii global_col = dq_head_wave_id * dq_head_per_wave + hi * 16 + lane_mod_16 - global_idx = offset_qo_nheads + global_row * head_dim + global_col + global_idx = global_row * head_dim + global_col global_idx_bytes = global_idx * 4 acc_idx = mi * dq_num_subtiles_head + hi @@ -1369,7 +1358,7 @@ def store_dk_atomic(final_accs): global_row = global_offset_n + dk_n_wave_id * dk_n_per_wave + ni * 16 + lane_div_16 * 4 + ii global_col = dk_head_wave_id * dk_head_per_wave + hi * 16 + lane_mod_16 - global_idx = offset_kv_nheads + global_row * head_dim + global_col + global_idx = global_row * head_dim + global_col acc_idx = ni * dk_num_subtiles_head + hi acc = final_accs[acc_idx] @@ -1391,7 +1380,7 @@ def store_dv_atomic(final_accs): global_row = global_offset_n + dv_n_wave_id * dv_n_per_wave + ni * 16 + lane_div_16 * 4 + ii global_col = dv_head_wave_id * dv_head_per_wave + hi * 16 + lane_mod_16 - global_idx = offset_kv_nheads + global_row * head_dim + global_col + global_idx = global_row * head_dim + global_col acc_idx = ni * dv_head_num_subtiles + hi acc = final_accs[acc_idx] @@ -1503,15 +1492,15 @@ def pingpong(offset_m, inner_state): dk = [acc_init] * dk_n_accs dv = [acc_init] * dv_n_accs - num_tiles_loop = seqlen // tile_m + num_tiles_loop = seqlen_rounded // tile_m if const_expr((num_tiles_loop % 2) == 1): - upper_bound = seqlen - tile_m + upper_bound = seqlen_rounded - tile_m init_state = _pack_state(dk, dv) for iv, inner in range(start_m, upper_bound, tile_m * 2, init=init_state): results = yield pingpong(iv, inner) dk, dv = _unpack_state(results) - curr_m = arith.index(seqlen - tile_m) + curr_m = arith.index(seqlen_rounded - tile_m) qk = compute_qk(lds_q_quant_head_pong, lds_q_scale_head_pong, lds_k_quant_head, lds_k_scale_head) p = softmax(qk, curr_m) mxquant_m_and_store_to_lds(p, lds_ppt_shuffle, lds_ppt_scale_shuffle) 
@@ -1526,14 +1515,14 @@ def pingpong(offset_m, inner_state): dq = compute_dq(lds_ds_shuffle, lds_ds_scale_shuffle, lds_k_quant_n, lds_k_scale_n) store_dq_atomic(dq, curr_m) else: - upper_bound = seqlen - (tile_m * 2) + upper_bound = seqlen_rounded - (tile_m * 2) init_state = _pack_state(dk, dv) for iv, inner in range(start_m, upper_bound, tile_m * 2, init=init_state): results = yield pingpong(iv, inner) dk, dv = _unpack_state(results) - curr_m = arith.index(seqlen - tile_m * 2) - last_m = arith.index(seqlen - tile_m) + curr_m = arith.index(seqlen_rounded - tile_m * 2) + last_m = arith.index(seqlen_rounded - tile_m) last_m_mx = last_m // 32 store_q_tile_to_lds(prefetch_q_quant_head_tile(last_m), lds_q_quant_head_ping) store_q_scale_tile_to_lds(prefetch_q_scale_head_tile(last_m), lds_q_scale_head_ping) @@ -1602,7 +1591,7 @@ def launch_attn_bwd( arg_do_scale_head: fx.Tensor, arg_do_quant_m: fx.Tensor, arg_do_scale_m: fx.Tensor, - arg_m: fx.Tensor, + arg_M: fx.Tensor, arg_D: fx.Tensor, batch: fx.Int32, stride_qo_batch: fx.Int32, @@ -1648,10 +1637,10 @@ def launch_attn_bwd( allocator_ds_scale_shuffle.finalize() gx = num_heads_q - gy = seqlen // tile_n + gy = (seqlen + tile_n - 1) // tile_n gz = batch - launcher = kernel_attn_bwd(arg_dq, arg_dk, arg_dv, arg_q_quant_head, arg_q_scale_head, arg_q_quant_m, arg_q_scale_m, arg_k_quant_head, arg_k_scale_head, arg_k_quant_n, arg_k_scale_n, arg_v, arg_v_scale, arg_do_quant_head, arg_do_scale_head, arg_do_quant_m, arg_do_scale_m, arg_m, arg_D, + launcher = kernel_attn_bwd(arg_dq, arg_dk, arg_dv, arg_q_quant_head, arg_q_scale_head, arg_q_quant_m, arg_q_scale_m, arg_k_quant_head, arg_k_scale_head, arg_k_quant_n, arg_k_scale_n, arg_v, arg_v_scale, arg_do_quant_head, arg_do_scale_head, arg_do_quant_m, arg_do_scale_m, arg_M, arg_D, batch, stride_qo_batch, stride_qo_scale_batch, diff --git a/op_tests/flydsl_tests/test_attn_bwd_mxfp8_gfx950.py b/op_tests/flydsl_tests/test_attn_bwd_mxfp8_gfx950.py index 7397b8de91..6fd580e491 100644 --- 
a/op_tests/flydsl_tests/test_attn_bwd_mxfp8_gfx950.py +++ b/op_tests/flydsl_tests/test_attn_bwd_mxfp8_gfx950.py @@ -82,7 +82,7 @@ def run_torch(q_fp32_head, q_fp32_m, k_fp32_head, k_fp32_n, v, do_fp32_head, do_ (80, 20) ], ) -@pytest.mark.parametrize("seqlen", [128, 1024, 1152, 4096]) +@pytest.mark.parametrize("seqlen", [128, 1024, 1056, 1152, 4096]) @pytest.mark.parametrize("head_dim", [64, 128]) @pytest.mark.parametrize("tile_m", [64, 128]) @pytest.mark.parametrize("tile_n", [64, 128]) From 125345eb88a56685f2fd587681905369ff5ab3b4 Mon Sep 17 00:00:00 2001 From: Lukasz Burzawa Date: Thu, 14 May 2026 00:32:40 +0000 Subject: [PATCH 6/8] reformat --- .../flydsl/kernels/attn_bwd_mxfp8_gfx950.py | 1510 +++++++++++++---- .../test_attn_bwd_mxfp8_gfx950.py | 127 +- .../flydsl/bench_attn_bwd_mxfp8_gfx950.py | 147 +- op_tests/op_benchmarks/flydsl/utils.py | 2 +- 4 files changed, 1359 insertions(+), 427 deletions(-) diff --git a/aiter/ops/flydsl/kernels/attn_bwd_mxfp8_gfx950.py b/aiter/ops/flydsl/kernels/attn_bwd_mxfp8_gfx950.py index a370a89090..96263b8c3a 100644 --- a/aiter/ops/flydsl/kernels/attn_bwd_mxfp8_gfx950.py +++ b/aiter/ops/flydsl/kernels/attn_bwd_mxfp8_gfx950.py @@ -22,6 +22,7 @@ swizzle_xor16, ) + def lds_transpose_load(lds_memref, elem_offset): """Transpose-load from LDS memref via ds_read_tr8_b64 (gfx950). 
@@ -47,8 +48,10 @@ def lds_transpose_load(lds_memref, elem_offset): addr_i32 = _to_raw(arith.index_cast(T.i32, total_byte_idx)) ptr_val = llvm.inttoptr(lds_ptr_ty, addr_i32) - result_type=T.i32x2 - result = llvm.call_intrinsic(result_type, "llvm.amdgcn.ds.read.tr8.b64", [ptr_val], [], []) + result_type = T.i32x2 + result = llvm.call_intrinsic( + result_type, "llvm.amdgcn.ds.read.tr8.b64", [ptr_val], [], [] + ) return result @@ -84,27 +87,49 @@ def compile_attn_bwd_mxfp8_gfx950( allocator_pong = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem0") allocator_ping = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem1") - allocator_k_quant_head = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem_k_quant_head") - allocator_k_scale_head = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem_k_scale_head") - allocator_k_quant_n = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem_k_quant_n") - allocator_k_scale_n = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem_k_scale_n") + allocator_k_quant_head = SmemAllocator( + None, arch=gpu_arch, global_sym_name="smem_k_quant_head" + ) + allocator_k_scale_head = SmemAllocator( + None, arch=gpu_arch, global_sym_name="smem_k_scale_head" + ) + allocator_k_quant_n = SmemAllocator( + None, arch=gpu_arch, global_sym_name="smem_k_quant_n" + ) + allocator_k_scale_n = SmemAllocator( + None, arch=gpu_arch, global_sym_name="smem_k_scale_n" + ) allocator_v = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem_v") - allocator_v_scale = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem_v_scale") - allocator_ppt_shuffle = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem_ppt_shuffle") - allocator_ppt_scale_shuffle = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem_ppt_scale_shuffle") - allocator_dst_shuffle = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem_dst_shuffle") - allocator_dst_scale_shuffle = SmemAllocator(None, arch=gpu_arch, 
global_sym_name="smem_dst_scale_shuffle") - allocator_ds_shuffle = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem_ds_shuffle") - allocator_ds_scale_shuffle = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem_ds_scale_shuffle") + allocator_v_scale = SmemAllocator( + None, arch=gpu_arch, global_sym_name="smem_v_scale" + ) + allocator_ppt_shuffle = SmemAllocator( + None, arch=gpu_arch, global_sym_name="smem_ppt_shuffle" + ) + allocator_ppt_scale_shuffle = SmemAllocator( + None, arch=gpu_arch, global_sym_name="smem_ppt_scale_shuffle" + ) + allocator_dst_shuffle = SmemAllocator( + None, arch=gpu_arch, global_sym_name="smem_dst_shuffle" + ) + allocator_dst_scale_shuffle = SmemAllocator( + None, arch=gpu_arch, global_sym_name="smem_dst_scale_shuffle" + ) + allocator_ds_shuffle = SmemAllocator( + None, arch=gpu_arch, global_sym_name="smem_ds_shuffle" + ) + allocator_ds_scale_shuffle = SmemAllocator( + None, arch=gpu_arch, global_sym_name="smem_ds_scale_shuffle" + ) wave_size = 64 total_threads = 256 - bytes_per_tile_qo = int(tile_m) * int(tile_head) + bytes_per_tile_qo = int(tile_m) * int(tile_head) bytes_per_thread_qo = bytes_per_tile_qo // total_threads qo_load_bytes = 16 - bytes_per_tile_kv = int(tile_n) * int(tile_head) + bytes_per_tile_kv = int(tile_n) * int(tile_head) bytes_per_thread_kv = bytes_per_tile_kv // total_threads kv_load_bytes = 16 @@ -113,7 +138,7 @@ def compile_attn_bwd_mxfp8_gfx950( bytes_per_tile_kv_scale = (int(tile_n) * int(tile_head)) // 32 bytes_per_thread_kv_scale = max(1, bytes_per_tile_kv_scale // total_threads) - + def _elem_type(): return T.f8 @@ -159,10 +184,14 @@ def _vec16_type(): lds_do_quant_m_ping_offset = lds_do_scale_head_ping_offset + lds_qo_scale_tile_bytes lds_do_scale_m_ping_offset = lds_do_quant_m_ping_offset + lds_qo_tile_bytes - lds_k_quant_head_offset = allocator_k_quant_head._align(allocator_k_quant_head.ptr, 16) + lds_k_quant_head_offset = allocator_k_quant_head._align( + allocator_k_quant_head.ptr, 16 + ) 
allocator_k_quant_head.ptr = lds_k_quant_head_offset + lds_k_tile_bytes - lds_k_scale_head_offset = allocator_k_scale_head._align(allocator_k_scale_head.ptr, 16) + lds_k_scale_head_offset = allocator_k_scale_head._align( + allocator_k_scale_head.ptr, 16 + ) allocator_k_scale_head.ptr = lds_k_scale_head_offset + lds_k_scale_head_tile_bytes lds_k_quant_n_offset = allocator_k_quant_n._align(allocator_k_quant_n.ptr, 16) @@ -180,20 +209,32 @@ def _vec16_type(): lds_ppt_shuffle_offset = allocator_ppt_shuffle._align(allocator_ppt_shuffle.ptr, 16) allocator_ppt_shuffle.ptr = lds_ppt_shuffle_offset + lds_ppt_tile_bytes - lds_ppt_scale_shuffle_offset = allocator_ppt_scale_shuffle._align(allocator_ppt_scale_shuffle.ptr, 16) - allocator_ppt_scale_shuffle.ptr = lds_ppt_scale_shuffle_offset + lds_ppt_scale_tile_bytes + lds_ppt_scale_shuffle_offset = allocator_ppt_scale_shuffle._align( + allocator_ppt_scale_shuffle.ptr, 16 + ) + allocator_ppt_scale_shuffle.ptr = ( + lds_ppt_scale_shuffle_offset + lds_ppt_scale_tile_bytes + ) lds_dst_shuffle_offset = allocator_dst_shuffle._align(allocator_dst_shuffle.ptr, 16) allocator_dst_shuffle.ptr = lds_dst_shuffle_offset + lds_dst_tile_bytes - lds_dst_scale_shuffle_offset = allocator_dst_scale_shuffle._align(allocator_dst_scale_shuffle.ptr, 16) - allocator_dst_scale_shuffle.ptr = lds_dst_scale_shuffle_offset + lds_dst_scale_tile_bytes + lds_dst_scale_shuffle_offset = allocator_dst_scale_shuffle._align( + allocator_dst_scale_shuffle.ptr, 16 + ) + allocator_dst_scale_shuffle.ptr = ( + lds_dst_scale_shuffle_offset + lds_dst_scale_tile_bytes + ) lds_ds_shuffle_offset = allocator_ds_shuffle._align(allocator_ds_shuffle.ptr, 16) allocator_ds_shuffle.ptr = lds_ds_shuffle_offset + lds_ds_tile_bytes - lds_ds_scale_shuffle_offset = allocator_ds_scale_shuffle._align(allocator_ds_scale_shuffle.ptr, 16) - allocator_ds_scale_shuffle.ptr = lds_ds_scale_shuffle_offset + lds_ds_scale_tile_bytes + lds_ds_scale_shuffle_offset = allocator_ds_scale_shuffle._align( 
+ allocator_ds_scale_shuffle.ptr, 16 + ) + allocator_ds_scale_shuffle.ptr = ( + lds_ds_scale_shuffle_offset + lds_ds_scale_tile_bytes + ) # ── Kernel function ──────────────────────────────────────────────────── @flyc.kernel @@ -223,9 +264,9 @@ def kernel_attn_bwd( stride_kv_batch: fx.Int32, stride_kv_scale_batch: fx.Int32, stride_MD_batch: fx.Int32, - stride_qkvo_nheads: fx.Int32, + stride_qkvo_nheads: fx.Int32, stride_qkvo_scale_nheads: fx.Int32, - stride_MD_nheads: fx.Int32 + stride_MD_nheads: fx.Int32, ): # ---- Types ---- @@ -239,7 +280,7 @@ def kernel_attn_bwd( bx = gpu.block_id("x") by = gpu.block_id("y") bz = gpu.block_id("z") - batch_id = bz + batch_id = bz head_q = bx head_kv = head_q // gqa_size @@ -260,16 +301,28 @@ def kernel_attn_bwd( base_ptr_ds_scale_shuffle = allocator_ds_scale_shuffle.get_base() lds_q_quant_head_pong = SmemPtr( - base_ptr_pong, lds_q_quant_head_pong_offset, T.f8, shape=(tile_m * tile_head,) + base_ptr_pong, + lds_q_quant_head_pong_offset, + T.f8, + shape=(tile_m * tile_head,), ).get() lds_q_quant_head_ping = SmemPtr( - base_ptr_ping, lds_q_quant_head_ping_offset, T.f8, shape=(tile_m * tile_head,) + base_ptr_ping, + lds_q_quant_head_ping_offset, + T.f8, + shape=(tile_m * tile_head,), ).get() lds_q_scale_head_pong = SmemPtr( - base_ptr_pong, lds_q_scale_head_pong_offset, T.i8, shape=(tile_m * tile_head_mx,) + base_ptr_pong, + lds_q_scale_head_pong_offset, + T.i8, + shape=(tile_m * tile_head_mx,), ).get() lds_q_scale_head_ping = SmemPtr( - base_ptr_ping, lds_q_scale_head_ping_offset, T.i8, shape=(tile_m * tile_head_mx,) + base_ptr_ping, + lds_q_scale_head_ping_offset, + T.i8, + shape=(tile_m * tile_head_mx,), ).get() lds_q_quant_m_pong = SmemPtr( base_ptr_pong, lds_q_quant_m_pong_offset, T.f8, shape=(tile_m * tile_head,) @@ -278,22 +331,40 @@ def kernel_attn_bwd( base_ptr_ping, lds_q_quant_m_ping_offset, T.f8, shape=(tile_m * tile_head,) ).get() lds_q_scale_m_pong = SmemPtr( - base_ptr_pong, lds_q_scale_m_pong_offset, T.i8, 
shape=(tile_m_mx * tile_head,) + base_ptr_pong, + lds_q_scale_m_pong_offset, + T.i8, + shape=(tile_m_mx * tile_head,), ).get() lds_q_scale_m_ping = SmemPtr( - base_ptr_ping, lds_q_scale_m_ping_offset, T.i8, shape=(tile_m_mx * tile_head,) + base_ptr_ping, + lds_q_scale_m_ping_offset, + T.i8, + shape=(tile_m_mx * tile_head,), ).get() lds_do_quant_head_pong = SmemPtr( - base_ptr_pong, lds_do_quant_head_pong_offset, T.f8, shape=(tile_m * tile_head,) + base_ptr_pong, + lds_do_quant_head_pong_offset, + T.f8, + shape=(tile_m * tile_head,), ).get() lds_do_quant_head_ping = SmemPtr( - base_ptr_ping, lds_do_quant_head_ping_offset, T.f8, shape=(tile_m * tile_head,) + base_ptr_ping, + lds_do_quant_head_ping_offset, + T.f8, + shape=(tile_m * tile_head,), ).get() lds_do_scale_head_pong = SmemPtr( - base_ptr_pong, lds_do_scale_head_pong_offset, T.i8, shape=(tile_m * tile_head_mx,) + base_ptr_pong, + lds_do_scale_head_pong_offset, + T.i8, + shape=(tile_m * tile_head_mx,), ).get() lds_do_scale_head_ping = SmemPtr( - base_ptr_ping, lds_do_scale_head_ping_offset, T.i8, shape=(tile_m * tile_head_mx,) + base_ptr_ping, + lds_do_scale_head_ping_offset, + T.i8, + shape=(tile_m * tile_head_mx,), ).get() lds_do_quant_m_pong = SmemPtr( base_ptr_pong, lds_do_quant_m_pong_offset, T.f8, shape=(tile_m * tile_head,) @@ -302,22 +373,37 @@ def kernel_attn_bwd( base_ptr_ping, lds_do_quant_m_ping_offset, T.f8, shape=(tile_m * tile_head,) ).get() lds_do_scale_m_pong = SmemPtr( - base_ptr_pong, lds_do_scale_m_pong_offset, T.i8, shape=(tile_head * tile_m_mx,) + base_ptr_pong, + lds_do_scale_m_pong_offset, + T.i8, + shape=(tile_head * tile_m_mx,), ).get() lds_do_scale_m_ping = SmemPtr( - base_ptr_ping, lds_do_scale_m_ping_offset, T.i8, shape=(tile_head * tile_m_mx,) + base_ptr_ping, + lds_do_scale_m_ping_offset, + T.i8, + shape=(tile_head * tile_m_mx,), ).get() lds_k_quant_head = SmemPtr( - base_ptr_k_quant_head, lds_k_quant_head_offset, T.f8, shape=(tile_n * tile_head,) + base_ptr_k_quant_head, + 
lds_k_quant_head_offset, + T.f8, + shape=(tile_n * tile_head,), ).get() lds_k_scale_head = SmemPtr( - base_ptr_k_scale_head, lds_k_scale_head_offset, T.i8, shape=(tile_n * tile_head_mx,) + base_ptr_k_scale_head, + lds_k_scale_head_offset, + T.i8, + shape=(tile_n * tile_head_mx,), ).get() lds_k_quant_n = SmemPtr( base_ptr_k_quant_n, lds_k_quant_n_offset, T.f8, shape=(tile_n * tile_head,) ).get() lds_k_scale_n = SmemPtr( - base_ptr_k_scale_n, lds_k_scale_n_offset, T.i8, shape=(tile_n_mx * tile_head,) + base_ptr_k_scale_n, + lds_k_scale_n_offset, + T.i8, + shape=(tile_n_mx * tile_head,), ).get() lds_v = SmemPtr( base_ptr_v, lds_v_offset, T.f8, shape=(tile_n * tile_head,) @@ -329,28 +415,47 @@ def kernel_attn_bwd( base_ptr_ppt_shuffle, lds_ppt_shuffle_offset, T.f8, shape=(tile_n * tile_m,) ).get() lds_ppt_scale_shuffle = SmemPtr( - base_ptr_ppt_scale_shuffle, lds_ppt_scale_shuffle_offset, T.i8, shape=(tile_n * tile_m_mx,) + base_ptr_ppt_scale_shuffle, + lds_ppt_scale_shuffle_offset, + T.i8, + shape=(tile_n * tile_m_mx,), ).get() lds_dst_shuffle = SmemPtr( base_ptr_dst_shuffle, lds_dst_shuffle_offset, T.f8, shape=(tile_n * tile_m,) ).get() lds_dst_scale_shuffle = SmemPtr( - base_ptr_dst_scale_shuffle, lds_dst_scale_shuffle_offset, T.i8, shape=(tile_n * tile_m_mx,) + base_ptr_dst_scale_shuffle, + lds_dst_scale_shuffle_offset, + T.i8, + shape=(tile_n * tile_m_mx,), ).get() lds_ds_shuffle = SmemPtr( base_ptr_ds_shuffle, lds_ds_shuffle_offset, T.f8, shape=(tile_m * tile_n,) ).get() lds_ds_scale_shuffle = SmemPtr( - base_ptr_ds_scale_shuffle, lds_ds_scale_shuffle_offset, T.i8, shape=(tile_m * tile_n_mx,) + base_ptr_ds_scale_shuffle, + lds_ds_scale_shuffle_offset, + T.i8, + shape=(tile_m * tile_n_mx,), ).get() - offset_qo_nheads = batch_id * fx.Index(stride_qo_batch) + head_q * fx.Index(stride_qkvo_nheads) + offset_qo_nheads = batch_id * fx.Index(stride_qo_batch) + head_q * fx.Index( + stride_qkvo_nheads + ) offset_dq_nheads = offset_qo_nheads * 4 - offset_kv_nheads = 
batch_id * fx.Index(stride_kv_batch) + head_kv * fx.Index(stride_qkvo_nheads) + offset_kv_nheads = batch_id * fx.Index(stride_kv_batch) + head_kv * fx.Index( + stride_qkvo_nheads + ) offset_dkdv_nheads = offset_kv_nheads * 4 - offset_qo_scale_nheads = batch_id * fx.Index(stride_qo_scale_batch) + head_q * fx.Index(stride_qkvo_scale_nheads) - offset_kv_scale_nheads = batch_id * fx.Index(stride_kv_scale_batch) + head_kv * fx.Index(stride_qkvo_scale_nheads) - offset_MD_nheads = (batch_id * fx.Index(stride_MD_batch) + head_q * fx.Index(stride_MD_nheads)) * 4 + offset_qo_scale_nheads = batch_id * fx.Index( + stride_qo_scale_batch + ) + head_q * fx.Index(stride_qkvo_scale_nheads) + offset_kv_scale_nheads = batch_id * fx.Index( + stride_kv_scale_batch + ) + head_kv * fx.Index(stride_qkvo_scale_nheads) + offset_MD_nheads = ( + batch_id * fx.Index(stride_MD_batch) + head_q * fx.Index(stride_MD_nheads) + ) * 4 # ---- Buffer resources (runtime byte sizes for OOB protection) ---- head_dim_mx = head_dim // 32 @@ -367,25 +472,120 @@ def kernel_attn_bwd( output_nrec = arith.index_cast(T.i64, global_buffer_size_tensor * 4) MD_nrec = arith.index_cast(T.i64, fx.Index(seqlen * 4)) - q_quant_head_rsrc = buffer_ops.create_buffer_resource(arg_q_quant_head, max_size=False, num_records_bytes=q_nrec, base_byte_offset=offset_qo_nheads) - q_scale_head_rsrc = buffer_ops.create_buffer_resource(arg_q_scale_head, max_size=False, num_records_bytes=q_scale_nrec, base_byte_offset=offset_qo_scale_nheads) - q_quant_m_rsrc = buffer_ops.create_buffer_resource(arg_q_quant_m, max_size=False, num_records_bytes=q_nrec, base_byte_offset=offset_qo_nheads) - q_scale_m_rsrc = buffer_ops.create_buffer_resource(arg_q_scale_m, max_size=False, num_records_bytes=q_scale_nrec, base_byte_offset=offset_qo_scale_nheads) - k_quant_head_rsrc = buffer_ops.create_buffer_resource(arg_k_quant_head, max_size=False, num_records_bytes=k_nrec, base_byte_offset=offset_kv_nheads) - k_scale_head_rsrc = 
buffer_ops.create_buffer_resource(arg_k_scale_head, max_size=False, num_records_bytes=k_scale_nrec, base_byte_offset=offset_kv_scale_nheads) - k_quant_n_rsrc = buffer_ops.create_buffer_resource(arg_k_quant_n, max_size=False, num_records_bytes=k_nrec, base_byte_offset=offset_kv_nheads) - k_scale_n_rsrc = buffer_ops.create_buffer_resource(arg_k_scale_n, max_size=False, num_records_bytes=k_scale_nrec, base_byte_offset=offset_kv_scale_nheads) - v_rsrc = buffer_ops.create_buffer_resource(arg_v, max_size=False, num_records_bytes=v_nrec, base_byte_offset=offset_kv_nheads) - v_scale_rsrc = buffer_ops.create_buffer_resource(arg_v_scale, max_size=False, num_records_bytes=v_scale_nrec, base_byte_offset=offset_kv_scale_nheads) - do_quant_head_rsrc = buffer_ops.create_buffer_resource(arg_do_quant_head, max_size=False, num_records_bytes=do_nrec, base_byte_offset=offset_qo_nheads) - do_scale_head_rsrc = buffer_ops.create_buffer_resource(arg_do_scale_head, max_size=False, num_records_bytes=do_scale_nrec, base_byte_offset=offset_qo_scale_nheads) - do_quant_m_rsrc = buffer_ops.create_buffer_resource(arg_do_quant_m, max_size=False, num_records_bytes=do_nrec, base_byte_offset=offset_qo_nheads) - do_scale_m_rsrc = buffer_ops.create_buffer_resource(arg_do_scale_m, max_size=False, num_records_bytes=do_scale_nrec, base_byte_offset=offset_qo_scale_nheads) - dq_rsrc = buffer_ops.create_buffer_resource(arg_dq, max_size=False, num_records_bytes=output_nrec, base_byte_offset=offset_dq_nheads) - dk_rsrc = buffer_ops.create_buffer_resource(arg_dk, max_size=False, num_records_bytes=output_nrec, base_byte_offset=offset_dkdv_nheads) - dv_rsrc = buffer_ops.create_buffer_resource(arg_dv, max_size=False, num_records_bytes=output_nrec, base_byte_offset=offset_dkdv_nheads) - M_rsrc = buffer_ops.create_buffer_resource(arg_M, max_size=False, num_records_bytes=MD_nrec, base_byte_offset=offset_MD_nheads) - D_rsrc = buffer_ops.create_buffer_resource(arg_D, max_size=False, num_records_bytes=MD_nrec, 
base_byte_offset=offset_MD_nheads) + q_quant_head_rsrc = buffer_ops.create_buffer_resource( + arg_q_quant_head, + max_size=False, + num_records_bytes=q_nrec, + base_byte_offset=offset_qo_nheads, + ) + q_scale_head_rsrc = buffer_ops.create_buffer_resource( + arg_q_scale_head, + max_size=False, + num_records_bytes=q_scale_nrec, + base_byte_offset=offset_qo_scale_nheads, + ) + q_quant_m_rsrc = buffer_ops.create_buffer_resource( + arg_q_quant_m, + max_size=False, + num_records_bytes=q_nrec, + base_byte_offset=offset_qo_nheads, + ) + q_scale_m_rsrc = buffer_ops.create_buffer_resource( + arg_q_scale_m, + max_size=False, + num_records_bytes=q_scale_nrec, + base_byte_offset=offset_qo_scale_nheads, + ) + k_quant_head_rsrc = buffer_ops.create_buffer_resource( + arg_k_quant_head, + max_size=False, + num_records_bytes=k_nrec, + base_byte_offset=offset_kv_nheads, + ) + k_scale_head_rsrc = buffer_ops.create_buffer_resource( + arg_k_scale_head, + max_size=False, + num_records_bytes=k_scale_nrec, + base_byte_offset=offset_kv_scale_nheads, + ) + k_quant_n_rsrc = buffer_ops.create_buffer_resource( + arg_k_quant_n, + max_size=False, + num_records_bytes=k_nrec, + base_byte_offset=offset_kv_nheads, + ) + k_scale_n_rsrc = buffer_ops.create_buffer_resource( + arg_k_scale_n, + max_size=False, + num_records_bytes=k_scale_nrec, + base_byte_offset=offset_kv_scale_nheads, + ) + v_rsrc = buffer_ops.create_buffer_resource( + arg_v, + max_size=False, + num_records_bytes=v_nrec, + base_byte_offset=offset_kv_nheads, + ) + v_scale_rsrc = buffer_ops.create_buffer_resource( + arg_v_scale, + max_size=False, + num_records_bytes=v_scale_nrec, + base_byte_offset=offset_kv_scale_nheads, + ) + do_quant_head_rsrc = buffer_ops.create_buffer_resource( + arg_do_quant_head, + max_size=False, + num_records_bytes=do_nrec, + base_byte_offset=offset_qo_nheads, + ) + do_scale_head_rsrc = buffer_ops.create_buffer_resource( + arg_do_scale_head, + max_size=False, + num_records_bytes=do_scale_nrec, + 
base_byte_offset=offset_qo_scale_nheads, + ) + do_quant_m_rsrc = buffer_ops.create_buffer_resource( + arg_do_quant_m, + max_size=False, + num_records_bytes=do_nrec, + base_byte_offset=offset_qo_nheads, + ) + do_scale_m_rsrc = buffer_ops.create_buffer_resource( + arg_do_scale_m, + max_size=False, + num_records_bytes=do_scale_nrec, + base_byte_offset=offset_qo_scale_nheads, + ) + dq_rsrc = buffer_ops.create_buffer_resource( + arg_dq, + max_size=False, + num_records_bytes=output_nrec, + base_byte_offset=offset_dq_nheads, + ) + dk_rsrc = buffer_ops.create_buffer_resource( + arg_dk, + max_size=False, + num_records_bytes=output_nrec, + base_byte_offset=offset_dkdv_nheads, + ) + dv_rsrc = buffer_ops.create_buffer_resource( + arg_dv, + max_size=False, + num_records_bytes=output_nrec, + base_byte_offset=offset_dkdv_nheads, + ) + M_rsrc = buffer_ops.create_buffer_resource( + arg_M, + max_size=False, + num_records_bytes=MD_nrec, + base_byte_offset=offset_MD_nheads, + ) + D_rsrc = buffer_ops.create_buffer_resource( + arg_D, + max_size=False, + num_records_bytes=MD_nrec, + base_byte_offset=offset_MD_nheads, + ) global_offset_n = by * tile_n global_offset_n_mx = global_offset_n // 32 @@ -409,7 +609,9 @@ def kernel_attn_bwd( # wave partitioning for qk, p, dp, ds ps_m_num_waves = 2 ps_n_num_waves = 2 - ps_wave_layout = fx.make_layout((ps_m_num_waves, ps_n_num_waves), (ps_n_num_waves, 1)) + ps_wave_layout = fx.make_layout( + (ps_m_num_waves, ps_n_num_waves), (ps_n_num_waves, 1) + ) ps_coord = fx.idx2crd(wave_id, ps_wave_layout) ps_m_wave_id = fx.get(ps_coord, 0) ps_n_wave_id = fx.get(ps_coord, 1) @@ -419,12 +621,14 @@ def kernel_attn_bwd( ps_n_per_wave = tile_n // ps_n_num_waves ps_n_mx_per_wave = tile_n_mx // ps_n_num_waves ps_n_num_subtiles = ps_n_per_wave // 16 - ps_n_accs = ps_n_num_subtiles * ps_m_num_subtiles + ps_n_accs = ps_n_num_subtiles * ps_m_num_subtiles # wave partitioning for dv gemm dv_n_num_waves = 2 dv_head_num_waves = 2 - dv_wave_layout = 
fx.make_layout((dv_n_num_waves, dv_head_num_waves), (dv_head_num_waves, 1)) + dv_wave_layout = fx.make_layout( + (dv_n_num_waves, dv_head_num_waves), (dv_head_num_waves, 1) + ) dv_coord = fx.idx2crd(wave_id, dv_wave_layout) dv_n_wave_id = fx.get(dv_coord, 0) dv_head_wave_id = fx.get(dv_coord, 1) @@ -437,7 +641,9 @@ def kernel_attn_bwd( # wave partitioning for dk gemm dk_n_num_waves = 2 dk_head_num_waves = 2 - dk_wave_layout = fx.make_layout((dk_n_num_waves, dk_head_num_waves), (dk_head_num_waves, 1)) + dk_wave_layout = fx.make_layout( + (dk_n_num_waves, dk_head_num_waves), (dk_head_num_waves, 1) + ) dk_coord = fx.idx2crd(wave_id, dk_wave_layout) dk_n_wave_id = fx.get(dk_coord, 0) dk_head_wave_id = fx.get(dk_coord, 1) @@ -450,7 +656,9 @@ def kernel_attn_bwd( # wave partitioning for dq gemm dq_m_num_waves = 2 dq_head_num_waves = 2 - dq_wave_layout = fx.make_layout((dq_m_num_waves, dq_head_num_waves), (dq_head_num_waves, 1)) + dq_wave_layout = fx.make_layout( + (dq_m_num_waves, dq_head_num_waves), (dq_head_num_waves, 1) + ) dq_coord = fx.idx2crd(wave_id, dq_wave_layout) dq_m_wave_id = fx.get(dq_coord, 0) dq_head_wave_id = fx.get(dq_coord, 1) @@ -467,27 +675,35 @@ def lds_load_16b(curr_row_lds, col_base, lds_stride, lds_buffer, swizzle=16): col_base = swizzle_xor16(curr_row_lds, col_base, lds_stride // swizzle) idx = curr_row_lds * lds_stride + col_base return vector.load_op(_vec16_type(), lds_buffer, [idx]) - - def lds_load_8b_transposed(curr_row_lds, col_base, lds_stride, lds_buffer, swizzle=16): + + def lds_load_8b_transposed( + curr_row_lds, col_base, lds_stride, lds_buffer, swizzle=16 + ): if swizzle == 16: col_base = swizzle_xor16(curr_row_lds, col_base, lds_stride // swizzle) col_base = col_base + lane_mod_2 * 8 idx = curr_row_lds * lds_stride + col_base return lds_transpose_load(lds_buffer, idx) - - def lds_load_packs_k64(curr_row_lds, col_base, lds_stride, lds_buffer, swizzle=16): + + def lds_load_packs_k64( + curr_row_lds, col_base, lds_stride, lds_buffer, 
swizzle=16 + ): vec = lds_load_16b(curr_row_lds, col_base, lds_stride, lds_buffer, swizzle) vec = vector.bitcast(T.i64x2, vec) val0_i64 = vector.extract(vec, static_position=[0], dynamic_position=[]) val1_i64 = vector.extract(vec, static_position=[1], dynamic_position=[]) return val0_i64, val1_i64 - - def lds_load_packs_k32_transposed(curr_row_lds, col_base, lds_stride, lds_buffer, swizzle=16): - vec = lds_load_8b_transposed(curr_row_lds, col_base, lds_stride, lds_buffer, swizzle) + + def lds_load_packs_k32_transposed( + curr_row_lds, col_base, lds_stride, lds_buffer, swizzle=16 + ): + vec = lds_load_8b_transposed( + curr_row_lds, col_base, lds_stride, lds_buffer, swizzle + ) vec = vector.bitcast(T.vec(1, T.i64), vec) val_i64 = vector.extract(vec, static_position=[0], dynamic_position=[]) return val_i64 - + def lds_scale_load(row, col, lds_stride, lds_buffer): idx = row * lds_stride + col vec = vector.load_op(T.vec(1, T.i8), lds_buffer, [idx]) @@ -495,7 +711,6 @@ def lds_scale_load(row, col, lds_stride, lds_buffer): val = val.extui(T.i32) return val - # ── A global→reg load ───────────────────────────────────────────── head_dim_div4 = head_dim // 4 tile_m_div16 = tile_m // 16 @@ -503,84 +718,108 @@ def lds_scale_load(row, col, lds_stride, lds_buffer): num_qo_loads = bytes_per_thread_qo // qo_load_bytes num_kv_loads = bytes_per_thread_kv // kv_load_bytes tile_head_dwords = tile_head // 4 - layout_qo_tile_div4 = fx.make_layout((tile_m, tile_head_dwords), (tile_head_dwords, 1)) - layout_kv_tile_div4 = fx.make_layout((tile_n, tile_head_dwords), (tile_head_dwords, 1)) + layout_qo_tile_div4 = fx.make_layout( + (tile_m, tile_head_dwords), (tile_head_dwords, 1) + ) + layout_kv_tile_div4 = fx.make_layout( + (tile_n, tile_head_dwords), (tile_head_dwords, 1) + ) c4 = fx.Index(4) tx_i32_base = tx * c4 def load_q_quant_head_16(idx_elem): return buffer_copy_gmem16_dwordx4( - buffer_ops, vector, + buffer_ops, + vector, elem_type=_elem_type(), idx_i32=idx_elem, - 
rsrc=q_quant_head_rsrc, vec_elems=16, + rsrc=q_quant_head_rsrc, + vec_elems=16, elem_bytes=elem_bytes, ) - + def load_q_quant_m_16(idx_elem): return buffer_copy_gmem16_dwordx4( - buffer_ops, vector, + buffer_ops, + vector, elem_type=_elem_type(), idx_i32=idx_elem, - rsrc=q_quant_m_rsrc, vec_elems=16, + rsrc=q_quant_m_rsrc, + vec_elems=16, elem_bytes=elem_bytes, ) - + def load_k_quant_head_16(idx_elem): return buffer_copy_gmem16_dwordx4( - buffer_ops, vector, + buffer_ops, + vector, elem_type=_elem_type(), idx_i32=idx_elem, - rsrc=k_quant_head_rsrc, vec_elems=16, + rsrc=k_quant_head_rsrc, + vec_elems=16, elem_bytes=elem_bytes, ) - + def load_k_quant_n_16(idx_elem): return buffer_copy_gmem16_dwordx4( - buffer_ops, vector, + buffer_ops, + vector, elem_type=_elem_type(), idx_i32=idx_elem, - rsrc=k_quant_n_rsrc, vec_elems=16, + rsrc=k_quant_n_rsrc, + vec_elems=16, elem_bytes=elem_bytes, ) - + def load_v_16(idx_elem): return buffer_copy_gmem16_dwordx4( - buffer_ops, vector, + buffer_ops, + vector, elem_type=_elem_type(), idx_i32=idx_elem, - rsrc=v_rsrc, vec_elems=16, + rsrc=v_rsrc, + vec_elems=16, elem_bytes=elem_bytes, ) - + def load_do_quant_head_16(idx_elem): return buffer_copy_gmem16_dwordx4( - buffer_ops, vector, + buffer_ops, + vector, elem_type=_elem_type(), idx_i32=idx_elem, - rsrc=do_quant_head_rsrc, vec_elems=16, + rsrc=do_quant_head_rsrc, + vec_elems=16, elem_bytes=elem_bytes, ) def load_do_quant_m_16(idx_elem): return buffer_copy_gmem16_dwordx4( - buffer_ops, vector, + buffer_ops, + vector, elem_type=_elem_type(), idx_i32=idx_elem, - rsrc=do_quant_m_rsrc, vec_elems=16, + rsrc=do_quant_m_rsrc, + vec_elems=16, elem_bytes=elem_bytes, ) def qo_tile_chunk_coord_i32(i: int): return tile_chunk_coord_i32( - arith, tx_i32_base=tx_i32_base, i=i, - total_threads=total_threads, layout_tile_div4=layout_qo_tile_div4, + arith, + tx_i32_base=tx_i32_base, + i=i, + total_threads=total_threads, + layout_tile_div4=layout_qo_tile_div4, ) - + def kv_tile_chunk_coord_i32(i: int): 
return tile_chunk_coord_i32( - arith, tx_i32_base=tx_i32_base, i=i, - total_threads=total_threads, layout_tile_div4=layout_kv_tile_div4, + arith, + tx_i32_base=tx_i32_base, + i=i, + total_threads=total_threads, + layout_tile_div4=layout_kv_tile_div4, ) def prefetch_q_quant_head_tile(offset_m): @@ -592,7 +831,7 @@ def prefetch_q_quant_head_tile(offset_m): q_16B = load_q_quant_head_16(idx_elem) parts.append(vector.bitcast(T.i32x4, q_16B)) return parts - + def prefetch_q_quant_m_tile(offset_m): parts = [] for i in range_constexpr(num_qo_loads): @@ -602,7 +841,7 @@ def prefetch_q_quant_m_tile(offset_m): q_16B = load_q_quant_m_16(idx_elem) parts.append(vector.bitcast(T.i32x4, q_16B)) return parts - + def prefetch_k_quant_head_tile(): parts = [] for i in range_constexpr(num_kv_loads): @@ -612,7 +851,7 @@ def prefetch_k_quant_head_tile(): k_16B = load_k_quant_head_16(idx_elem) parts.append(vector.bitcast(T.i32x4, k_16B)) return parts - + def prefetch_k_quant_n_tile(): parts = [] for i in range_constexpr(num_kv_loads): @@ -622,7 +861,7 @@ def prefetch_k_quant_n_tile(): k_16B = load_k_quant_n_16(idx_elem) parts.append(vector.bitcast(T.i32x4, k_16B)) return parts - + def prefetch_v_tile(): parts = [] for i in range_constexpr(num_kv_loads): @@ -642,7 +881,7 @@ def prefetch_do_quant_head_tile(offset_m): do_16B = load_do_quant_head_16(idx_elem) parts.append(vector.bitcast(T.i32x4, do_16B)) return parts - + def prefetch_do_quant_m_tile(offset_m): parts = [] for i in range_constexpr(num_qo_loads): @@ -652,19 +891,23 @@ def prefetch_do_quant_m_tile(offset_m): do_16B = load_do_quant_m_16(idx_elem) parts.append(vector.bitcast(T.i32x4, do_16B)) return parts - + def prefetch_q_scale_head_tile(offset_m): vec_width = bytes_per_thread_qo_scale if const_expr(vec_width == 1): if const_expr(bytes_per_tile_qo_scale < total_threads): idx_elem = offset_m * head_dim_mx + tx % bytes_per_tile_qo_scale else: - idx_elem = offset_m * head_dim_mx + tx - vec = buffer_ops.buffer_load(q_scale_head_rsrc, 
idx_elem, vec_width=1, dtype=T.i8) + idx_elem = offset_m * head_dim_mx + tx + vec = buffer_ops.buffer_load( + q_scale_head_rsrc, idx_elem, vec_width=1, dtype=T.i8 + ) vec = vector.from_elements(T.vec(1, T.i8), [vec]) else: # vec_width=2 idx_elem = (offset_m * head_dim_mx + tx * vec_width) // 2 - vec = buffer_ops.buffer_load(q_scale_head_rsrc, idx_elem, vec_width=1, dtype=T.i16) + vec = buffer_ops.buffer_load( + q_scale_head_rsrc, idx_elem, vec_width=1, dtype=T.i16 + ) vec = vector.from_elements(T.vec(1, T.i16), [vec]) vec = vector.bitcast(T.i8x2, vec) return vec @@ -676,27 +919,37 @@ def prefetch_q_scale_m_tile(offset_m): idx_elem = offset_m * head_dim + tx % bytes_per_tile_qo_scale else: idx_elem = offset_m * head_dim + tx - vec = buffer_ops.buffer_load(q_scale_m_rsrc, idx_elem, vec_width=1, dtype=T.i8) + vec = buffer_ops.buffer_load( + q_scale_m_rsrc, idx_elem, vec_width=1, dtype=T.i8 + ) vec = vector.from_elements(T.vec(1, T.i8), [vec]) else: # vec_width=2 idx_elem = (offset_m * head_dim + tx * vec_width) // 2 - vec = buffer_ops.buffer_load(q_scale_m_rsrc, idx_elem, vec_width=1, dtype=T.i16) + vec = buffer_ops.buffer_load( + q_scale_m_rsrc, idx_elem, vec_width=1, dtype=T.i16 + ) vec = vector.from_elements(T.vec(1, T.i16), [vec]) vec = vector.bitcast(T.i8x2, vec) return vec - + def prefetch_k_scale_head_tile(): vec_width = bytes_per_thread_kv_scale if const_expr(vec_width == 1): if const_expr(bytes_per_tile_kv_scale < total_threads): - idx_elem = global_offset_n * head_dim_mx + tx % bytes_per_tile_kv_scale + idx_elem = ( + global_offset_n * head_dim_mx + tx % bytes_per_tile_kv_scale + ) else: idx_elem = global_offset_n * head_dim_mx + tx - vec = buffer_ops.buffer_load(k_scale_head_rsrc, idx_elem, vec_width=1, dtype=T.i8) + vec = buffer_ops.buffer_load( + k_scale_head_rsrc, idx_elem, vec_width=1, dtype=T.i8 + ) vec = vector.from_elements(T.vec(1, T.i8), [vec]) else: # vec_width=2 idx_elem = (global_offset_n * head_dim_mx + tx * vec_width) // 2 - vec = 
buffer_ops.buffer_load(k_scale_head_rsrc, idx_elem, vec_width=1, dtype=T.i16) + vec = buffer_ops.buffer_load( + k_scale_head_rsrc, idx_elem, vec_width=1, dtype=T.i16 + ) vec = vector.from_elements(T.vec(1, T.i16), [vec]) vec = vector.bitcast(T.i8x2, vec) return vec @@ -705,34 +958,46 @@ def prefetch_k_scale_n_tile(): vec_width = bytes_per_thread_kv_scale if const_expr(vec_width == 1): if const_expr(bytes_per_tile_kv_scale < total_threads): - idx_elem = global_offset_n_mx * head_dim + tx % bytes_per_tile_kv_scale + idx_elem = ( + global_offset_n_mx * head_dim + tx % bytes_per_tile_kv_scale + ) else: - idx_elem = global_offset_n_mx * head_dim + tx - vec = buffer_ops.buffer_load(k_scale_n_rsrc, idx_elem, vec_width=1, dtype=T.i8) + idx_elem = global_offset_n_mx * head_dim + tx + vec = buffer_ops.buffer_load( + k_scale_n_rsrc, idx_elem, vec_width=1, dtype=T.i8 + ) vec = vector.from_elements(T.vec(1, T.i8), [vec]) else: # vec_width=2 idx_elem = (global_offset_n_mx * head_dim + tx * vec_width) // 2 - vec = buffer_ops.buffer_load(k_scale_n_rsrc, idx_elem, vec_width=1, dtype=T.i16) + vec = buffer_ops.buffer_load( + k_scale_n_rsrc, idx_elem, vec_width=1, dtype=T.i16 + ) vec = vector.from_elements(T.vec(1, T.i16), [vec]) vec = vector.bitcast(T.i8x2, vec) return vec - + def prefetch_v_scale_tile(): vec_width = bytes_per_thread_kv_scale if const_expr(vec_width == 1): if const_expr(bytes_per_tile_kv_scale < total_threads): - idx_elem = global_offset_n * head_dim_mx + tx % bytes_per_tile_kv_scale + idx_elem = ( + global_offset_n * head_dim_mx + tx % bytes_per_tile_kv_scale + ) else: idx_elem = global_offset_n * head_dim_mx + tx - vec = buffer_ops.buffer_load(v_scale_rsrc, idx_elem, vec_width=1, dtype=T.i8) + vec = buffer_ops.buffer_load( + v_scale_rsrc, idx_elem, vec_width=1, dtype=T.i8 + ) vec = vector.from_elements(T.vec(1, T.i8), [vec]) else: # vec_width=2 idx_elem = (global_offset_n * head_dim_mx + tx * vec_width) // 2 - vec = buffer_ops.buffer_load(v_scale_rsrc, idx_elem, 
vec_width=1, dtype=T.i16) + vec = buffer_ops.buffer_load( + v_scale_rsrc, idx_elem, vec_width=1, dtype=T.i16 + ) vec = vector.from_elements(T.vec(1, T.i16), [vec]) vec = vector.bitcast(T.i8x2, vec) return vec - + def prefetch_do_scale_head_tile(offset_m): vec_width = bytes_per_thread_qo_scale if const_expr(vec_width == 1): @@ -740,11 +1005,15 @@ def prefetch_do_scale_head_tile(offset_m): idx_elem = offset_m * head_dim_mx + tx % bytes_per_tile_qo_scale else: idx_elem = offset_m * head_dim_mx + tx - vec = buffer_ops.buffer_load(do_scale_head_rsrc, idx_elem, vec_width=1, dtype=T.i8) + vec = buffer_ops.buffer_load( + do_scale_head_rsrc, idx_elem, vec_width=1, dtype=T.i8 + ) vec = vector.from_elements(T.vec(1, T.i8), [vec]) else: # vec_width=2 idx_elem = (offset_m * head_dim_mx + tx * vec_width) // 2 - vec = buffer_ops.buffer_load(do_scale_head_rsrc, idx_elem, vec_width=1, dtype=T.i16) + vec = buffer_ops.buffer_load( + do_scale_head_rsrc, idx_elem, vec_width=1, dtype=T.i16 + ) vec = vector.from_elements(T.vec(1, T.i16), [vec]) vec = vector.bitcast(T.i8x2, vec) return vec @@ -756,11 +1025,15 @@ def prefetch_do_scale_m_tile(offset_m): idx_elem = offset_m * head_dim + tx % bytes_per_tile_qo_scale else: idx_elem = offset_m * head_dim + tx - vec = buffer_ops.buffer_load(do_scale_m_rsrc, idx_elem, vec_width=1, dtype=T.i8) + vec = buffer_ops.buffer_load( + do_scale_m_rsrc, idx_elem, vec_width=1, dtype=T.i8 + ) vec = vector.from_elements(T.vec(1, T.i8), [vec]) else: # vec_width=2 idx_elem = (offset_m * head_dim + tx * vec_width) // 2 - vec = buffer_ops.buffer_load(do_scale_m_rsrc, idx_elem, vec_width=1, dtype=T.i16) + vec = buffer_ops.buffer_load( + do_scale_m_rsrc, idx_elem, vec_width=1, dtype=T.i16 + ) vec = vector.from_elements(T.vec(1, T.i16), [vec]) vec = vector.bitcast(T.i8x2, vec) return vec @@ -769,7 +1042,9 @@ def store_q_tile_to_lds(vec_q_parts, lds_buffer): for i in range_constexpr(num_qo_loads): row_q_local, col_q_local_i32 = qo_tile_chunk_coord_i32(i) 
col_local_bytes = col_q_local_i32 * c4 - col_swz_bytes = swizzle_xor16(row_q_local, col_local_bytes, tile_head_div16) + col_swz_bytes = swizzle_xor16( + row_q_local, col_local_bytes, tile_head_div16 + ) col_swz = col_swz_bytes idx0 = row_q_local * tile_head + col_swz v16 = vector.bitcast(_vec16_type(), vec_q_parts[i]) @@ -779,7 +1054,9 @@ def store_k_tile_to_lds(vec_k_parts, lds_buffer): for i in range_constexpr(num_kv_loads): row_k_local, col_k_local_i32 = kv_tile_chunk_coord_i32(i) col_local_bytes = col_k_local_i32 * c4 - col_swz_bytes = swizzle_xor16(row_k_local, col_local_bytes, tile_head_div16) + col_swz_bytes = swizzle_xor16( + row_k_local, col_local_bytes, tile_head_div16 + ) col_swz = col_swz_bytes idx0 = row_k_local * tile_head + col_swz v16 = vector.bitcast(_vec16_type(), vec_k_parts[i]) @@ -789,7 +1066,9 @@ def store_v_tile_to_lds(vec_v_parts, lds_buffer): for i in range_constexpr(num_kv_loads): row_v_local, col_v_local_i32 = kv_tile_chunk_coord_i32(i) col_local_bytes = col_v_local_i32 * c4 - col_swz_bytes = swizzle_xor16(row_v_local, col_local_bytes, tile_head_div16) + col_swz_bytes = swizzle_xor16( + row_v_local, col_local_bytes, tile_head_div16 + ) col_swz = col_swz_bytes idx0 = row_v_local * tile_head + col_swz v16 = vector.bitcast(_vec16_type(), vec_v_parts[i]) @@ -799,13 +1078,14 @@ def store_do_tile_to_lds(vec_do_parts, lds_buffer): for i in range_constexpr(num_qo_loads): row_do_local, col_do_local_i32 = qo_tile_chunk_coord_i32(i) col_local_bytes = col_do_local_i32 * c4 - col_swz_bytes = swizzle_xor16(row_do_local, col_local_bytes, tile_head_div16) + col_swz_bytes = swizzle_xor16( + row_do_local, col_local_bytes, tile_head_div16 + ) col_swz = col_swz_bytes idx0 = row_do_local * tile_head + col_swz v16 = vector.bitcast(_vec16_type(), vec_do_parts[i]) vector.store(v16, lds_buffer, [idx0]) - def store_q_scale_tile_to_lds(vec_scale, lds_buffer): vec_width = bytes_per_thread_qo_scale idx = tx * vec_width @@ -833,8 +1113,7 @@ def 
store_do_scale_tile_to_lds(vec_scale, lds_buffer): if total_threads > bytes_per_tile_qo_scale: idx = idx % bytes_per_tile_qo_scale vector.store(vec_scale, lds_buffer, [idx]) - - + # ── Compute tile (MFMA) ─────────────────────────────────────────── def pack_i64x4_to_i32x8(x0, x1, x2, x3): @@ -843,8 +1122,9 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): v4 = vector.from_elements(vec4_i64, [x0, x1, x2, x3]) return vector.bitcast(vec8_i32, v4) - - def compute_qk(lds_a_buffer, lds_a_scale_buffer, lds_b_buffer, lds_b_scale_buffer): + def compute_qk( + lds_a_buffer, lds_a_scale_buffer, lds_b_buffer, lds_b_scale_buffer + ): # (m, head) @ (head, n) = (m, n) current_accs_list = [acc_init] * ps_n_accs @@ -852,7 +1132,9 @@ def compute_qk(lds_a_buffer, lds_a_scale_buffer, lds_b_buffer, lds_b_scale_buffe ku0 = 0 ku1 = 1 - lds_col0 = ku0 * 64 + lane_div_16 * 16 # 16 elements packed per lane, 64 per wave + lds_col0 = ( + ku0 * 64 + lane_div_16 * 16 + ) # 16 elements packed per lane, 64 per wave lds_col1 = ku1 * 64 + lane_div_16 * 16 lds_scale_col = lane_div_16 if const_expr(tile_head == 64): @@ -860,35 +1142,55 @@ def compute_qk(lds_a_buffer, lds_a_scale_buffer, lds_b_buffer, lds_b_scale_buffe for mi in range_constexpr(ps_m_num_subtiles): lds_a_row = ps_m_wave_id * ps_m_per_wave + mi * 16 + lane_mod_16 - a0, a1 = lds_load_packs_k64(lds_a_row, lds_col0, tile_head, lds_a_buffer) + a0, a1 = lds_load_packs_k64( + lds_a_row, lds_col0, tile_head, lds_a_buffer + ) if const_expr(tile_head == 128): - a2, a3 = lds_load_packs_k64(lds_a_row, lds_col1, tile_head, lds_a_buffer) + a2, a3 = lds_load_packs_k64( + lds_a_row, lds_col1, tile_head, lds_a_buffer + ) else: a2 = a3 = fx.Int64(0) a128 = pack_i64x4_to_i32x8(a0, a1, a2, a3) - lds_a_scale_row = lds_a_row - a_scale = lds_scale_load(lds_a_scale_row, lds_scale_col, tile_head_mx, lds_a_scale_buffer) + lds_a_scale_row = lds_a_row + a_scale = lds_scale_load( + lds_a_scale_row, lds_scale_col, tile_head_mx, lds_a_scale_buffer + ) for ni in 
range_constexpr(ps_n_num_subtiles): - lds_b_row = ps_n_wave_id * ps_n_per_wave + ni * 16 + lane_mod_16 - b0, b1 = lds_load_packs_k64(lds_b_row, lds_col0, tile_head, lds_b_buffer) + lds_b_row = ps_n_wave_id * ps_n_per_wave + ni * 16 + lane_mod_16 + b0, b1 = lds_load_packs_k64( + lds_b_row, lds_col0, tile_head, lds_b_buffer + ) if const_expr(tile_head == 128): - b2, b3 = lds_load_packs_k64(lds_b_row, lds_col1, tile_head, lds_b_buffer) + b2, b3 = lds_load_packs_k64( + lds_b_row, lds_col1, tile_head, lds_b_buffer + ) else: b2 = b3 = fx.Int64(0) b128 = pack_i64x4_to_i32x8(b0, b1, b2, b3) - b_scale = lds_scale_load(lds_b_row, lds_scale_col, tile_head_mx, lds_b_scale_buffer) + b_scale = lds_scale_load( + lds_b_row, lds_scale_col, tile_head_mx, lds_b_scale_buffer + ) - #fx.printf("ni={}, mi={}", ni, mi) + # fx.printf("ni={}, mi={}", ni, mi) acc_idx = mi * ps_n_num_subtiles + ni current_accs_list[acc_idx] = rocdl.mfma_scale_f32_16x16x128_f8f6f4( mfma_res_ty, - [a128, b128, current_accs_list[acc_idx], - 0, 0, 0, a_scale, 0, b_scale], + [ + a128, + b128, + current_accs_list[acc_idx], + 0, + 0, + 0, + a_scale, + 0, + b_scale, + ], ) return current_accs_list - def softmax(accs_in, offset_m): # inputs are tile_m x tile_n shape @@ -896,26 +1198,47 @@ def softmax(accs_in, offset_m): accs_out = [acc_init] * ps_n_accs for mi in range_constexpr(ps_m_num_subtiles): - global_m_norm_idx = offset_m + ps_m_wave_id * ps_m_per_wave + mi * 16 + lane_div_16 * 4 - m_norm_vector = buffer_ops.buffer_load(M_rsrc, global_m_norm_idx, vec_width=4) + global_m_norm_idx = ( + offset_m + ps_m_wave_id * ps_m_per_wave + mi * 16 + lane_div_16 * 4 + ) + m_norm_vector = buffer_ops.buffer_load( + M_rsrc, global_m_norm_idx, vec_width=4 + ) for ni in range_constexpr(ps_n_num_subtiles): - + acc_idx = mi * ps_n_num_subtiles + ni acc = accs_in[acc_idx] vals_f32 = [] for ii in range_constexpr(4): - val_f32 = vector.extract(acc, static_position=[ii], dynamic_position=[]) - m_norm = vector.extract(m_norm_vector, 
static_position=[ii], dynamic_position=[]) - val_f32 = val_f32 * c_sm_scale + val_f32 = vector.extract( + acc, static_position=[ii], dynamic_position=[] + ) + m_norm = vector.extract( + m_norm_vector, static_position=[ii], dynamic_position=[] + ) + val_f32 = val_f32 * c_sm_scale val_f32 = val_f32 - m_norm - val_f32 = val_f32 * log2e + val_f32 = val_f32 * log2e val_f32 = rocdl.exp2(T.f32, val_f32) if causal: - global_m = offset_m + ps_m_wave_id * ps_m_per_wave + mi * 16 + lane_div_16 * 4 + ii - global_n = global_offset_n + ps_n_wave_id * ps_n_per_wave + ni * 16 + lane_mod_16 - needs_mask = arith.cmpi(arith.CmpIPredicate.ugt, global_n, global_m) + global_m = ( + offset_m + + ps_m_wave_id * ps_m_per_wave + + mi * 16 + + lane_div_16 * 4 + + ii + ) + global_n = ( + global_offset_n + + ps_n_wave_id * ps_n_per_wave + + ni * 16 + + lane_mod_16 + ) + needs_mask = arith.cmpi( + arith.CmpIPredicate.ugt, global_n, global_m + ) mask_if = scf.IfOp(needs_mask, [T.f32], has_else=True) with ir.InsertionPoint(mask_if.then_block): scf.YieldOp([arith.constant(0.0, type=T.f32)]) @@ -928,16 +1251,19 @@ def softmax(accs_in, offset_m): return accs_out - - def compute_dv(accs_in, lds_a_buffer, lds_a_scale_buffer, lds_b_buffer, lds_b_scale_buffer): + def compute_dv( + accs_in, lds_a_buffer, lds_a_scale_buffer, lds_b_buffer, lds_b_scale_buffer + ): current_accs_list = list(accs_in) - mfma_res_ty = T.f32x4 + mfma_res_ty = T.f32x4 num_subtiles_reduction = max(1, tile_m // 128) for ku128 in range_constexpr(num_subtiles_reduction): ku0 = ku128 * 2 ku1 = ku0 + 1 - lds_a_col0 = ku0 * 64 + lane_div_16 * 16 # 16 elements packed per lane, 64 per wave + lds_a_col0 = ( + ku0 * 64 + lane_div_16 * 16 + ) # 16 elements packed per lane, 64 per wave lds_a_col1 = ku1 * 64 + lane_div_16 * 16 lds_b_row0 = ku0 * 64 + lane_div_16 * 16 + lane_div_2 @@ -953,77 +1279,132 @@ def compute_dv(accs_in, lds_a_buffer, lds_a_scale_buffer, lds_b_buffer, lds_b_sc for ni in range_constexpr(dv_n_num_subtiles): lds_a_row = 
dv_n_wave_id * dv_n_per_wave + ni * 16 + lane_mod_16 - a0, a1 = lds_load_packs_k64(lds_a_row, lds_a_col0, tile_m, lds_a_buffer) + a0, a1 = lds_load_packs_k64( + lds_a_row, lds_a_col0, tile_m, lds_a_buffer + ) if const_expr(tile_m == 128): - a2, a3 = lds_load_packs_k64(lds_a_row, lds_a_col1, tile_m, lds_a_buffer) + a2, a3 = lds_load_packs_k64( + lds_a_row, lds_a_col1, tile_m, lds_a_buffer + ) else: a2 = a3 = fx.Int64(0) a128 = pack_i64x4_to_i32x8(a0, a1, a2, a3) - a_scale = lds_scale_load(lds_a_row, lds_a_scale_col, tile_m_mx, lds_a_scale_buffer) + a_scale = lds_scale_load( + lds_a_row, lds_a_scale_col, tile_m_mx, lds_a_scale_buffer + ) for hi in range_constexpr(dv_head_num_subtiles): - lds_b_col = dv_head_wave_id * dv_head_per_wave + hi * 16 - b0 = lds_load_packs_k32_transposed(lds_b_row0, lds_b_col, tile_head, lds_b_buffer) - b1 = lds_load_packs_k32_transposed(lds_b_row1, lds_b_col, tile_head, lds_b_buffer) + lds_b_col = dv_head_wave_id * dv_head_per_wave + hi * 16 + b0 = lds_load_packs_k32_transposed( + lds_b_row0, lds_b_col, tile_head, lds_b_buffer + ) + b1 = lds_load_packs_k32_transposed( + lds_b_row1, lds_b_col, tile_head, lds_b_buffer + ) if const_expr(tile_m == 128): - b2 = lds_load_packs_k32_transposed(lds_b_row2, lds_b_col, tile_head, lds_b_buffer) - b3 = lds_load_packs_k32_transposed(lds_b_row3, lds_b_col, tile_head, lds_b_buffer) + b2 = lds_load_packs_k32_transposed( + lds_b_row2, lds_b_col, tile_head, lds_b_buffer + ) + b3 = lds_load_packs_k32_transposed( + lds_b_row3, lds_b_col, tile_head, lds_b_buffer + ) else: b2 = fx.Int64(0) b3 = fx.Int64(0) b128 = pack_i64x4_to_i32x8(b0, b1, b2, b3) - lds_b_scale_col = dv_head_wave_id * dv_head_per_wave + hi * 16 + lane_mod_16 - b_scale = lds_scale_load(lds_b_scale_row, lds_b_scale_col, tile_head, lds_b_scale_buffer) + lds_b_scale_col = ( + dv_head_wave_id * dv_head_per_wave + hi * 16 + lane_mod_16 + ) + b_scale = lds_scale_load( + lds_b_scale_row, + lds_b_scale_col, + tile_head, + lds_b_scale_buffer, + ) acc_idx 
= ni * dv_head_num_subtiles + hi - current_accs_list[acc_idx] = rocdl.mfma_scale_f32_16x16x128_f8f6f4( - mfma_res_ty, - [a128, b128, current_accs_list[acc_idx], - 0, 0, 0, a_scale, 0, b_scale], + current_accs_list[acc_idx] = ( + rocdl.mfma_scale_f32_16x16x128_f8f6f4( + mfma_res_ty, + [ + a128, + b128, + current_accs_list[acc_idx], + 0, + 0, + 0, + a_scale, + 0, + b_scale, + ], + ) ) return current_accs_list - - def compute_dp(lds_a_buffer, lds_a_scale_buffer, lds_b_buffer, lds_b_scale_buffer): + def compute_dp( + lds_a_buffer, lds_a_scale_buffer, lds_b_buffer, lds_b_scale_buffer + ): current_accs_list = [acc_init] * ps_n_accs mfma_res_ty = T.f32x4 ku0 = 0 ku1 = 1 - lds_col0 = ku0 * 64 + lane_div_16 * 16 # 16 elements packed per lane, 64 per wave + lds_col0 = ( + ku0 * 64 + lane_div_16 * 16 + ) # 16 elements packed per lane, 64 per wave lds_col1 = ku1 * 64 + lane_div_16 * 16 lds_scale_col = lane_div_16 if const_expr(tile_head == 64): lds_scale_col = lds_scale_col % 2 for mi in range_constexpr(ps_m_num_subtiles): - lds_a_row = ps_m_wave_id * ps_m_per_wave + mi * 16 + lane_mod_16 - a0, a1 = lds_load_packs_k64(lds_a_row, lds_col0, tile_head, lds_a_buffer) + lds_a_row = ps_m_wave_id * ps_m_per_wave + mi * 16 + lane_mod_16 + a0, a1 = lds_load_packs_k64( + lds_a_row, lds_col0, tile_head, lds_a_buffer + ) if const_expr(tile_head == 128): - a2, a3 = lds_load_packs_k64(lds_a_row, lds_col1, tile_head, lds_a_buffer) + a2, a3 = lds_load_packs_k64( + lds_a_row, lds_col1, tile_head, lds_a_buffer + ) else: a2 = a3 = fx.Int64(0) a128 = pack_i64x4_to_i32x8(a0, a1, a2, a3) - a_scale = lds_scale_load(lds_a_row, lds_scale_col, tile_head_mx, lds_a_scale_buffer) + a_scale = lds_scale_load( + lds_a_row, lds_scale_col, tile_head_mx, lds_a_scale_buffer + ) for ni in range_constexpr(ps_n_num_subtiles): lds_b_row = ps_n_wave_id * ps_n_per_wave + ni * 16 + lane_mod_16 - b0, b1 = lds_load_packs_k64(lds_b_row, lds_col0, tile_head, lds_b_buffer) + b0, b1 = lds_load_packs_k64( + lds_b_row, 
lds_col0, tile_head, lds_b_buffer + ) if const_expr(tile_head == 128): - b2, b3 = lds_load_packs_k64(lds_b_row, lds_col1, tile_head, lds_b_buffer) + b2, b3 = lds_load_packs_k64( + lds_b_row, lds_col1, tile_head, lds_b_buffer + ) else: b2 = b3 = fx.Int64(0) b128 = pack_i64x4_to_i32x8(b0, b1, b2, b3) - b_scale = lds_scale_load(lds_b_row, lds_scale_col, tile_head_mx, lds_b_scale_buffer) + b_scale = lds_scale_load( + lds_b_row, lds_scale_col, tile_head_mx, lds_b_scale_buffer + ) acc_idx = mi * ps_n_num_subtiles + ni current_accs_list[acc_idx] = rocdl.mfma_scale_f32_16x16x128_f8f6f4( mfma_res_ty, - [a128, b128, current_accs_list[acc_idx], - 0, 0, 0, a_scale, 0, b_scale], + [ + a128, + b128, + current_accs_list[acc_idx], + 0, + 0, + 0, + a_scale, + 0, + b_scale, + ], ) return current_accs_list - def compute_ds(dp_accs, p_accs, offset_m): # inputs are tile_m x tile_n shape @@ -1032,28 +1413,35 @@ def compute_ds(dp_accs, p_accs, offset_m): for mi in range_constexpr(ps_m_num_subtiles): - global_D_idx = offset_m + ps_m_wave_id * ps_m_per_wave + mi * 16 + lane_div_16 * 4 + global_D_idx = ( + offset_m + ps_m_wave_id * ps_m_per_wave + mi * 16 + lane_div_16 * 4 + ) D_vector = buffer_ops.buffer_load(D_rsrc, global_D_idx, vec_width=4) for ni in range_constexpr(ps_n_num_subtiles): - + acc_idx = mi * ps_n_num_subtiles + ni dp_f32x4 = dp_accs[acc_idx] p_f32x4 = p_accs[acc_idx] vals_f32 = [] for ii in range_constexpr(4): - dp_f32 = vector.extract(dp_f32x4, static_position=[ii], dynamic_position=[]) - p_f32 = vector.extract(p_f32x4, static_position=[ii], dynamic_position=[]) - D = vector.extract(D_vector, static_position=[ii], dynamic_position=[]) - ds_f32 = p_f32 * (dp_f32 - D) + dp_f32 = vector.extract( + dp_f32x4, static_position=[ii], dynamic_position=[] + ) + p_f32 = vector.extract( + p_f32x4, static_position=[ii], dynamic_position=[] + ) + D = vector.extract( + D_vector, static_position=[ii], dynamic_position=[] + ) + ds_f32 = p_f32 * (dp_f32 - D) vals_f32.append(ds_f32) 
vals_f32_vector = vector.from_elements(T.f32x4, vals_f32) accs_out[acc_idx] = vals_f32_vector return accs_out - def wave_reduce_max_4threads(x): width_i32 = arith.constant(64, type=T.i32) @@ -1063,14 +1451,13 @@ def wave_reduce_max_4threads(x): peer = w.shuffle_xor(off, width_i32) w = w.maximumf(peer) return w - def mxquant_m_and_store_to_lds(accs_in, lds_buffer, lds_buffer_scale): # inputs are tile_m x tile_n shape for mi in range_constexpr(ps_m_num_subtiles // 2): for ni in range_constexpr(ps_n_num_subtiles): - + acc_idx0 = (mi * 2) * ps_n_num_subtiles + ni acc_idx1 = (mi * 2 + 1) * ps_n_num_subtiles + ni acc0 = accs_in[acc_idx0] @@ -1080,9 +1467,13 @@ def mxquant_m_and_store_to_lds(accs_in, lds_buffer, lds_buffer_scale): vals_subtile1 = [] vals_abs = [] for ii in range_constexpr(4): - val0 = vector.extract(acc0, static_position=[ii], dynamic_position=[]) + val0 = vector.extract( + acc0, static_position=[ii], dynamic_position=[] + ) vals_subtile0.append(val0) - val1 = vector.extract(acc1, static_position=[ii], dynamic_position=[]) + val1 = vector.extract( + acc1, static_position=[ii], dynamic_position=[] + ) vals_subtile1.append(val1) val0_abs = fx_math.absf(val0) val1_abs = fx_math.absf(val1) @@ -1097,42 +1488,67 @@ def mxquant_m_and_store_to_lds(accs_in, lds_buffer, lds_buffer_scale): val_max = val_max + arith.constant(0x007FFFFF, type=T.i32) val_max = val_max & arith.constant(0x7F800000, type=T.i32) val_max_f32 = arith.bitcast(T.f32, val_max) - val_max_rcp = arith.select(val_max_f32 == zero_f, zero_f, rocdl.rcp(T.f32, val_max_f32)) + val_max_rcp = arith.select( + val_max_f32 == zero_f, zero_f, rocdl.rcp(T.f32, val_max_f32) + ) scale = val_max >> 23 scale = arith.trunci(T.i8, scale) scale_vector = vector.from_elements(T.vec(1, T.i8), [scale]) for ii in range_constexpr(4): - vals_subtile0[ii] = vals_subtile0[ii] * val_max_rcp + vals_subtile0[ii] = vals_subtile0[ii] * val_max_rcp vals_subtile1[ii] = vals_subtile1[ii] * val_max_rcp - val_f8_packed_i32 = 
rocdl.cvt_pk_fp8_f32(T.i32, vals_subtile0[0], vals_subtile0[1], fx.Int32(0), False) - val_f8_packed_i32 = rocdl.cvt_pk_fp8_f32(T.i32, vals_subtile0[2], vals_subtile0[3], val_f8_packed_i32, True) - val_f8_packed_i32_vector = vector.from_elements(T.vec(1, T.i32), [val_f8_packed_i32]) + val_f8_packed_i32 = rocdl.cvt_pk_fp8_f32( + T.i32, vals_subtile0[0], vals_subtile0[1], fx.Int32(0), False + ) + val_f8_packed_i32 = rocdl.cvt_pk_fp8_f32( + T.i32, + vals_subtile0[2], + vals_subtile0[3], + val_f8_packed_i32, + True, + ) + val_f8_packed_i32_vector = vector.from_elements( + T.vec(1, T.i32), [val_f8_packed_i32] + ) val_f8x4_subtile0 = vector.bitcast(T.f8x4, val_f8_packed_i32_vector) - val_f8_packed_i32 = rocdl.cvt_pk_fp8_f32(T.i32, vals_subtile1[0], vals_subtile1[1], fx.Int32(0), False) - val_f8_packed_i32 = rocdl.cvt_pk_fp8_f32(T.i32, vals_subtile1[2], vals_subtile1[3], val_f8_packed_i32, True) - val_f8_packed_i32_vector = vector.from_elements(T.vec(1, T.i32), [val_f8_packed_i32]) + val_f8_packed_i32 = rocdl.cvt_pk_fp8_f32( + T.i32, vals_subtile1[0], vals_subtile1[1], fx.Int32(0), False + ) + val_f8_packed_i32 = rocdl.cvt_pk_fp8_f32( + T.i32, + vals_subtile1[2], + vals_subtile1[3], + val_f8_packed_i32, + True, + ) + val_f8_packed_i32_vector = vector.from_elements( + T.vec(1, T.i32), [val_f8_packed_i32] + ) val_f8x4_subtile1 = vector.bitcast(T.f8x4, val_f8_packed_i32_vector) lds_row = ps_n_wave_id * ps_n_per_wave + ni * 16 + lane_mod_16 - lds_col_base0 = ps_m_wave_id * ps_m_per_wave + (mi * 2) * 16 #+ lane_div_16 * 4 - lds_col_base1 = ps_m_wave_id * ps_m_per_wave + (mi * 2 + 1) * 16 #+ lane_div_16 * 4 + lds_col_base0 = ( + ps_m_wave_id * ps_m_per_wave + (mi * 2) * 16 + ) # + lane_div_16 * 4 + lds_col_base1 = ( + ps_m_wave_id * ps_m_per_wave + (mi * 2 + 1) * 16 + ) # + lane_div_16 * 4 lds_col0 = swizzle_xor16(lds_row, lds_col_base0, tile_m_div16) lds_col1 = swizzle_xor16(lds_row, lds_col_base1, tile_m_div16) lds_col0 = lds_col0 + lane_div_16 * 4 lds_col1 = lds_col1 + 
lane_div_16 * 4 - lds_scale_col = ps_m_wave_id * ps_m_mx_per_wave + mi + lds_scale_col = ps_m_wave_id * ps_m_mx_per_wave + mi lds_idx0 = lds_row * tile_m + lds_col0 lds_idx1 = lds_row * tile_m + lds_col1 lds_scale_idx = lds_row * tile_m_mx + lds_scale_col - + vector.store(val_f8x4_subtile0, lds_buffer, [lds_idx0]) vector.store(val_f8x4_subtile1, lds_buffer, [lds_idx1]) vector.store(scale_vector, lds_buffer_scale, [lds_scale_idx]) - def wave_reduce_max_16threads(x): width_i32 = arith.constant(64, type=T.i32) w = x @@ -1141,14 +1557,13 @@ def wave_reduce_max_16threads(x): peer = w.shuffle_xor(off, width_i32) w = w.maximumf(peer) return w - def mxquant_n_and_store_to_lds(accs_in, lds_buffer, lds_buffer_scale): # inputs are tile_m x tile_n shape for mi in range_constexpr(ps_m_num_subtiles): for ni in range_constexpr(ps_n_num_subtiles // 2): - + acc_idx0 = mi * ps_n_num_subtiles + ni * 2 acc_idx1 = mi * ps_n_num_subtiles + ni * 2 + 1 acc0 = accs_in[acc_idx0] @@ -1158,8 +1573,12 @@ def mxquant_n_and_store_to_lds(accs_in, lds_buffer, lds_buffer_scale): vals_subtile1 = [] scales = [] for ii in range_constexpr(4): - val0 = vector.extract(acc0, static_position=[ii], dynamic_position=[]) - val1 = vector.extract(acc1, static_position=[ii], dynamic_position=[]) + val0 = vector.extract( + acc0, static_position=[ii], dynamic_position=[] + ) + val1 = vector.extract( + acc1, static_position=[ii], dynamic_position=[] + ) val0_abs = fx_math.absf(val0) val1_abs = fx_math.absf(val1) val_max = arith.maximumf(val0_abs, val1_abs) @@ -1169,7 +1588,9 @@ def mxquant_n_and_store_to_lds(accs_in, lds_buffer, lds_buffer_scale): val_max = val_max + arith.constant(0x007FFFFF, type=T.i32) val_max = val_max & arith.constant(0x7F800000, type=T.i32) val_max_f32 = arith.bitcast(T.f32, val_max) - val_max_rcp = arith.select(val_max_f32 == zero_f, zero_f, rocdl.rcp(T.f32, val_max_f32)) + val_max_rcp = arith.select( + val_max_f32 == zero_f, zero_f, rocdl.rcp(T.f32, val_max_f32) + ) val0_quant = val0 * 
val_max_rcp vals_subtile0.append(val0_quant) val1_quant = val1 * val_max_rcp @@ -1178,19 +1599,45 @@ def mxquant_n_and_store_to_lds(accs_in, lds_buffer, lds_buffer_scale): scale = arith.trunci(T.i8, scale) scales.append(scale) - val_f8_packed_i32 = rocdl.cvt_pk_fp8_f32(T.i32, vals_subtile0[0], vals_subtile0[1], fx.Int32(0), False) - val_f8_packed_i32 = rocdl.cvt_pk_fp8_f32(T.i32, vals_subtile0[2], vals_subtile0[3], val_f8_packed_i32, True) - val_f8_packed_i32_vector = vector.from_elements(T.vec(1, T.i32), [val_f8_packed_i32]) + val_f8_packed_i32 = rocdl.cvt_pk_fp8_f32( + T.i32, vals_subtile0[0], vals_subtile0[1], fx.Int32(0), False + ) + val_f8_packed_i32 = rocdl.cvt_pk_fp8_f32( + T.i32, + vals_subtile0[2], + vals_subtile0[3], + val_f8_packed_i32, + True, + ) + val_f8_packed_i32_vector = vector.from_elements( + T.vec(1, T.i32), [val_f8_packed_i32] + ) val_f8x4_subtile0 = vector.bitcast(T.f8x4, val_f8_packed_i32_vector) - val_f8_packed_i32 = rocdl.cvt_pk_fp8_f32(T.i32, vals_subtile1[0], vals_subtile1[1], fx.Int32(0), False) - val_f8_packed_i32 = rocdl.cvt_pk_fp8_f32(T.i32, vals_subtile1[2], vals_subtile1[3], val_f8_packed_i32, True) - val_f8_packed_i32_vector = vector.from_elements(T.vec(1, T.i32), [val_f8_packed_i32]) + val_f8_packed_i32 = rocdl.cvt_pk_fp8_f32( + T.i32, vals_subtile1[0], vals_subtile1[1], fx.Int32(0), False + ) + val_f8_packed_i32 = rocdl.cvt_pk_fp8_f32( + T.i32, + vals_subtile1[2], + vals_subtile1[3], + val_f8_packed_i32, + True, + ) + val_f8_packed_i32_vector = vector.from_elements( + T.vec(1, T.i32), [val_f8_packed_i32] + ) val_f8x4_subtile1 = vector.bitcast(T.f8x4, val_f8_packed_i32_vector) - lds_row0 = ps_n_wave_id * ps_n_per_wave + (ni * 2) * 16 + lane_mod_16 - lds_row1 = ps_n_wave_id * ps_n_per_wave + (ni * 2 + 1) * 16 + lane_mod_16 - lds_col_base = ps_m_wave_id * ps_m_per_wave + mi * 16 #+ lane_div_16 * 4 + lds_row0 = ( + ps_n_wave_id * ps_n_per_wave + (ni * 2) * 16 + lane_mod_16 + ) + lds_row1 = ( + ps_n_wave_id * ps_n_per_wave + (ni * 2 + 
1) * 16 + lane_mod_16 + ) + lds_col_base = ( + ps_m_wave_id * ps_m_per_wave + mi * 16 + ) # + lane_div_16 * 4 lds_col0 = swizzle_xor16(lds_row0, lds_col_base, tile_m_div16) lds_col1 = swizzle_xor16(lds_row1, lds_col_base, tile_m_div16) lds_col0 = lds_col0 + lane_div_16 * 4 @@ -1201,23 +1648,33 @@ def mxquant_n_and_store_to_lds(accs_in, lds_buffer, lds_buffer_scale): vector.store(val_f8x4_subtile1, lds_buffer, [lds_idx1]) for ii in range_constexpr(4): - lds_scale_row = ps_m_wave_id * ps_m_per_wave + mi * 16 + lane_div_16 * 4 + ii - lds_scale_col = ps_n_wave_id * ps_n_mx_per_wave + ni + lds_scale_row = ( + ps_m_wave_id * ps_m_per_wave + + mi * 16 + + lane_div_16 * 4 + + ii + ) + lds_scale_col = ps_n_wave_id * ps_n_mx_per_wave + ni lds_scale_idx = lds_scale_row * tile_n_mx + lds_scale_col - - scale_vector = vector.from_elements(T.vec(1, T.i8), [scales[ii]]) - vector.store(scale_vector, lds_buffer_scale, [lds_scale_idx]) + scale_vector = vector.from_elements( + T.vec(1, T.i8), [scales[ii]] + ) + vector.store(scale_vector, lds_buffer_scale, [lds_scale_idx]) - def compute_dk(accs_in, lds_a_buffer, lds_a_scale_buffer, lds_b_buffer, lds_b_scale_buffer): + def compute_dk( + accs_in, lds_a_buffer, lds_a_scale_buffer, lds_b_buffer, lds_b_scale_buffer + ): current_accs_list = list(accs_in) - mfma_res_ty = T.f32x4 - num_subtiles_reduction = max(1, tile_m // 128) + mfma_res_ty = T.f32x4 + num_subtiles_reduction = max(1, tile_m // 128) for ku128 in range_constexpr(num_subtiles_reduction): ku0 = ku128 * 2 ku1 = ku0 + 1 - lds_a_col0 = ku0 * 64 + lane_div_16 * 16 # 16 elements packed per lane, 64 per wave + lds_a_col0 = ( + ku0 * 64 + lane_div_16 * 16 + ) # 16 elements packed per lane, 64 per wave lds_a_col1 = ku1 * 64 + lane_div_16 * 16 lds_b_row0 = ku0 * 64 + lane_div_16 * 16 + lane_div_2 @@ -1233,40 +1690,75 @@ def compute_dk(accs_in, lds_a_buffer, lds_a_scale_buffer, lds_b_buffer, lds_b_sc for ni in range_constexpr(dk_num_subtiles_n): lds_a_row = dk_n_wave_id * dk_n_per_wave + 
ni * 16 + lane_mod_16 - a0, a1 = lds_load_packs_k64(lds_a_row, lds_a_col0, tile_m, lds_a_buffer) + a0, a1 = lds_load_packs_k64( + lds_a_row, lds_a_col0, tile_m, lds_a_buffer + ) if const_expr(tile_m == 128): - a2, a3 = lds_load_packs_k64(lds_a_row, lds_a_col1, tile_m, lds_a_buffer) + a2, a3 = lds_load_packs_k64( + lds_a_row, lds_a_col1, tile_m, lds_a_buffer + ) else: a2 = fx.Int64(0) a3 = fx.Int64(0) a128 = pack_i64x4_to_i32x8(a0, a1, a2, a3) - a_scale = lds_scale_load(lds_a_row, lds_a_scale_col, tile_m_mx, lds_a_scale_buffer) + a_scale = lds_scale_load( + lds_a_row, lds_a_scale_col, tile_m_mx, lds_a_scale_buffer + ) for hi in range_constexpr(dk_num_subtiles_head): - lds_b_col = dk_head_wave_id * dk_head_per_wave + hi * 16 #+ lane_mod_2 * 8 - b0 = lds_load_packs_k32_transposed(lds_b_row0, lds_b_col, tile_head, lds_b_buffer) - b1 = lds_load_packs_k32_transposed(lds_b_row1, lds_b_col, tile_head, lds_b_buffer) + lds_b_col = ( + dk_head_wave_id * dk_head_per_wave + hi * 16 + ) # + lane_mod_2 * 8 + b0 = lds_load_packs_k32_transposed( + lds_b_row0, lds_b_col, tile_head, lds_b_buffer + ) + b1 = lds_load_packs_k32_transposed( + lds_b_row1, lds_b_col, tile_head, lds_b_buffer + ) if const_expr(tile_m == 128): - b2 = lds_load_packs_k32_transposed(lds_b_row2, lds_b_col, tile_head, lds_b_buffer) - b3 = lds_load_packs_k32_transposed(lds_b_row3, lds_b_col, tile_head, lds_b_buffer) + b2 = lds_load_packs_k32_transposed( + lds_b_row2, lds_b_col, tile_head, lds_b_buffer + ) + b3 = lds_load_packs_k32_transposed( + lds_b_row3, lds_b_col, tile_head, lds_b_buffer + ) else: b2 = fx.Int64(0) b3 = fx.Int64(0) b128 = pack_i64x4_to_i32x8(b0, b1, b2, b3) - lds_b_scale_col = dk_head_wave_id * dk_head_per_wave + hi * 16 + lane_mod_16 - b_scale = lds_scale_load(lds_b_scale_row, lds_b_scale_col, tile_head, lds_b_scale_buffer) + lds_b_scale_col = ( + dk_head_wave_id * dk_head_per_wave + hi * 16 + lane_mod_16 + ) + b_scale = lds_scale_load( + lds_b_scale_row, + lds_b_scale_col, + tile_head, + 
lds_b_scale_buffer, + ) acc_idx = ni * dk_num_subtiles_head + hi - current_accs_list[acc_idx] = rocdl.mfma_scale_f32_16x16x128_f8f6f4( - mfma_res_ty, - [a128, b128, current_accs_list[acc_idx], - 0, 0, 0, a_scale, 0, b_scale], + current_accs_list[acc_idx] = ( + rocdl.mfma_scale_f32_16x16x128_f8f6f4( + mfma_res_ty, + [ + a128, + b128, + current_accs_list[acc_idx], + 0, + 0, + 0, + a_scale, + 0, + b_scale, + ], + ) ) return current_accs_list - - def compute_dq(lds_a_buffer, lds_a_scale_buffer, lds_b_buffer, lds_b_scale_buffer): + def compute_dq( + lds_a_buffer, lds_a_scale_buffer, lds_b_buffer, lds_b_scale_buffer + ): # (m, n) @ (n, head) = (m, head) current_accs_list = [acc_init] * dq_n_accs @@ -1294,60 +1786,116 @@ def compute_dq(lds_a_buffer, lds_a_scale_buffer, lds_b_buffer, lds_b_scale_buffe lds_b_scale_row = lds_b_scale_row % 2 for mi in range_constexpr(dq_num_subtiles_m): - lds_a_col = dq_m_wave_id * dq_m_per_wave + mi * 16 #+ lane_mod_2 * 8 - a0 = lds_load_packs_k32_transposed(lds_a_row0, lds_a_col, tile_m, lds_a_buffer) - a1 = lds_load_packs_k32_transposed(lds_a_row1, lds_a_col, tile_m, lds_a_buffer) + lds_a_col = ( + dq_m_wave_id * dq_m_per_wave + mi * 16 + ) # + lane_mod_2 * 8 + a0 = lds_load_packs_k32_transposed( + lds_a_row0, lds_a_col, tile_m, lds_a_buffer + ) + a1 = lds_load_packs_k32_transposed( + lds_a_row1, lds_a_col, tile_m, lds_a_buffer + ) if const_expr(tile_n == 128): - a2 = lds_load_packs_k32_transposed(lds_a_row2, lds_a_col, tile_m, lds_a_buffer) - a3 = lds_load_packs_k32_transposed(lds_a_row3, lds_a_col, tile_m, lds_a_buffer) + a2 = lds_load_packs_k32_transposed( + lds_a_row2, lds_a_col, tile_m, lds_a_buffer + ) + a3 = lds_load_packs_k32_transposed( + lds_a_row3, lds_a_col, tile_m, lds_a_buffer + ) else: a2 = fx.Int64(0) a3 = fx.Int64(0) a128 = pack_i64x4_to_i32x8(a0, a1, a2, a3) - lds_a_scale_row = dq_m_wave_id * dq_m_per_wave + mi * 16 + lane_mod_16 - a_scale = lds_scale_load(lds_a_scale_row, lds_a_scale_col, tile_n_mx, lds_a_scale_buffer) + 
lds_a_scale_row = ( + dq_m_wave_id * dq_m_per_wave + mi * 16 + lane_mod_16 + ) + a_scale = lds_scale_load( + lds_a_scale_row, lds_a_scale_col, tile_n_mx, lds_a_scale_buffer + ) for hi in range_constexpr(dq_num_subtiles_head): - lds_b_col = dq_head_wave_id * dq_head_per_wave + hi * 16 #+ lane_mod_2 * 8 - b0 = lds_load_packs_k32_transposed(lds_b_row0, lds_b_col, tile_head, lds_b_buffer) - b1 = lds_load_packs_k32_transposed(lds_b_row1, lds_b_col, tile_head, lds_b_buffer) + lds_b_col = ( + dq_head_wave_id * dq_head_per_wave + hi * 16 + ) # + lane_mod_2 * 8 + b0 = lds_load_packs_k32_transposed( + lds_b_row0, lds_b_col, tile_head, lds_b_buffer + ) + b1 = lds_load_packs_k32_transposed( + lds_b_row1, lds_b_col, tile_head, lds_b_buffer + ) if const_expr(tile_n == 128): - b2 = lds_load_packs_k32_transposed(lds_b_row2, lds_b_col, tile_head, lds_b_buffer) - b3 = lds_load_packs_k32_transposed(lds_b_row3, lds_b_col, tile_head, lds_b_buffer) + b2 = lds_load_packs_k32_transposed( + lds_b_row2, lds_b_col, tile_head, lds_b_buffer + ) + b3 = lds_load_packs_k32_transposed( + lds_b_row3, lds_b_col, tile_head, lds_b_buffer + ) else: b2 = fx.Int64(0) b3 = fx.Int64(0) b128 = pack_i64x4_to_i32x8(b0, b1, b2, b3) - lds_b_scale_col = dq_head_wave_id * dq_head_per_wave + hi * 16 + lane_mod_16 - b_scale = lds_scale_load(lds_b_scale_row, lds_b_scale_col, tile_head, lds_b_scale_buffer) + lds_b_scale_col = ( + dq_head_wave_id * dq_head_per_wave + hi * 16 + lane_mod_16 + ) + b_scale = lds_scale_load( + lds_b_scale_row, + lds_b_scale_col, + tile_head, + lds_b_scale_buffer, + ) acc_idx = mi * dq_num_subtiles_head + hi - current_accs_list[acc_idx] = rocdl.mfma_scale_f32_16x16x128_f8f6f4( - mfma_res_ty, - [a128, b128, current_accs_list[acc_idx], - 0, 0, 0, a_scale, 0, b_scale], + current_accs_list[acc_idx] = ( + rocdl.mfma_scale_f32_16x16x128_f8f6f4( + mfma_res_ty, + [ + a128, + b128, + current_accs_list[acc_idx], + 0, + 0, + 0, + a_scale, + 0, + b_scale, + ], + ) ) return current_accs_list - def 
store_dq_atomic(final_accs, offset_m): for mi in range_constexpr(dq_num_subtiles_m): for hi in range_constexpr(dq_num_subtiles_head): for ii in range_constexpr(4): - global_row = offset_m + dq_m_wave_id * dq_m_per_wave + mi * 16 + lane_div_16 * 4 + ii - global_col = dq_head_wave_id * dq_head_per_wave + hi * 16 + lane_mod_16 + global_row = ( + offset_m + + dq_m_wave_id * dq_m_per_wave + + mi * 16 + + lane_div_16 * 4 + + ii + ) + global_col = ( + dq_head_wave_id * dq_head_per_wave + hi * 16 + lane_mod_16 + ) global_idx = global_row * head_dim + global_col global_idx_bytes = global_idx * 4 - + acc_idx = mi * dq_num_subtiles_head + hi acc = final_accs[acc_idx] - val_f32 = vector.extract(acc, static_position=[ii], dynamic_position=[]) + val_f32 = vector.extract( + acc, static_position=[ii], dynamic_position=[] + ) val_f32 = val_f32 * c_sm_scale - rocdl.raw_ptr_buffer_atomic_fadd(val_f32, dq_rsrc, fx.Int32(global_idx_bytes), fx.Int32(0), fx.Int32(0)) - #buffer_ops.buffer_store(val_f32, dq_rsrc, global_idx) - + rocdl.raw_ptr_buffer_atomic_fadd( + val_f32, + dq_rsrc, + fx.Int32(global_idx_bytes), + fx.Int32(0), + fx.Int32(0), + ) + # buffer_ops.buffer_store(val_f32, dq_rsrc, global_idx) def store_dk_atomic(final_accs): for ni in range_constexpr(dk_num_subtiles_n): @@ -1356,20 +1904,35 @@ def store_dk_atomic(final_accs): acc = final_accs[acc_idx] for ii in range_constexpr(4): - global_row = global_offset_n + dk_n_wave_id * dk_n_per_wave + ni * 16 + lane_div_16 * 4 + ii - global_col = dk_head_wave_id * dk_head_per_wave + hi * 16 + lane_mod_16 + global_row = ( + global_offset_n + + dk_n_wave_id * dk_n_per_wave + + ni * 16 + + lane_div_16 * 4 + + ii + ) + global_col = ( + dk_head_wave_id * dk_head_per_wave + hi * 16 + lane_mod_16 + ) global_idx = global_row * head_dim + global_col - + acc_idx = ni * dk_num_subtiles_head + hi acc = final_accs[acc_idx] - val_f32 = vector.extract(acc, static_position=[ii], dynamic_position=[]) + val_f32 = vector.extract( + acc, 
static_position=[ii], dynamic_position=[] + ) val_f32 = val_f32 * c_sm_scale if const_expr(gqa_size == 1): buffer_ops.buffer_store(val_f32, dk_rsrc, global_idx) else: global_idx_bytes = global_idx * 4 - rocdl.raw_ptr_buffer_atomic_fadd(val_f32, dk_rsrc, fx.Int32(global_idx_bytes), fx.Int32(0), fx.Int32(0)) - + rocdl.raw_ptr_buffer_atomic_fadd( + val_f32, + dk_rsrc, + fx.Int32(global_idx_bytes), + fx.Int32(0), + fx.Int32(0), + ) def store_dv_atomic(final_accs): for ni in range_constexpr(dv_n_num_subtiles): @@ -1378,19 +1941,34 @@ def store_dv_atomic(final_accs): acc = final_accs[acc_idx] for ii in range_constexpr(4): - global_row = global_offset_n + dv_n_wave_id * dv_n_per_wave + ni * 16 + lane_div_16 * 4 + ii - global_col = dv_head_wave_id * dv_head_per_wave + hi * 16 + lane_mod_16 + global_row = ( + global_offset_n + + dv_n_wave_id * dv_n_per_wave + + ni * 16 + + lane_div_16 * 4 + + ii + ) + global_col = ( + dv_head_wave_id * dv_head_per_wave + hi * 16 + lane_mod_16 + ) global_idx = global_row * head_dim + global_col - + acc_idx = ni * dv_head_num_subtiles + hi acc = final_accs[acc_idx] - val_f32 = vector.extract(acc, static_position=[ii], dynamic_position=[]) + val_f32 = vector.extract( + acc, static_position=[ii], dynamic_position=[] + ) if const_expr(gqa_size == 1): buffer_ops.buffer_store(val_f32, dv_rsrc, global_idx) else: global_idx_bytes = global_idx * 4 - rocdl.raw_ptr_buffer_atomic_fadd(val_f32, dv_rsrc, fx.Int32(global_idx_bytes), fx.Int32(0), fx.Int32(0)) - + rocdl.raw_ptr_buffer_atomic_fadd( + val_f32, + dv_rsrc, + fx.Int32(global_idx_bytes), + fx.Int32(0), + fx.Int32(0), + ) # ── Scheduling hints ────────────────────────────────────────────── rocdl.sched_barrier(0) @@ -1402,71 +1980,145 @@ def hot_loop_scheduler(): # ── Main pipeline ───────────────────────────────────────────────── def _pack_state(dk, dv): - return list(dk) + list(dv) + return list(dk) + list(dv) def _unpack_state(vals): dk = list(vals[:dk_n_accs]) dv = list(vals[dk_n_accs:]) - 
return dk, dv + return dk, dv def pingpong(offset_m, inner_state): dk, dv = _unpack_state(inner_state) next_offset_m = offset_m + tile_m next_offset_m_mx = next_offset_m // 32 - store_q_tile_to_lds(prefetch_q_quant_head_tile(next_offset_m), lds_q_quant_head_ping) - store_q_scale_tile_to_lds(prefetch_q_scale_head_tile(next_offset_m), lds_q_scale_head_ping) - store_q_tile_to_lds(prefetch_q_quant_m_tile(next_offset_m), lds_q_quant_m_ping) - store_q_scale_tile_to_lds(prefetch_q_scale_m_tile(next_offset_m_mx), lds_q_scale_m_ping) - store_do_tile_to_lds(prefetch_do_quant_head_tile(next_offset_m), lds_do_quant_head_ping) - store_do_scale_tile_to_lds(prefetch_do_scale_head_tile(next_offset_m), lds_do_scale_head_ping) - store_do_tile_to_lds(prefetch_do_quant_m_tile(next_offset_m), lds_do_quant_m_ping) - store_do_scale_tile_to_lds(prefetch_do_scale_m_tile(next_offset_m_mx), lds_do_scale_m_ping) - - qk = compute_qk(lds_q_quant_head_pong, lds_q_scale_head_pong, lds_k_quant_head, lds_k_scale_head) + store_q_tile_to_lds( + prefetch_q_quant_head_tile(next_offset_m), lds_q_quant_head_ping + ) + store_q_scale_tile_to_lds( + prefetch_q_scale_head_tile(next_offset_m), lds_q_scale_head_ping + ) + store_q_tile_to_lds( + prefetch_q_quant_m_tile(next_offset_m), lds_q_quant_m_ping + ) + store_q_scale_tile_to_lds( + prefetch_q_scale_m_tile(next_offset_m_mx), lds_q_scale_m_ping + ) + store_do_tile_to_lds( + prefetch_do_quant_head_tile(next_offset_m), lds_do_quant_head_ping + ) + store_do_scale_tile_to_lds( + prefetch_do_scale_head_tile(next_offset_m), lds_do_scale_head_ping + ) + store_do_tile_to_lds( + prefetch_do_quant_m_tile(next_offset_m), lds_do_quant_m_ping + ) + store_do_scale_tile_to_lds( + prefetch_do_scale_m_tile(next_offset_m_mx), lds_do_scale_m_ping + ) + + qk = compute_qk( + lds_q_quant_head_pong, + lds_q_scale_head_pong, + lds_k_quant_head, + lds_k_scale_head, + ) p = softmax(qk, offset_m) mxquant_m_and_store_to_lds(p, lds_ppt_shuffle, lds_ppt_scale_shuffle) gpu.barrier() - dv 
= compute_dv(dv, lds_ppt_shuffle, lds_ppt_scale_shuffle, lds_do_quant_m_pong, lds_do_scale_m_pong) - dp = compute_dp(lds_do_quant_head_pong, lds_do_scale_head_pong, lds_v, lds_v_scale) + dv = compute_dv( + dv, + lds_ppt_shuffle, + lds_ppt_scale_shuffle, + lds_do_quant_m_pong, + lds_do_scale_m_pong, + ) + dp = compute_dp( + lds_do_quant_head_pong, lds_do_scale_head_pong, lds_v, lds_v_scale + ) ds = compute_ds(dp, p, offset_m) mxquant_m_and_store_to_lds(ds, lds_dst_shuffle, lds_dst_scale_shuffle) mxquant_n_and_store_to_lds(ds, lds_ds_shuffle, lds_ds_scale_shuffle) gpu.barrier() - dk = compute_dk(dk, lds_dst_shuffle, lds_dst_scale_shuffle, lds_q_quant_m_pong, lds_q_scale_m_pong) - dq = compute_dq(lds_ds_shuffle, lds_ds_scale_shuffle, lds_k_quant_n, lds_k_scale_n) + dk = compute_dk( + dk, + lds_dst_shuffle, + lds_dst_scale_shuffle, + lds_q_quant_m_pong, + lds_q_scale_m_pong, + ) + dq = compute_dq( + lds_ds_shuffle, lds_ds_scale_shuffle, lds_k_quant_n, lds_k_scale_n + ) store_dq_atomic(dq, offset_m) hot_loop_scheduler() gpu.barrier() next_offset_m = offset_m + (tile_m * 2) next_offset_m_mx = next_offset_m // 32 - store_q_tile_to_lds(prefetch_q_quant_head_tile(next_offset_m), lds_q_quant_head_pong) - store_q_scale_tile_to_lds(prefetch_q_scale_head_tile(next_offset_m), lds_q_scale_head_pong) - store_q_tile_to_lds(prefetch_q_quant_m_tile(next_offset_m), lds_q_quant_m_pong) - store_q_scale_tile_to_lds(prefetch_q_scale_m_tile(next_offset_m_mx), lds_q_scale_m_pong) - store_do_tile_to_lds(prefetch_do_quant_head_tile(next_offset_m), lds_do_quant_head_pong) - store_do_scale_tile_to_lds(prefetch_do_scale_head_tile(next_offset_m), lds_do_scale_head_pong) - store_do_tile_to_lds(prefetch_do_quant_m_tile(next_offset_m), lds_do_quant_m_pong) - store_do_scale_tile_to_lds(prefetch_do_scale_m_tile(next_offset_m_mx), lds_do_scale_m_pong) - - qk = compute_qk(lds_q_quant_head_ping, lds_q_scale_head_ping, lds_k_quant_head, lds_k_scale_head) + store_q_tile_to_lds( + 
prefetch_q_quant_head_tile(next_offset_m), lds_q_quant_head_pong + ) + store_q_scale_tile_to_lds( + prefetch_q_scale_head_tile(next_offset_m), lds_q_scale_head_pong + ) + store_q_tile_to_lds( + prefetch_q_quant_m_tile(next_offset_m), lds_q_quant_m_pong + ) + store_q_scale_tile_to_lds( + prefetch_q_scale_m_tile(next_offset_m_mx), lds_q_scale_m_pong + ) + store_do_tile_to_lds( + prefetch_do_quant_head_tile(next_offset_m), lds_do_quant_head_pong + ) + store_do_scale_tile_to_lds( + prefetch_do_scale_head_tile(next_offset_m), lds_do_scale_head_pong + ) + store_do_tile_to_lds( + prefetch_do_quant_m_tile(next_offset_m), lds_do_quant_m_pong + ) + store_do_scale_tile_to_lds( + prefetch_do_scale_m_tile(next_offset_m_mx), lds_do_scale_m_pong + ) + + qk = compute_qk( + lds_q_quant_head_ping, + lds_q_scale_head_ping, + lds_k_quant_head, + lds_k_scale_head, + ) p = softmax(qk, offset_m + tile_m) mxquant_m_and_store_to_lds(p, lds_ppt_shuffle, lds_ppt_scale_shuffle) gpu.barrier() - dv = compute_dv(dv, lds_ppt_shuffle, lds_ppt_scale_shuffle, lds_do_quant_m_ping, lds_do_scale_m_ping) - dp = compute_dp(lds_do_quant_head_ping, lds_do_scale_head_ping, lds_v, lds_v_scale) + dv = compute_dv( + dv, + lds_ppt_shuffle, + lds_ppt_scale_shuffle, + lds_do_quant_m_ping, + lds_do_scale_m_ping, + ) + dp = compute_dp( + lds_do_quant_head_ping, lds_do_scale_head_ping, lds_v, lds_v_scale + ) ds = compute_ds(dp, p, offset_m + tile_m) mxquant_m_and_store_to_lds(ds, lds_dst_shuffle, lds_dst_scale_shuffle) mxquant_n_and_store_to_lds(ds, lds_ds_shuffle, lds_ds_scale_shuffle) gpu.barrier() - dk = compute_dk(dk, lds_dst_shuffle, lds_dst_scale_shuffle, lds_q_quant_m_ping, lds_q_scale_m_ping) - dq = compute_dq(lds_ds_shuffle, lds_ds_scale_shuffle, lds_k_quant_n, lds_k_scale_n) + dk = compute_dk( + dk, + lds_dst_shuffle, + lds_dst_scale_shuffle, + lds_q_quant_m_ping, + lds_q_scale_m_ping, + ) + dq = compute_dq( + lds_ds_shuffle, lds_ds_scale_shuffle, lds_k_quant_n, lds_k_scale_n + ) store_dq_atomic(dq, 
offset_m + tile_m) hot_loop_scheduler() gpu.barrier() - - return _pack_state(dk, dv) + + return _pack_state(dk, dv) if const_expr(causal): start_m = (global_offset_n // (tile_m * 2)) * (tile_m * 2) @@ -1475,19 +2127,29 @@ def pingpong(offset_m, inner_state): start_m_mx = start_m // 32 store_q_tile_to_lds(prefetch_q_quant_head_tile(start_m), lds_q_quant_head_pong) - store_q_scale_tile_to_lds(prefetch_q_scale_head_tile(start_m), lds_q_scale_head_pong) + store_q_scale_tile_to_lds( + prefetch_q_scale_head_tile(start_m), lds_q_scale_head_pong + ) store_q_tile_to_lds(prefetch_q_quant_m_tile(start_m), lds_q_quant_m_pong) - store_q_scale_tile_to_lds(prefetch_q_scale_m_tile(start_m_mx), lds_q_scale_m_pong) + store_q_scale_tile_to_lds( + prefetch_q_scale_m_tile(start_m_mx), lds_q_scale_m_pong + ) store_k_tile_to_lds(prefetch_k_quant_head_tile(), lds_k_quant_head) store_k_scale_tile_to_lds(prefetch_k_scale_head_tile(), lds_k_scale_head) store_k_tile_to_lds(prefetch_k_quant_n_tile(), lds_k_quant_n) store_k_scale_tile_to_lds(prefetch_k_scale_n_tile(), lds_k_scale_n) store_v_tile_to_lds(prefetch_v_tile(), lds_v) store_v_scale_tile_to_lds(prefetch_v_scale_tile(), lds_v_scale) - store_do_tile_to_lds(prefetch_do_quant_head_tile(start_m), lds_do_quant_head_pong) - store_do_scale_tile_to_lds(prefetch_do_scale_head_tile(start_m), lds_do_scale_head_pong) + store_do_tile_to_lds( + prefetch_do_quant_head_tile(start_m), lds_do_quant_head_pong + ) + store_do_scale_tile_to_lds( + prefetch_do_scale_head_tile(start_m), lds_do_scale_head_pong + ) store_do_tile_to_lds(prefetch_do_quant_m_tile(start_m), lds_do_quant_m_pong) - store_do_scale_tile_to_lds(prefetch_do_scale_m_tile(start_m_mx), lds_do_scale_m_pong) + store_do_scale_tile_to_lds( + prefetch_do_scale_m_tile(start_m_mx), lds_do_scale_m_pong + ) gpu.barrier() dk = [acc_init] * dk_n_accs dv = [acc_init] * dv_n_accs @@ -1495,24 +2157,45 @@ def pingpong(offset_m, inner_state): num_tiles_loop = seqlen_rounded // tile_m if 
const_expr((num_tiles_loop % 2) == 1): upper_bound = seqlen_rounded - tile_m - init_state = _pack_state(dk, dv) + init_state = _pack_state(dk, dv) for iv, inner in range(start_m, upper_bound, tile_m * 2, init=init_state): results = yield pingpong(iv, inner) dk, dv = _unpack_state(results) curr_m = arith.index(seqlen_rounded - tile_m) - qk = compute_qk(lds_q_quant_head_pong, lds_q_scale_head_pong, lds_k_quant_head, lds_k_scale_head) + qk = compute_qk( + lds_q_quant_head_pong, + lds_q_scale_head_pong, + lds_k_quant_head, + lds_k_scale_head, + ) p = softmax(qk, curr_m) mxquant_m_and_store_to_lds(p, lds_ppt_shuffle, lds_ppt_scale_shuffle) gpu.barrier() - dv = compute_dv(dv, lds_ppt_shuffle, lds_ppt_scale_shuffle, lds_do_quant_m_pong, lds_do_scale_m_pong) - dp = compute_dp(lds_do_quant_head_pong, lds_do_scale_head_pong, lds_v, lds_v_scale) + dv = compute_dv( + dv, + lds_ppt_shuffle, + lds_ppt_scale_shuffle, + lds_do_quant_m_pong, + lds_do_scale_m_pong, + ) + dp = compute_dp( + lds_do_quant_head_pong, lds_do_scale_head_pong, lds_v, lds_v_scale + ) ds = compute_ds(dp, p, curr_m) mxquant_m_and_store_to_lds(ds, lds_dst_shuffle, lds_dst_scale_shuffle) mxquant_n_and_store_to_lds(ds, lds_ds_shuffle, lds_ds_scale_shuffle) gpu.barrier() - dk = compute_dk(dk, lds_dst_shuffle, lds_dst_scale_shuffle, lds_q_quant_m_pong, lds_q_scale_m_pong) - dq = compute_dq(lds_ds_shuffle, lds_ds_scale_shuffle, lds_k_quant_n, lds_k_scale_n) + dk = compute_dk( + dk, + lds_dst_shuffle, + lds_dst_scale_shuffle, + lds_q_quant_m_pong, + lds_q_scale_m_pong, + ) + dq = compute_dq( + lds_ds_shuffle, lds_ds_scale_shuffle, lds_k_quant_n, lds_k_scale_n + ) store_dq_atomic(dq, curr_m) else: upper_bound = seqlen_rounded - (tile_m * 2) @@ -1524,51 +2207,104 @@ def pingpong(offset_m, inner_state): curr_m = arith.index(seqlen_rounded - tile_m * 2) last_m = arith.index(seqlen_rounded - tile_m) last_m_mx = last_m // 32 - store_q_tile_to_lds(prefetch_q_quant_head_tile(last_m), lds_q_quant_head_ping) - 
store_q_scale_tile_to_lds(prefetch_q_scale_head_tile(last_m), lds_q_scale_head_ping) + store_q_tile_to_lds( + prefetch_q_quant_head_tile(last_m), lds_q_quant_head_ping + ) + store_q_scale_tile_to_lds( + prefetch_q_scale_head_tile(last_m), lds_q_scale_head_ping + ) store_q_tile_to_lds(prefetch_q_quant_m_tile(last_m), lds_q_quant_m_ping) - store_q_scale_tile_to_lds(prefetch_q_scale_m_tile(last_m_mx), lds_q_scale_m_ping) - store_do_tile_to_lds(prefetch_do_quant_head_tile(last_m), lds_do_quant_head_ping) - store_do_scale_tile_to_lds(prefetch_do_scale_head_tile(last_m), lds_do_scale_head_ping) + store_q_scale_tile_to_lds( + prefetch_q_scale_m_tile(last_m_mx), lds_q_scale_m_ping + ) + store_do_tile_to_lds( + prefetch_do_quant_head_tile(last_m), lds_do_quant_head_ping + ) + store_do_scale_tile_to_lds( + prefetch_do_scale_head_tile(last_m), lds_do_scale_head_ping + ) store_do_tile_to_lds(prefetch_do_quant_m_tile(last_m), lds_do_quant_m_ping) - store_do_scale_tile_to_lds(prefetch_do_scale_m_tile(last_m_mx), lds_do_scale_m_ping) + store_do_scale_tile_to_lds( + prefetch_do_scale_m_tile(last_m_mx), lds_do_scale_m_ping + ) - qk = compute_qk(lds_q_quant_head_pong, lds_q_scale_head_pong, lds_k_quant_head, lds_k_scale_head) + qk = compute_qk( + lds_q_quant_head_pong, + lds_q_scale_head_pong, + lds_k_quant_head, + lds_k_scale_head, + ) p = softmax(qk, curr_m) mxquant_m_and_store_to_lds(p, lds_ppt_shuffle, lds_ppt_scale_shuffle) gpu.barrier() - dv = compute_dv(dv, lds_ppt_shuffle, lds_ppt_scale_shuffle, lds_do_quant_m_pong, lds_do_scale_m_pong) - dp = compute_dp(lds_do_quant_head_pong, lds_do_scale_head_pong, lds_v, lds_v_scale) + dv = compute_dv( + dv, + lds_ppt_shuffle, + lds_ppt_scale_shuffle, + lds_do_quant_m_pong, + lds_do_scale_m_pong, + ) + dp = compute_dp( + lds_do_quant_head_pong, lds_do_scale_head_pong, lds_v, lds_v_scale + ) ds = compute_ds(dp, p, curr_m) mxquant_m_and_store_to_lds(ds, lds_dst_shuffle, lds_dst_scale_shuffle) mxquant_n_and_store_to_lds(ds, lds_ds_shuffle, 
lds_ds_scale_shuffle) gpu.barrier() - dk = compute_dk(dk, lds_dst_shuffle, lds_dst_scale_shuffle, lds_q_quant_m_pong, lds_q_scale_m_pong) - dq = compute_dq(lds_ds_shuffle, lds_ds_scale_shuffle, lds_k_quant_n, lds_k_scale_n) + dk = compute_dk( + dk, + lds_dst_shuffle, + lds_dst_scale_shuffle, + lds_q_quant_m_pong, + lds_q_scale_m_pong, + ) + dq = compute_dq( + lds_ds_shuffle, lds_ds_scale_shuffle, lds_k_quant_n, lds_k_scale_n + ) store_dq_atomic(dq, curr_m) hot_loop_scheduler() gpu.barrier() - + curr_m = last_m - qk = compute_qk(lds_q_quant_head_ping, lds_q_scale_head_ping, lds_k_quant_head, lds_k_scale_head) + qk = compute_qk( + lds_q_quant_head_ping, + lds_q_scale_head_ping, + lds_k_quant_head, + lds_k_scale_head, + ) p = softmax(qk, curr_m) mxquant_m_and_store_to_lds(p, lds_ppt_shuffle, lds_ppt_scale_shuffle) gpu.barrier() - dv = compute_dv(dv, lds_ppt_shuffle, lds_ppt_scale_shuffle, lds_do_quant_m_ping, lds_do_scale_m_ping) - dp = compute_dp(lds_do_quant_head_ping, lds_do_scale_head_ping, lds_v, lds_v_scale) + dv = compute_dv( + dv, + lds_ppt_shuffle, + lds_ppt_scale_shuffle, + lds_do_quant_m_ping, + lds_do_scale_m_ping, + ) + dp = compute_dp( + lds_do_quant_head_ping, lds_do_scale_head_ping, lds_v, lds_v_scale + ) ds = compute_ds(dp, p, curr_m) mxquant_m_and_store_to_lds(ds, lds_dst_shuffle, lds_dst_scale_shuffle) mxquant_n_and_store_to_lds(ds, lds_ds_shuffle, lds_ds_scale_shuffle) gpu.barrier() - dk = compute_dk(dk, lds_dst_shuffle, lds_dst_scale_shuffle, lds_q_quant_m_ping, lds_q_scale_m_ping) - dq = compute_dq(lds_ds_shuffle, lds_ds_scale_shuffle, lds_k_quant_n, lds_k_scale_n) + dk = compute_dk( + dk, + lds_dst_shuffle, + lds_dst_scale_shuffle, + lds_q_quant_m_ping, + lds_q_scale_m_ping, + ) + dq = compute_dq( + lds_ds_shuffle, lds_ds_scale_shuffle, lds_k_quant_n, lds_k_scale_n + ) store_dq_atomic(dq, curr_m) store_dk_atomic(dk) store_dv_atomic(dv) - # ── Host launcher ────────────────────────────────────────────────────── _cache_tag = (tile_m, tile_n, 
head_dim) @@ -1599,7 +2335,7 @@ def launch_attn_bwd( stride_kv_batch: fx.Int32, stride_kv_scale_batch: fx.Int32, stride_MD_batch: fx.Int32, - stride_qkvo_nheads: fx.Int32, + stride_qkvo_nheads: fx.Int32, stride_qkvo_scale_nheads: fx.Int32, stride_MD_nheads: fx.Int32, stream: fx.Stream, @@ -1636,26 +2372,48 @@ def launch_attn_bwd( allocator_ds_shuffle.finalize() allocator_ds_scale_shuffle.finalize() - gx = num_heads_q + gx = num_heads_q gy = (seqlen + tile_n - 1) // tile_n gz = batch - launcher = kernel_attn_bwd(arg_dq, arg_dk, arg_dv, arg_q_quant_head, arg_q_scale_head, arg_q_quant_m, arg_q_scale_m, arg_k_quant_head, arg_k_scale_head, arg_k_quant_n, arg_k_scale_n, arg_v, arg_v_scale, arg_do_quant_head, arg_do_scale_head, arg_do_quant_m, arg_do_scale_m, arg_M, arg_D, - batch, - stride_qo_batch, - stride_qo_scale_batch, - stride_kv_batch, - stride_kv_scale_batch, - stride_MD_batch, - stride_qkvo_nheads, - stride_qkvo_scale_nheads, - stride_MD_nheads) + launcher = kernel_attn_bwd( + arg_dq, + arg_dk, + arg_dv, + arg_q_quant_head, + arg_q_scale_head, + arg_q_quant_m, + arg_q_scale_m, + arg_k_quant_head, + arg_k_scale_head, + arg_k_quant_n, + arg_k_scale_n, + arg_v, + arg_v_scale, + arg_do_quant_head, + arg_do_scale_head, + arg_do_quant_m, + arg_do_scale_m, + arg_M, + arg_D, + batch, + stride_qo_batch, + stride_qo_scale_batch, + stride_kv_batch, + stride_kv_scale_batch, + stride_MD_batch, + stride_qkvo_nheads, + stride_qkvo_scale_nheads, + stride_MD_nheads, + ) if waves_per_eu is not None: _wpe = int(waves_per_eu) if _wpe >= 1: for op in ctx.gpu_module_body.operations: - if hasattr(op, 'attributes') and op.OPERATION_NAME == "gpu.func": - op.attributes["rocdl.waves_per_eu"] = ir.IntegerAttr.get(T.i32, _wpe) + if hasattr(op, "attributes") and op.OPERATION_NAME == "gpu.func": + op.attributes["rocdl.waves_per_eu"] = ir.IntegerAttr.get( + T.i32, _wpe + ) launcher.launch( grid=(gx, gy, gz), block=(256, 1, 1), diff --git a/op_tests/flydsl_tests/test_attn_bwd_mxfp8_gfx950.py 
b/op_tests/flydsl_tests/test_attn_bwd_mxfp8_gfx950.py index 6fd580e491..1a2e791fe1 100644 --- a/op_tests/flydsl_tests/test_attn_bwd_mxfp8_gfx950.py +++ b/op_tests/flydsl_tests/test_attn_bwd_mxfp8_gfx950.py @@ -19,6 +19,7 @@ ARCH = str(get_rocm_arch()) + def check_result(test_out, ref_out, atol=0.01, rtol=0.01, pass_pct=95.0): """Compare outputs and print result. Returns (passed, max_delta, pct_close).""" close_mask = torch.isclose(test_out.float(), ref_out.float(), atol=atol, rtol=rtol) @@ -26,7 +27,7 @@ def check_result(test_out, ref_out, atol=0.01, rtol=0.01, pass_pct=95.0): passed = pct_close > pass_pct if passed: return True - + max_delta = (ref_out.float() - test_out.float()).abs().max().item() print( f" max_delta={max_delta:.4f}, {pct_close:.1f}% close (atol={atol}, rtol={rtol})" @@ -42,7 +43,20 @@ def mx_quant(x, dim=-1): return x_fp32.contiguous(), x_fp8.contiguous(), x_scale.contiguous() -def run_torch(q_fp32_head, q_fp32_m, k_fp32_head, k_fp32_n, v, do_fp32_head, do_fp32_m, m, D, sm_scale, causal, gqa_size): +def run_torch( + q_fp32_head, + q_fp32_m, + k_fp32_head, + k_fp32_n, + v, + do_fp32_head, + do_fp32_m, + m, + D, + sm_scale, + causal, + gqa_size, +): batch = q_fp32_head.shape[0] num_heads_q = q_fp32_head.shape[1] num_heads_kv = num_heads_q // gqa_size @@ -55,10 +69,10 @@ def run_torch(q_fp32_head, q_fp32_m, k_fp32_head, k_fp32_n, v, do_fp32_head, do_ if causal: mask = torch.tril(torch.ones((seqlen, seqlen), device=device)) p[:, :, mask == 0] = 0.0 - + ppT, _, _ = mx_quant(p, -2) ppT = ppT.transpose(-2, -1) - dv = torch.matmul(ppT, do_fp32_m) + dv = torch.matmul(ppT, do_fp32_m) dp = torch.matmul(do_fp32_head, v_f32.transpose(-2, -1)) ds = p * (dp - D[:, :, :, None]) dsT, _, _ = mx_quant(ds, -1) @@ -76,11 +90,7 @@ def run_torch(q_fp32_head, q_fp32_m, k_fp32_head, k_fp32_n, v, do_fp32_head, do_ @pytest.mark.parametrize("batch", [1, 4]) @pytest.mark.parametrize( "num_heads_q, num_heads_kv", - [ - (48, 48), - (64, 8), - (80, 20) - ], + [(48, 48), (64, 
8), (80, 20)], ) @pytest.mark.parametrize("seqlen", [128, 1024, 1056, 1152, 4096]) @pytest.mark.parametrize("head_dim", [64, 128]) @@ -88,8 +98,13 @@ def run_torch(q_fp32_head, q_fp32_m, k_fp32_head, k_fp32_n, v, do_fp32_head, do_ @pytest.mark.parametrize("tile_n", [64, 128]) @pytest.mark.parametrize("causal", [False, True]) def test_attn_bwd_flyc( - batch, num_heads_q, num_heads_kv, seqlen, head_dim, - tile_m, tile_n, + batch, + num_heads_q, + num_heads_kv, + seqlen, + head_dim, + tile_m, + tile_n, causal, waves_per_eu: int = 0, ): @@ -102,9 +117,13 @@ def test_attn_bwd_flyc( sm_scale = 0.5 _wpe = int(waves_per_eu) launch_fn = compile_attn_bwd_mxfp8_gfx950( - num_heads_q=num_heads_q, num_heads_kv=num_heads_kv, - seqlen=seqlen, head_dim=head_dim, - tile_m=tile_m, tile_n=tile_n, tile_head=tile_head, + num_heads_q=num_heads_q, + num_heads_kv=num_heads_kv, + seqlen=seqlen, + head_dim=head_dim, + tile_m=tile_m, + tile_n=tile_n, + tile_head=tile_head, sm_scale=sm_scale, causal=causal, waves_per_eu=_wpe, @@ -112,10 +131,30 @@ def test_attn_bwd_flyc( device = torch.device("cuda") gqa_size = num_heads_q // num_heads_kv - q_fp32 = torch.randn(batch, num_heads_q, seqlen, head_dim, device=device, dtype=torch.float32) * 0.5 - k_fp32 = torch.randn(batch, num_heads_kv, seqlen, head_dim, device=device, dtype=torch.float32) * 0.5 - v_fp32 = torch.randn(batch, num_heads_kv, seqlen, head_dim, device=device, dtype=torch.float32) * 0.5 - do_fp32 = torch.randn(batch, num_heads_q, seqlen, head_dim, device=device, dtype=torch.float32) * 0.5 + q_fp32 = ( + torch.randn( + batch, num_heads_q, seqlen, head_dim, device=device, dtype=torch.float32 + ) + * 0.5 + ) + k_fp32 = ( + torch.randn( + batch, num_heads_kv, seqlen, head_dim, device=device, dtype=torch.float32 + ) + * 0.5 + ) + v_fp32 = ( + torch.randn( + batch, num_heads_kv, seqlen, head_dim, device=device, dtype=torch.float32 + ) + * 0.5 + ) + do_fp32 = ( + torch.randn( + batch, num_heads_q, seqlen, head_dim, device=device, 
dtype=torch.float32 + ) + * 0.5 + ) q_fp32_head, q_quant_head, q_scale_head = mx_quant(q_fp32, -1) q_fp32_m, q_quant_m, q_scale_m = mx_quant(q_fp32, -2) @@ -140,13 +179,53 @@ def test_attn_bwd_flyc( m = m + torch.log(l) D = (o_fp32 * do_fp32).sum(dim=-1) - dq_ref, dk_ref, dv_ref = run_torch(q_fp32_head, q_fp32_m, k_fp32_head, k_fp32_n, v_fp32, do_fp32_head, do_fp32_m, m, D, sm_scale, causal, gqa_size) + dq_ref, dk_ref, dv_ref = run_torch( + q_fp32_head, + q_fp32_m, + k_fp32_head, + k_fp32_n, + v_fp32, + do_fp32_head, + do_fp32_m, + m, + D, + sm_scale, + causal, + gqa_size, + ) - dq_fly = torch.zeros((batch, num_heads_q, seqlen, head_dim), dtype=torch.float32, device=device) - dk_fly = torch.zeros((batch, num_heads_kv, seqlen, head_dim), dtype=torch.float32, device=device) - dv_fly = torch.zeros((batch, num_heads_kv, seqlen, head_dim), dtype=torch.float32, device=device) + dq_fly = torch.zeros( + (batch, num_heads_q, seqlen, head_dim), dtype=torch.float32, device=device + ) + dk_fly = torch.zeros( + (batch, num_heads_kv, seqlen, head_dim), dtype=torch.float32, device=device + ) + dv_fly = torch.zeros( + (batch, num_heads_kv, seqlen, head_dim), dtype=torch.float32, device=device + ) - def launch_kernel(dq, dk, dv, q_quant_head, q_scale_head, q_quant_m, q_scale_m, k_quant_head, k_scale_head, k_quant_n, k_scale_n, v, v_scale, do_quant_head, do_scale_head, do_quant_m, do_scale_m, m, D, batch): + def launch_kernel( + dq, + dk, + dv, + q_quant_head, + q_scale_head, + q_quant_m, + q_scale_m, + k_quant_head, + k_scale_head, + k_quant_n, + k_scale_n, + v, + v_scale, + do_quant_head, + do_scale_head, + do_quant_m, + do_scale_m, + m, + D, + batch, + ): launch_fn( dq.contiguous().view(-1), dk.contiguous().view(-1), @@ -199,7 +278,7 @@ def launch_kernel(dq, dk, dv, q_quant_head, q_scale_head, q_quant_m, q_scale_m, do_scale_m, m, D, - batch + batch, ) dq_fly_fp32 = dq_fly.to(torch.float32) diff --git a/op_tests/op_benchmarks/flydsl/bench_attn_bwd_mxfp8_gfx950.py 
b/op_tests/op_benchmarks/flydsl/bench_attn_bwd_mxfp8_gfx950.py index 3f1ce84779..85b5d52113 100644 --- a/op_tests/op_benchmarks/flydsl/bench_attn_bwd_mxfp8_gfx950.py +++ b/op_tests/op_benchmarks/flydsl/bench_attn_bwd_mxfp8_gfx950.py @@ -9,7 +9,11 @@ import torch from aiter.ops.flydsl.kernels.attn_bwd_mxfp8_gfx950 import compile_attn_bwd_mxfp8_gfx950 from utils import run_perftest -from op_tests.flydsl_tests.test_attn_bwd_mxfp8_gfx950 import run_torch, mx_quant, check_result +from op_tests.flydsl_tests.test_attn_bwd_mxfp8_gfx950 import ( + run_torch, + mx_quant, + check_result, +) from flydsl.runtime.device import get_rocm_arch logging.basicConfig(level=logging.INFO) @@ -17,29 +21,38 @@ DEFAULT_BENCH_ITERS = 20 DEFAULT_BENCH_WARMUP = 3 + def bench_attn_bwd_flyc( - batch, num_heads_q, num_heads_kv, seqlen, head_dim, - tile_m, tile_n, + batch, + num_heads_q, + num_heads_kv, + seqlen, + head_dim, + tile_m, + tile_n, causal, test_graph, bench_iters: int = DEFAULT_BENCH_ITERS, bench_warmup: int = DEFAULT_BENCH_WARMUP, waves_per_eu: int = 0, - check_correctness: bool = False + check_correctness: bool = False, ): """Attention bwd using the @flyc.kernel / @flyc.jit API.""" tile_head = head_dim print("=" * 80) - print( - f"[flyc] Attention Backward Test (Tile: {tile_m}x{tile_n}x{tile_head})" - ) + print(f"[flyc] Attention Backward Test (Tile: {tile_m}x{tile_n}x{tile_head})") print("=" * 80) - + sm_scale = 0.5 _wpe = int(waves_per_eu) if waves_per_eu else 0 launch_fn = compile_attn_bwd_mxfp8_gfx950( - num_heads_q=num_heads_q, num_heads_kv=num_heads_kv, seqlen=seqlen, head_dim=head_dim, - tile_m=tile_m, tile_n=tile_n, tile_head=tile_head, + num_heads_q=num_heads_q, + num_heads_kv=num_heads_kv, + seqlen=seqlen, + head_dim=head_dim, + tile_m=tile_m, + tile_n=tile_n, + tile_head=tile_head, sm_scale=sm_scale, causal=causal, waves_per_eu=_wpe, @@ -48,11 +61,36 @@ def bench_attn_bwd_flyc( device = torch.device("cuda") gqa_size = num_heads_q // num_heads_kv - q_fp32 = 
torch.randn(batch, num_heads_q, seqlen, head_dim, device=device, dtype=torch.float32) * 0.5 - k_fp32 = torch.randn(batch, num_heads_kv, seqlen, head_dim, device=device, dtype=torch.float32) * 0.5 - v_fp32 = torch.randn(batch, num_heads_kv, seqlen, head_dim, device=device, dtype=torch.float32) * 0.5 - o_fp32 = torch.randn(batch, num_heads_q, seqlen, head_dim, device=device, dtype=torch.float32) * 0.5 - do_fp32 = torch.randn(batch, num_heads_q, seqlen, head_dim, device=device, dtype=torch.float32) * 0.5 + q_fp32 = ( + torch.randn( + batch, num_heads_q, seqlen, head_dim, device=device, dtype=torch.float32 + ) + * 0.5 + ) + k_fp32 = ( + torch.randn( + batch, num_heads_kv, seqlen, head_dim, device=device, dtype=torch.float32 + ) + * 0.5 + ) + v_fp32 = ( + torch.randn( + batch, num_heads_kv, seqlen, head_dim, device=device, dtype=torch.float32 + ) + * 0.5 + ) + o_fp32 = ( + torch.randn( + batch, num_heads_q, seqlen, head_dim, device=device, dtype=torch.float32 + ) + * 0.5 + ) + do_fp32 = ( + torch.randn( + batch, num_heads_q, seqlen, head_dim, device=device, dtype=torch.float32 + ) + * 0.5 + ) q_fp32_head, q_quant_head, q_scale_head = mx_quant(q_fp32, -1) q_fp32_m, q_quant_m, q_scale_m = mx_quant(q_fp32, -2) @@ -76,12 +114,52 @@ def bench_attn_bwd_flyc( D = (o_fp32 * do_fp32).sum(dim=-1) if check_correctness: - dq_ref, dk_ref, dv_ref = run_torch(q_fp32_head, q_fp32_m, k_fp32_head, k_fp32_n, v_fp32, do_fp32_head, do_fp32_m, m, D, sm_scale, causal, gqa_size) - dq_fly = torch.zeros((batch, num_heads_q, seqlen, head_dim), dtype=torch.float32, device=device) - dk_fly = torch.zeros((batch, num_heads_kv, seqlen, head_dim), dtype=torch.float32, device=device) - dv_fly = torch.zeros((batch, num_heads_kv, seqlen, head_dim), dtype=torch.float32, device=device) + dq_ref, dk_ref, dv_ref = run_torch( + q_fp32_head, + q_fp32_m, + k_fp32_head, + k_fp32_n, + v_fp32, + do_fp32_head, + do_fp32_m, + m, + D, + sm_scale, + causal, + gqa_size, + ) + dq_fly = torch.zeros( + (batch, num_heads_q, 
seqlen, head_dim), dtype=torch.float32, device=device + ) + dk_fly = torch.zeros( + (batch, num_heads_kv, seqlen, head_dim), dtype=torch.float32, device=device + ) + dv_fly = torch.zeros( + (batch, num_heads_kv, seqlen, head_dim), dtype=torch.float32, device=device + ) - def launch_kernel(dq, dk, dv, q_quant_head, q_scale_head, q_quant_m, q_scale_m, k_quant_head, k_scale_head, k_quant_n, k_scale_n, v, v_scale, do_quant_head, do_scale_head, do_quant_m, do_scale_m, m, D, batch): + def launch_kernel( + dq, + dk, + dv, + q_quant_head, + q_scale_head, + q_quant_m, + q_scale_m, + k_quant_head, + k_scale_head, + k_quant_n, + k_scale_n, + v, + v_scale, + do_quant_head, + do_scale_head, + do_quant_m, + do_scale_m, + m, + D, + batch, + ): launch_fn( dq.contiguous().view(-1), dk.contiguous().view(-1), @@ -167,7 +245,7 @@ def launch_kernel(dq, dk, dv, q_quant_head, q_scale_head, q_quant_m, q_scale_m, do_scale_m, m, D, - batch + batch, ) dq_fly_fp32 = dq_fly.to(torch.float32) @@ -179,8 +257,20 @@ def launch_kernel(dq, dk, dv, q_quant_head, q_scale_head, q_quant_m, q_scale_m, assert check_result(dk_fly_fp32, dk_ref, rtol=0.01, atol=0.01) assert check_result(dv_fly_fp32, dv_ref, rtol=0.01, atol=0.01) - bytes_moved = (4 + 4) * batch * num_heads_q * seqlen * head_dim + (3 + 2 * 4) * batch * num_heads_kv * seqlen * head_dim + 2 * 4 * batch * num_heads_q * seqlen - flops = batch * num_heads_q * (5 * 2 * seqlen * seqlen * head_dim + 5 * seqlen * seqlen + 2 * 3 * seqlen * seqlen) + bytes_moved = ( + (4 + 4) * batch * num_heads_q * seqlen * head_dim + + (3 + 2 * 4) * batch * num_heads_kv * seqlen * head_dim + + 2 * 4 * batch * num_heads_q * seqlen + ) + flops = ( + batch + * num_heads_q + * ( + 5 * 2 * seqlen * seqlen * head_dim + + 5 * seqlen * seqlen + + 2 * 3 * seqlen * seqlen + ) + ) if causal: flops /= 2 tflops = flops / (us / 1e6) / 1e12 @@ -209,12 +299,17 @@ def launch_kernel(dq, dk, dv, q_quant_head, q_scale_head, q_quant_m, q_scale_m, torch.set_default_device("cuda") 
bench_attn_bwd_flyc( - batch=args.batch, num_heads_q=args.num_heads_q, num_heads_kv=args.num_heads_kv, seqlen=args.seqlen, head_dim=args.head, - tile_m=args.tile_m, tile_n=args.tile_n, + batch=args.batch, + num_heads_q=args.num_heads_q, + num_heads_kv=args.num_heads_kv, + seqlen=args.seqlen, + head_dim=args.head, + tile_m=args.tile_m, + tile_n=args.tile_n, causal=args.causal, test_graph=bool(args.test_graph), bench_iters=args.num_iters, bench_warmup=args.num_warmup, waves_per_eu=int(args.waves_per_eu), - check_correctness=args.check_correctness + check_correctness=args.check_correctness, ) diff --git a/op_tests/op_benchmarks/flydsl/utils.py b/op_tests/op_benchmarks/flydsl/utils.py index bbcf69d2bf..b3d67bc375 100644 --- a/op_tests/op_benchmarks/flydsl/utils.py +++ b/op_tests/op_benchmarks/flydsl/utils.py @@ -365,4 +365,4 @@ def get_trace_perf(prof, num_iters): pd.set_option("display.max_colwidth", 90) pd.set_option("display.float_format", "{:,.1f}".format) logger.info(f"{df}") - return df.at[avg_name, "device_time_sum"] \ No newline at end of file + return df.at[avg_name, "device_time_sum"] From 74adeec0ddd3202df0b4c14bfc0f6015a36e6b23 Mon Sep 17 00:00:00 2001 From: Lukasz Burzawa Date: Sat, 16 May 2026 19:53:28 +0000 Subject: [PATCH 7/8] supports 2d quant --- .../flydsl/kernels/attn_bwd_mxfp8_gfx950.py | 1052 ++++++----------- .../_triton_kernels/quant/mxfp8_quant.py | 262 +++- aiter/ops/triton/quant/mxfp8_quant.py | 134 +++ .../test_attn_bwd_mxfp8_gfx950.py | 139 +-- .../flydsl/bench_attn_bwd_mxfp8_gfx950.py | 159 +-- op_tests/op_benchmarks/flydsl/utils.py | 1 - .../triton_tests/quant/test_quant_mxfp8.py | 176 ++- 7 files changed, 1064 insertions(+), 859 deletions(-) diff --git a/aiter/ops/flydsl/kernels/attn_bwd_mxfp8_gfx950.py b/aiter/ops/flydsl/kernels/attn_bwd_mxfp8_gfx950.py index 96263b8c3a..191c164451 100644 --- a/aiter/ops/flydsl/kernels/attn_bwd_mxfp8_gfx950.py +++ b/aiter/ops/flydsl/kernels/attn_bwd_mxfp8_gfx950.py @@ -87,18 +87,7 @@ def 
compile_attn_bwd_mxfp8_gfx950( allocator_pong = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem0") allocator_ping = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem1") - allocator_k_quant_head = SmemAllocator( - None, arch=gpu_arch, global_sym_name="smem_k_quant_head" - ) - allocator_k_scale_head = SmemAllocator( - None, arch=gpu_arch, global_sym_name="smem_k_scale_head" - ) - allocator_k_quant_n = SmemAllocator( - None, arch=gpu_arch, global_sym_name="smem_k_quant_n" - ) - allocator_k_scale_n = SmemAllocator( - None, arch=gpu_arch, global_sym_name="smem_k_scale_n" - ) + allocator_k = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem_k") allocator_v = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem_v") allocator_v_scale = SmemAllocator( None, arch=gpu_arch, global_sym_name="smem_v_scale" @@ -147,10 +136,7 @@ def _vec16_type(): # ── LDS sizing (pure Python, no MLIR ops) ──────────────────────────────── lds_qo_tile_bytes = int(tile_m) * int(tile_head) - lds_qo_scale_tile_bytes = (int(tile_m) * int(tile_head)) // 32 lds_k_tile_bytes = int(tile_n) * int(tile_head) - lds_k_scale_head_tile_bytes = int(tile_n) * int(tile_head_mx) - lds_k_scale_n_tile_bytes = int(tile_n_mx) * int(tile_head) lds_v_tile_bytes = int(tile_n) * int(tile_head) lds_v_scale_tile_bytes = int(tile_n) * int(tile_head_mx) lds_ppt_tile_bytes = int(tile_n) * int(tile_m) @@ -160,45 +146,20 @@ def _vec16_type(): lds_ds_tile_bytes = int(tile_m) * int(tile_n) lds_ds_scale_tile_bytes = int(tile_m) * int(tile_n_mx) - buffer_size_bytes = lds_qo_tile_bytes * 4 + lds_qo_scale_tile_bytes * 4 + buffer_size_bytes = lds_qo_tile_bytes * 2 # + lds_qo_scale_tile_bytes * 4 lds_pong_offset = allocator_pong._align(allocator_pong.ptr, 16) allocator_pong.ptr = lds_pong_offset + buffer_size_bytes - lds_q_quant_head_pong_offset = lds_pong_offset - lds_q_scale_head_pong_offset = lds_q_quant_head_pong_offset + lds_qo_tile_bytes - lds_q_quant_m_pong_offset = lds_q_scale_head_pong_offset + 
lds_qo_scale_tile_bytes - lds_q_scale_m_pong_offset = lds_q_quant_m_pong_offset + lds_qo_tile_bytes - lds_do_quant_head_pong_offset = lds_q_scale_m_pong_offset + lds_qo_scale_tile_bytes - lds_do_scale_head_pong_offset = lds_do_quant_head_pong_offset + lds_qo_tile_bytes - lds_do_quant_m_pong_offset = lds_do_scale_head_pong_offset + lds_qo_scale_tile_bytes - lds_do_scale_m_pong_offset = lds_do_quant_m_pong_offset + lds_qo_tile_bytes + lds_q_pong_offset = lds_pong_offset + lds_do_pong_offset = lds_q_pong_offset + lds_qo_tile_bytes lds_ping_offset = allocator_ping._align(allocator_ping.ptr, 16) allocator_ping.ptr = lds_ping_offset + buffer_size_bytes - lds_q_quant_head_ping_offset = lds_ping_offset - lds_q_scale_head_ping_offset = lds_q_quant_head_ping_offset + lds_qo_tile_bytes - lds_q_quant_m_ping_offset = lds_q_scale_head_ping_offset + lds_qo_scale_tile_bytes - lds_q_scale_m_ping_offset = lds_q_quant_m_ping_offset + lds_qo_tile_bytes - lds_do_quant_head_ping_offset = lds_q_scale_m_ping_offset + lds_qo_scale_tile_bytes - lds_do_scale_head_ping_offset = lds_do_quant_head_ping_offset + lds_qo_tile_bytes - lds_do_quant_m_ping_offset = lds_do_scale_head_ping_offset + lds_qo_scale_tile_bytes - lds_do_scale_m_ping_offset = lds_do_quant_m_ping_offset + lds_qo_tile_bytes - - lds_k_quant_head_offset = allocator_k_quant_head._align( - allocator_k_quant_head.ptr, 16 - ) - allocator_k_quant_head.ptr = lds_k_quant_head_offset + lds_k_tile_bytes + lds_q_ping_offset = lds_ping_offset + lds_do_ping_offset = lds_q_ping_offset + lds_qo_tile_bytes - lds_k_scale_head_offset = allocator_k_scale_head._align( - allocator_k_scale_head.ptr, 16 - ) - allocator_k_scale_head.ptr = lds_k_scale_head_offset + lds_k_scale_head_tile_bytes - - lds_k_quant_n_offset = allocator_k_quant_n._align(allocator_k_quant_n.ptr, 16) - allocator_k_quant_n.ptr = lds_k_quant_n_offset + lds_k_tile_bytes - - lds_k_scale_n_offset = allocator_k_scale_n._align(allocator_k_scale_n.ptr, 16) - allocator_k_scale_n.ptr = 
lds_k_scale_n_offset + lds_k_scale_n_tile_bytes + lds_k_offset = allocator_k._align(allocator_k.ptr, 16) + allocator_k.ptr = lds_k_offset + lds_k_tile_bytes lds_v_offset = allocator_v._align(allocator_v.ptr, 16) allocator_v.ptr = lds_v_offset + lds_v_tile_bytes @@ -242,31 +203,30 @@ def kernel_attn_bwd( arg_dq: fx.Tensor, arg_dk: fx.Tensor, arg_dv: fx.Tensor, - arg_q_quant_head: fx.Tensor, - arg_q_scale_head: fx.Tensor, - arg_q_quant_m: fx.Tensor, - arg_q_scale_m: fx.Tensor, - arg_k_quant_head: fx.Tensor, - arg_k_scale_head: fx.Tensor, - arg_k_quant_n: fx.Tensor, - arg_k_scale_n: fx.Tensor, + arg_q: fx.Tensor, + arg_q_scale: fx.Tensor, + arg_k: fx.Tensor, + arg_k_scale: fx.Tensor, arg_v: fx.Tensor, arg_v_scale: fx.Tensor, - arg_do_quant_head: fx.Tensor, - arg_do_scale_head: fx.Tensor, - arg_do_quant_m: fx.Tensor, - arg_do_scale_m: fx.Tensor, + arg_do: fx.Tensor, + arg_do_scale: fx.Tensor, arg_M: fx.Tensor, arg_D: fx.Tensor, batch: fx.Int32, stride_qo_batch: fx.Int32, - stride_qo_scale_batch: fx.Int32, stride_kv_batch: fx.Int32, - stride_kv_scale_batch: fx.Int32, stride_MD_batch: fx.Int32, stride_qkvo_nheads: fx.Int32, - stride_qkvo_scale_nheads: fx.Int32, stride_MD_nheads: fx.Int32, + stride_q_scale_batch: fx.Int32, + stride_q_scale_nheads: fx.Int32, + stride_k_scale_batch: fx.Int32, + stride_k_scale_nheads: fx.Int32, + stride_v_scale_batch: fx.Int32, + stride_v_scale_nheads: fx.Int32, + stride_do_scale_batch: fx.Int32, + stride_do_scale_nheads: fx.Int32, ): # ---- Types ---- @@ -287,10 +247,7 @@ def kernel_attn_bwd( # ---- LDS (separate ping/pong buffers) ---- base_ptr_pong = allocator_pong.get_base() base_ptr_ping = allocator_ping.get_base() - base_ptr_k_quant_head = allocator_k_quant_head.get_base() - base_ptr_k_scale_head = allocator_k_scale_head.get_base() - base_ptr_k_quant_n = allocator_k_quant_n.get_base() - base_ptr_k_scale_n = allocator_k_scale_n.get_base() + base_ptr_k = allocator_k.get_base() base_ptr_v = allocator_v.get_base() base_ptr_v_scale = 
allocator_v_scale.get_base() base_ptr_ppt_shuffle = allocator_ppt_shuffle.get_base() @@ -300,111 +257,36 @@ def kernel_attn_bwd( base_ptr_ds_shuffle = allocator_ds_shuffle.get_base() base_ptr_ds_scale_shuffle = allocator_ds_scale_shuffle.get_base() - lds_q_quant_head_pong = SmemPtr( + lds_q_pong = SmemPtr( base_ptr_pong, - lds_q_quant_head_pong_offset, + lds_q_pong_offset, T.f8, shape=(tile_m * tile_head,), ).get() - lds_q_quant_head_ping = SmemPtr( + lds_q_ping = SmemPtr( base_ptr_ping, - lds_q_quant_head_ping_offset, + lds_q_ping_offset, T.f8, shape=(tile_m * tile_head,), ).get() - lds_q_scale_head_pong = SmemPtr( - base_ptr_pong, - lds_q_scale_head_pong_offset, - T.i8, - shape=(tile_m * tile_head_mx,), - ).get() - lds_q_scale_head_ping = SmemPtr( - base_ptr_ping, - lds_q_scale_head_ping_offset, - T.i8, - shape=(tile_m * tile_head_mx,), - ).get() - lds_q_quant_m_pong = SmemPtr( - base_ptr_pong, lds_q_quant_m_pong_offset, T.f8, shape=(tile_m * tile_head,) - ).get() - lds_q_quant_m_ping = SmemPtr( - base_ptr_ping, lds_q_quant_m_ping_offset, T.f8, shape=(tile_m * tile_head,) - ).get() - lds_q_scale_m_pong = SmemPtr( - base_ptr_pong, - lds_q_scale_m_pong_offset, - T.i8, - shape=(tile_m_mx * tile_head,), - ).get() - lds_q_scale_m_ping = SmemPtr( - base_ptr_ping, - lds_q_scale_m_ping_offset, - T.i8, - shape=(tile_m_mx * tile_head,), - ).get() - lds_do_quant_head_pong = SmemPtr( + lds_do_pong = SmemPtr( base_ptr_pong, - lds_do_quant_head_pong_offset, + lds_do_pong_offset, T.f8, shape=(tile_m * tile_head,), ).get() - lds_do_quant_head_ping = SmemPtr( + lds_do_ping = SmemPtr( base_ptr_ping, - lds_do_quant_head_ping_offset, + lds_do_ping_offset, T.f8, shape=(tile_m * tile_head,), ).get() - lds_do_scale_head_pong = SmemPtr( - base_ptr_pong, - lds_do_scale_head_pong_offset, - T.i8, - shape=(tile_m * tile_head_mx,), - ).get() - lds_do_scale_head_ping = SmemPtr( - base_ptr_ping, - lds_do_scale_head_ping_offset, - T.i8, - shape=(tile_m * tile_head_mx,), - ).get() - 
lds_do_quant_m_pong = SmemPtr( - base_ptr_pong, lds_do_quant_m_pong_offset, T.f8, shape=(tile_m * tile_head,) - ).get() - lds_do_quant_m_ping = SmemPtr( - base_ptr_ping, lds_do_quant_m_ping_offset, T.f8, shape=(tile_m * tile_head,) - ).get() - lds_do_scale_m_pong = SmemPtr( - base_ptr_pong, - lds_do_scale_m_pong_offset, - T.i8, - shape=(tile_head * tile_m_mx,), - ).get() - lds_do_scale_m_ping = SmemPtr( - base_ptr_ping, - lds_do_scale_m_ping_offset, - T.i8, - shape=(tile_head * tile_m_mx,), - ).get() - lds_k_quant_head = SmemPtr( - base_ptr_k_quant_head, - lds_k_quant_head_offset, + lds_k = SmemPtr( + base_ptr_k, + lds_k_offset, T.f8, shape=(tile_n * tile_head,), ).get() - lds_k_scale_head = SmemPtr( - base_ptr_k_scale_head, - lds_k_scale_head_offset, - T.i8, - shape=(tile_n * tile_head_mx,), - ).get() - lds_k_quant_n = SmemPtr( - base_ptr_k_quant_n, lds_k_quant_n_offset, T.f8, shape=(tile_n * tile_head,) - ).get() - lds_k_scale_n = SmemPtr( - base_ptr_k_scale_n, - lds_k_scale_n_offset, - T.i8, - shape=(tile_n_mx * tile_head,), - ).get() lds_v = SmemPtr( base_ptr_v, lds_v_offset, T.f8, shape=(tile_n * tile_head,) ).get() @@ -447,78 +329,62 @@ def kernel_attn_bwd( stride_qkvo_nheads ) offset_dkdv_nheads = offset_kv_nheads * 4 - offset_qo_scale_nheads = batch_id * fx.Index( - stride_qo_scale_batch - ) + head_q * fx.Index(stride_qkvo_scale_nheads) - offset_kv_scale_nheads = batch_id * fx.Index( - stride_kv_scale_batch - ) + head_kv * fx.Index(stride_qkvo_scale_nheads) + offset_q_scale_nheads = batch_id * fx.Index( + stride_q_scale_batch + ) + head_q * fx.Index(stride_q_scale_nheads) + offset_k_scale_nheads = batch_id * fx.Index( + stride_k_scale_batch + ) + head_kv * fx.Index(stride_k_scale_nheads) + offset_v_scale_nheads = batch_id * fx.Index( + stride_v_scale_batch + ) + head_kv * fx.Index(stride_v_scale_nheads) + offset_do_scale_nheads = batch_id * fx.Index( + stride_do_scale_batch + ) + head_q * fx.Index(stride_do_scale_nheads) offset_MD_nheads = ( batch_id * 
fx.Index(stride_MD_batch) + head_q * fx.Index(stride_MD_nheads) ) * 4 # ---- Buffer resources (runtime byte sizes for OOB protection) ---- head_dim_mx = head_dim // 32 + seqlen_mx = seqlen // 32 global_buffer_size_tensor = fx.Index(seqlen * head_dim) global_buffer_size_scale = fx.Index(seqlen * head_dim_mx) + global_buffer_size_scale_2d = fx.Index(seqlen_mx * head_dim_mx) q_nrec = arith.index_cast(T.i64, global_buffer_size_tensor) - q_scale_nrec = arith.index_cast(T.i64, global_buffer_size_scale) + q_scale_nrec = arith.index_cast(T.i64, global_buffer_size_scale_2d) k_nrec = arith.index_cast(T.i64, global_buffer_size_tensor) - k_scale_nrec = arith.index_cast(T.i64, global_buffer_size_scale) + k_scale_nrec = arith.index_cast(T.i64, global_buffer_size_scale_2d) v_nrec = arith.index_cast(T.i64, global_buffer_size_tensor) v_scale_nrec = arith.index_cast(T.i64, global_buffer_size_scale) do_nrec = arith.index_cast(T.i64, global_buffer_size_tensor) - do_scale_nrec = arith.index_cast(T.i64, global_buffer_size_scale) + do_scale_nrec = arith.index_cast(T.i64, global_buffer_size_scale_2d) output_nrec = arith.index_cast(T.i64, global_buffer_size_tensor * 4) MD_nrec = arith.index_cast(T.i64, fx.Index(seqlen * 4)) - q_quant_head_rsrc = buffer_ops.create_buffer_resource( - arg_q_quant_head, - max_size=False, - num_records_bytes=q_nrec, - base_byte_offset=offset_qo_nheads, - ) - q_scale_head_rsrc = buffer_ops.create_buffer_resource( - arg_q_scale_head, - max_size=False, - num_records_bytes=q_scale_nrec, - base_byte_offset=offset_qo_scale_nheads, - ) - q_quant_m_rsrc = buffer_ops.create_buffer_resource( - arg_q_quant_m, + q_rsrc = buffer_ops.create_buffer_resource( + arg_q, max_size=False, num_records_bytes=q_nrec, base_byte_offset=offset_qo_nheads, ) - q_scale_m_rsrc = buffer_ops.create_buffer_resource( - arg_q_scale_m, + q_scale_rsrc = buffer_ops.create_buffer_resource( + arg_q_scale, max_size=False, num_records_bytes=q_scale_nrec, - base_byte_offset=offset_qo_scale_nheads, + 
base_byte_offset=offset_q_scale_nheads, ) - k_quant_head_rsrc = buffer_ops.create_buffer_resource( - arg_k_quant_head, + k_rsrc = buffer_ops.create_buffer_resource( + arg_k, max_size=False, num_records_bytes=k_nrec, base_byte_offset=offset_kv_nheads, ) - k_scale_head_rsrc = buffer_ops.create_buffer_resource( - arg_k_scale_head, + k_scale_rsrc = buffer_ops.create_buffer_resource( + arg_k_scale, max_size=False, num_records_bytes=k_scale_nrec, - base_byte_offset=offset_kv_scale_nheads, - ) - k_quant_n_rsrc = buffer_ops.create_buffer_resource( - arg_k_quant_n, - max_size=False, - num_records_bytes=k_nrec, - base_byte_offset=offset_kv_nheads, - ) - k_scale_n_rsrc = buffer_ops.create_buffer_resource( - arg_k_scale_n, - max_size=False, - num_records_bytes=k_scale_nrec, - base_byte_offset=offset_kv_scale_nheads, + base_byte_offset=offset_k_scale_nheads, ) v_rsrc = buffer_ops.create_buffer_resource( arg_v, @@ -530,31 +396,19 @@ def kernel_attn_bwd( arg_v_scale, max_size=False, num_records_bytes=v_scale_nrec, - base_byte_offset=offset_kv_scale_nheads, + base_byte_offset=offset_v_scale_nheads, ) - do_quant_head_rsrc = buffer_ops.create_buffer_resource( - arg_do_quant_head, + do_rsrc = buffer_ops.create_buffer_resource( + arg_do, max_size=False, num_records_bytes=do_nrec, base_byte_offset=offset_qo_nheads, ) - do_scale_head_rsrc = buffer_ops.create_buffer_resource( - arg_do_scale_head, + do_scale_rsrc = buffer_ops.create_buffer_resource( + arg_do_scale, max_size=False, num_records_bytes=do_scale_nrec, - base_byte_offset=offset_qo_scale_nheads, - ) - do_quant_m_rsrc = buffer_ops.create_buffer_resource( - arg_do_quant_m, - max_size=False, - num_records_bytes=do_nrec, - base_byte_offset=offset_qo_nheads, - ) - do_scale_m_rsrc = buffer_ops.create_buffer_resource( - arg_do_scale_m, - max_size=False, - num_records_bytes=do_scale_nrec, - base_byte_offset=offset_qo_scale_nheads, + base_byte_offset=offset_do_scale_nheads, ) dq_rsrc = buffer_ops.create_buffer_resource( arg_dq, @@ -635,6 
+489,7 @@ def kernel_attn_bwd( dv_n_per_wave = tile_n // dv_n_num_waves dv_n_num_subtiles = dv_n_per_wave // 16 dv_head_per_wave = tile_head // dv_head_num_waves + dv_head_mx_per_wave = tile_head_mx // dv_head_num_waves dv_head_num_subtiles = dv_head_per_wave // 16 dv_n_accs = dv_n_num_subtiles * dv_head_num_subtiles @@ -650,6 +505,7 @@ def kernel_attn_bwd( dk_n_per_wave = tile_n // dk_n_num_waves dk_num_subtiles_n = dk_n_per_wave // 16 dk_head_per_wave = tile_head // dk_head_num_waves + dk_head_mx_per_wave = tile_head_mx // dk_head_num_waves dk_num_subtiles_head = dk_head_per_wave // 16 dk_n_accs = dk_num_subtiles_n * dk_num_subtiles_head @@ -665,6 +521,7 @@ def kernel_attn_bwd( dq_m_per_wave = tile_m // dq_m_num_waves dq_num_subtiles_m = dq_m_per_wave // 16 dq_head_per_wave = tile_head // dq_head_num_waves + dq_head_mx_per_wave = tile_head_mx // dq_head_num_waves dq_num_subtiles_head = dq_head_per_wave // 16 dq_n_accs = dq_num_subtiles_m * dq_num_subtiles_head @@ -727,46 +584,24 @@ def lds_scale_load(row, col, lds_stride, lds_buffer): c4 = fx.Index(4) tx_i32_base = tx * c4 - def load_q_quant_head_16(idx_elem): - return buffer_copy_gmem16_dwordx4( - buffer_ops, - vector, - elem_type=_elem_type(), - idx_i32=idx_elem, - rsrc=q_quant_head_rsrc, - vec_elems=16, - elem_bytes=elem_bytes, - ) - - def load_q_quant_m_16(idx_elem): - return buffer_copy_gmem16_dwordx4( - buffer_ops, - vector, - elem_type=_elem_type(), - idx_i32=idx_elem, - rsrc=q_quant_m_rsrc, - vec_elems=16, - elem_bytes=elem_bytes, - ) - - def load_k_quant_head_16(idx_elem): + def load_q_16(idx_elem): return buffer_copy_gmem16_dwordx4( buffer_ops, vector, elem_type=_elem_type(), idx_i32=idx_elem, - rsrc=k_quant_head_rsrc, + rsrc=q_rsrc, vec_elems=16, elem_bytes=elem_bytes, ) - def load_k_quant_n_16(idx_elem): + def load_k_16(idx_elem): return buffer_copy_gmem16_dwordx4( buffer_ops, vector, elem_type=_elem_type(), idx_i32=idx_elem, - rsrc=k_quant_n_rsrc, + rsrc=k_rsrc, vec_elems=16, elem_bytes=elem_bytes, ) 
@@ -782,24 +617,13 @@ def load_v_16(idx_elem): elem_bytes=elem_bytes, ) - def load_do_quant_head_16(idx_elem): - return buffer_copy_gmem16_dwordx4( - buffer_ops, - vector, - elem_type=_elem_type(), - idx_i32=idx_elem, - rsrc=do_quant_head_rsrc, - vec_elems=16, - elem_bytes=elem_bytes, - ) - - def load_do_quant_m_16(idx_elem): + def load_do_16(idx_elem): return buffer_copy_gmem16_dwordx4( buffer_ops, vector, elem_type=_elem_type(), idx_i32=idx_elem, - rsrc=do_quant_m_rsrc, + rsrc=do_rsrc, vec_elems=16, elem_bytes=elem_bytes, ) @@ -822,43 +646,23 @@ def kv_tile_chunk_coord_i32(i: int): layout_tile_div4=layout_kv_tile_div4, ) - def prefetch_q_quant_head_tile(offset_m): - parts = [] - for i in range_constexpr(num_qo_loads): - row_q_local, col_q_local_i32 = qo_tile_chunk_coord_i32(i) - row_q_global = offset_m + row_q_local - idx_elem = row_q_global * head_dim_div4 + col_q_local_i32 - q_16B = load_q_quant_head_16(idx_elem) - parts.append(vector.bitcast(T.i32x4, q_16B)) - return parts - - def prefetch_q_quant_m_tile(offset_m): + def prefetch_q_tile(offset_m): parts = [] for i in range_constexpr(num_qo_loads): row_q_local, col_q_local_i32 = qo_tile_chunk_coord_i32(i) row_q_global = offset_m + row_q_local idx_elem = row_q_global * head_dim_div4 + col_q_local_i32 - q_16B = load_q_quant_m_16(idx_elem) + q_16B = load_q_16(idx_elem) parts.append(vector.bitcast(T.i32x4, q_16B)) return parts - def prefetch_k_quant_head_tile(): - parts = [] - for i in range_constexpr(num_kv_loads): - row_k_local, col_k_local_i32 = kv_tile_chunk_coord_i32(i) - row_k_global = global_offset_n + row_k_local - idx_elem = row_k_global * head_dim_div4 + col_k_local_i32 - k_16B = load_k_quant_head_16(idx_elem) - parts.append(vector.bitcast(T.i32x4, k_16B)) - return parts - - def prefetch_k_quant_n_tile(): + def prefetch_k_tile(): parts = [] for i in range_constexpr(num_kv_loads): row_k_local, col_k_local_i32 = kv_tile_chunk_coord_i32(i) row_k_global = global_offset_n + row_k_local idx_elem = row_k_global 
* head_dim_div4 + col_k_local_i32 - k_16B = load_k_quant_n_16(idx_elem) + k_16B = load_k_16(idx_elem) parts.append(vector.bitcast(T.i32x4, k_16B)) return parts @@ -872,109 +676,67 @@ def prefetch_v_tile(): parts.append(vector.bitcast(T.i32x4, v_16B)) return parts - def prefetch_do_quant_head_tile(offset_m): + def prefetch_do_tile(offset_m): parts = [] for i in range_constexpr(num_qo_loads): row_do_local, col_do_local_i32 = qo_tile_chunk_coord_i32(i) row_do_global = offset_m + row_do_local idx_elem = row_do_global * head_dim_div4 + col_do_local_i32 - do_16B = load_do_quant_head_16(idx_elem) + do_16B = load_do_16(idx_elem) parts.append(vector.bitcast(T.i32x4, do_16B)) return parts - def prefetch_do_quant_m_tile(offset_m): + def prefetch_q_scale_head_2d_tile(offset_m): parts = [] - for i in range_constexpr(num_qo_loads): - row_do_local, col_do_local_i32 = qo_tile_chunk_coord_i32(i) - row_do_global = offset_m + row_do_local - idx_elem = row_do_global * head_dim_div4 + col_do_local_i32 - do_16B = load_do_quant_m_16(idx_elem) - parts.append(vector.bitcast(T.i32x4, do_16B)) - return parts - - def prefetch_q_scale_head_tile(offset_m): - vec_width = bytes_per_thread_qo_scale - if const_expr(vec_width == 1): - if const_expr(bytes_per_tile_qo_scale < total_threads): - idx_elem = offset_m * head_dim_mx + tx % bytes_per_tile_qo_scale - else: - idx_elem = offset_m * head_dim_mx + tx + for i in range_constexpr(ps_m_num_subtiles // 2): + global_row = offset_m + ps_m_wave_id * ps_m_mx_per_wave + i + global_col = lane_div_16 % tile_head_mx + global_idx = global_row * head_dim_mx + global_col vec = buffer_ops.buffer_load( - q_scale_head_rsrc, idx_elem, vec_width=1, dtype=T.i8 + q_scale_rsrc, global_idx, vec_width=1, dtype=T.i8 ) - vec = vector.from_elements(T.vec(1, T.i8), [vec]) - else: # vec_width=2 - idx_elem = (offset_m * head_dim_mx + tx * vec_width) // 2 - vec = buffer_ops.buffer_load( - q_scale_head_rsrc, idx_elem, vec_width=1, dtype=T.i16 - ) - vec = 
vector.from_elements(T.vec(1, T.i16), [vec]) - vec = vector.bitcast(T.i8x2, vec) - return vec + vec = vec.extui(T.i32) + parts.append(vec) + return parts - def prefetch_q_scale_m_tile(offset_m): - vec_width = bytes_per_thread_qo_scale - if const_expr(vec_width == 1): - if const_expr(bytes_per_tile_qo_scale < total_threads): - idx_elem = offset_m * head_dim + tx % bytes_per_tile_qo_scale - else: - idx_elem = offset_m * head_dim + tx - vec = buffer_ops.buffer_load( - q_scale_m_rsrc, idx_elem, vec_width=1, dtype=T.i8 - ) - vec = vector.from_elements(T.vec(1, T.i8), [vec]) - else: # vec_width=2 - idx_elem = (offset_m * head_dim + tx * vec_width) // 2 + def prefetch_q_scale_m_2d_tile(offset_m): + parts = [] + for i in range_constexpr(dk_num_subtiles_head // 2): + global_row = offset_m + lane_div_16 % tile_m_mx + global_col = dk_head_wave_id * dk_head_mx_per_wave + i + global_idx = global_row * head_dim_mx + global_col vec = buffer_ops.buffer_load( - q_scale_m_rsrc, idx_elem, vec_width=1, dtype=T.i16 + q_scale_rsrc, global_idx, vec_width=1, dtype=T.i8 ) - vec = vector.from_elements(T.vec(1, T.i16), [vec]) - vec = vector.bitcast(T.i8x2, vec) - return vec + vec = vec.extui(T.i32) + parts.append(vec) + return parts - def prefetch_k_scale_head_tile(): - vec_width = bytes_per_thread_kv_scale - if const_expr(vec_width == 1): - if const_expr(bytes_per_tile_kv_scale < total_threads): - idx_elem = ( - global_offset_n * head_dim_mx + tx % bytes_per_tile_kv_scale - ) - else: - idx_elem = global_offset_n * head_dim_mx + tx - vec = buffer_ops.buffer_load( - k_scale_head_rsrc, idx_elem, vec_width=1, dtype=T.i8 - ) - vec = vector.from_elements(T.vec(1, T.i8), [vec]) - else: # vec_width=2 - idx_elem = (global_offset_n * head_dim_mx + tx * vec_width) // 2 + def prefetch_k_scale_head_2d_tile(): + parts = [] + for i in range_constexpr(ps_n_num_subtiles // 2): + global_row = global_offset_n_mx + ps_n_wave_id * ps_n_mx_per_wave + i + global_col = lane_div_16 % tile_head_mx + global_idx = 
global_row * head_dim_mx + global_col vec = buffer_ops.buffer_load( - k_scale_head_rsrc, idx_elem, vec_width=1, dtype=T.i16 + k_scale_rsrc, global_idx, vec_width=1, dtype=T.i8 ) - vec = vector.from_elements(T.vec(1, T.i16), [vec]) - vec = vector.bitcast(T.i8x2, vec) - return vec + vec = vec.extui(T.i32) + parts.append(vec) + return parts - def prefetch_k_scale_n_tile(): - vec_width = bytes_per_thread_kv_scale - if const_expr(vec_width == 1): - if const_expr(bytes_per_tile_kv_scale < total_threads): - idx_elem = ( - global_offset_n_mx * head_dim + tx % bytes_per_tile_kv_scale - ) - else: - idx_elem = global_offset_n_mx * head_dim + tx - vec = buffer_ops.buffer_load( - k_scale_n_rsrc, idx_elem, vec_width=1, dtype=T.i8 - ) - vec = vector.from_elements(T.vec(1, T.i8), [vec]) - else: # vec_width=2 - idx_elem = (global_offset_n_mx * head_dim + tx * vec_width) // 2 + def prefetch_k_scale_n_2d_tile(): + parts = [] + for i in range_constexpr(dq_num_subtiles_head // 2): + global_row = global_offset_n_mx + lane_div_16 % tile_n_mx + global_col = dq_head_wave_id * dq_head_mx_per_wave + i + global_idx = global_row * head_dim_mx + global_col vec = buffer_ops.buffer_load( - k_scale_n_rsrc, idx_elem, vec_width=1, dtype=T.i16 + k_scale_rsrc, global_idx, vec_width=1, dtype=T.i8 ) - vec = vector.from_elements(T.vec(1, T.i16), [vec]) - vec = vector.bitcast(T.i8x2, vec) - return vec + vec = vec.extui(T.i32) + parts.append(vec) + return parts def prefetch_v_scale_tile(): vec_width = bytes_per_thread_kv_scale @@ -998,45 +760,31 @@ def prefetch_v_scale_tile(): vec = vector.bitcast(T.i8x2, vec) return vec - def prefetch_do_scale_head_tile(offset_m): - vec_width = bytes_per_thread_qo_scale - if const_expr(vec_width == 1): - if const_expr(bytes_per_tile_qo_scale < total_threads): - idx_elem = offset_m * head_dim_mx + tx % bytes_per_tile_qo_scale - else: - idx_elem = offset_m * head_dim_mx + tx - vec = buffer_ops.buffer_load( - do_scale_head_rsrc, idx_elem, vec_width=1, dtype=T.i8 - ) - vec = 
vector.from_elements(T.vec(1, T.i8), [vec]) - else: # vec_width=2 - idx_elem = (offset_m * head_dim_mx + tx * vec_width) // 2 + def prefetch_do_scale_head_2d_tile(offset_m): + parts = [] + for i in range_constexpr(ps_m_num_subtiles // 2): + global_row = offset_m + ps_m_wave_id * ps_m_mx_per_wave + i + global_col = lane_div_16 % tile_head_mx + global_idx = global_row * head_dim_mx + global_col vec = buffer_ops.buffer_load( - do_scale_head_rsrc, idx_elem, vec_width=1, dtype=T.i16 + do_scale_rsrc, global_idx, vec_width=1, dtype=T.i8 ) - vec = vector.from_elements(T.vec(1, T.i16), [vec]) - vec = vector.bitcast(T.i8x2, vec) - return vec + vec = vec.extui(T.i32) + parts.append(vec) + return parts - def prefetch_do_scale_m_tile(offset_m): - vec_width = bytes_per_thread_qo_scale - if const_expr(vec_width == 1): - if const_expr(bytes_per_tile_qo_scale < total_threads): - idx_elem = offset_m * head_dim + tx % bytes_per_tile_qo_scale - else: - idx_elem = offset_m * head_dim + tx - vec = buffer_ops.buffer_load( - do_scale_m_rsrc, idx_elem, vec_width=1, dtype=T.i8 - ) - vec = vector.from_elements(T.vec(1, T.i8), [vec]) - else: # vec_width=2 - idx_elem = (offset_m * head_dim + tx * vec_width) // 2 + def prefetch_do_scale_m_2d_tile(offset_m): + parts = [] + for i in range_constexpr(dv_head_num_subtiles // 2): + global_row = offset_m + lane_div_16 % tile_m_mx + global_col = dv_head_wave_id * dv_head_mx_per_wave + i + global_idx = global_row * head_dim_mx + global_col vec = buffer_ops.buffer_load( - do_scale_m_rsrc, idx_elem, vec_width=1, dtype=T.i16 + do_scale_rsrc, global_idx, vec_width=1, dtype=T.i8 ) - vec = vector.from_elements(T.vec(1, T.i16), [vec]) - vec = vector.bitcast(T.i8x2, vec) - return vec + vec = vec.extui(T.i32) + parts.append(vec) + return parts def store_q_tile_to_lds(vec_q_parts, lds_buffer): for i in range_constexpr(num_qo_loads): @@ -1086,20 +834,6 @@ def store_do_tile_to_lds(vec_do_parts, lds_buffer): v16 = vector.bitcast(_vec16_type(), vec_do_parts[i]) 
vector.store(v16, lds_buffer, [idx0]) - def store_q_scale_tile_to_lds(vec_scale, lds_buffer): - vec_width = bytes_per_thread_qo_scale - idx = tx * vec_width - if total_threads > bytes_per_tile_qo_scale: - idx = idx % bytes_per_tile_qo_scale - vector.store(vec_scale, lds_buffer, [idx]) - - def store_k_scale_tile_to_lds(vec_scale, lds_buffer): - vec_width = bytes_per_thread_kv_scale - idx = tx * vec_width - if total_threads > bytes_per_tile_kv_scale: - idx = idx % bytes_per_tile_kv_scale - vector.store(vec_scale, lds_buffer, [idx]) - def store_v_scale_tile_to_lds(vec_scale, lds_buffer): vec_width = bytes_per_thread_kv_scale idx = tx * vec_width @@ -1107,13 +841,6 @@ def store_v_scale_tile_to_lds(vec_scale, lds_buffer): idx = idx % bytes_per_tile_kv_scale vector.store(vec_scale, lds_buffer, [idx]) - def store_do_scale_tile_to_lds(vec_scale, lds_buffer): - vec_width = bytes_per_thread_qo_scale - idx = tx * vec_width - if total_threads > bytes_per_tile_qo_scale: - idx = idx % bytes_per_tile_qo_scale - vector.store(vec_scale, lds_buffer, [idx]) - # ── Compute tile (MFMA) ─────────────────────────────────────────── def pack_i64x4_to_i32x8(x0, x1, x2, x3): @@ -1122,9 +849,7 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3): v4 = vector.from_elements(vec4_i64, [x0, x1, x2, x3]) return vector.bitcast(vec8_i32, v4) - def compute_qk( - lds_a_buffer, lds_a_scale_buffer, lds_b_buffer, lds_b_scale_buffer - ): + def compute_qk(lds_a_buffer, a_scales, lds_b_buffer, b_scales): # (m, head) @ (head, n) = (m, n) current_accs_list = [acc_init] * ps_n_accs @@ -1132,9 +857,7 @@ def compute_qk( ku0 = 0 ku1 = 1 - lds_col0 = ( - ku0 * 64 + lane_div_16 * 16 - ) # 16 elements packed per lane, 64 per wave + lds_col0 = ku0 * 64 + lane_div_16 * 16 lds_col1 = ku1 * 64 + lane_div_16 * 16 lds_scale_col = lane_div_16 if const_expr(tile_head == 64): @@ -1154,9 +877,10 @@ def compute_qk( a128 = pack_i64x4_to_i32x8(a0, a1, a2, a3) lds_a_scale_row = lds_a_row - a_scale = lds_scale_load( - lds_a_scale_row, 
lds_scale_col, tile_head_mx, lds_a_scale_buffer - ) + # a_scale = lds_scale_load( + # lds_a_scale_row, lds_scale_col, tile_head_mx, lds_a_scale_buffer + # ) + a_scale = a_scales[mi // 2] for ni in range_constexpr(ps_n_num_subtiles): lds_b_row = ps_n_wave_id * ps_n_per_wave + ni * 16 + lane_mod_16 @@ -1170,11 +894,9 @@ def compute_qk( else: b2 = b3 = fx.Int64(0) b128 = pack_i64x4_to_i32x8(b0, b1, b2, b3) - b_scale = lds_scale_load( - lds_b_row, lds_scale_col, tile_head_mx, lds_b_scale_buffer - ) - # fx.printf("ni={}, mi={}", ni, mi) + b_scale = b_scales[ni // 2] + acc_idx = mi * ps_n_num_subtiles + ni current_accs_list[acc_idx] = rocdl.mfma_scale_f32_16x16x128_f8f6f4( mfma_res_ty, @@ -1252,7 +974,7 @@ def softmax(accs_in, offset_m): return accs_out def compute_dv( - accs_in, lds_a_buffer, lds_a_scale_buffer, lds_b_buffer, lds_b_scale_buffer + accs_in, lds_a_buffer, lds_a_scale_buffer, lds_b_buffer, b_scales ): current_accs_list = list(accs_in) mfma_res_ty = T.f32x4 @@ -1313,15 +1035,7 @@ def compute_dv( b3 = fx.Int64(0) b128 = pack_i64x4_to_i32x8(b0, b1, b2, b3) - lds_b_scale_col = ( - dv_head_wave_id * dv_head_per_wave + hi * 16 + lane_mod_16 - ) - b_scale = lds_scale_load( - lds_b_scale_row, - lds_b_scale_col, - tile_head, - lds_b_scale_buffer, - ) + b_scale = b_scales[hi // 2] acc_idx = ni * dv_head_num_subtiles + hi current_accs_list[acc_idx] = ( @@ -1342,16 +1056,12 @@ def compute_dv( ) return current_accs_list - def compute_dp( - lds_a_buffer, lds_a_scale_buffer, lds_b_buffer, lds_b_scale_buffer - ): + def compute_dp(lds_a_buffer, a_scales, lds_b_buffer, lds_b_scale_buffer): current_accs_list = [acc_init] * ps_n_accs mfma_res_ty = T.f32x4 ku0 = 0 ku1 = 1 - lds_col0 = ( - ku0 * 64 + lane_div_16 * 16 - ) # 16 elements packed per lane, 64 per wave + lds_col0 = ku0 * 64 + lane_div_16 * 16 lds_col1 = ku1 * 64 + lane_div_16 * 16 lds_scale_col = lane_div_16 if const_expr(tile_head == 64): @@ -1369,9 +1079,8 @@ def compute_dp( else: a2 = a3 = fx.Int64(0) a128 = 
pack_i64x4_to_i32x8(a0, a1, a2, a3) - a_scale = lds_scale_load( - lds_a_row, lds_scale_col, tile_head_mx, lds_a_scale_buffer - ) + + a_scale = a_scales[mi // 2] for ni in range_constexpr(ps_n_num_subtiles): lds_b_row = ps_n_wave_id * ps_n_per_wave + ni * 16 + lane_mod_16 @@ -1663,7 +1372,7 @@ def mxquant_n_and_store_to_lds(accs_in, lds_buffer, lds_buffer_scale): vector.store(scale_vector, lds_buffer_scale, [lds_scale_idx]) def compute_dk( - accs_in, lds_a_buffer, lds_a_scale_buffer, lds_b_buffer, lds_b_scale_buffer + accs_in, lds_a_buffer, lds_a_scale_buffer, lds_b_buffer, b_scales ): current_accs_list = list(accs_in) mfma_res_ty = T.f32x4 @@ -1672,9 +1381,7 @@ def compute_dk( ku0 = ku128 * 2 ku1 = ku0 + 1 - lds_a_col0 = ( - ku0 * 64 + lane_div_16 * 16 - ) # 16 elements packed per lane, 64 per wave + lds_a_col0 = ku0 * 64 + lane_div_16 * 16 lds_a_col1 = ku1 * 64 + lane_div_16 * 16 lds_b_row0 = ku0 * 64 + lane_div_16 * 16 + lane_div_2 @@ -1706,9 +1413,7 @@ def compute_dk( ) for hi in range_constexpr(dk_num_subtiles_head): - lds_b_col = ( - dk_head_wave_id * dk_head_per_wave + hi * 16 - ) # + lane_mod_2 * 8 + lds_b_col = dk_head_wave_id * dk_head_per_wave + hi * 16 b0 = lds_load_packs_k32_transposed( lds_b_row0, lds_b_col, tile_head, lds_b_buffer ) @@ -1727,15 +1432,7 @@ def compute_dk( b3 = fx.Int64(0) b128 = pack_i64x4_to_i32x8(b0, b1, b2, b3) - lds_b_scale_col = ( - dk_head_wave_id * dk_head_per_wave + hi * 16 + lane_mod_16 - ) - b_scale = lds_scale_load( - lds_b_scale_row, - lds_b_scale_col, - tile_head, - lds_b_scale_buffer, - ) + b_scale = b_scales[hi // 2] acc_idx = ni * dk_num_subtiles_head + hi current_accs_list[acc_idx] = ( @@ -1756,9 +1453,7 @@ def compute_dk( ) return current_accs_list - def compute_dq( - lds_a_buffer, lds_a_scale_buffer, lds_b_buffer, lds_b_scale_buffer - ): + def compute_dq(lds_a_buffer, lds_a_scale_buffer, lds_b_buffer, b_scales): # (m, n) @ (n, head) = (m, head) current_accs_list = [acc_init] * dq_n_accs @@ -1780,10 +1475,10 @@ def 
compute_dq( lds_b_row3 = ku1 * 64 + lane_div_16 * 16 + 8 + lane_div_2 lds_a_scale_col = lane_div_16 - lds_b_scale_row = lane_div_16 + # lds_b_scale_row = lane_div_16 if const_expr(tile_n == 64): lds_a_scale_col = lds_a_scale_col % 2 - lds_b_scale_row = lds_b_scale_row % 2 + # lds_b_scale_row = lds_b_scale_row % 2 for mi in range_constexpr(dq_num_subtiles_m): lds_a_col = ( @@ -1815,9 +1510,7 @@ def compute_dq( ) for hi in range_constexpr(dq_num_subtiles_head): - lds_b_col = ( - dq_head_wave_id * dq_head_per_wave + hi * 16 - ) # + lane_mod_2 * 8 + lds_b_col = dq_head_wave_id * dq_head_per_wave + hi * 16 b0 = lds_load_packs_k32_transposed( lds_b_row0, lds_b_col, tile_head, lds_b_buffer ) @@ -1836,15 +1529,7 @@ def compute_dq( b3 = fx.Int64(0) b128 = pack_i64x4_to_i32x8(b0, b1, b2, b3) - lds_b_scale_col = ( - dq_head_wave_id * dq_head_per_wave + hi * 16 + lane_mod_16 - ) - b_scale = lds_scale_load( - lds_b_scale_row, - lds_b_scale_col, - tile_head, - lds_b_scale_buffer, - ) + b_scale = b_scales[hi // 2] acc_idx = mi * dq_num_subtiles_head + hi current_accs_list[acc_idx] = ( @@ -1979,49 +1664,84 @@ def hot_loop_scheduler(): # ── Main pipeline ───────────────────────────────────────────────── - def _pack_state(dk, dv): - return list(dk) + list(dv) + def _pack_state(dk, dv, q_scales_head, q_scales_m, do_scales_head, do_scales_m): + return ( + list(dk) + + list(dv) + + list(q_scales_head) + + list(q_scales_m) + + list(do_scales_head) + + list(do_scales_m) + ) def _unpack_state(vals): dk = list(vals[:dk_n_accs]) - dv = list(vals[dk_n_accs:]) - return dk, dv + dv = list(vals[dk_n_accs : dk_n_accs + dv_n_accs]) + q_scales_head = list( + vals[ + dk_n_accs + + dv_n_accs : dk_n_accs + + dv_n_accs + + ps_m_num_subtiles // 2 + ] + ) + q_scales_m = list( + vals[ + dk_n_accs + + dv_n_accs + + ps_m_num_subtiles // 2 : dk_n_accs + + dv_n_accs + + ps_m_num_subtiles // 2 + + dk_num_subtiles_head // 2 + ] + ) + do_scales_head = list( + vals[ + dk_n_accs + + dv_n_accs + + 
ps_m_num_subtiles // 2 + + dk_num_subtiles_head // 2 : dk_n_accs + + dv_n_accs + + ps_m_num_subtiles // 2 + + dk_num_subtiles_head // 2 + + ps_m_num_subtiles // 2 + ] + ) + do_scales_m = list( + vals[ + dk_n_accs + + dv_n_accs + + ps_m_num_subtiles // 2 + + dk_num_subtiles_head // 2 + + ps_m_num_subtiles // 2 : + ] + ) + return dk, dv, q_scales_head, q_scales_m, do_scales_head, do_scales_m def pingpong(offset_m, inner_state): - dk, dv = _unpack_state(inner_state) + ( + dk, + dv, + q_scales_head_pong, + q_scales_m_pong, + do_scales_head_pong, + do_scales_m_pong, + ) = _unpack_state(inner_state) next_offset_m = offset_m + tile_m next_offset_m_mx = next_offset_m // 32 - store_q_tile_to_lds( - prefetch_q_quant_head_tile(next_offset_m), lds_q_quant_head_ping - ) - store_q_scale_tile_to_lds( - prefetch_q_scale_head_tile(next_offset_m), lds_q_scale_head_ping - ) - store_q_tile_to_lds( - prefetch_q_quant_m_tile(next_offset_m), lds_q_quant_m_ping - ) - store_q_scale_tile_to_lds( - prefetch_q_scale_m_tile(next_offset_m_mx), lds_q_scale_m_ping - ) - store_do_tile_to_lds( - prefetch_do_quant_head_tile(next_offset_m), lds_do_quant_head_ping - ) - store_do_scale_tile_to_lds( - prefetch_do_scale_head_tile(next_offset_m), lds_do_scale_head_ping - ) - store_do_tile_to_lds( - prefetch_do_quant_m_tile(next_offset_m), lds_do_quant_m_ping - ) - store_do_scale_tile_to_lds( - prefetch_do_scale_m_tile(next_offset_m_mx), lds_do_scale_m_ping - ) + store_q_tile_to_lds(prefetch_q_tile(next_offset_m), lds_q_ping) + q_scales_head_ping = prefetch_q_scale_head_2d_tile(next_offset_m_mx) + q_scales_m_ping = prefetch_q_scale_m_2d_tile(next_offset_m_mx) + store_do_tile_to_lds(prefetch_do_tile(next_offset_m), lds_do_ping) + do_scales_head_ping = prefetch_do_scale_head_2d_tile(next_offset_m_mx) + do_scales_m_ping = prefetch_do_scale_m_2d_tile(next_offset_m_mx) qk = compute_qk( - lds_q_quant_head_pong, - lds_q_scale_head_pong, - lds_k_quant_head, - lds_k_scale_head, + lds_q_pong, + q_scales_head_pong, + 
lds_k, + k_scales_head, ) p = softmax(qk, offset_m) mxquant_m_and_store_to_lds(p, lds_ppt_shuffle, lds_ppt_scale_shuffle) @@ -2030,12 +1750,10 @@ def pingpong(offset_m, inner_state): dv, lds_ppt_shuffle, lds_ppt_scale_shuffle, - lds_do_quant_m_pong, - lds_do_scale_m_pong, - ) - dp = compute_dp( - lds_do_quant_head_pong, lds_do_scale_head_pong, lds_v, lds_v_scale + lds_do_pong, + do_scales_m_pong, ) + dp = compute_dp(lds_do_pong, do_scales_head_pong, lds_v, lds_v_scale) ds = compute_ds(dp, p, offset_m) mxquant_m_and_store_to_lds(ds, lds_dst_shuffle, lds_dst_scale_shuffle) mxquant_n_and_store_to_lds(ds, lds_ds_shuffle, lds_ds_scale_shuffle) @@ -2044,48 +1762,28 @@ def pingpong(offset_m, inner_state): dk, lds_dst_shuffle, lds_dst_scale_shuffle, - lds_q_quant_m_pong, - lds_q_scale_m_pong, - ) - dq = compute_dq( - lds_ds_shuffle, lds_ds_scale_shuffle, lds_k_quant_n, lds_k_scale_n + lds_q_pong, + q_scales_m_pong, ) + dq = compute_dq(lds_ds_shuffle, lds_ds_scale_shuffle, lds_k, k_scales_n) store_dq_atomic(dq, offset_m) hot_loop_scheduler() gpu.barrier() next_offset_m = offset_m + (tile_m * 2) next_offset_m_mx = next_offset_m // 32 - store_q_tile_to_lds( - prefetch_q_quant_head_tile(next_offset_m), lds_q_quant_head_pong - ) - store_q_scale_tile_to_lds( - prefetch_q_scale_head_tile(next_offset_m), lds_q_scale_head_pong - ) - store_q_tile_to_lds( - prefetch_q_quant_m_tile(next_offset_m), lds_q_quant_m_pong - ) - store_q_scale_tile_to_lds( - prefetch_q_scale_m_tile(next_offset_m_mx), lds_q_scale_m_pong - ) - store_do_tile_to_lds( - prefetch_do_quant_head_tile(next_offset_m), lds_do_quant_head_pong - ) - store_do_scale_tile_to_lds( - prefetch_do_scale_head_tile(next_offset_m), lds_do_scale_head_pong - ) - store_do_tile_to_lds( - prefetch_do_quant_m_tile(next_offset_m), lds_do_quant_m_pong - ) - store_do_scale_tile_to_lds( - prefetch_do_scale_m_tile(next_offset_m_mx), lds_do_scale_m_pong - ) + store_q_tile_to_lds(prefetch_q_tile(next_offset_m), lds_q_pong) + q_scales_head_pong 
= prefetch_q_scale_head_2d_tile(next_offset_m_mx) + q_scales_m_pong = prefetch_q_scale_m_2d_tile(next_offset_m_mx) + store_do_tile_to_lds(prefetch_do_tile(next_offset_m), lds_do_pong) + do_scales_head_pong = prefetch_do_scale_head_2d_tile(next_offset_m_mx) + do_scales_m_pong = prefetch_do_scale_m_2d_tile(next_offset_m_mx) qk = compute_qk( - lds_q_quant_head_ping, - lds_q_scale_head_ping, - lds_k_quant_head, - lds_k_scale_head, + lds_q_ping, + q_scales_head_ping, + lds_k, + k_scales_head, ) p = softmax(qk, offset_m + tile_m) mxquant_m_and_store_to_lds(p, lds_ppt_shuffle, lds_ppt_scale_shuffle) @@ -2094,12 +1792,10 @@ def pingpong(offset_m, inner_state): dv, lds_ppt_shuffle, lds_ppt_scale_shuffle, - lds_do_quant_m_ping, - lds_do_scale_m_ping, - ) - dp = compute_dp( - lds_do_quant_head_ping, lds_do_scale_head_ping, lds_v, lds_v_scale + lds_do_ping, + do_scales_m_ping, ) + dp = compute_dp(lds_do_ping, do_scales_head_ping, lds_v, lds_v_scale) ds = compute_ds(dp, p, offset_m + tile_m) mxquant_m_and_store_to_lds(ds, lds_dst_shuffle, lds_dst_scale_shuffle) mxquant_n_and_store_to_lds(ds, lds_ds_shuffle, lds_ds_scale_shuffle) @@ -2108,17 +1804,22 @@ def pingpong(offset_m, inner_state): dk, lds_dst_shuffle, lds_dst_scale_shuffle, - lds_q_quant_m_ping, - lds_q_scale_m_ping, - ) - dq = compute_dq( - lds_ds_shuffle, lds_ds_scale_shuffle, lds_k_quant_n, lds_k_scale_n + lds_q_ping, + q_scales_m_ping, ) + dq = compute_dq(lds_ds_shuffle, lds_ds_scale_shuffle, lds_k, k_scales_n) store_dq_atomic(dq, offset_m + tile_m) hot_loop_scheduler() gpu.barrier() - return _pack_state(dk, dv) + return _pack_state( + dk, + dv, + q_scales_head_pong, + q_scales_m_pong, + do_scales_head_pong, + do_scales_m_pong, + ) if const_expr(causal): start_m = (global_offset_n // (tile_m * 2)) * (tile_m * 2) @@ -2126,30 +1827,17 @@ def pingpong(offset_m, inner_state): start_m = fx.Index(0) start_m_mx = start_m // 32 - store_q_tile_to_lds(prefetch_q_quant_head_tile(start_m), lds_q_quant_head_pong) - 
store_q_scale_tile_to_lds( - prefetch_q_scale_head_tile(start_m), lds_q_scale_head_pong - ) - store_q_tile_to_lds(prefetch_q_quant_m_tile(start_m), lds_q_quant_m_pong) - store_q_scale_tile_to_lds( - prefetch_q_scale_m_tile(start_m_mx), lds_q_scale_m_pong - ) - store_k_tile_to_lds(prefetch_k_quant_head_tile(), lds_k_quant_head) - store_k_scale_tile_to_lds(prefetch_k_scale_head_tile(), lds_k_scale_head) - store_k_tile_to_lds(prefetch_k_quant_n_tile(), lds_k_quant_n) - store_k_scale_tile_to_lds(prefetch_k_scale_n_tile(), lds_k_scale_n) + store_q_tile_to_lds(prefetch_q_tile(start_m), lds_q_pong) + q_scales_head_pong = prefetch_q_scale_head_2d_tile(start_m_mx) + q_scales_m_pong = prefetch_q_scale_m_2d_tile(start_m_mx) + store_k_tile_to_lds(prefetch_k_tile(), lds_k) + k_scales_head = prefetch_k_scale_head_2d_tile() + k_scales_n = prefetch_k_scale_n_2d_tile() store_v_tile_to_lds(prefetch_v_tile(), lds_v) store_v_scale_tile_to_lds(prefetch_v_scale_tile(), lds_v_scale) - store_do_tile_to_lds( - prefetch_do_quant_head_tile(start_m), lds_do_quant_head_pong - ) - store_do_scale_tile_to_lds( - prefetch_do_scale_head_tile(start_m), lds_do_scale_head_pong - ) - store_do_tile_to_lds(prefetch_do_quant_m_tile(start_m), lds_do_quant_m_pong) - store_do_scale_tile_to_lds( - prefetch_do_scale_m_tile(start_m_mx), lds_do_scale_m_pong - ) + store_do_tile_to_lds(prefetch_do_tile(start_m), lds_do_pong) + do_scales_head_pong = prefetch_do_scale_head_2d_tile(start_m_mx) + do_scales_m_pong = prefetch_do_scale_m_2d_tile(start_m_mx) gpu.barrier() dk = [acc_init] * dk_n_accs dv = [acc_init] * dv_n_accs @@ -2157,17 +1845,31 @@ def pingpong(offset_m, inner_state): num_tiles_loop = seqlen_rounded // tile_m if const_expr((num_tiles_loop % 2) == 1): upper_bound = seqlen_rounded - tile_m - init_state = _pack_state(dk, dv) + init_state = _pack_state( + dk, + dv, + q_scales_head_pong, + q_scales_m_pong, + do_scales_head_pong, + do_scales_m_pong, + ) for iv, inner in range(start_m, upper_bound, tile_m * 2, 
init=init_state): results = yield pingpong(iv, inner) - dk, dv = _unpack_state(results) + ( + dk, + dv, + q_scales_head_pong, + q_scales_m_pong, + do_scales_head_pong, + do_scales_m_pong, + ) = _unpack_state(results) curr_m = arith.index(seqlen_rounded - tile_m) qk = compute_qk( - lds_q_quant_head_pong, - lds_q_scale_head_pong, - lds_k_quant_head, - lds_k_scale_head, + lds_q_pong, + q_scales_head_pong, + lds_k, + k_scales_head, ) p = softmax(qk, curr_m) mxquant_m_and_store_to_lds(p, lds_ppt_shuffle, lds_ppt_scale_shuffle) @@ -2176,12 +1878,10 @@ def pingpong(offset_m, inner_state): dv, lds_ppt_shuffle, lds_ppt_scale_shuffle, - lds_do_quant_m_pong, - lds_do_scale_m_pong, - ) - dp = compute_dp( - lds_do_quant_head_pong, lds_do_scale_head_pong, lds_v, lds_v_scale + lds_do_pong, + do_scales_m_pong, ) + dp = compute_dp(lds_do_pong, do_scales_head_pong, lds_v, lds_v_scale) ds = compute_ds(dp, p, curr_m) mxquant_m_and_store_to_lds(ds, lds_dst_shuffle, lds_dst_scale_shuffle) mxquant_n_and_store_to_lds(ds, lds_ds_shuffle, lds_ds_scale_shuffle) @@ -2190,49 +1890,47 @@ def pingpong(offset_m, inner_state): dk, lds_dst_shuffle, lds_dst_scale_shuffle, - lds_q_quant_m_pong, - lds_q_scale_m_pong, - ) - dq = compute_dq( - lds_ds_shuffle, lds_ds_scale_shuffle, lds_k_quant_n, lds_k_scale_n + lds_q_pong, + q_scales_m_pong, ) + dq = compute_dq(lds_ds_shuffle, lds_ds_scale_shuffle, lds_k, k_scales_n) store_dq_atomic(dq, curr_m) else: upper_bound = seqlen_rounded - (tile_m * 2) - init_state = _pack_state(dk, dv) + init_state = _pack_state( + dk, + dv, + q_scales_head_pong, + q_scales_m_pong, + do_scales_head_pong, + do_scales_m_pong, + ) for iv, inner in range(start_m, upper_bound, tile_m * 2, init=init_state): results = yield pingpong(iv, inner) - dk, dv = _unpack_state(results) + ( + dk, + dv, + q_scales_head_pong, + q_scales_m_pong, + do_scales_head_pong, + do_scales_m_pong, + ) = _unpack_state(results) curr_m = arith.index(seqlen_rounded - tile_m * 2) last_m = 
arith.index(seqlen_rounded - tile_m) last_m_mx = last_m // 32 - store_q_tile_to_lds( - prefetch_q_quant_head_tile(last_m), lds_q_quant_head_ping - ) - store_q_scale_tile_to_lds( - prefetch_q_scale_head_tile(last_m), lds_q_scale_head_ping - ) - store_q_tile_to_lds(prefetch_q_quant_m_tile(last_m), lds_q_quant_m_ping) - store_q_scale_tile_to_lds( - prefetch_q_scale_m_tile(last_m_mx), lds_q_scale_m_ping - ) - store_do_tile_to_lds( - prefetch_do_quant_head_tile(last_m), lds_do_quant_head_ping - ) - store_do_scale_tile_to_lds( - prefetch_do_scale_head_tile(last_m), lds_do_scale_head_ping - ) - store_do_tile_to_lds(prefetch_do_quant_m_tile(last_m), lds_do_quant_m_ping) - store_do_scale_tile_to_lds( - prefetch_do_scale_m_tile(last_m_mx), lds_do_scale_m_ping - ) + store_q_tile_to_lds(prefetch_q_tile(last_m), lds_q_ping) + q_scales_head_ping = prefetch_q_scale_head_2d_tile(last_m_mx) + q_scales_m_ping = prefetch_q_scale_m_2d_tile(last_m_mx) + store_do_tile_to_lds(prefetch_do_tile(last_m), lds_do_ping) + do_scales_head_ping = prefetch_do_scale_head_2d_tile(last_m_mx) + do_scales_m_ping = prefetch_do_scale_m_2d_tile(last_m_mx) qk = compute_qk( - lds_q_quant_head_pong, - lds_q_scale_head_pong, - lds_k_quant_head, - lds_k_scale_head, + lds_q_pong, + q_scales_head_pong, + lds_k, + k_scales_head, ) p = softmax(qk, curr_m) mxquant_m_and_store_to_lds(p, lds_ppt_shuffle, lds_ppt_scale_shuffle) @@ -2241,12 +1939,10 @@ def pingpong(offset_m, inner_state): dv, lds_ppt_shuffle, lds_ppt_scale_shuffle, - lds_do_quant_m_pong, - lds_do_scale_m_pong, - ) - dp = compute_dp( - lds_do_quant_head_pong, lds_do_scale_head_pong, lds_v, lds_v_scale + lds_do_pong, + do_scales_m_pong, ) + dp = compute_dp(lds_do_pong, do_scales_head_pong, lds_v, lds_v_scale) ds = compute_ds(dp, p, curr_m) mxquant_m_and_store_to_lds(ds, lds_dst_shuffle, lds_dst_scale_shuffle) mxquant_n_and_store_to_lds(ds, lds_ds_shuffle, lds_ds_scale_shuffle) @@ -2255,12 +1951,10 @@ def pingpong(offset_m, inner_state): dk, 
lds_dst_shuffle, lds_dst_scale_shuffle, - lds_q_quant_m_pong, - lds_q_scale_m_pong, - ) - dq = compute_dq( - lds_ds_shuffle, lds_ds_scale_shuffle, lds_k_quant_n, lds_k_scale_n + lds_q_pong, + q_scales_m_pong, ) + dq = compute_dq(lds_ds_shuffle, lds_ds_scale_shuffle, lds_k, k_scales_n) store_dq_atomic(dq, curr_m) hot_loop_scheduler() @@ -2268,10 +1962,10 @@ def pingpong(offset_m, inner_state): curr_m = last_m qk = compute_qk( - lds_q_quant_head_ping, - lds_q_scale_head_ping, - lds_k_quant_head, - lds_k_scale_head, + lds_q_ping, + q_scales_head_ping, + lds_k, + k_scales_head, ) p = softmax(qk, curr_m) mxquant_m_and_store_to_lds(p, lds_ppt_shuffle, lds_ppt_scale_shuffle) @@ -2280,12 +1974,10 @@ def pingpong(offset_m, inner_state): dv, lds_ppt_shuffle, lds_ppt_scale_shuffle, - lds_do_quant_m_ping, - lds_do_scale_m_ping, - ) - dp = compute_dp( - lds_do_quant_head_ping, lds_do_scale_head_ping, lds_v, lds_v_scale + lds_do_ping, + do_scales_m_ping, ) + dp = compute_dp(lds_do_ping, do_scales_head_ping, lds_v, lds_v_scale) ds = compute_ds(dp, p, curr_m) mxquant_m_and_store_to_lds(ds, lds_dst_shuffle, lds_dst_scale_shuffle) mxquant_n_and_store_to_lds(ds, lds_ds_shuffle, lds_ds_scale_shuffle) @@ -2294,12 +1986,10 @@ def pingpong(offset_m, inner_state): dk, lds_dst_shuffle, lds_dst_scale_shuffle, - lds_q_quant_m_ping, - lds_q_scale_m_ping, - ) - dq = compute_dq( - lds_ds_shuffle, lds_ds_scale_shuffle, lds_k_quant_n, lds_k_scale_n + lds_q_ping, + q_scales_m_ping, ) + dq = compute_dq(lds_ds_shuffle, lds_ds_scale_shuffle, lds_k, k_scales_n) store_dq_atomic(dq, curr_m) store_dk_atomic(dk) @@ -2313,40 +2003,36 @@ def launch_attn_bwd( arg_dq: fx.Tensor, arg_dk: fx.Tensor, arg_dv: fx.Tensor, - arg_q_quant_head: fx.Tensor, - arg_q_scale_head: fx.Tensor, - arg_q_quant_m: fx.Tensor, - arg_q_scale_m: fx.Tensor, - arg_k_quant_head: fx.Tensor, - arg_k_scale_head: fx.Tensor, - arg_k_quant_n: fx.Tensor, - arg_k_scale_n: fx.Tensor, + arg_q: fx.Tensor, + arg_q_scale: fx.Tensor, + arg_k: 
fx.Tensor, + arg_k_scale: fx.Tensor, arg_v: fx.Tensor, arg_v_scale: fx.Tensor, arg_do_quant_head: fx.Tensor, - arg_do_scale_head: fx.Tensor, - arg_do_quant_m: fx.Tensor, - arg_do_scale_m: fx.Tensor, + arg_do_scale: fx.Tensor, arg_M: fx.Tensor, arg_D: fx.Tensor, batch: fx.Int32, stride_qo_batch: fx.Int32, - stride_qo_scale_batch: fx.Int32, stride_kv_batch: fx.Int32, - stride_kv_scale_batch: fx.Int32, stride_MD_batch: fx.Int32, stride_qkvo_nheads: fx.Int32, - stride_qkvo_scale_nheads: fx.Int32, stride_MD_nheads: fx.Int32, + stride_q_scale_batch: fx.Int32, + stride_q_scale_nheads: fx.Int32, + stride_k_scale_batch: fx.Int32, + stride_k_scale_nheads: fx.Int32, + stride_v_scale_batch: fx.Int32, + stride_v_scale_nheads: fx.Int32, + stride_do_scale_batch: fx.Int32, + stride_do_scale_nheads: fx.Int32, stream: fx.Stream, ): _ = _cache_tag allocator_pong.finalized = False allocator_ping.finalized = False - allocator_k_quant_head.finalized = False - allocator_k_scale_head.finalized = False - allocator_k_quant_n.finalized = False - allocator_k_scale_n.finalized = False + allocator_k.finalized = False allocator_v.finalized = False allocator_v_scale.finalized = False allocator_ppt_shuffle.finalized = False @@ -2359,10 +2045,7 @@ def launch_attn_bwd( with ir.InsertionPoint(ctx.gpu_module_body): allocator_pong.finalize() allocator_ping.finalize() - allocator_k_quant_head.finalize() - allocator_k_scale_head.finalize() - allocator_k_quant_n.finalize() - allocator_k_scale_n.finalize() + allocator_k.finalize() allocator_v.finalize() allocator_v_scale.finalize() allocator_ppt_shuffle.finalize() @@ -2380,31 +2063,30 @@ def launch_attn_bwd( arg_dq, arg_dk, arg_dv, - arg_q_quant_head, - arg_q_scale_head, - arg_q_quant_m, - arg_q_scale_m, - arg_k_quant_head, - arg_k_scale_head, - arg_k_quant_n, - arg_k_scale_n, + arg_q, + arg_q_scale, + arg_k, + arg_k_scale, arg_v, arg_v_scale, arg_do_quant_head, - arg_do_scale_head, - arg_do_quant_m, - arg_do_scale_m, + arg_do_scale, arg_M, arg_D, batch, 
stride_qo_batch, - stride_qo_scale_batch, stride_kv_batch, - stride_kv_scale_batch, stride_MD_batch, stride_qkvo_nheads, - stride_qkvo_scale_nheads, stride_MD_nheads, + stride_q_scale_batch, + stride_q_scale_nheads, + stride_k_scale_batch, + stride_k_scale_nheads, + stride_v_scale_batch, + stride_v_scale_nheads, + stride_do_scale_batch, + stride_do_scale_nheads, ) if waves_per_eu is not None: _wpe = int(waves_per_eu) diff --git a/aiter/ops/triton/_triton_kernels/quant/mxfp8_quant.py b/aiter/ops/triton/_triton_kernels/quant/mxfp8_quant.py index 8e3ec93371..56212cd93a 100644 --- a/aiter/ops/triton/_triton_kernels/quant/mxfp8_quant.py +++ b/aiter/ops/triton/_triton_kernels/quant/mxfp8_quant.py @@ -15,6 +15,16 @@ def _get_max_quant_val(dtype: tl.constexpr): tl.static_assert(False, f"Invalid {dtype=}") +@triton.jit +def _get_max_power_of_2_quant_val(dtype: tl.constexpr): + if dtype == tl.float8e5: + return 32768.0 + elif dtype == tl.float8e4nv: + return 256.0 + else: + tl.static_assert(False, f"Invalid {dtype=}") + + @triton.jit def _compute_mx_quant_and_scale( src_tensor, @@ -36,13 +46,13 @@ def _compute_mx_quant_and_scale( abs_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 32] ) max_val = tl.max(abs_tensor, axis=2, keep_dims=True) - dequant_scale = max_val / _get_max_quant_val(mx_tensor_dtype) if SCALE_ROUNDING_MODE == 0: # ROUND_UP # compute 2 ** ceil(log2(dequant_scale)) # Adding 0x007FFFFF adds exponent by 1 unless mantissa is all zeros # A corner case: exponent is 0xFF that will overflow but that's already # NaN so assume we don't care. 
+ dequant_scale = max_val / _get_max_quant_val(mx_tensor_dtype) dequant_scale_exponent = ( dequant_scale.to(tl.uint32, bitcast=True) + 0x007FFFFF ) & 0x7F800000 @@ -50,6 +60,7 @@ def _compute_mx_quant_and_scale( # ROUND_DOWN # compute 2 ** floor(log2(dequant_scale)) assert SCALE_ROUNDING_MODE == 1 + dequant_scale = max_val / _get_max_power_of_2_quant_val(mx_tensor_dtype) dequant_scale_exponent = dequant_scale.to(tl.uint32, bitcast=True) & 0x7F800000 dequant_scale_rounded = dequant_scale_exponent.to(tl.float32, bitcast=True) quant_scale = tl.where(dequant_scale_rounded == 0, 0, 1.0 / dequant_scale_rounded) @@ -271,3 +282,252 @@ def _upcast_from_mxfp8( out_tensor = out_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM]) out_tensor = out_tensor.to(dst_dtype) tl.store(out_ptr + out_offsets, out_tensor, mask=full_mask_out) + + +@triton.jit +def _compute_mx_quant_and_scale_2d( + src_tensor, + valid_src_mask, + mx_tensor_dtype: tl.constexpr, + SCALE_ROUNDING_MODE: tl.constexpr = 0, +): + BLOCK_SIZE_M: tl.constexpr = src_tensor.shape[0] + BLOCK_SIZE_N: tl.constexpr = src_tensor.shape[1] + BLOCK_SIZE_M_SCALE: tl.constexpr = BLOCK_SIZE_M // 32 + BLOCK_SIZE_N_SCALE: tl.constexpr = BLOCK_SIZE_N // 32 + + # Explicit cast to fp32 since most ops are not supported on bfloat16. + f32_tensor = src_tensor.to(tl.float32) + abs_tensor = tl.abs(f32_tensor) + # Don't consider padding tensors in scale computation + abs_tensor = tl.where(valid_src_mask, abs_tensor, -1.0) + + # Reshape to (M_SCALE, 32, N_SCALE, 32) so that for each (i, j) scale block, + # the elements live at abs_4d[i, :, j, :] — a 32x32 sub-block. + abs_4d = tl.reshape(abs_tensor, [BLOCK_SIZE_M_SCALE, 32, BLOCK_SIZE_N_SCALE, 32]) + # Two sequential reductions to compute the max over each 32x32 block. 
+ max_val = tl.max(abs_4d, axis=3, keep_dims=True) + max_val = tl.max(max_val, axis=1, keep_dims=True) + + dequant_scale = max_val / _get_max_quant_val(mx_tensor_dtype) + if SCALE_ROUNDING_MODE == 0: + # ROUND_UP + dequant_scale_exponent = ( + dequant_scale.to(tl.uint32, bitcast=True) + 0x007FFFFF + ) & 0x7F800000 + else: + # ROUND_DOWN + tl.static_assert(SCALE_ROUNDING_MODE == 1) + dequant_scale_exponent = dequant_scale.to(tl.uint32, bitcast=True) & 0x7F800000 + dequant_scale_rounded = dequant_scale_exponent.to(tl.float32, bitcast=True) + quant_scale = tl.where(dequant_scale_rounded == 0, 0, 1.0 / dequant_scale_rounded) + + # Broadcast (M_SCALE, 1, N_SCALE, 1) over (M_SCALE, 32, N_SCALE, 32). + f32_tensor_4d = tl.reshape( + f32_tensor, [BLOCK_SIZE_M_SCALE, 32, BLOCK_SIZE_N_SCALE, 32] + ) + quant_tensor = f32_tensor_4d * quant_scale + + quant_tensor = quant_tensor.reshape([BLOCK_SIZE_M, BLOCK_SIZE_N]) + quant_tensor = tl.where(valid_src_mask, quant_tensor, 0) + + dequant_scale_exponent = dequant_scale_exponent.reshape( + [BLOCK_SIZE_M_SCALE, BLOCK_SIZE_N_SCALE] + ) + dequant_scale_exponent = (dequant_scale_exponent >> 23).to(tl.uint8) + out_tensor = quant_tensor.to(mx_tensor_dtype) + + return out_tensor, dequant_scale_exponent + + +@triton.jit +def _downcast_to_mxfp8_2d( + mx_tensor_ptr, + stride_mxt_b, + stride_mxt_m, + stride_mxt_n: tl.constexpr, + mx_scale_ptr, + stride_mx_scale_b, + stride_mx_scale_m, + stride_mx_scale_n, + src_ptr, + stride_src_b, + stride_src_m, + stride_src_n, + M, + N, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + SCALE_ROUNDING_MODE: tl.constexpr, +): + + tl.static_assert(stride_mxt_n == 1, f"Output stride, {stride_mxt_n=} must be 1.") + tl.static_assert( + BLOCK_SIZE_M % 32 == 0, f"{BLOCK_SIZE_M=} must be a multiple of 32" + ) + tl.static_assert( + BLOCK_SIZE_N % 32 == 0, f"{BLOCK_SIZE_N=} must be a multiple of 32" + ) + + mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty + tl.static_assert( + mx_tensor_dtype 
== tl.float8e4nv or mx_tensor_dtype == tl.float8e5, + f"Invalid {mx_tensor_dtype=}. Must be float8.", + ) + + src_dtype: tl.constexpr = src_ptr.dtype.element_ty + tl.static_assert( + mx_scale_ptr.dtype.element_ty == tl.uint8, + f"{mx_scale_ptr.dtype.element_ty=} must be uint8", + ) + tl.static_assert( + (src_dtype == tl.float32) + or (src_dtype == tl.bfloat16) + or (src_dtype == tl.float16), + f"{src_dtype=} must be float32 or bfloat16 or float16", + ) + + batch = tl.program_id(0).to(tl.int64) + m_block = tl.program_id(1).to(tl.int64) + n_block = tl.program_id(2).to(tl.int64) + + BLOCK_SIZE_M_SCALE: tl.constexpr = BLOCK_SIZE_M // 32 + BLOCK_SIZE_N_SCALE: tl.constexpr = BLOCK_SIZE_N // 32 + + start_m = m_block * BLOCK_SIZE_M + start_n = n_block * BLOCK_SIZE_N + start_scale_m = m_block * BLOCK_SIZE_M_SCALE + start_scale_n = n_block * BLOCK_SIZE_N_SCALE + + src_ptr += batch * stride_src_b + start_m * stride_src_m + start_n * stride_src_n + mx_tensor_ptr += ( + batch * stride_mxt_b + start_m * stride_mxt_m + start_n * stride_mxt_n + ) + mx_scale_ptr += ( + batch * stride_mx_scale_b + + start_scale_m * stride_mx_scale_m + + start_scale_n * stride_mx_scale_n + ) + + offs_m = tl.arange(0, BLOCK_SIZE_M)[:, None].to(tl.int64) + offs_n = tl.arange(0, BLOCK_SIZE_N)[None, :].to(tl.int64) + offs_scale_m = tl.arange(0, BLOCK_SIZE_M_SCALE)[:, None].to(tl.int64) + offs_scale_n = tl.arange(0, BLOCK_SIZE_N_SCALE)[None, :].to(tl.int64) + + mask_m = start_m + offs_m < M + mask_n = start_n + offs_n < N + full_mask = mask_m & mask_n + + mask_scale_m = start_scale_m + offs_scale_m < tl.cdiv(M, 32) + mask_scale_n = start_scale_n + offs_scale_n < tl.cdiv(N, 32) + full_scale_mask = mask_scale_m & mask_scale_n + + src_offsets = offs_m * stride_src_m + offs_n * stride_src_n + mx_tensor_offsets = offs_m * stride_mxt_m + offs_n * stride_mxt_n + scale_offsets = offs_scale_m * stride_mx_scale_m + offs_scale_n * stride_mx_scale_n + + src_tensor = tl.load(src_ptr + src_offsets, mask=full_mask) + + 
out_tensor, scale_tensor = _compute_mx_quant_and_scale_2d( + src_tensor, full_mask, mx_tensor_dtype, SCALE_ROUNDING_MODE + ) + + tl.store(mx_scale_ptr + scale_offsets, scale_tensor, mask=full_scale_mask) + tl.store(mx_tensor_ptr + mx_tensor_offsets, out_tensor, mask=full_mask) + + +@triton.jit +def _upcast_from_mxfp8_2d( + out_ptr, + stride_o_b, + stride_o_m, + stride_o_n: tl.constexpr, + mx_scale_ptr, + stride_scale_b, + stride_scale_m, + stride_scale_n, + mx_tensor_ptr, + stride_tensor_b, + stride_tensor_m, + stride_tensor_n: tl.constexpr, + M, + N, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + + tl.static_assert( + stride_o_n == 1, + "the weight must be contiguous in the n dimension for mx", + ) + tl.static_assert(BLOCK_SIZE_M % 32 == 0, "BLOCK_SIZE_M must be a multiple of 32") + tl.static_assert(BLOCK_SIZE_N % 32 == 0, "BLOCK_SIZE_N must be a multiple of 32") + mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty + dst_dtype: tl.constexpr = out_ptr.dtype.element_ty + tl.static_assert( + dst_dtype == tl.float32 or dst_dtype == tl.float16 or dst_dtype == tl.bfloat16 + ) + tl.static_assert( + (mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5) + or mx_tensor_dtype == dst_dtype, + "mx_tensor_ptr must be float8 or dst_dtype", + ) + tl.static_assert( + mx_scale_ptr.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8" + ) + + BLOCK_SIZE_M_SCALE: tl.constexpr = BLOCK_SIZE_M // 32 + BLOCK_SIZE_N_SCALE: tl.constexpr = BLOCK_SIZE_N // 32 + + batch = tl.program_id(0).to(tl.int64) + m_block = tl.program_id(1).to(tl.int64) + n_block = tl.program_id(2).to(tl.int64) + + start_m = m_block * BLOCK_SIZE_M + start_n = n_block * BLOCK_SIZE_N + start_scale_m = m_block * BLOCK_SIZE_M_SCALE + start_scale_n = n_block * BLOCK_SIZE_N_SCALE + + mx_tensor_ptr += ( + batch * stride_tensor_b + start_m * stride_tensor_m + start_n * stride_tensor_n + ) + mx_scale_ptr += ( + batch * stride_scale_b + + start_scale_m * stride_scale_m + + 
start_scale_n * stride_scale_n + ) + out_ptr += batch * stride_o_b + start_m * stride_o_m + start_n * stride_o_n + + offs_m = tl.arange(0, BLOCK_SIZE_M)[:, None].to(tl.int64) + offs_n = tl.arange(0, BLOCK_SIZE_N)[None, :].to(tl.int64) + offs_scale_m = tl.arange(0, BLOCK_SIZE_M_SCALE)[:, None].to(tl.int64) + offs_scale_n = tl.arange(0, BLOCK_SIZE_N_SCALE)[None, :].to(tl.int64) + + mask_m = start_m + offs_m < M + mask_n = start_n + offs_n < N + full_mask = mask_m & mask_n + + mask_scale_m = start_scale_m + offs_scale_m < tl.cdiv(M, 32) + mask_scale_n = start_scale_n + offs_scale_n < tl.cdiv(N, 32) + full_scale_mask = mask_scale_m & mask_scale_n + + tensor_offsets = offs_m * stride_tensor_m + offs_n * stride_tensor_n + scale_offsets = offs_scale_m * stride_scale_m + offs_scale_n * stride_scale_n + out_offsets = offs_m * stride_o_m + offs_n * stride_o_n + + tensor = tl.load(mx_tensor_ptr + tensor_offsets, mask=full_mask) + scale = tl.load(mx_scale_ptr + scale_offsets, mask=full_scale_mask) + + dst_scale = (scale.to(tl.uint32) << 23).to(tl.float32, bitcast=True) + dst_tensor = tensor.to(tl.float32) + + # Broadcast the per-32x32-block scale across the full tile. 
+ dst_tensor = dst_tensor.reshape([BLOCK_SIZE_M_SCALE, 32, BLOCK_SIZE_N_SCALE, 32]) + dst_scale = dst_scale.reshape([BLOCK_SIZE_M_SCALE, 1, BLOCK_SIZE_N_SCALE, 1]) + scale_4d = scale.reshape([BLOCK_SIZE_M_SCALE, 1, BLOCK_SIZE_N_SCALE, 1]) + + out_tensor = dst_tensor * dst_scale + out_tensor = tl.where(scale_4d == 0xFF, float("nan"), out_tensor) + out_tensor = out_tensor.reshape([BLOCK_SIZE_M, BLOCK_SIZE_N]) + out_tensor = out_tensor.to(dst_dtype) + tl.store(out_ptr + out_offsets, out_tensor, mask=full_mask) diff --git a/aiter/ops/triton/quant/mxfp8_quant.py b/aiter/ops/triton/quant/mxfp8_quant.py index a5d0a2a635..cd9eb4f63a 100644 --- a/aiter/ops/triton/quant/mxfp8_quant.py +++ b/aiter/ops/triton/quant/mxfp8_quant.py @@ -6,12 +6,16 @@ from aiter.ops.triton._triton_kernels.quant.mxfp8_quant import ( _downcast_to_mxfp8, _upcast_from_mxfp8, + _downcast_to_mxfp8_2d, + _upcast_from_mxfp8_2d, ) from aiter.ops.triton.utils.logger import AiterTritonLogger __all__ = [ "downcast_to_mxfp8", "upcast_from_mxfp8", + "downcast_to_mxfp8_2d", + "upcast_from_mxfp8_2d", ] @@ -130,3 +134,133 @@ def upcast_from_mxfp8( ) out = out.transpose(axis, scale.ndim - 1).contiguous() return out + + +def downcast_to_mxfp8_2d( + src_tensor: torch.Tensor, + out_quant_type: torch.dtype, + SCALE_ROUNDING_MODE: int = 0, +): + """ + Convert the last two dimensions of ``src_tensor`` to the mxfp8 format, + where each 32x32 block of those two dimensions shares a single scale. + + The quantized tensor preserves the input shape; the scale tensor has shape + ``(..., cdiv(M, 32), cdiv(N, 32))`` and dtype uint8, with ``M`` and ``N`` + being the last two dims of ``src_tensor``. + + ``out_quant_type`` must be ``torch.float8_e4m3fn`` or ``torch.float8_e5m2``. 
+ """ + assert ( + src_tensor.ndim >= 2 + ), f"src_tensor must have at least 2 dimensions, got {src_tensor.ndim}" + assert out_quant_type in { + torch.float8_e4m3fn, + torch.float8_e5m2, + }, f"Invalid out_quant_type {out_quant_type=}" + + src_tensor = src_tensor.contiguous() + M = src_tensor.shape[-2] + N = src_tensor.shape[-1] + M_scale = triton.cdiv(M, 32) + N_scale = triton.cdiv(N, 32) + + out_shape = src_tensor.shape + out_scale_shape = src_tensor.shape[:-2] + (M_scale, N_scale) + + out_quant_tensor = src_tensor.new_empty(out_shape, dtype=out_quant_type) + out_scale = src_tensor.new_empty(out_scale_shape, dtype=torch.uint8) + + kernel_src = src_tensor.reshape(-1, M, N) + kernel_quant = out_quant_tensor.view(-1, M, N) + kernel_scale = out_scale.view(-1, M_scale, N_scale) + + BLOCK_M = 128 + BLOCK_N = 128 + B = kernel_src.shape[0] + grid = (B, triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N)) + + _downcast_to_mxfp8_2d[grid]( + kernel_quant, + *kernel_quant.stride(), + kernel_scale, + *kernel_scale.stride(), + kernel_src, + *kernel_src.stride(), + M, + N, + BLOCK_M, + BLOCK_N, + SCALE_ROUNDING_MODE, + num_warps=8, + ) + + return out_quant_tensor, out_scale + + +def upcast_from_mxfp8_2d( + tensor: torch.Tensor, + scale: torch.Tensor, + dtype: torch.dtype, +): + """ + Inverse of :func:`downcast_to_mxfp8_2d`. ``scale`` is expected to be a + ``(..., cdiv(M, 32), cdiv(N, 32))`` uint8 tensor where each entry + corresponds to a 32x32 block of ``tensor``'s last two dimensions. + """ + assert tensor.ndim >= 2 and scale.ndim >= 2 + assert tensor.dtype in { + torch.float8_e5m2, + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + }, f"Invalid tensor dtype {tensor.dtype=}" + assert scale.dtype == torch.uint8, f"Invalid scale dtype {scale.dtype=}" + assert dtype in ( + torch.float16, + torch.bfloat16, + torch.float32, + ), f"Invalid output dtype {dtype=}" + assert tensor.ndim == scale.ndim, ( + f"tensor and scale must have the same number of dimensions. 
" + f"Got {tensor.ndim=} and {scale.ndim=}" + ) + + tensor = tensor.contiguous() + scale = scale.contiguous() + M = tensor.shape[-2] + N = tensor.shape[-1] + M_scale = scale.shape[-2] + N_scale = scale.shape[-1] + assert M_scale == triton.cdiv( + M, 32 + ), f"scale shape mismatch: got {M_scale=} expected {triton.cdiv(M, 32)}" + assert N_scale == triton.cdiv( + N, 32 + ), f"scale shape mismatch: got {N_scale=} expected {triton.cdiv(N, 32)}" + + out = torch.empty(tensor.shape, dtype=dtype, device=tensor.device) + + kernel_out = out.view(-1, M, N) + kernel_tensor = tensor.view(-1, M, N) + kernel_scale = scale.view(-1, M_scale, N_scale) + + BLOCK_M = 128 + BLOCK_N = 128 + B = kernel_out.shape[0] + grid = (B, triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N)) + + _upcast_from_mxfp8_2d[grid]( + kernel_out, + *kernel_out.stride(), + kernel_scale, + *kernel_scale.stride(), + kernel_tensor, + *kernel_tensor.stride(), + M, + N, + BLOCK_M, + BLOCK_N, + num_warps=8, + ) + + return out diff --git a/op_tests/flydsl_tests/test_attn_bwd_mxfp8_gfx950.py b/op_tests/flydsl_tests/test_attn_bwd_mxfp8_gfx950.py index 1a2e791fe1..30b8615d62 100644 --- a/op_tests/flydsl_tests/test_attn_bwd_mxfp8_gfx950.py +++ b/op_tests/flydsl_tests/test_attn_bwd_mxfp8_gfx950.py @@ -8,7 +8,12 @@ import torch import pytest -from aiter.ops.triton.quant.mxfp8_quant import downcast_to_mxfp8, upcast_from_mxfp8 +from aiter.ops.triton.quant.mxfp8_quant import ( + downcast_to_mxfp8, + upcast_from_mxfp8, + downcast_to_mxfp8_2d, + upcast_from_mxfp8_2d, +) from aiter.ops.flydsl.kernels.attn_bwd_mxfp8_gfx950 import compile_attn_bwd_mxfp8_gfx950 from flydsl.runtime.device import get_rocm_arch @@ -43,28 +48,31 @@ def mx_quant(x, dim=-1): return x_fp32.contiguous(), x_fp8.contiguous(), x_scale.contiguous() +def mx_quant_2d(x): + x_fp8, x_scale = downcast_to_mxfp8_2d(x, torch.float8_e4m3fn) + x_fp32 = upcast_from_mxfp8_2d(x_fp8, x_scale, torch.float32) + return x_fp32.contiguous(), x_fp8.contiguous(), x_scale.contiguous() 
+ + def run_torch( - q_fp32_head, - q_fp32_m, - k_fp32_head, - k_fp32_n, + q_fp32, + k_fp32, v, - do_fp32_head, - do_fp32_m, + do_fp32, m, D, sm_scale, causal, gqa_size, ): - batch = q_fp32_head.shape[0] - num_heads_q = q_fp32_head.shape[1] + batch = q_fp32.shape[0] + num_heads_q = q_fp32.shape[1] num_heads_kv = num_heads_q // gqa_size - seqlen = q_fp32_head.shape[2] - head_dim = q_fp32_head.shape[3] - device = q_fp32_head.device + seqlen = q_fp32.shape[2] + head_dim = q_fp32.shape[3] + device = q_fp32.device v_f32 = v.to(torch.float32) - qk = torch.matmul(q_fp32_head, k_fp32_head.transpose(-2, -1)) * sm_scale + qk = torch.matmul(q_fp32, k_fp32.transpose(-2, -1)) * sm_scale p = torch.exp(qk - m[:, :, :, None]) if causal: mask = torch.tril(torch.ones((seqlen, seqlen), device=device)) @@ -72,14 +80,14 @@ def run_torch( ppT, _, _ = mx_quant(p, -2) ppT = ppT.transpose(-2, -1) - dv = torch.matmul(ppT, do_fp32_m) - dp = torch.matmul(do_fp32_head, v_f32.transpose(-2, -1)) + dv = torch.matmul(ppT, do_fp32) + dp = torch.matmul(do_fp32, v_f32.transpose(-2, -1)) ds = p * (dp - D[:, :, :, None]) dsT, _, _ = mx_quant(ds, -1) dsT = dsT.transpose(-2, -1) ds, _, _ = mx_quant(ds, -2) - dk = torch.matmul(dsT, q_fp32_m) * sm_scale - dq = torch.matmul(ds, k_fp32_n) * sm_scale + dk = torch.matmul(dsT, q_fp32) * sm_scale + dq = torch.matmul(ds, k_fp32) * sm_scale dk = dk.view(batch, num_heads_kv, gqa_size, seqlen, head_dim).sum(dim=2) dv = dv.view(batch, num_heads_kv, gqa_size, seqlen, head_dim).sum(dim=2) @@ -156,37 +164,29 @@ def test_attn_bwd_flyc( * 0.5 ) - q_fp32_head, q_quant_head, q_scale_head = mx_quant(q_fp32, -1) - q_fp32_m, q_quant_m, q_scale_m = mx_quant(q_fp32, -2) - k_fp32_head, k_quant_head, k_scale_head = mx_quant(k_fp32, -1) - k_fp32_n, k_quant_n, k_scale_n = mx_quant(k_fp32, -2) + q_fp32, q_quant, q_scale = mx_quant_2d(q_fp32) + k_fp32, k_quant, k_scale = mx_quant_2d(k_fp32) v_fp32, v_quant, v_scale = mx_quant(v_fp32) - do_fp32_head, do_quant_head, do_scale_head = 
mx_quant(do_fp32, -1) - do_fp32_m, do_quant_m, do_scale_m = mx_quant(do_fp32, -2) + do_fp32, do_quant, do_scale = mx_quant_2d(do_fp32) k_fp32 = k_fp32.repeat_interleave(gqa_size, dim=1) - k_fp32_head = k_fp32_head.repeat_interleave(gqa_size, dim=1) - k_fp32_n = k_fp32_n.repeat_interleave(gqa_size, dim=1) v_fp32 = v_fp32.repeat_interleave(gqa_size, dim=1) qk = q_fp32 @ k_fp32.transpose(-2, -1) qk = qk * sm_scale m = qk.max(dim=-1)[0] p = (qk - m[:, :, :, None]).exp() - l = p.sum(dim=-1) - p = p / l[:, :, :, None] + L = p.sum(dim=-1) + p = p / L[:, :, :, None] o_fp32 = torch.matmul(p, v_fp32) - m = m + torch.log(l) + m = m + torch.log(L) D = (o_fp32 * do_fp32).sum(dim=-1) dq_ref, dk_ref, dv_ref = run_torch( - q_fp32_head, - q_fp32_m, - k_fp32_head, - k_fp32_n, + q_fp32, + k_fp32, v_fp32, - do_fp32_head, - do_fp32_m, + do_fp32, m, D, sm_scale, @@ -208,20 +208,14 @@ def launch_kernel( dq, dk, dv, - q_quant_head, - q_scale_head, - q_quant_m, - q_scale_m, - k_quant_head, - k_scale_head, - k_quant_n, - k_scale_n, + q, + q_scale, + k, + k_scale, v, v_scale, - do_quant_head, - do_scale_head, - do_quant_m, - do_scale_m, + do, + do_scale, m, D, batch, @@ -230,31 +224,30 @@ def launch_kernel( dq.contiguous().view(-1), dk.contiguous().view(-1), dv.contiguous().view(-1), - q_quant_head.contiguous().view(-1), - q_scale_head.contiguous().view(-1), - q_quant_m.contiguous().view(-1), - q_scale_m.contiguous().view(-1), - k_quant_head.contiguous().view(-1), - k_scale_head.contiguous().view(-1), - k_quant_n.contiguous().view(-1), - k_scale_n.contiguous().view(-1), + q.contiguous().view(-1), + q_scale.contiguous().view(-1), + k.contiguous().view(-1), + k_scale.contiguous().view(-1), v.contiguous().view(-1), v_scale.contiguous().view(-1), - do_quant_head.contiguous().view(-1), - do_scale_head.contiguous().view(-1), - do_quant_m.contiguous().view(-1), - do_scale_m.contiguous().view(-1), + do.contiguous().view(-1), + do_scale.contiguous().view(-1), m.contiguous().view(-1), 
D.contiguous().view(-1), batch, - q_quant_head.stride(0), - q_scale_head.stride(0), - k_quant_head.stride(0), - k_scale_head.stride(0), + q.stride(0), + k.stride(0), m.stride(0), - q_quant_head.stride(1), - q_scale_head.stride(1), + q.stride(1), m.stride(1), + q_scale.stride(0), + q_scale.stride(1), + k_scale.stride(0), + k_scale.stride(1), + v_scale.stride(0), + v_scale.stride(1), + do_scale.stride(0), + do_scale.stride(1), torch.cuda.current_stream(), ) @@ -262,20 +255,14 @@ def launch_kernel( dq_fly, dk_fly, dv_fly, - q_quant_head, - q_scale_head, - q_quant_m, - q_scale_m, - k_quant_head, - k_scale_head, - k_quant_n, - k_scale_n, + q_quant, + q_scale, + k_quant, + k_scale, v_quant, v_scale, - do_quant_head, - do_scale_head, - do_quant_m, - do_scale_m, + do_quant, + do_scale, m, D, batch, diff --git a/op_tests/op_benchmarks/flydsl/bench_attn_bwd_mxfp8_gfx950.py b/op_tests/op_benchmarks/flydsl/bench_attn_bwd_mxfp8_gfx950.py index 85b5d52113..32c25933b8 100644 --- a/op_tests/op_benchmarks/flydsl/bench_attn_bwd_mxfp8_gfx950.py +++ b/op_tests/op_benchmarks/flydsl/bench_attn_bwd_mxfp8_gfx950.py @@ -12,6 +12,7 @@ from op_tests.flydsl_tests.test_attn_bwd_mxfp8_gfx950 import ( run_torch, mx_quant, + mx_quant_2d, check_result, ) from flydsl.runtime.device import get_rocm_arch @@ -57,7 +58,7 @@ def bench_attn_bwd_flyc( causal=causal, waves_per_eu=_wpe, ) - print(f"✓ Kernel prepared") + print("✓ Kernel prepared") device = torch.device("cuda") gqa_size = num_heads_q // num_heads_kv @@ -92,36 +93,28 @@ def bench_attn_bwd_flyc( * 0.5 ) - q_fp32_head, q_quant_head, q_scale_head = mx_quant(q_fp32, -1) - q_fp32_m, q_quant_m, q_scale_m = mx_quant(q_fp32, -2) - k_fp32_head, k_quant_head, k_scale_head = mx_quant(k_fp32, -1) - k_fp32_n, k_quant_n, k_scale_n = mx_quant(k_fp32, -2) + q_fp32, q_quant, q_scale = mx_quant_2d(q_fp32) + k_fp32, k_quant, k_scale = mx_quant_2d(k_fp32) v_fp32, v_quant, v_scale = mx_quant(v_fp32) - do_fp32_head, do_quant_head, do_scale_head = mx_quant(do_fp32, 
-1) - do_fp32_m, do_quant_m, do_scale_m = mx_quant(do_fp32, -2) + do_fp32, do_quant, do_scale = mx_quant_2d(do_fp32) k_fp32 = k_fp32.repeat_interleave(gqa_size, dim=1) - k_fp32_head = k_fp32_head.repeat_interleave(gqa_size, dim=1) - k_fp32_n = k_fp32_n.repeat_interleave(gqa_size, dim=1) v_fp32 = v_fp32.repeat_interleave(gqa_size, dim=1) qk = torch.matmul(q_fp32, k_fp32.transpose(-2, -1)) qk = qk * sm_scale m = qk.max(dim=-1)[0] p = (qk - m[:, :, :, None]).exp() - l = p.sum(dim=-1) - m = m + torch.log(l) + L = p.sum(dim=-1) + m = m + torch.log(L) D = (o_fp32 * do_fp32).sum(dim=-1) if check_correctness: dq_ref, dk_ref, dv_ref = run_torch( - q_fp32_head, - q_fp32_m, - k_fp32_head, - k_fp32_n, + q_fp32, + k_fp32, v_fp32, - do_fp32_head, - do_fp32_m, + do_fp32, m, D, sm_scale, @@ -142,20 +135,14 @@ def launch_kernel( dq, dk, dv, - q_quant_head, - q_scale_head, - q_quant_m, - q_scale_m, - k_quant_head, - k_scale_head, - k_quant_n, - k_scale_n, + q, + q_scale, + k, + k_scale, v, v_scale, - do_quant_head, - do_scale_head, - do_quant_m, - do_scale_m, + do, + do_scale, m, D, batch, @@ -164,31 +151,30 @@ def launch_kernel( dq.contiguous().view(-1), dk.contiguous().view(-1), dv.contiguous().view(-1), - q_quant_head.contiguous().view(-1), - q_scale_head.contiguous().view(-1), - q_quant_m.contiguous().view(-1), - q_scale_m.contiguous().view(-1), - k_quant_head.contiguous().view(-1), - k_scale_head.contiguous().view(-1), - k_quant_n.contiguous().view(-1), - k_scale_n.contiguous().view(-1), + q.contiguous().view(-1), + q_scale.contiguous().view(-1), + k.contiguous().view(-1), + k_scale.contiguous().view(-1), v.contiguous().view(-1), v_scale.contiguous().view(-1), - do_quant_head.contiguous().view(-1), - do_scale_head.contiguous().view(-1), - do_quant_m.contiguous().view(-1), - do_scale_m.contiguous().view(-1), + do.contiguous().view(-1), + do_scale.contiguous().view(-1), m.contiguous().view(-1), D.contiguous().view(-1), batch, - q_quant_head.stride(0), - q_scale_head.stride(0), - 
k_quant_head.stride(0), - k_scale_head.stride(0), + q.stride(0), + k.stride(0), m.stride(0), - q_quant_head.stride(1), - q_scale_head.stride(1), + q.stride(1), m.stride(1), + q_scale.stride(0), + q_scale.stride(1), + k_scale.stride(0), + k_scale.stride(1), + v_scale.stride(0), + v_scale.stride(1), + do_scale.stride(0), + do_scale.stride(1), torch.cuda.current_stream(), ) @@ -199,20 +185,14 @@ def launch_kernel( dq_fly, dk_fly, dv_fly, - q_quant_head, - q_scale_head, - q_quant_m, - q_scale_m, - k_quant_head, - k_scale_head, - k_quant_n, - k_scale_n, + q_quant, + q_scale, + k_quant, + k_scale, v_quant, v_scale, - do_quant_head, - do_scale_head, - do_quant_m, - do_scale_m, + do_quant, + do_scale, m, D, batch, @@ -220,46 +200,37 @@ def launch_kernel( num_warmup=bench_warmup, testGraph=test_graph, ) - torch.cuda.synchronize() - dq_fly.zero_() - dk_fly.zero_() - dv_fly.zero_() - launch_kernel( - dq_fly, - dk_fly, - dv_fly, - q_quant_head, - q_scale_head, - q_quant_m, - q_scale_m, - k_quant_head, - k_scale_head, - k_quant_n, - k_scale_n, - v_quant, - v_scale, - do_quant_head, - do_scale_head, - do_quant_m, - do_scale_m, - m, - D, - batch, - ) + if check_correctness: + torch.cuda.synchronize() - dq_fly_fp32 = dq_fly.to(torch.float32) - dk_fly_fp32 = dk_fly.to(torch.float32) - dv_fly_fp32 = dv_fly.to(torch.float32) + dq_fly.zero_() + dk_fly.zero_() + dv_fly.zero_() + launch_kernel( + dq_fly, + dk_fly, + dv_fly, + q_quant, + q_scale, + k_quant, + k_scale, + v_quant, + v_scale, + do_quant, + do_scale, + m, + D, + batch, + ) - if check_correctness: - assert check_result(dq_fly_fp32, dq_ref, rtol=0.01, atol=0.01) - assert check_result(dk_fly_fp32, dk_ref, rtol=0.01, atol=0.01) - assert check_result(dv_fly_fp32, dv_ref, rtol=0.01, atol=0.01) + assert check_result(dq_fly, dq_ref, rtol=0.01, atol=0.01) + assert check_result(dk_fly, dk_ref, rtol=0.01, atol=0.01) + assert check_result(dv_fly, dv_ref, rtol=0.01, atol=0.01) bytes_moved = ( - (4 + 4) * batch * num_heads_q * seqlen * 
head_dim - + (3 + 2 * 4) * batch * num_heads_kv * seqlen * head_dim + (2 + 4) * batch * num_heads_q * seqlen * head_dim + + (2 + 2 * 4) * batch * num_heads_kv * seqlen * head_dim + 2 * 4 * batch * num_heads_q * seqlen ) flops = ( diff --git a/op_tests/op_benchmarks/flydsl/utils.py b/op_tests/op_benchmarks/flydsl/utils.py index b3d67bc375..8bf3824b30 100644 --- a/op_tests/op_benchmarks/flydsl/utils.py +++ b/op_tests/op_benchmarks/flydsl/utils.py @@ -5,7 +5,6 @@ import torch.profiler as tpf import os import copy -import time import numpy as np import pandas as pd import logging diff --git a/op_tests/triton_tests/quant/test_quant_mxfp8.py b/op_tests/triton_tests/quant/test_quant_mxfp8.py index e81d3f78da..85d72b9b89 100644 --- a/op_tests/triton_tests/quant/test_quant_mxfp8.py +++ b/op_tests/triton_tests/quant/test_quant_mxfp8.py @@ -3,7 +3,12 @@ import torch import pytest -from aiter.ops.triton.quant.mxfp8_quant import downcast_to_mxfp8, upcast_from_mxfp8 +from aiter.ops.triton.quant.mxfp8_quant import ( + downcast_to_mxfp8, + upcast_from_mxfp8, + downcast_to_mxfp8_2d, + upcast_from_mxfp8_2d, +) def get_max_quant_val(dtype): @@ -13,6 +18,13 @@ def get_max_quant_val(dtype): return 57344.0 +def get_max_quant_power_of_2_val(dtype): + if dtype == torch.float8_e4m3fn: + return 256.0 + else: + return 32768.0 + + def torch_downcast_to_mxfp8( x: torch.Tensor, dtype: torch.dtype, axis: int, SCALE_ROUNDING_MODE: int = 0 ): @@ -39,10 +51,11 @@ def torch_downcast_to_mxfp8( x_padded = x_padded.reshape(new_shape) x_abs_padded = x_abs_padded.reshape(new_shape) scale = torch.amax(x_abs_padded, -1) - scale = scale / get_max_quant_val(dtype) if SCALE_ROUNDING_MODE == 0: + scale = scale / get_max_quant_val(dtype) scale = (scale.view(torch.int32) + 0x007FFFFF) & 0x7F800000 else: + scale = scale / get_max_quant_power_of_2_val(dtype) scale = scale.view(torch.int32) & 0x7F800000 scale = scale.view(torch.float32) @@ -50,6 +63,8 @@ def torch_downcast_to_mxfp8( x_padded = x_padded * scale_inv 
x_padded = x_padded.reshape(padded_shape) x = x_padded[..., :quant_dim].clone() + max_val = get_max_quant_val(dtype) + x = x.clamp_(-max_val, max_val) x = x.to(dtype).to(torch.float32) x = x.transpose(axis, x.ndim - 1) scale = scale.transpose(axis, x.ndim - 1) @@ -174,3 +189,160 @@ def test_upcast_from_mxfp8(shape, axis, in_dtype, out_dtype): out_torch = torch_upcast_from_mxfp8(x, x_scale, out_dtype, axis) torch.testing.assert_close(out_triton, out_torch, atol=0.01, rtol=0.01) + + +def torch_downcast_to_mxfp8_2d( + x: torch.Tensor, dtype: torch.dtype, SCALE_ROUNDING_MODE: int = 0 +): + """Reference implementation for the 2D 32x32-block mxfp8 downcast.""" + x = x.to(torch.float32) + orig_shape = x.shape + M = orig_shape[-2] + N = orig_shape[-1] + pad_m = (-M) % 32 + pad_n = (-N) % 32 + Mp = M + pad_m + Np = N + pad_n + + x_padded = torch.zeros(orig_shape[:-2] + (Mp, Np), dtype=x.dtype, device=x.device) + x_padded[..., :M, :N] = x + abs_padded = torch.full( + orig_shape[:-2] + (Mp, Np), -1.0, dtype=x.dtype, device=x.device + ) + abs_padded[..., :M, :N] = torch.abs(x) + + M_s = Mp // 32 + N_s = Np // 32 + leading = orig_shape[:-2] + + abs_blocked = abs_padded.reshape(*leading, M_s, 32, N_s, 32) + # max over the two 32-axes -> (..., M_s, N_s) + scale = abs_blocked.amax(dim=-1).amax(dim=-2) + scale = scale / get_max_quant_val(dtype) + if SCALE_ROUNDING_MODE == 0: + scale = (scale.view(torch.int32) + 0x007FFFFF) & 0x7F800000 + else: + scale = scale.view(torch.int32) & 0x7F800000 + scale = scale.view(torch.float32) + + scale_inv = torch.where(scale == 0.0, 0.0, 1.0 / scale) + # broadcast over the inner 32x32 block + scale_inv_b = scale_inv.unsqueeze(-1).unsqueeze(-3) + + x_blocked = x_padded.reshape(*leading, M_s, 32, N_s, 32) + x_blocked = x_blocked * scale_inv_b + x_padded = x_blocked.reshape(*leading, Mp, Np) + x = x_padded[..., :M, :N].contiguous() + # Triton's fp32 -> fp8 cast saturates on overflow; torch's .to(fp8_e4m3fn) + # turns too large values into NaN. 
ROUND_DOWN can push the per-block max + # over fp8's representable range, so clamp here to match Triton's behavior. + max_val = get_max_quant_val(dtype) + x = x.clamp_(-max_val, max_val) + x = x.to(dtype).to(torch.float32) + return x, scale + + +def torch_upcast_from_mxfp8_2d( + x: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype +): + """Reference implementation for the 2D 32x32-block mxfp8 upcast.""" + x = x.to(torch.float32) + orig_shape = x.shape + M = orig_shape[-2] + N = orig_shape[-1] + pad_m = (-M) % 32 + pad_n = (-N) % 32 + Mp = M + pad_m + Np = N + pad_n + + x_padded = torch.zeros(orig_shape[:-2] + (Mp, Np), dtype=x.dtype, device=x.device) + x_padded[..., :M, :N] = x + + M_s = Mp // 32 + N_s = Np // 32 + leading = orig_shape[:-2] + + scale_f = upcast_scale(scale) + scale_b = scale_f.unsqueeze(-1).unsqueeze(-3) + x_blocked = x_padded.reshape(*leading, M_s, 32, N_s, 32) + x_blocked = x_blocked * scale_b + x_padded = x_blocked.reshape(*leading, Mp, Np) + x = x_padded[..., :M, :N].contiguous() + x = x.to(dtype) + return x + + +@pytest.mark.parametrize( + "shape", + [ + (32, 32), + (32, 64), + (64, 32), + (64, 64), + (96, 128), + (128, 128), + (128, 256), + (256, 256), + (40, 50), + (68, 68), + (4, 20), + (1, 32), + (32, 1), + (2, 64, 64), + (3, 96, 128), + (2, 3, 64, 64), + ], +) +@pytest.mark.parametrize("in_dtype", [torch.bfloat16, torch.float32]) +@pytest.mark.parametrize("out_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) +@pytest.mark.parametrize("SCALE_ROUNDING_MODE", [0, 1]) +def test_downcast_to_mxfp8_2d(shape, in_dtype, out_dtype, SCALE_ROUNDING_MODE): + torch.cuda.empty_cache() + torch.manual_seed(20) + x = torch.randn(*shape, dtype=in_dtype, device="cuda") + + out_torch, out_scale_torch = torch_downcast_to_mxfp8_2d( + x, out_dtype, SCALE_ROUNDING_MODE + ) + out_triton, out_scale_triton = downcast_to_mxfp8_2d( + x, out_dtype, SCALE_ROUNDING_MODE + ) + out_triton = out_triton.to(torch.float32) + out_scale_triton = 
upcast_scale(out_scale_triton) + + torch.testing.assert_close(out_triton, out_torch, atol=0.01, rtol=0.01) + torch.testing.assert_close(out_scale_triton, out_scale_torch, atol=0.01, rtol=0.01) + + +@pytest.mark.parametrize( + "shape", + [ + (32, 32), + (32, 64), + (64, 32), + (64, 64), + (96, 128), + (128, 128), + (128, 256), + (256, 256), + (40, 50), + (68, 68), + (4, 20), + (1, 32), + (32, 1), + (2, 64, 64), + (3, 96, 128), + (2, 3, 64, 64), + ], +) +@pytest.mark.parametrize("in_dtype", [torch.float8_e4m3fn, torch.float8_e5m2]) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32]) +def test_upcast_from_mxfp8_2d(shape, in_dtype, out_dtype): + torch.cuda.empty_cache() + torch.manual_seed(20) + x = torch.randn(*shape, dtype=out_dtype, device="cuda") + x, x_scale = downcast_to_mxfp8_2d(x, in_dtype) + out_triton = upcast_from_mxfp8_2d(x, x_scale, out_dtype) + out_torch = torch_upcast_from_mxfp8_2d(x, x_scale, out_dtype) + + torch.testing.assert_close(out_triton, out_torch, atol=0.01, rtol=0.01) From f4c9101851ac8b31dea569f352fac59343014ed5 Mon Sep 17 00:00:00 2001 From: Lukasz Burzawa Date: Sat, 16 May 2026 12:55:41 -0700 Subject: [PATCH 8/8] Update aiter/ops/flydsl/kernels/attn_bwd_mxfp8_gfx950.py Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- aiter/ops/flydsl/kernels/attn_bwd_mxfp8_gfx950.py | 1 - 1 file changed, 1 deletion(-) diff --git a/aiter/ops/flydsl/kernels/attn_bwd_mxfp8_gfx950.py b/aiter/ops/flydsl/kernels/attn_bwd_mxfp8_gfx950.py index 191c164451..3ae939e4a5 100644 --- a/aiter/ops/flydsl/kernels/attn_bwd_mxfp8_gfx950.py +++ b/aiter/ops/flydsl/kernels/attn_bwd_mxfp8_gfx950.py @@ -876,7 +876,6 @@ def compute_qk(lds_a_buffer, a_scales, lds_b_buffer, b_scales): a2 = a3 = fx.Int64(0) a128 = pack_i64x4_to_i32x8(a0, a1, a2, a3) - lds_a_scale_row = lds_a_row # a_scale = lds_scale_load( # lds_a_scale_row, lds_scale_col, tile_head_mx, lds_a_scale_buffer # )