diff --git a/aiter/ops/flydsl/kernels/attn_bwd_mxfp8_gfx950.py b/aiter/ops/flydsl/kernels/attn_bwd_mxfp8_gfx950.py new file mode 100644 index 0000000000..3ae939e4a5 --- /dev/null +++ b/aiter/ops/flydsl/kernels/attn_bwd_mxfp8_gfx950.py @@ -0,0 +1,2107 @@ +"""Attn bwd kernel using the @flyc.kernel API.""" + +import flydsl.compiler as flyc +import flydsl.expr as fx +from flydsl.compiler.kernel_function import CompilationContext + +from flydsl.expr import range_constexpr +from flydsl.runtime.device import get_rocm_arch as get_hip_arch +from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr + +from flydsl._mlir import ir +from flydsl._mlir.dialects import scf + +from flydsl.expr import arith, vector, math as fx_math, const_expr +from flydsl.expr import gpu +from flydsl.expr import buffer_ops, rocdl +from flydsl.expr.typing import T + +from aiter.ops.flydsl.kernels.mfma_preshuffle_pipeline import ( + buffer_copy_gmem16_dwordx4, + tile_chunk_coord_i32, + swizzle_xor16, +) + + +def lds_transpose_load(lds_memref, elem_offset): + """Transpose-load from LDS memref via ds_read_tr8_b64 (gfx950). + + Args: + lds_memref: LDS memref value (address-space 3), typically from + ``SmemPtr.get()`` or ``get_op_result_or_value(...)``. + elem_offset: Per-lane offset added unscaled to the memref base pointer + (byte offset; ArithValue / ir.Value of index type / Python int). + + Returns: + Loaded and transposed ``T.i32x2`` vector ``ir.Value`` (64 bits). 
+ """ + from flydsl._mlir.dialects import llvm, memref + from flydsl.expr.arith import _to_raw + from flydsl.expr.utils.arith import ArithValue as AV + + lds_ptr_ty = ir.Type.parse("!llvm.ptr<3>") + raw_memref = arith.unwrap(lds_memref) + lds_base = memref.extract_aligned_pointer_as_index(raw_memref) + + byte_off = AV(arith.unwrap(elem_offset, index=True)) + total_byte_idx = AV(lds_base) + byte_off + addr_i32 = _to_raw(arith.index_cast(T.i32, total_byte_idx)) + ptr_val = llvm.inttoptr(lds_ptr_ty, addr_i32) + + result_type = T.i32x2 + result = llvm.call_intrinsic( + result_type, "llvm.amdgcn.ds.read.tr8.b64", [ptr_val], [], [] + ) + return result + + +def compile_attn_bwd_mxfp8_gfx950( + *, + num_heads_q: int, + num_heads_kv: int, + seqlen: int, + head_dim: int, + tile_m: int, + tile_n: int, + tile_head: int, + sm_scale: float, + causal: bool = False, + waves_per_eu: int = None, +): + """Compile the attention backward mx8 kernel using the @flyc.kernel API. + + Returns a JitFunction that auto-compiles and executes when called. 
+ Compile-time constants: seqlen, head_dim, tile_m/n/head + Runtime parameters: batch + + """ + + elem_bytes = 1 + tile_head_mx = tile_head // 32 + tile_m_mx = tile_m // 32 + tile_n_mx = tile_n // 32 + gqa_size = num_heads_q // num_heads_kv + seqlen_rounded = ((seqlen + tile_m - 1) // tile_m) * tile_m + + gpu_arch = get_hip_arch() + + allocator_pong = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem0") + allocator_ping = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem1") + allocator_k = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem_k") + allocator_v = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem_v") + allocator_v_scale = SmemAllocator( + None, arch=gpu_arch, global_sym_name="smem_v_scale" + ) + allocator_ppt_shuffle = SmemAllocator( + None, arch=gpu_arch, global_sym_name="smem_ppt_shuffle" + ) + allocator_ppt_scale_shuffle = SmemAllocator( + None, arch=gpu_arch, global_sym_name="smem_ppt_scale_shuffle" + ) + allocator_dst_shuffle = SmemAllocator( + None, arch=gpu_arch, global_sym_name="smem_dst_shuffle" + ) + allocator_dst_scale_shuffle = SmemAllocator( + None, arch=gpu_arch, global_sym_name="smem_dst_scale_shuffle" + ) + allocator_ds_shuffle = SmemAllocator( + None, arch=gpu_arch, global_sym_name="smem_ds_shuffle" + ) + allocator_ds_scale_shuffle = SmemAllocator( + None, arch=gpu_arch, global_sym_name="smem_ds_scale_shuffle" + ) + + wave_size = 64 + total_threads = 256 + + bytes_per_tile_qo = int(tile_m) * int(tile_head) + bytes_per_thread_qo = bytes_per_tile_qo // total_threads + qo_load_bytes = 16 + + bytes_per_tile_kv = int(tile_n) * int(tile_head) + bytes_per_thread_kv = bytes_per_tile_kv // total_threads + kv_load_bytes = 16 + + bytes_per_tile_qo_scale = (int(tile_m) * int(tile_head)) // 32 + bytes_per_thread_qo_scale = max(1, bytes_per_tile_qo_scale // total_threads) + + bytes_per_tile_kv_scale = (int(tile_n) * int(tile_head)) // 32 + bytes_per_thread_kv_scale = max(1, bytes_per_tile_kv_scale // total_threads) + + def 
_elem_type(): + return T.f8 + + def _vec16_type(): + return T.f8x16 + + # ── LDS sizing (pure Python, no MLIR ops) ──────────────────────────────── + lds_qo_tile_bytes = int(tile_m) * int(tile_head) + lds_k_tile_bytes = int(tile_n) * int(tile_head) + lds_v_tile_bytes = int(tile_n) * int(tile_head) + lds_v_scale_tile_bytes = int(tile_n) * int(tile_head_mx) + lds_ppt_tile_bytes = int(tile_n) * int(tile_m) + lds_ppt_scale_tile_bytes = int(tile_n) * int(tile_m_mx) + lds_dst_tile_bytes = int(tile_n) * int(tile_m) + lds_dst_scale_tile_bytes = int(tile_n) * int(tile_m_mx) + lds_ds_tile_bytes = int(tile_m) * int(tile_n) + lds_ds_scale_tile_bytes = int(tile_m) * int(tile_n_mx) + + buffer_size_bytes = lds_qo_tile_bytes * 2 # + lds_qo_scale_tile_bytes * 4 + + lds_pong_offset = allocator_pong._align(allocator_pong.ptr, 16) + allocator_pong.ptr = lds_pong_offset + buffer_size_bytes + lds_q_pong_offset = lds_pong_offset + lds_do_pong_offset = lds_q_pong_offset + lds_qo_tile_bytes + + lds_ping_offset = allocator_ping._align(allocator_ping.ptr, 16) + allocator_ping.ptr = lds_ping_offset + buffer_size_bytes + lds_q_ping_offset = lds_ping_offset + lds_do_ping_offset = lds_q_ping_offset + lds_qo_tile_bytes + + lds_k_offset = allocator_k._align(allocator_k.ptr, 16) + allocator_k.ptr = lds_k_offset + lds_k_tile_bytes + + lds_v_offset = allocator_v._align(allocator_v.ptr, 16) + allocator_v.ptr = lds_v_offset + lds_v_tile_bytes + + lds_v_scale_offset = allocator_v_scale._align(allocator_v_scale.ptr, 16) + allocator_v_scale.ptr = lds_v_scale_offset + lds_v_scale_tile_bytes + + lds_ppt_shuffle_offset = allocator_ppt_shuffle._align(allocator_ppt_shuffle.ptr, 16) + allocator_ppt_shuffle.ptr = lds_ppt_shuffle_offset + lds_ppt_tile_bytes + + lds_ppt_scale_shuffle_offset = allocator_ppt_scale_shuffle._align( + allocator_ppt_scale_shuffle.ptr, 16 + ) + allocator_ppt_scale_shuffle.ptr = ( + lds_ppt_scale_shuffle_offset + lds_ppt_scale_tile_bytes + ) + + lds_dst_shuffle_offset = 
allocator_dst_shuffle._align(allocator_dst_shuffle.ptr, 16) + allocator_dst_shuffle.ptr = lds_dst_shuffle_offset + lds_dst_tile_bytes + + lds_dst_scale_shuffle_offset = allocator_dst_scale_shuffle._align( + allocator_dst_scale_shuffle.ptr, 16 + ) + allocator_dst_scale_shuffle.ptr = ( + lds_dst_scale_shuffle_offset + lds_dst_scale_tile_bytes + ) + + lds_ds_shuffle_offset = allocator_ds_shuffle._align(allocator_ds_shuffle.ptr, 16) + allocator_ds_shuffle.ptr = lds_ds_shuffle_offset + lds_ds_tile_bytes + + lds_ds_scale_shuffle_offset = allocator_ds_scale_shuffle._align( + allocator_ds_scale_shuffle.ptr, 16 + ) + allocator_ds_scale_shuffle.ptr = ( + lds_ds_scale_shuffle_offset + lds_ds_scale_tile_bytes + ) + + # ── Kernel function ──────────────────────────────────────────────────── + @flyc.kernel + def kernel_attn_bwd( + arg_dq: fx.Tensor, + arg_dk: fx.Tensor, + arg_dv: fx.Tensor, + arg_q: fx.Tensor, + arg_q_scale: fx.Tensor, + arg_k: fx.Tensor, + arg_k_scale: fx.Tensor, + arg_v: fx.Tensor, + arg_v_scale: fx.Tensor, + arg_do: fx.Tensor, + arg_do_scale: fx.Tensor, + arg_M: fx.Tensor, + arg_D: fx.Tensor, + batch: fx.Int32, + stride_qo_batch: fx.Int32, + stride_kv_batch: fx.Int32, + stride_MD_batch: fx.Int32, + stride_qkvo_nheads: fx.Int32, + stride_MD_nheads: fx.Int32, + stride_q_scale_batch: fx.Int32, + stride_q_scale_nheads: fx.Int32, + stride_k_scale_batch: fx.Int32, + stride_k_scale_nheads: fx.Int32, + stride_v_scale_batch: fx.Int32, + stride_v_scale_nheads: fx.Int32, + stride_do_scale_batch: fx.Int32, + stride_do_scale_nheads: fx.Int32, + ): + + # ---- Types ---- + zero_f = arith.constant(0.0, type=T.f32) + acc_init = arith.constant_vector(0.0, T.f32x4) + log2e = arith.constant(1.4426950408889634, type=T.f32) + c_sm_scale = arith.constant(sm_scale, type=T.f32) + fp8_max_rcp = arith.constant(1.0 / 448.0, type=T.f32) + + tx = gpu.thread_id("x") + bx = gpu.block_id("x") + by = gpu.block_id("y") + bz = gpu.block_id("z") + batch_id = bz + head_q = bx + head_kv = head_q 
// gqa_size + + # ---- LDS (separate ping/pong buffers) ---- + base_ptr_pong = allocator_pong.get_base() + base_ptr_ping = allocator_ping.get_base() + base_ptr_k = allocator_k.get_base() + base_ptr_v = allocator_v.get_base() + base_ptr_v_scale = allocator_v_scale.get_base() + base_ptr_ppt_shuffle = allocator_ppt_shuffle.get_base() + base_ptr_ppt_scale_shuffle = allocator_ppt_scale_shuffle.get_base() + base_ptr_dst_shuffle = allocator_dst_shuffle.get_base() + base_ptr_dst_scale_shuffle = allocator_dst_scale_shuffle.get_base() + base_ptr_ds_shuffle = allocator_ds_shuffle.get_base() + base_ptr_ds_scale_shuffle = allocator_ds_scale_shuffle.get_base() + + lds_q_pong = SmemPtr( + base_ptr_pong, + lds_q_pong_offset, + T.f8, + shape=(tile_m * tile_head,), + ).get() + lds_q_ping = SmemPtr( + base_ptr_ping, + lds_q_ping_offset, + T.f8, + shape=(tile_m * tile_head,), + ).get() + lds_do_pong = SmemPtr( + base_ptr_pong, + lds_do_pong_offset, + T.f8, + shape=(tile_m * tile_head,), + ).get() + lds_do_ping = SmemPtr( + base_ptr_ping, + lds_do_ping_offset, + T.f8, + shape=(tile_m * tile_head,), + ).get() + lds_k = SmemPtr( + base_ptr_k, + lds_k_offset, + T.f8, + shape=(tile_n * tile_head,), + ).get() + lds_v = SmemPtr( + base_ptr_v, lds_v_offset, T.f8, shape=(tile_n * tile_head,) + ).get() + lds_v_scale = SmemPtr( + base_ptr_v_scale, lds_v_scale_offset, T.i8, shape=(tile_n * tile_head_mx,) + ).get() + lds_ppt_shuffle = SmemPtr( + base_ptr_ppt_shuffle, lds_ppt_shuffle_offset, T.f8, shape=(tile_n * tile_m,) + ).get() + lds_ppt_scale_shuffle = SmemPtr( + base_ptr_ppt_scale_shuffle, + lds_ppt_scale_shuffle_offset, + T.i8, + shape=(tile_n * tile_m_mx,), + ).get() + lds_dst_shuffle = SmemPtr( + base_ptr_dst_shuffle, lds_dst_shuffle_offset, T.f8, shape=(tile_n * tile_m,) + ).get() + lds_dst_scale_shuffle = SmemPtr( + base_ptr_dst_scale_shuffle, + lds_dst_scale_shuffle_offset, + T.i8, + shape=(tile_n * tile_m_mx,), + ).get() + lds_ds_shuffle = SmemPtr( + base_ptr_ds_shuffle, 
lds_ds_shuffle_offset, T.f8, shape=(tile_m * tile_n,) + ).get() + lds_ds_scale_shuffle = SmemPtr( + base_ptr_ds_scale_shuffle, + lds_ds_scale_shuffle_offset, + T.i8, + shape=(tile_m * tile_n_mx,), + ).get() + + offset_qo_nheads = batch_id * fx.Index(stride_qo_batch) + head_q * fx.Index( + stride_qkvo_nheads + ) + offset_dq_nheads = offset_qo_nheads * 4 + offset_kv_nheads = batch_id * fx.Index(stride_kv_batch) + head_kv * fx.Index( + stride_qkvo_nheads + ) + offset_dkdv_nheads = offset_kv_nheads * 4 + offset_q_scale_nheads = batch_id * fx.Index( + stride_q_scale_batch + ) + head_q * fx.Index(stride_q_scale_nheads) + offset_k_scale_nheads = batch_id * fx.Index( + stride_k_scale_batch + ) + head_kv * fx.Index(stride_k_scale_nheads) + offset_v_scale_nheads = batch_id * fx.Index( + stride_v_scale_batch + ) + head_kv * fx.Index(stride_v_scale_nheads) + offset_do_scale_nheads = batch_id * fx.Index( + stride_do_scale_batch + ) + head_q * fx.Index(stride_do_scale_nheads) + offset_MD_nheads = ( + batch_id * fx.Index(stride_MD_batch) + head_q * fx.Index(stride_MD_nheads) + ) * 4 + + # ---- Buffer resources (runtime byte sizes for OOB protection) ---- + head_dim_mx = head_dim // 32 + seqlen_mx = seqlen // 32 + global_buffer_size_tensor = fx.Index(seqlen * head_dim) + global_buffer_size_scale = fx.Index(seqlen * head_dim_mx) + global_buffer_size_scale_2d = fx.Index(seqlen_mx * head_dim_mx) + q_nrec = arith.index_cast(T.i64, global_buffer_size_tensor) + q_scale_nrec = arith.index_cast(T.i64, global_buffer_size_scale_2d) + k_nrec = arith.index_cast(T.i64, global_buffer_size_tensor) + k_scale_nrec = arith.index_cast(T.i64, global_buffer_size_scale_2d) + v_nrec = arith.index_cast(T.i64, global_buffer_size_tensor) + v_scale_nrec = arith.index_cast(T.i64, global_buffer_size_scale) + do_nrec = arith.index_cast(T.i64, global_buffer_size_tensor) + do_scale_nrec = arith.index_cast(T.i64, global_buffer_size_scale_2d) + output_nrec = arith.index_cast(T.i64, global_buffer_size_tensor * 4) + 
MD_nrec = arith.index_cast(T.i64, fx.Index(seqlen * 4)) + + q_rsrc = buffer_ops.create_buffer_resource( + arg_q, + max_size=False, + num_records_bytes=q_nrec, + base_byte_offset=offset_qo_nheads, + ) + q_scale_rsrc = buffer_ops.create_buffer_resource( + arg_q_scale, + max_size=False, + num_records_bytes=q_scale_nrec, + base_byte_offset=offset_q_scale_nheads, + ) + k_rsrc = buffer_ops.create_buffer_resource( + arg_k, + max_size=False, + num_records_bytes=k_nrec, + base_byte_offset=offset_kv_nheads, + ) + k_scale_rsrc = buffer_ops.create_buffer_resource( + arg_k_scale, + max_size=False, + num_records_bytes=k_scale_nrec, + base_byte_offset=offset_k_scale_nheads, + ) + v_rsrc = buffer_ops.create_buffer_resource( + arg_v, + max_size=False, + num_records_bytes=v_nrec, + base_byte_offset=offset_kv_nheads, + ) + v_scale_rsrc = buffer_ops.create_buffer_resource( + arg_v_scale, + max_size=False, + num_records_bytes=v_scale_nrec, + base_byte_offset=offset_v_scale_nheads, + ) + do_rsrc = buffer_ops.create_buffer_resource( + arg_do, + max_size=False, + num_records_bytes=do_nrec, + base_byte_offset=offset_qo_nheads, + ) + do_scale_rsrc = buffer_ops.create_buffer_resource( + arg_do_scale, + max_size=False, + num_records_bytes=do_scale_nrec, + base_byte_offset=offset_do_scale_nheads, + ) + dq_rsrc = buffer_ops.create_buffer_resource( + arg_dq, + max_size=False, + num_records_bytes=output_nrec, + base_byte_offset=offset_dq_nheads, + ) + dk_rsrc = buffer_ops.create_buffer_resource( + arg_dk, + max_size=False, + num_records_bytes=output_nrec, + base_byte_offset=offset_dkdv_nheads, + ) + dv_rsrc = buffer_ops.create_buffer_resource( + arg_dv, + max_size=False, + num_records_bytes=output_nrec, + base_byte_offset=offset_dkdv_nheads, + ) + M_rsrc = buffer_ops.create_buffer_resource( + arg_M, + max_size=False, + num_records_bytes=MD_nrec, + base_byte_offset=offset_MD_nheads, + ) + D_rsrc = buffer_ops.create_buffer_resource( + arg_D, + max_size=False, + num_records_bytes=MD_nrec, + 
base_byte_offset=offset_MD_nheads, + ) + + global_offset_n = by * tile_n + global_offset_n_mx = global_offset_n // 32 + + # ---- Wave / lane decomposition ---- + layout_wave_lane = fx.make_layout((4, wave_size), (64, 1)) + coord_wave_lane = fx.idx2crd(tx, layout_wave_lane) + wave_id = fx.get(coord_wave_lane, 0) + lane_id = fx.get(coord_wave_lane, 1) + + layout_lane16 = fx.make_layout((4, 16), (16, 1)) + coord_lane16 = fx.idx2crd(lane_id, layout_lane16) + lane_div_16 = fx.get(coord_lane16, 0) + lane_mod_16 = fx.get(coord_lane16, 1) + + layout_lane2 = fx.make_layout((8, 2), (2, 1)) + coord_lane2 = fx.idx2crd(lane_mod_16, layout_lane2) + lane_div_2 = fx.get(coord_lane2, 0) + lane_mod_2 = fx.get(coord_lane2, 1) + + # wave partitioning for qk, p, dp, ds + ps_m_num_waves = 2 + ps_n_num_waves = 2 + ps_wave_layout = fx.make_layout( + (ps_m_num_waves, ps_n_num_waves), (ps_n_num_waves, 1) + ) + ps_coord = fx.idx2crd(wave_id, ps_wave_layout) + ps_m_wave_id = fx.get(ps_coord, 0) + ps_n_wave_id = fx.get(ps_coord, 1) + ps_m_per_wave = tile_m // ps_m_num_waves + ps_m_mx_per_wave = tile_m_mx // ps_m_num_waves + ps_m_num_subtiles = ps_m_per_wave // 16 + ps_n_per_wave = tile_n // ps_n_num_waves + ps_n_mx_per_wave = tile_n_mx // ps_n_num_waves + ps_n_num_subtiles = ps_n_per_wave // 16 + ps_n_accs = ps_n_num_subtiles * ps_m_num_subtiles + + # wave partitioning for dv gemm + dv_n_num_waves = 2 + dv_head_num_waves = 2 + dv_wave_layout = fx.make_layout( + (dv_n_num_waves, dv_head_num_waves), (dv_head_num_waves, 1) + ) + dv_coord = fx.idx2crd(wave_id, dv_wave_layout) + dv_n_wave_id = fx.get(dv_coord, 0) + dv_head_wave_id = fx.get(dv_coord, 1) + dv_n_per_wave = tile_n // dv_n_num_waves + dv_n_num_subtiles = dv_n_per_wave // 16 + dv_head_per_wave = tile_head // dv_head_num_waves + dv_head_mx_per_wave = tile_head_mx // dv_head_num_waves + dv_head_num_subtiles = dv_head_per_wave // 16 + dv_n_accs = dv_n_num_subtiles * dv_head_num_subtiles + + # wave partitioning for dk gemm + dk_n_num_waves = 
2 + dk_head_num_waves = 2 + dk_wave_layout = fx.make_layout( + (dk_n_num_waves, dk_head_num_waves), (dk_head_num_waves, 1) + ) + dk_coord = fx.idx2crd(wave_id, dk_wave_layout) + dk_n_wave_id = fx.get(dk_coord, 0) + dk_head_wave_id = fx.get(dk_coord, 1) + dk_n_per_wave = tile_n // dk_n_num_waves + dk_num_subtiles_n = dk_n_per_wave // 16 + dk_head_per_wave = tile_head // dk_head_num_waves + dk_head_mx_per_wave = tile_head_mx // dk_head_num_waves + dk_num_subtiles_head = dk_head_per_wave // 16 + dk_n_accs = dk_num_subtiles_n * dk_num_subtiles_head + + # wave partitioning for dq gemm + dq_m_num_waves = 2 + dq_head_num_waves = 2 + dq_wave_layout = fx.make_layout( + (dq_m_num_waves, dq_head_num_waves), (dq_head_num_waves, 1) + ) + dq_coord = fx.idx2crd(wave_id, dq_wave_layout) + dq_m_wave_id = fx.get(dq_coord, 0) + dq_head_wave_id = fx.get(dq_coord, 1) + dq_m_per_wave = tile_m // dq_m_num_waves + dq_num_subtiles_m = dq_m_per_wave // 16 + dq_head_per_wave = tile_head // dq_head_num_waves + dq_head_mx_per_wave = tile_head_mx // dq_head_num_waves + dq_num_subtiles_head = dq_head_per_wave // 16 + dq_n_accs = dq_num_subtiles_m * dq_num_subtiles_head + + # ── A LDS load helpers ── + + def lds_load_16b(curr_row_lds, col_base, lds_stride, lds_buffer, swizzle=16): + if swizzle == 16: + col_base = swizzle_xor16(curr_row_lds, col_base, lds_stride // swizzle) + idx = curr_row_lds * lds_stride + col_base + return vector.load_op(_vec16_type(), lds_buffer, [idx]) + + def lds_load_8b_transposed( + curr_row_lds, col_base, lds_stride, lds_buffer, swizzle=16 + ): + if swizzle == 16: + col_base = swizzle_xor16(curr_row_lds, col_base, lds_stride // swizzle) + col_base = col_base + lane_mod_2 * 8 + idx = curr_row_lds * lds_stride + col_base + return lds_transpose_load(lds_buffer, idx) + + def lds_load_packs_k64( + curr_row_lds, col_base, lds_stride, lds_buffer, swizzle=16 + ): + vec = lds_load_16b(curr_row_lds, col_base, lds_stride, lds_buffer, swizzle) + vec = vector.bitcast(T.i64x2, vec) + 
val0_i64 = vector.extract(vec, static_position=[0], dynamic_position=[]) + val1_i64 = vector.extract(vec, static_position=[1], dynamic_position=[]) + return val0_i64, val1_i64 + + def lds_load_packs_k32_transposed( + curr_row_lds, col_base, lds_stride, lds_buffer, swizzle=16 + ): + vec = lds_load_8b_transposed( + curr_row_lds, col_base, lds_stride, lds_buffer, swizzle + ) + vec = vector.bitcast(T.vec(1, T.i64), vec) + val_i64 = vector.extract(vec, static_position=[0], dynamic_position=[]) + return val_i64 + + def lds_scale_load(row, col, lds_stride, lds_buffer): + idx = row * lds_stride + col + vec = vector.load_op(T.vec(1, T.i8), lds_buffer, [idx]) + val = vector.extract(vec, static_position=[0], dynamic_position=[]) + val = val.extui(T.i32) + return val + + # ── A global→reg load ───────────────────────────────────────────── + head_dim_div4 = head_dim // 4 + tile_m_div16 = tile_m // 16 + tile_head_div16 = arith.index(tile_head // 16) + num_qo_loads = bytes_per_thread_qo // qo_load_bytes + num_kv_loads = bytes_per_thread_kv // kv_load_bytes + tile_head_dwords = tile_head // 4 + layout_qo_tile_div4 = fx.make_layout( + (tile_m, tile_head_dwords), (tile_head_dwords, 1) + ) + layout_kv_tile_div4 = fx.make_layout( + (tile_n, tile_head_dwords), (tile_head_dwords, 1) + ) + c4 = fx.Index(4) + tx_i32_base = tx * c4 + + def load_q_16(idx_elem): + return buffer_copy_gmem16_dwordx4( + buffer_ops, + vector, + elem_type=_elem_type(), + idx_i32=idx_elem, + rsrc=q_rsrc, + vec_elems=16, + elem_bytes=elem_bytes, + ) + + def load_k_16(idx_elem): + return buffer_copy_gmem16_dwordx4( + buffer_ops, + vector, + elem_type=_elem_type(), + idx_i32=idx_elem, + rsrc=k_rsrc, + vec_elems=16, + elem_bytes=elem_bytes, + ) + + def load_v_16(idx_elem): + return buffer_copy_gmem16_dwordx4( + buffer_ops, + vector, + elem_type=_elem_type(), + idx_i32=idx_elem, + rsrc=v_rsrc, + vec_elems=16, + elem_bytes=elem_bytes, + ) + + def load_do_16(idx_elem): + return buffer_copy_gmem16_dwordx4( + buffer_ops, + 
vector, + elem_type=_elem_type(), + idx_i32=idx_elem, + rsrc=do_rsrc, + vec_elems=16, + elem_bytes=elem_bytes, + ) + + def qo_tile_chunk_coord_i32(i: int): + return tile_chunk_coord_i32( + arith, + tx_i32_base=tx_i32_base, + i=i, + total_threads=total_threads, + layout_tile_div4=layout_qo_tile_div4, + ) + + def kv_tile_chunk_coord_i32(i: int): + return tile_chunk_coord_i32( + arith, + tx_i32_base=tx_i32_base, + i=i, + total_threads=total_threads, + layout_tile_div4=layout_kv_tile_div4, + ) + + def prefetch_q_tile(offset_m): + parts = [] + for i in range_constexpr(num_qo_loads): + row_q_local, col_q_local_i32 = qo_tile_chunk_coord_i32(i) + row_q_global = offset_m + row_q_local + idx_elem = row_q_global * head_dim_div4 + col_q_local_i32 + q_16B = load_q_16(idx_elem) + parts.append(vector.bitcast(T.i32x4, q_16B)) + return parts + + def prefetch_k_tile(): + parts = [] + for i in range_constexpr(num_kv_loads): + row_k_local, col_k_local_i32 = kv_tile_chunk_coord_i32(i) + row_k_global = global_offset_n + row_k_local + idx_elem = row_k_global * head_dim_div4 + col_k_local_i32 + k_16B = load_k_16(idx_elem) + parts.append(vector.bitcast(T.i32x4, k_16B)) + return parts + + def prefetch_v_tile(): + parts = [] + for i in range_constexpr(num_kv_loads): + row_v_local, col_v_local_i32 = kv_tile_chunk_coord_i32(i) + row_v_global = global_offset_n + row_v_local + idx_elem = row_v_global * head_dim_div4 + col_v_local_i32 + v_16B = load_v_16(idx_elem) + parts.append(vector.bitcast(T.i32x4, v_16B)) + return parts + + def prefetch_do_tile(offset_m): + parts = [] + for i in range_constexpr(num_qo_loads): + row_do_local, col_do_local_i32 = qo_tile_chunk_coord_i32(i) + row_do_global = offset_m + row_do_local + idx_elem = row_do_global * head_dim_div4 + col_do_local_i32 + do_16B = load_do_16(idx_elem) + parts.append(vector.bitcast(T.i32x4, do_16B)) + return parts + + def prefetch_q_scale_head_2d_tile(offset_m): + parts = [] + for i in range_constexpr(ps_m_num_subtiles // 2): + global_row = 
offset_m + ps_m_wave_id * ps_m_mx_per_wave + i + global_col = lane_div_16 % tile_head_mx + global_idx = global_row * head_dim_mx + global_col + vec = buffer_ops.buffer_load( + q_scale_rsrc, global_idx, vec_width=1, dtype=T.i8 + ) + vec = vec.extui(T.i32) + parts.append(vec) + return parts + + def prefetch_q_scale_m_2d_tile(offset_m): + parts = [] + for i in range_constexpr(dk_num_subtiles_head // 2): + global_row = offset_m + lane_div_16 % tile_m_mx + global_col = dk_head_wave_id * dk_head_mx_per_wave + i + global_idx = global_row * head_dim_mx + global_col + vec = buffer_ops.buffer_load( + q_scale_rsrc, global_idx, vec_width=1, dtype=T.i8 + ) + vec = vec.extui(T.i32) + parts.append(vec) + return parts + + def prefetch_k_scale_head_2d_tile(): + parts = [] + for i in range_constexpr(ps_n_num_subtiles // 2): + global_row = global_offset_n_mx + ps_n_wave_id * ps_n_mx_per_wave + i + global_col = lane_div_16 % tile_head_mx + global_idx = global_row * head_dim_mx + global_col + vec = buffer_ops.buffer_load( + k_scale_rsrc, global_idx, vec_width=1, dtype=T.i8 + ) + vec = vec.extui(T.i32) + parts.append(vec) + return parts + + def prefetch_k_scale_n_2d_tile(): + parts = [] + for i in range_constexpr(dq_num_subtiles_head // 2): + global_row = global_offset_n_mx + lane_div_16 % tile_n_mx + global_col = dq_head_wave_id * dq_head_mx_per_wave + i + global_idx = global_row * head_dim_mx + global_col + vec = buffer_ops.buffer_load( + k_scale_rsrc, global_idx, vec_width=1, dtype=T.i8 + ) + vec = vec.extui(T.i32) + parts.append(vec) + return parts + + def prefetch_v_scale_tile(): + vec_width = bytes_per_thread_kv_scale + if const_expr(vec_width == 1): + if const_expr(bytes_per_tile_kv_scale < total_threads): + idx_elem = ( + global_offset_n * head_dim_mx + tx % bytes_per_tile_kv_scale + ) + else: + idx_elem = global_offset_n * head_dim_mx + tx + vec = buffer_ops.buffer_load( + v_scale_rsrc, idx_elem, vec_width=1, dtype=T.i8 + ) + vec = vector.from_elements(T.vec(1, T.i8), [vec]) + 
else: # vec_width=2 + idx_elem = (global_offset_n * head_dim_mx + tx * vec_width) // 2 + vec = buffer_ops.buffer_load( + v_scale_rsrc, idx_elem, vec_width=1, dtype=T.i16 + ) + vec = vector.from_elements(T.vec(1, T.i16), [vec]) + vec = vector.bitcast(T.i8x2, vec) + return vec + + def prefetch_do_scale_head_2d_tile(offset_m): + parts = [] + for i in range_constexpr(ps_m_num_subtiles // 2): + global_row = offset_m + ps_m_wave_id * ps_m_mx_per_wave + i + global_col = lane_div_16 % tile_head_mx + global_idx = global_row * head_dim_mx + global_col + vec = buffer_ops.buffer_load( + do_scale_rsrc, global_idx, vec_width=1, dtype=T.i8 + ) + vec = vec.extui(T.i32) + parts.append(vec) + return parts + + def prefetch_do_scale_m_2d_tile(offset_m): + parts = [] + for i in range_constexpr(dv_head_num_subtiles // 2): + global_row = offset_m + lane_div_16 % tile_m_mx + global_col = dv_head_wave_id * dv_head_mx_per_wave + i + global_idx = global_row * head_dim_mx + global_col + vec = buffer_ops.buffer_load( + do_scale_rsrc, global_idx, vec_width=1, dtype=T.i8 + ) + vec = vec.extui(T.i32) + parts.append(vec) + return parts + + def store_q_tile_to_lds(vec_q_parts, lds_buffer): + for i in range_constexpr(num_qo_loads): + row_q_local, col_q_local_i32 = qo_tile_chunk_coord_i32(i) + col_local_bytes = col_q_local_i32 * c4 + col_swz_bytes = swizzle_xor16( + row_q_local, col_local_bytes, tile_head_div16 + ) + col_swz = col_swz_bytes + idx0 = row_q_local * tile_head + col_swz + v16 = vector.bitcast(_vec16_type(), vec_q_parts[i]) + vector.store(v16, lds_buffer, [idx0]) + + def store_k_tile_to_lds(vec_k_parts, lds_buffer): + for i in range_constexpr(num_kv_loads): + row_k_local, col_k_local_i32 = kv_tile_chunk_coord_i32(i) + col_local_bytes = col_k_local_i32 * c4 + col_swz_bytes = swizzle_xor16( + row_k_local, col_local_bytes, tile_head_div16 + ) + col_swz = col_swz_bytes + idx0 = row_k_local * tile_head + col_swz + v16 = vector.bitcast(_vec16_type(), vec_k_parts[i]) + vector.store(v16, 
lds_buffer, [idx0]) + + def store_v_tile_to_lds(vec_v_parts, lds_buffer): + for i in range_constexpr(num_kv_loads): + row_v_local, col_v_local_i32 = kv_tile_chunk_coord_i32(i) + col_local_bytes = col_v_local_i32 * c4 + col_swz_bytes = swizzle_xor16( + row_v_local, col_local_bytes, tile_head_div16 + ) + col_swz = col_swz_bytes + idx0 = row_v_local * tile_head + col_swz + v16 = vector.bitcast(_vec16_type(), vec_v_parts[i]) + vector.store(v16, lds_buffer, [idx0]) + + def store_do_tile_to_lds(vec_do_parts, lds_buffer): + for i in range_constexpr(num_qo_loads): + row_do_local, col_do_local_i32 = qo_tile_chunk_coord_i32(i) + col_local_bytes = col_do_local_i32 * c4 + col_swz_bytes = swizzle_xor16( + row_do_local, col_local_bytes, tile_head_div16 + ) + col_swz = col_swz_bytes + idx0 = row_do_local * tile_head + col_swz + v16 = vector.bitcast(_vec16_type(), vec_do_parts[i]) + vector.store(v16, lds_buffer, [idx0]) + + def store_v_scale_tile_to_lds(vec_scale, lds_buffer): + vec_width = bytes_per_thread_kv_scale + idx = tx * vec_width + if total_threads > bytes_per_tile_kv_scale: + idx = idx % bytes_per_tile_kv_scale + vector.store(vec_scale, lds_buffer, [idx]) + + # ── Compute tile (MFMA) ─────────────────────────────────────────── + + def pack_i64x4_to_i32x8(x0, x1, x2, x3): + vec4_i64 = T.vec(4, T.i64) + vec8_i32 = T.vec(8, T.i32) + v4 = vector.from_elements(vec4_i64, [x0, x1, x2, x3]) + return vector.bitcast(vec8_i32, v4) + + def compute_qk(lds_a_buffer, a_scales, lds_b_buffer, b_scales): + # (m, head) @ (head, n) = (m, n) + + current_accs_list = [acc_init] * ps_n_accs + mfma_res_ty = T.f32x4 + + ku0 = 0 + ku1 = 1 + lds_col0 = ku0 * 64 + lane_div_16 * 16 + lds_col1 = ku1 * 64 + lane_div_16 * 16 + lds_scale_col = lane_div_16 + if const_expr(tile_head == 64): + lds_scale_col = lds_scale_col % 2 + + for mi in range_constexpr(ps_m_num_subtiles): + lds_a_row = ps_m_wave_id * ps_m_per_wave + mi * 16 + lane_mod_16 + a0, a1 = lds_load_packs_k64( + lds_a_row, lds_col0, tile_head, 
lds_a_buffer + ) + if const_expr(tile_head == 128): + a2, a3 = lds_load_packs_k64( + lds_a_row, lds_col1, tile_head, lds_a_buffer + ) + else: + a2 = a3 = fx.Int64(0) + a128 = pack_i64x4_to_i32x8(a0, a1, a2, a3) + + # a_scale = lds_scale_load( + # lds_a_scale_row, lds_scale_col, tile_head_mx, lds_a_scale_buffer + # ) + a_scale = a_scales[mi // 2] + + for ni in range_constexpr(ps_n_num_subtiles): + lds_b_row = ps_n_wave_id * ps_n_per_wave + ni * 16 + lane_mod_16 + b0, b1 = lds_load_packs_k64( + lds_b_row, lds_col0, tile_head, lds_b_buffer + ) + if const_expr(tile_head == 128): + b2, b3 = lds_load_packs_k64( + lds_b_row, lds_col1, tile_head, lds_b_buffer + ) + else: + b2 = b3 = fx.Int64(0) + b128 = pack_i64x4_to_i32x8(b0, b1, b2, b3) + + b_scale = b_scales[ni // 2] + + acc_idx = mi * ps_n_num_subtiles + ni + current_accs_list[acc_idx] = rocdl.mfma_scale_f32_16x16x128_f8f6f4( + mfma_res_ty, + [ + a128, + b128, + current_accs_list[acc_idx], + 0, + 0, + 0, + a_scale, + 0, + b_scale, + ], + ) + return current_accs_list + + def softmax(accs_in, offset_m): + # inputs are tile_m x tile_n shape + + accs_out = [acc_init] * ps_n_accs + + for mi in range_constexpr(ps_m_num_subtiles): + global_m_norm_idx = ( + offset_m + ps_m_wave_id * ps_m_per_wave + mi * 16 + lane_div_16 * 4 + ) + m_norm_vector = buffer_ops.buffer_load( + M_rsrc, global_m_norm_idx, vec_width=4 + ) + + for ni in range_constexpr(ps_n_num_subtiles): + + acc_idx = mi * ps_n_num_subtiles + ni + acc = accs_in[acc_idx] + + vals_f32 = [] + for ii in range_constexpr(4): + val_f32 = vector.extract( + acc, static_position=[ii], dynamic_position=[] + ) + m_norm = vector.extract( + m_norm_vector, static_position=[ii], dynamic_position=[] + ) + val_f32 = val_f32 * c_sm_scale + val_f32 = val_f32 - m_norm + val_f32 = val_f32 * log2e + val_f32 = rocdl.exp2(T.f32, val_f32) + if causal: + global_m = ( + offset_m + + ps_m_wave_id * ps_m_per_wave + + mi * 16 + + lane_div_16 * 4 + + ii + ) + global_n = ( + global_offset_n + + 
ps_n_wave_id * ps_n_per_wave + + ni * 16 + + lane_mod_16 + ) + needs_mask = arith.cmpi( + arith.CmpIPredicate.ugt, global_n, global_m + ) + mask_if = scf.IfOp(needs_mask, [T.f32], has_else=True) + with ir.InsertionPoint(mask_if.then_block): + scf.YieldOp([arith.constant(0.0, type=T.f32)]) + with ir.InsertionPoint(mask_if.else_block): + scf.YieldOp([val_f32]) + val_f32 = mask_if.results[0] + vals_f32.append(val_f32) + vals_f32_vector = vector.from_elements(T.f32x4, vals_f32) + accs_out[acc_idx] = vals_f32_vector + + return accs_out + + def compute_dv( + accs_in, lds_a_buffer, lds_a_scale_buffer, lds_b_buffer, b_scales + ): + current_accs_list = list(accs_in) + mfma_res_ty = T.f32x4 + num_subtiles_reduction = max(1, tile_m // 128) + for ku128 in range_constexpr(num_subtiles_reduction): + ku0 = ku128 * 2 + ku1 = ku0 + 1 + + lds_a_col0 = ( + ku0 * 64 + lane_div_16 * 16 + ) # 16 elements packed per lane, 64 per wave + lds_a_col1 = ku1 * 64 + lane_div_16 * 16 + + lds_b_row0 = ku0 * 64 + lane_div_16 * 16 + lane_div_2 + lds_b_row1 = ku0 * 64 + lane_div_16 * 16 + 8 + lane_div_2 + lds_b_row2 = ku1 * 64 + lane_div_16 * 16 + lane_div_2 + lds_b_row3 = ku1 * 64 + lane_div_16 * 16 + 8 + lane_div_2 + + lds_a_scale_col = lane_div_16 + lds_b_scale_row = lane_div_16 + if const_expr(tile_m == 64): + lds_a_scale_col = lds_a_scale_col % 2 + lds_b_scale_row = lds_b_scale_row % 2 + + for ni in range_constexpr(dv_n_num_subtiles): + lds_a_row = dv_n_wave_id * dv_n_per_wave + ni * 16 + lane_mod_16 + a0, a1 = lds_load_packs_k64( + lds_a_row, lds_a_col0, tile_m, lds_a_buffer + ) + if const_expr(tile_m == 128): + a2, a3 = lds_load_packs_k64( + lds_a_row, lds_a_col1, tile_m, lds_a_buffer + ) + else: + a2 = a3 = fx.Int64(0) + a128 = pack_i64x4_to_i32x8(a0, a1, a2, a3) + a_scale = lds_scale_load( + lds_a_row, lds_a_scale_col, tile_m_mx, lds_a_scale_buffer + ) + + for hi in range_constexpr(dv_head_num_subtiles): + lds_b_col = dv_head_wave_id * dv_head_per_wave + hi * 16 + b0 = 
def compute_dp(lds_a_buffer, a_scales, lds_b_buffer, lds_b_scale_buffer):
    """Compute the dP accumulators for one tile with scaled MFMA ops.

    The A operand rows are read from ``lds_a_buffer`` with scales taken from
    the ``a_scales`` sequence (one entry per pair of M subtiles). The B
    operand rows are read from ``lds_b_buffer`` with scales loaded from
    ``lds_b_scale_buffer``. Accumulation starts from ``acc_init``.

    Returns:
        A list of ``ps_n_accs`` accumulators indexed by
        ``mi * ps_n_num_subtiles + ni``.
    """
    out_ty = T.f32x4
    accs = [acc_init] * ps_n_accs

    # Two 64-element halves of the 128-wide reduction dimension; each lane
    # group starts 16 elements further in.
    k_half0 = 0
    k_half1 = 1
    col_lo = k_half0 * 64 + lane_div_16 * 16
    col_hi = k_half1 * 64 + lane_div_16 * 16

    scale_col = lane_div_16
    if const_expr(tile_head == 64):
        scale_col = scale_col % 2

    for mi in range_constexpr(ps_m_num_subtiles):
        a_row = ps_m_wave_id * ps_m_per_wave + mi * 16 + lane_mod_16
        a_lo0, a_lo1 = lds_load_packs_k64(a_row, col_lo, tile_head, lds_a_buffer)
        if const_expr(tile_head == 128):
            a_hi0, a_hi1 = lds_load_packs_k64(
                a_row, col_hi, tile_head, lds_a_buffer
            )
        else:
            # Only 64 reduction elements exist; the upper half is zero.
            a_hi0 = fx.Int64(0)
            a_hi1 = a_hi0
        a_operand = pack_i64x4_to_i32x8(a_lo0, a_lo1, a_hi0, a_hi1)

        scale_a = a_scales[mi // 2]

        for ni in range_constexpr(ps_n_num_subtiles):
            b_row = ps_n_wave_id * ps_n_per_wave + ni * 16 + lane_mod_16
            b_lo0, b_lo1 = lds_load_packs_k64(
                b_row, col_lo, tile_head, lds_b_buffer
            )
            if const_expr(tile_head == 128):
                b_hi0, b_hi1 = lds_load_packs_k64(
                    b_row, col_hi, tile_head, lds_b_buffer
                )
            else:
                b_hi0 = fx.Int64(0)
                b_hi1 = b_hi0
            b_operand = pack_i64x4_to_i32x8(b_lo0, b_lo1, b_hi0, b_hi1)
            scale_b = lds_scale_load(
                b_row, scale_col, tile_head_mx, lds_b_scale_buffer
            )

            slot = mi * ps_n_num_subtiles + ni
            mfma_operands = [
                a_operand,
                b_operand,
                accs[slot],
                0,
                0,
                0,
                scale_a,
                0,
                scale_b,
            ]
            accs[slot] = rocdl.mfma_scale_f32_16x16x128_f8f6f4(
                out_ty, mfma_operands
            )
    return accs
static_position=[ii], dynamic_position=[] + ) + vals_subtile0.append(val0) + val1 = vector.extract( + acc1, static_position=[ii], dynamic_position=[] + ) + vals_subtile1.append(val1) + val0_abs = fx_math.absf(val0) + val1_abs = fx_math.absf(val1) + vals_abs.append(val0_abs) + vals_abs.append(val1_abs) + + vals_abs_vector = vector.from_elements(T.vec(8, T.f32), vals_abs) + val_max = vector.reduction(T.f32, "maxnumf", vals_abs_vector) + val_max = wave_reduce_max_4threads(val_max) + val_max = val_max * fp8_max_rcp + val_max = arith.bitcast(T.i32, val_max) + val_max = val_max + arith.constant(0x007FFFFF, type=T.i32) + val_max = val_max & arith.constant(0x7F800000, type=T.i32) + val_max_f32 = arith.bitcast(T.f32, val_max) + val_max_rcp = arith.select( + val_max_f32 == zero_f, zero_f, rocdl.rcp(T.f32, val_max_f32) + ) + scale = val_max >> 23 + scale = arith.trunci(T.i8, scale) + scale_vector = vector.from_elements(T.vec(1, T.i8), [scale]) + + for ii in range_constexpr(4): + vals_subtile0[ii] = vals_subtile0[ii] * val_max_rcp + vals_subtile1[ii] = vals_subtile1[ii] * val_max_rcp + + val_f8_packed_i32 = rocdl.cvt_pk_fp8_f32( + T.i32, vals_subtile0[0], vals_subtile0[1], fx.Int32(0), False + ) + val_f8_packed_i32 = rocdl.cvt_pk_fp8_f32( + T.i32, + vals_subtile0[2], + vals_subtile0[3], + val_f8_packed_i32, + True, + ) + val_f8_packed_i32_vector = vector.from_elements( + T.vec(1, T.i32), [val_f8_packed_i32] + ) + val_f8x4_subtile0 = vector.bitcast(T.f8x4, val_f8_packed_i32_vector) + + val_f8_packed_i32 = rocdl.cvt_pk_fp8_f32( + T.i32, vals_subtile1[0], vals_subtile1[1], fx.Int32(0), False + ) + val_f8_packed_i32 = rocdl.cvt_pk_fp8_f32( + T.i32, + vals_subtile1[2], + vals_subtile1[3], + val_f8_packed_i32, + True, + ) + val_f8_packed_i32_vector = vector.from_elements( + T.vec(1, T.i32), [val_f8_packed_i32] + ) + val_f8x4_subtile1 = vector.bitcast(T.f8x4, val_f8_packed_i32_vector) + + lds_row = ps_n_wave_id * ps_n_per_wave + ni * 16 + lane_mod_16 + lds_col_base0 = ( + 
def wave_reduce_max_16threads(x):
    """Max-reduce ``x`` over lanes via XOR shuffles with offsets 8, 4, 2, 1.

    Each step exchanges the running value with the lane whose id differs by
    the given XOR offset (shuffle width 64) and keeps the larger one, so the
    lanes that share all but the low four lane-id bits are combined.

    Returns:
        The reduced value.
    """
    shuffle_width = arith.constant(64, type=T.i32)
    running = x
    for xor_offset in [8, 4, 2, 1]:
        offset_i32 = arith.constant(xor_offset, type=T.i32)
        partner = running.shuffle_xor(offset_i32, shuffle_width)
        running = running.maximumf(partner)
    return running
val_max_rcp = arith.select( + val_max_f32 == zero_f, zero_f, rocdl.rcp(T.f32, val_max_f32) + ) + val0_quant = val0 * val_max_rcp + vals_subtile0.append(val0_quant) + val1_quant = val1 * val_max_rcp + vals_subtile1.append(val1_quant) + scale = val_max >> 23 + scale = arith.trunci(T.i8, scale) + scales.append(scale) + + val_f8_packed_i32 = rocdl.cvt_pk_fp8_f32( + T.i32, vals_subtile0[0], vals_subtile0[1], fx.Int32(0), False + ) + val_f8_packed_i32 = rocdl.cvt_pk_fp8_f32( + T.i32, + vals_subtile0[2], + vals_subtile0[3], + val_f8_packed_i32, + True, + ) + val_f8_packed_i32_vector = vector.from_elements( + T.vec(1, T.i32), [val_f8_packed_i32] + ) + val_f8x4_subtile0 = vector.bitcast(T.f8x4, val_f8_packed_i32_vector) + + val_f8_packed_i32 = rocdl.cvt_pk_fp8_f32( + T.i32, vals_subtile1[0], vals_subtile1[1], fx.Int32(0), False + ) + val_f8_packed_i32 = rocdl.cvt_pk_fp8_f32( + T.i32, + vals_subtile1[2], + vals_subtile1[3], + val_f8_packed_i32, + True, + ) + val_f8_packed_i32_vector = vector.from_elements( + T.vec(1, T.i32), [val_f8_packed_i32] + ) + val_f8x4_subtile1 = vector.bitcast(T.f8x4, val_f8_packed_i32_vector) + + lds_row0 = ( + ps_n_wave_id * ps_n_per_wave + (ni * 2) * 16 + lane_mod_16 + ) + lds_row1 = ( + ps_n_wave_id * ps_n_per_wave + (ni * 2 + 1) * 16 + lane_mod_16 + ) + lds_col_base = ( + ps_m_wave_id * ps_m_per_wave + mi * 16 + ) # + lane_div_16 * 4 + lds_col0 = swizzle_xor16(lds_row0, lds_col_base, tile_m_div16) + lds_col1 = swizzle_xor16(lds_row1, lds_col_base, tile_m_div16) + lds_col0 = lds_col0 + lane_div_16 * 4 + lds_col1 = lds_col1 + lane_div_16 * 4 + lds_idx0 = lds_row0 * tile_m + lds_col0 + lds_idx1 = lds_row1 * tile_m + lds_col1 + vector.store(val_f8x4_subtile0, lds_buffer, [lds_idx0]) + vector.store(val_f8x4_subtile1, lds_buffer, [lds_idx1]) + + for ii in range_constexpr(4): + lds_scale_row = ( + ps_m_wave_id * ps_m_per_wave + + mi * 16 + + lane_div_16 * 4 + + ii + ) + lds_scale_col = ps_n_wave_id * ps_n_mx_per_wave + ni + lds_scale_idx = 
def compute_dk(
    accs_in, lds_a_buffer, lds_a_scale_buffer, lds_b_buffer, b_scales
):
    """Accumulate one tile's contribution into the dK accumulators.

    The A operand rows are read from ``lds_a_buffer`` (row stride ``tile_m``)
    with per-row scales loaded from ``lds_a_scale_buffer``. The B operand is
    read from ``lds_b_buffer`` through the transposed LDS load, with scales
    taken from the ``b_scales`` sequence (one entry per pair of head
    subtiles).

    Args:
        accs_in: Current accumulators, indexed by
            ``ni * dk_num_subtiles_head + hi``. The input is copied, not
            modified.
        lds_a_buffer: LDS buffer holding the A operand.
        lds_a_scale_buffer: LDS buffer holding the A operand scales.
        lds_b_buffer: LDS buffer holding the B operand.
        b_scales: Sequence of B operand scales.

    Returns:
        A new list with the updated accumulators.
    """
    # NOTE(review): the scale column and the `tile_m == 128` guards do not
    # depend on `ku128`, so tile_m > 128 looks unsupported here even though
    # the reduction loop would iterate more than once -- confirm.
    current_accs_list = list(accs_in)
    mfma_res_ty = T.f32x4
    num_subtiles_reduction = max(1, tile_m // 128)
    for ku128 in range_constexpr(num_subtiles_reduction):
        ku0 = ku128 * 2
        ku1 = ku0 + 1

        # 16 packed elements per lane group, 64 per reduction half.
        lds_a_col0 = ku0 * 64 + lane_div_16 * 16
        lds_a_col1 = ku1 * 64 + lane_div_16 * 16

        lds_b_row0 = ku0 * 64 + lane_div_16 * 16 + lane_div_2
        lds_b_row1 = ku0 * 64 + lane_div_16 * 16 + 8 + lane_div_2
        lds_b_row2 = ku1 * 64 + lane_div_16 * 16 + lane_div_2
        lds_b_row3 = ku1 * 64 + lane_div_16 * 16 + 8 + lane_div_2

        lds_a_scale_col = lane_div_16
        # Compile-time tile-size check, written like the other
        # compile-time conditions in this kernel.
        if const_expr(tile_m == 64):
            lds_a_scale_col = lds_a_scale_col % 2

        for ni in range_constexpr(dk_num_subtiles_n):
            lds_a_row = dk_n_wave_id * dk_n_per_wave + ni * 16 + lane_mod_16
            a0, a1 = lds_load_packs_k64(
                lds_a_row, lds_a_col0, tile_m, lds_a_buffer
            )
            if const_expr(tile_m == 128):
                a2, a3 = lds_load_packs_k64(
                    lds_a_row, lds_a_col1, tile_m, lds_a_buffer
                )
            else:
                # Only 64 reduction elements exist; pad the upper half.
                a2 = fx.Int64(0)
                a3 = fx.Int64(0)
            a128 = pack_i64x4_to_i32x8(a0, a1, a2, a3)
            a_scale = lds_scale_load(
                lds_a_row, lds_a_scale_col, tile_m_mx, lds_a_scale_buffer
            )

            for hi in range_constexpr(dk_num_subtiles_head):
                lds_b_col = dk_head_wave_id * dk_head_per_wave + hi * 16
                b0 = lds_load_packs_k32_transposed(
                    lds_b_row0, lds_b_col, tile_head, lds_b_buffer
                )
                b1 = lds_load_packs_k32_transposed(
                    lds_b_row1, lds_b_col, tile_head, lds_b_buffer
                )
                if const_expr(tile_m == 128):
                    b2 = lds_load_packs_k32_transposed(
                        lds_b_row2, lds_b_col, tile_head, lds_b_buffer
                    )
                    b3 = lds_load_packs_k32_transposed(
                        lds_b_row3, lds_b_col, tile_head, lds_b_buffer
                    )
                else:
                    b2 = fx.Int64(0)
                    b3 = fx.Int64(0)
                b128 = pack_i64x4_to_i32x8(b0, b1, b2, b3)

                b_scale = b_scales[hi // 2]

                acc_idx = ni * dk_num_subtiles_head + hi
                current_accs_list[acc_idx] = (
                    rocdl.mfma_scale_f32_16x16x128_f8f6f4(
                        mfma_res_ty,
                        [
                            a128,
                            b128,
                            current_accs_list[acc_idx],
                            0,
                            0,
                            0,
                            a_scale,
                            0,
                            b_scale,
                        ],
                    )
                )
    return current_accs_list
def store_dq_atomic(final_accs, offset_m):
    """Atomically add the scaled dQ accumulators into ``dq_rsrc``.

    Every accumulator element is multiplied by ``c_sm_scale`` and added at
    byte offset ``(row * head_dim + col) * 4`` with a raw-buffer atomic
    float add.

    Args:
        final_accs: Accumulators indexed by ``mi * dq_num_subtiles_head + hi``.
        offset_m: Global row offset of the tile the accumulators belong to.
    """
    for mi in range_constexpr(dq_num_subtiles_m):
        for hi in range_constexpr(dq_num_subtiles_head):
            tile_acc = final_accs[mi * dq_num_subtiles_head + hi]
            for ii in range_constexpr(4):
                row = (
                    offset_m
                    + dq_m_wave_id * dq_m_per_wave
                    + mi * 16
                    + lane_div_16 * 4
                    + ii
                )
                col = dq_head_wave_id * dq_head_per_wave + hi * 16 + lane_mod_16
                elem_idx = row * head_dim + col
                # The atomic takes a byte offset; elements are 4 bytes wide.
                byte_offset = elem_idx * 4

                elem = vector.extract(
                    tile_acc, static_position=[ii], dynamic_position=[]
                )
                elem = elem * c_sm_scale
                rocdl.raw_ptr_buffer_atomic_fadd(
                    elem,
                    dq_rsrc,
                    fx.Int32(byte_offset),
                    fx.Int32(0),
                    fx.Int32(0),
                )
def store_dv_atomic(final_accs):
    """Write the dV accumulators to ``dv_rsrc``.

    When ``gqa_size == 1`` each element is written with a plain buffer store
    at element index ``row * head_dim + col``. Otherwise it is added with a
    raw-buffer atomic float add at the corresponding byte offset.

    Args:
        final_accs: Accumulators indexed by ``ni * dv_head_num_subtiles + hi``.
    """
    # NOTE(review): the atomic path presumably exists because several
    # query heads accumulate into one KV head when gqa_size > 1 -- confirm.
    for ni in range_constexpr(dv_n_num_subtiles):
        for hi in range_constexpr(dv_head_num_subtiles):
            # Looked up once per subtile; neither value depends on `ii`.
            acc = final_accs[ni * dv_head_num_subtiles + hi]
            global_col = (
                dv_head_wave_id * dv_head_per_wave + hi * 16 + lane_mod_16
            )
            for ii in range_constexpr(4):
                global_row = (
                    global_offset_n
                    + dv_n_wave_id * dv_n_per_wave
                    + ni * 16
                    + lane_div_16 * 4
                    + ii
                )
                global_idx = global_row * head_dim + global_col

                val_f32 = vector.extract(
                    acc, static_position=[ii], dynamic_position=[]
                )
                if const_expr(gqa_size == 1):
                    buffer_ops.buffer_store(val_f32, dv_rsrc, global_idx)
                else:
                    # The atomic takes a byte offset; elements are 4 bytes.
                    global_idx_bytes = global_idx * 4
                    rocdl.raw_ptr_buffer_atomic_fadd(
                        val_f32,
                        dv_rsrc,
                        fx.Int32(global_idx_bytes),
                        fx.Int32(0),
                        fx.Int32(0),
                    )
def _unpack_state(vals):
    """Split the flat loop-carried state into its six groups.

    The groups are consecutive slices of ``vals`` with these lengths, in
    order: ``dk_n_accs``, ``dv_n_accs``, ``ps_m_num_subtiles // 2``,
    ``dk_num_subtiles_head // 2``, ``ps_m_num_subtiles // 2``, and whatever
    remains.

    Returns:
        ``(dk, dv, q_scales_head, q_scales_m, do_scales_head, do_scales_m)``,
        each as a new list.
    """
    ps_m_scale_count = ps_m_num_subtiles // 2
    dk_head_scale_count = dk_num_subtiles_head // 2
    fixed_widths = (
        dk_n_accs,
        dv_n_accs,
        ps_m_scale_count,
        dk_head_scale_count,
        ps_m_scale_count,
    )

    groups = []
    cursor = 0
    for width in fixed_widths:
        groups.append(list(vals[cursor : cursor + width]))
        cursor += width
    # The last group takes everything that is left.
    groups.append(list(vals[cursor:]))

    dk, dv, q_scales_head, q_scales_m, do_scales_head, do_scales_m = groups
    return dk, dv, q_scales_head, q_scales_m, do_scales_head, do_scales_m
do_scales_head_pong, lds_v, lds_v_scale) + ds = compute_ds(dp, p, offset_m) + mxquant_m_and_store_to_lds(ds, lds_dst_shuffle, lds_dst_scale_shuffle) + mxquant_n_and_store_to_lds(ds, lds_ds_shuffle, lds_ds_scale_shuffle) + gpu.barrier() + dk = compute_dk( + dk, + lds_dst_shuffle, + lds_dst_scale_shuffle, + lds_q_pong, + q_scales_m_pong, + ) + dq = compute_dq(lds_ds_shuffle, lds_ds_scale_shuffle, lds_k, k_scales_n) + store_dq_atomic(dq, offset_m) + hot_loop_scheduler() + gpu.barrier() + + next_offset_m = offset_m + (tile_m * 2) + next_offset_m_mx = next_offset_m // 32 + store_q_tile_to_lds(prefetch_q_tile(next_offset_m), lds_q_pong) + q_scales_head_pong = prefetch_q_scale_head_2d_tile(next_offset_m_mx) + q_scales_m_pong = prefetch_q_scale_m_2d_tile(next_offset_m_mx) + store_do_tile_to_lds(prefetch_do_tile(next_offset_m), lds_do_pong) + do_scales_head_pong = prefetch_do_scale_head_2d_tile(next_offset_m_mx) + do_scales_m_pong = prefetch_do_scale_m_2d_tile(next_offset_m_mx) + + qk = compute_qk( + lds_q_ping, + q_scales_head_ping, + lds_k, + k_scales_head, + ) + p = softmax(qk, offset_m + tile_m) + mxquant_m_and_store_to_lds(p, lds_ppt_shuffle, lds_ppt_scale_shuffle) + gpu.barrier() + dv = compute_dv( + dv, + lds_ppt_shuffle, + lds_ppt_scale_shuffle, + lds_do_ping, + do_scales_m_ping, + ) + dp = compute_dp(lds_do_ping, do_scales_head_ping, lds_v, lds_v_scale) + ds = compute_ds(dp, p, offset_m + tile_m) + mxquant_m_and_store_to_lds(ds, lds_dst_shuffle, lds_dst_scale_shuffle) + mxquant_n_and_store_to_lds(ds, lds_ds_shuffle, lds_ds_scale_shuffle) + gpu.barrier() + dk = compute_dk( + dk, + lds_dst_shuffle, + lds_dst_scale_shuffle, + lds_q_ping, + q_scales_m_ping, + ) + dq = compute_dq(lds_ds_shuffle, lds_ds_scale_shuffle, lds_k, k_scales_n) + store_dq_atomic(dq, offset_m + tile_m) + hot_loop_scheduler() + gpu.barrier() + + return _pack_state( + dk, + dv, + q_scales_head_pong, + q_scales_m_pong, + do_scales_head_pong, + do_scales_m_pong, + ) + + if const_expr(causal): + 
start_m = (global_offset_n // (tile_m * 2)) * (tile_m * 2) + else: + start_m = fx.Index(0) + start_m_mx = start_m // 32 + + store_q_tile_to_lds(prefetch_q_tile(start_m), lds_q_pong) + q_scales_head_pong = prefetch_q_scale_head_2d_tile(start_m_mx) + q_scales_m_pong = prefetch_q_scale_m_2d_tile(start_m_mx) + store_k_tile_to_lds(prefetch_k_tile(), lds_k) + k_scales_head = prefetch_k_scale_head_2d_tile() + k_scales_n = prefetch_k_scale_n_2d_tile() + store_v_tile_to_lds(prefetch_v_tile(), lds_v) + store_v_scale_tile_to_lds(prefetch_v_scale_tile(), lds_v_scale) + store_do_tile_to_lds(prefetch_do_tile(start_m), lds_do_pong) + do_scales_head_pong = prefetch_do_scale_head_2d_tile(start_m_mx) + do_scales_m_pong = prefetch_do_scale_m_2d_tile(start_m_mx) + gpu.barrier() + dk = [acc_init] * dk_n_accs + dv = [acc_init] * dv_n_accs + + num_tiles_loop = seqlen_rounded // tile_m + if const_expr((num_tiles_loop % 2) == 1): + upper_bound = seqlen_rounded - tile_m + init_state = _pack_state( + dk, + dv, + q_scales_head_pong, + q_scales_m_pong, + do_scales_head_pong, + do_scales_m_pong, + ) + for iv, inner in range(start_m, upper_bound, tile_m * 2, init=init_state): + results = yield pingpong(iv, inner) + ( + dk, + dv, + q_scales_head_pong, + q_scales_m_pong, + do_scales_head_pong, + do_scales_m_pong, + ) = _unpack_state(results) + + curr_m = arith.index(seqlen_rounded - tile_m) + qk = compute_qk( + lds_q_pong, + q_scales_head_pong, + lds_k, + k_scales_head, + ) + p = softmax(qk, curr_m) + mxquant_m_and_store_to_lds(p, lds_ppt_shuffle, lds_ppt_scale_shuffle) + gpu.barrier() + dv = compute_dv( + dv, + lds_ppt_shuffle, + lds_ppt_scale_shuffle, + lds_do_pong, + do_scales_m_pong, + ) + dp = compute_dp(lds_do_pong, do_scales_head_pong, lds_v, lds_v_scale) + ds = compute_ds(dp, p, curr_m) + mxquant_m_and_store_to_lds(ds, lds_dst_shuffle, lds_dst_scale_shuffle) + mxquant_n_and_store_to_lds(ds, lds_ds_shuffle, lds_ds_scale_shuffle) + gpu.barrier() + dk = compute_dk( + dk, + lds_dst_shuffle, + 
lds_dst_scale_shuffle, + lds_q_pong, + q_scales_m_pong, + ) + dq = compute_dq(lds_ds_shuffle, lds_ds_scale_shuffle, lds_k, k_scales_n) + store_dq_atomic(dq, curr_m) + else: + upper_bound = seqlen_rounded - (tile_m * 2) + init_state = _pack_state( + dk, + dv, + q_scales_head_pong, + q_scales_m_pong, + do_scales_head_pong, + do_scales_m_pong, + ) + for iv, inner in range(start_m, upper_bound, tile_m * 2, init=init_state): + results = yield pingpong(iv, inner) + ( + dk, + dv, + q_scales_head_pong, + q_scales_m_pong, + do_scales_head_pong, + do_scales_m_pong, + ) = _unpack_state(results) + + curr_m = arith.index(seqlen_rounded - tile_m * 2) + last_m = arith.index(seqlen_rounded - tile_m) + last_m_mx = last_m // 32 + store_q_tile_to_lds(prefetch_q_tile(last_m), lds_q_ping) + q_scales_head_ping = prefetch_q_scale_head_2d_tile(last_m_mx) + q_scales_m_ping = prefetch_q_scale_m_2d_tile(last_m_mx) + store_do_tile_to_lds(prefetch_do_tile(last_m), lds_do_ping) + do_scales_head_ping = prefetch_do_scale_head_2d_tile(last_m_mx) + do_scales_m_ping = prefetch_do_scale_m_2d_tile(last_m_mx) + + qk = compute_qk( + lds_q_pong, + q_scales_head_pong, + lds_k, + k_scales_head, + ) + p = softmax(qk, curr_m) + mxquant_m_and_store_to_lds(p, lds_ppt_shuffle, lds_ppt_scale_shuffle) + gpu.barrier() + dv = compute_dv( + dv, + lds_ppt_shuffle, + lds_ppt_scale_shuffle, + lds_do_pong, + do_scales_m_pong, + ) + dp = compute_dp(lds_do_pong, do_scales_head_pong, lds_v, lds_v_scale) + ds = compute_ds(dp, p, curr_m) + mxquant_m_and_store_to_lds(ds, lds_dst_shuffle, lds_dst_scale_shuffle) + mxquant_n_and_store_to_lds(ds, lds_ds_shuffle, lds_ds_scale_shuffle) + gpu.barrier() + dk = compute_dk( + dk, + lds_dst_shuffle, + lds_dst_scale_shuffle, + lds_q_pong, + q_scales_m_pong, + ) + dq = compute_dq(lds_ds_shuffle, lds_ds_scale_shuffle, lds_k, k_scales_n) + store_dq_atomic(dq, curr_m) + + hot_loop_scheduler() + gpu.barrier() + + curr_m = last_m + qk = compute_qk( + lds_q_ping, + q_scales_head_ping, + lds_k, 
@flyc.jit
def launch_attn_bwd(
    arg_dq: fx.Tensor,
    arg_dk: fx.Tensor,
    arg_dv: fx.Tensor,
    arg_q: fx.Tensor,
    arg_q_scale: fx.Tensor,
    arg_k: fx.Tensor,
    arg_k_scale: fx.Tensor,
    arg_v: fx.Tensor,
    arg_v_scale: fx.Tensor,
    arg_do_quant_head: fx.Tensor,
    arg_do_scale: fx.Tensor,
    arg_M: fx.Tensor,
    arg_D: fx.Tensor,
    batch: fx.Int32,
    stride_qo_batch: fx.Int32,
    stride_kv_batch: fx.Int32,
    stride_MD_batch: fx.Int32,
    stride_qkvo_nheads: fx.Int32,
    stride_MD_nheads: fx.Int32,
    stride_q_scale_batch: fx.Int32,
    stride_q_scale_nheads: fx.Int32,
    stride_k_scale_batch: fx.Int32,
    stride_k_scale_nheads: fx.Int32,
    stride_v_scale_batch: fx.Int32,
    stride_v_scale_nheads: fx.Int32,
    stride_do_scale_batch: fx.Int32,
    stride_do_scale_nheads: fx.Int32,
    stream: fx.Stream,
):
    """Host-side launcher for ``kernel_attn_bwd``.

    Re-finalizes every ``allocator_*`` object inside the GPU module body,
    builds the kernel call, optionally tags each ``gpu.func`` with
    ``rocdl.waves_per_eu``, and launches on ``stream`` with a grid of
    ``(num_heads_q, ceil(seqlen / tile_n), batch)`` and 256-thread blocks.
    """
    _ = _cache_tag

    allocators = (
        allocator_pong,
        allocator_ping,
        allocator_k,
        allocator_v,
        allocator_v_scale,
        allocator_ppt_shuffle,
        allocator_ppt_scale_shuffle,
        allocator_dst_shuffle,
        allocator_dst_scale_shuffle,
        allocator_ds_shuffle,
        allocator_ds_scale_shuffle,
    )
    # Clear the flag on every allocator before any of them is finalized
    # again for this compilation.
    for alloc in allocators:
        alloc.finalized = False

    ctx = CompilationContext.get_current()
    with ir.InsertionPoint(ctx.gpu_module_body):
        for alloc in allocators:
            alloc.finalize()

    launcher = kernel_attn_bwd(
        arg_dq,
        arg_dk,
        arg_dv,
        arg_q,
        arg_q_scale,
        arg_k,
        arg_k_scale,
        arg_v,
        arg_v_scale,
        arg_do_quant_head,
        arg_do_scale,
        arg_M,
        arg_D,
        batch,
        stride_qo_batch,
        stride_kv_batch,
        stride_MD_batch,
        stride_qkvo_nheads,
        stride_MD_nheads,
        stride_q_scale_batch,
        stride_q_scale_nheads,
        stride_k_scale_batch,
        stride_k_scale_nheads,
        stride_v_scale_batch,
        stride_v_scale_nheads,
        stride_do_scale_batch,
        stride_do_scale_nheads,
    )

    if waves_per_eu is not None:
        _wpe = int(waves_per_eu)
        if _wpe >= 1:
            # Attach the occupancy hint to every gpu.func in the module.
            for op in ctx.gpu_module_body.operations:
                if hasattr(op, "attributes") and op.OPERATION_NAME == "gpu.func":
                    op.attributes["rocdl.waves_per_eu"] = ir.IntegerAttr.get(
                        T.i32, _wpe
                    )

    # One workgroup per (query head, tile_n-sized block, batch entry).
    grid_x = num_heads_q
    grid_y = (seqlen + tile_n - 1) // tile_n
    grid_z = batch
    launcher.launch(
        grid=(grid_x, grid_y, grid_z),
        block=(256, 1, 1),
        stream=stream,
    )
BLOCK_D_POW2): + scale = (scale.to(tl.uint32) << 23).to(tl.float32, bitcast=True) + tensor = tensor.to(tl.float32) + tensor = tensor.reshape((BLOCK_M, BLOCK_D_POW2 // 32, 32)) + tensor = tensor * scale[:, :, None] + tensor = tensor.reshape((BLOCK_M, BLOCK_D_POW2)) + return tensor + + +@triton.jit +def _bwd_preprocess_mxfp8( + o_ptr, + o_scale_ptr, + do_ptr, + do_scale_ptr, + delta_ptr, + stride_o_b, + stride_o_m, + stride_o_k, + stride_o_scale_b, + stride_o_scale_m, + stride_o_scale_k, + stride_do_b, + stride_do_m, + stride_do_k, + stride_do_scale_b, + stride_do_scale_m, + stride_do_scale_k, + stride_delta_b, + stride_delta_m, + max_seqlen_q, + BLOCK_M: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_D_POW2: tl.constexpr, +): + pid_m = tl.program_id(0) # seqlen + bid = tl.program_id(1) # batch + + # Compute offsets + BLOCK_D_SCALE: tl.constexpr = BLOCK_D // 32 + BLOCK_D_SCALE_POW2: tl.constexpr = BLOCK_D_POW2 // 32 + offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_k = tl.arange(0, BLOCK_D_POW2) + offs_k_scale = tl.arange(0, BLOCK_D_SCALE_POW2) + + offs_o = ( + bid * stride_o_b + offs_m[:, None] * stride_o_m + offs_k[None, :] * stride_o_k + ) + offs_o_scale = ( + bid * stride_o_scale_b + + offs_m[:, None] * stride_o_scale_m + + offs_k_scale[None, :] * stride_o_scale_k + ) + offs_do = ( + bid * stride_do_b + + offs_m[:, None] * stride_do_m + + offs_k[None, :] * stride_do_k + ) + offs_do_scale = ( + bid * stride_do_scale_b + + offs_m[:, None] * stride_do_scale_m + + offs_k_scale[None, :] * stride_do_scale_k + ) + + # create masks + mask_m = offs_m < max_seqlen_q + mask = mask_m[:, None] + mask_scale = mask + PADDED_HEAD: tl.constexpr = BLOCK_D != BLOCK_D_POW2 + if PADDED_HEAD: + mask &= offs_k[None, :] < BLOCK_D + mask_scale &= offs_k_scale[None, :] < BLOCK_D_SCALE + + # load [BLOCK_M, BLOCK_D_MODEL_POW2] + o = tl.load(o_ptr + offs_o, mask=mask, other=0.0) + o_scale = tl.load(o_scale_ptr + offs_o_scale, mask=mask_scale, other=0.0) + do = tl.load(do_ptr + 
offs_do, mask=mask, other=0.0) + do_scale = tl.load(do_scale_ptr + offs_do_scale, mask=mask_scale, other=0.0) + + # compute and write-back to delta + o_fp32 = upcast_mxfp8(o, o_scale, BLOCK_M, BLOCK_D_POW2) + do_fp32 = upcast_mxfp8(do, do_scale, BLOCK_M, BLOCK_D_POW2) + delta = tl.sum(o_fp32 * do_fp32, axis=1) + + offs_delta = bid * stride_delta_b + offs_m * stride_delta_m + tl.store(delta_ptr + offs_delta, delta, mask=mask_m) + + @triton.jit(repr=_bwd_preprocess_repr) def _bwd_preprocess( o_ptr, diff --git a/aiter/ops/triton/_triton_kernels/quant/mxfp8_quant.py b/aiter/ops/triton/_triton_kernels/quant/mxfp8_quant.py new file mode 100644 index 0000000000..56212cd93a --- /dev/null +++ b/aiter/ops/triton/_triton_kernels/quant/mxfp8_quant.py @@ -0,0 +1,533 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +import triton +import triton.language as tl + + +@triton.jit +def _get_max_quant_val(dtype: tl.constexpr): + if dtype == tl.float8e5: + return 57344.0 + elif dtype == tl.float8e4nv: + return 448.0 + else: + tl.static_assert(False, f"Invalid {dtype=}") + + +@triton.jit +def _get_max_power_of_2_quant_val(dtype: tl.constexpr): + if dtype == tl.float8e5: + return 32768.0 + elif dtype == tl.float8e4nv: + return 256.0 + else: + tl.static_assert(False, f"Invalid {dtype=}") + + +@triton.jit +def _compute_mx_quant_and_scale( + src_tensor, + valid_src_mask, + mx_tensor_dtype: tl.constexpr, + SCALE_ROUNDING_MODE: tl.constexpr = 0, +): + BLOCK_SIZE_OUT_DIM: tl.constexpr = src_tensor.shape[0] + BLOCK_SIZE_QUANT_DIM: tl.constexpr = src_tensor.shape[1] + BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = src_tensor.shape[1] // 32 + + # Explicit cast to fp32 since most ops are not supported on bfloat16. 
We avoid needless conversions to and from bf16 + f32_tensor = src_tensor.to(tl.float32) + abs_tensor = tl.abs(f32_tensor) + abs_tensor = tl.where( + valid_src_mask, abs_tensor, -1.0 + ) # Don't consider padding tensors in scale computation + abs_tensor = tl.reshape( + abs_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 32] + ) + max_val = tl.max(abs_tensor, axis=2, keep_dims=True) + if SCALE_ROUNDING_MODE == 0: + # ROUND_UP + # compute 2 ** ceil(log2(dequant_scale)) + # Adding 0x007FFFFF adds exponent by 1 unless mantissa is all zeros + # A corner case: exponent is 0xFF that will overflow but that's already + # NaN so assume we don't care. + dequant_scale = max_val / _get_max_quant_val(mx_tensor_dtype) + dequant_scale_exponent = ( + dequant_scale.to(tl.uint32, bitcast=True) + 0x007FFFFF + ) & 0x7F800000 + else: + # ROUND_DOWN + # compute 2 ** floor(log2(dequant_scale)) + assert SCALE_ROUNDING_MODE == 1 + dequant_scale = max_val / _get_max_power_of_2_quant_val(mx_tensor_dtype) + dequant_scale_exponent = dequant_scale.to(tl.uint32, bitcast=True) & 0x7F800000 + dequant_scale_rounded = dequant_scale_exponent.to(tl.float32, bitcast=True) + quant_scale = tl.where(dequant_scale_rounded == 0, 0, 1.0 / dequant_scale_rounded) + + f32_tensor = tl.reshape( + f32_tensor, [BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_MX_SCALE, 32] + ) + quant_tensor = f32_tensor * quant_scale + + # Reshape the tensors after scaling + quant_tensor = quant_tensor.reshape([BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM]) + # Set the invalid portions of the tensor to 0. This will ensure that any padding tensors are 0 in the mx format. 
@triton.jit
def _downcast_to_mxfp8(
    mx_tensor_ptr,
    stride_mxt_outer,
    stride_mxt_quant: tl.constexpr,
    mx_scale_ptr,
    stride_mx_scale_outer,
    stride_mx_scale_quant,
    src_ptr,
    stride_src_outer,
    stride_src_quant,
    outer_dim,
    quant_dim,
    BLOCK_SIZE_OUT_DIM: tl.constexpr,
    BLOCK_SIZE_QUANT_DIM: tl.constexpr,
    SCALE_ROUNDING_MODE: tl.constexpr,
):
    """Quantize a 2-D (outer, quant) tensor to fp8 with one scale per 32 elements.

    Grid axis 0 walks tiles of BLOCK_SIZE_OUT_DIM rows, grid axis 1 walks tiles
    of BLOCK_SIZE_QUANT_DIM columns. Each program loads its source tile, calls
    ``_compute_mx_quant_and_scale`` and stores the fp8 values and the uint8
    scale exponents.

    Args:
        mx_tensor_ptr: fp8 output (float8e4nv or float8e5); its stride along
            the quantized dimension must be 1.
        mx_scale_ptr: uint8 output holding one exponent per 32 source elements.
        src_ptr: float32 / bfloat16 / float16 input.
        outer_dim, quant_dim: Logical extents used for bounds masking.
        BLOCK_SIZE_OUT_DIM, BLOCK_SIZE_QUANT_DIM: Tile shape; the quantized
            extent must be a multiple of 32.
        SCALE_ROUNDING_MODE: Forwarded to ``_compute_mx_quant_and_scale``.
    """

    tl.static_assert(
        stride_mxt_quant == 1, f"Output stride, {stride_mxt_quant=} must be 1."
    )
    tl.static_assert(
        BLOCK_SIZE_QUANT_DIM % 32 == 0,
        f"{BLOCK_SIZE_QUANT_DIM=} must be a multiple of 32",
    )

    mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty
    tl.static_assert(
        mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5,
        f"Invalid {mx_tensor_dtype=}. Must be float8.",
    )

    src_dtype: tl.constexpr = src_ptr.dtype.element_ty
    tl.static_assert(
        mx_scale_ptr.dtype.element_ty == tl.uint8,
        f"{mx_scale_ptr.dtype.element_ty=} must be uint8",
    )
    tl.static_assert(
        (src_dtype == tl.float32)
        or (src_dtype == tl.bfloat16)
        or (src_dtype == tl.float16),
        f"{src_dtype=} must be float32 or bfloat16 or float16",
    )

    # Program ids are widened to int64 before they enter the offset arithmetic.
    outer_block = tl.program_id(0).to(tl.int64)
    quant_block = tl.program_id(1).to(tl.int64)

    # One scale per 32 elements; the fp8 tile has the same width as the source
    # tile (one output element per input element).
    BLOCK_SIZE_QUANT_MX_SCALE: tl.constexpr = BLOCK_SIZE_QUANT_DIM // 32
    BLOCK_SIZE_QUANT_MX_TENSOR: tl.constexpr = BLOCK_SIZE_QUANT_DIM

    start_src_quant = quant_block * BLOCK_SIZE_QUANT_DIM
    start_mx_scale_quant = quant_block * BLOCK_SIZE_QUANT_MX_SCALE
    start_mx_quant = quant_block * BLOCK_SIZE_QUANT_MX_TENSOR
    start_out = outer_block * BLOCK_SIZE_OUT_DIM

    # Advance the base pointers to this program's tile.
    src_ptr += start_src_quant * stride_src_quant + start_out * stride_src_outer
    mx_scale_ptr += (
        start_mx_scale_quant * stride_mx_scale_quant + start_out * stride_mx_scale_outer
    )
    mx_tensor_ptr += start_mx_quant * stride_mxt_quant + start_out * stride_mxt_outer

    offs_src_quant = tl.arange(0, BLOCK_SIZE_QUANT_DIM)[None, :].to(tl.int64)
    offs_mxt_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_TENSOR)[None, :].to(tl.int64)
    offs_scale_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_SCALE)[None, :].to(tl.int64)
    offs_outer = tl.arange(0, BLOCK_SIZE_OUT_DIM)[:, None].to(tl.int64)

    # Bounds masks for the source tile, the fp8 tile and the scale tile.
    mask_src_quant = start_src_quant + offs_src_quant < quant_dim
    mask_n = start_out + offs_outer < outer_dim
    full_mask_src = mask_src_quant & mask_n

    mask_mxt_quant = start_mx_quant + offs_mxt_quant < quant_dim
    full_mask_mxt = mask_mxt_quant & mask_n

    # cdiv: a trailing partial group of fewer than 32 elements still gets a scale.
    scale_mask_k = start_mx_scale_quant + offs_scale_quant < tl.cdiv(quant_dim, 32)
    full_scale_mask = scale_mask_k & mask_n

    src_tensor_offsets = (
        offs_src_quant * stride_src_quant + offs_outer * stride_src_outer
    )
    mx_scale_offsets = (
        offs_scale_quant * stride_mx_scale_quant + offs_outer * stride_mx_scale_outer
    )
    mx_tensor_offsets = (
        offs_mxt_quant * stride_mxt_quant + offs_outer * stride_mxt_outer
    )
    # No `other=` is given; the same mask is passed to the helper below as its
    # valid-element mask.
    src_tensor = tl.load(src_ptr + src_tensor_offsets, mask=full_mask_src)

    out_tensor, scale_tensor = _compute_mx_quant_and_scale(
        src_tensor, full_mask_src, mx_tensor_dtype, SCALE_ROUNDING_MODE
    )

    tl.store(mx_scale_ptr + mx_scale_offsets, scale_tensor, mask=full_scale_mask)
    tl.store(mx_tensor_ptr + mx_tensor_offsets, out_tensor, mask=full_mask_mxt)
+ outer_block = tl.program_id(0).to(tl.int64) + quant_block = tl.program_id(1).to(tl.int64) + + start_mxt_quant = quant_block * BLOCK_SIZE_QUANT_MX_TENSOR + start_out_quant = quant_block * BLOCK_SIZE_QUANT_DIM + start_mx_scale_quant = quant_block * BLOCK_SIZE_QUANT_MX_SCALE + start_out = outer_block * BLOCK_SIZE_OUT_DIM + + mx_tensor_ptr += ( + start_mxt_quant * stride_tensor_quant + start_out * stride_tensor_outer + ) + mx_scale_ptr += ( + start_mx_scale_quant * stride_scale_quant + start_out * stride_scale_outer + ) + out_ptr += start_out * stride_o_outer + start_out_quant * stride_o_quant + + # Compute offsets and masks. + offs_src_quant = tl.arange(0, BLOCK_SIZE_QUANT_MX_TENSOR)[None, :].to(tl.int64) + offs_out_quant = tl.arange(0, BLOCK_SIZE_QUANT_DIM)[None, :].to(tl.int64) + offs_outer = tl.arange(0, BLOCK_SIZE_OUT_DIM)[:, None].to(tl.int64) + offs_scale = tl.arange(0, BLOCK_SIZE_QUANT_MX_SCALE)[None, :].to(tl.int64) + + mask_outer = start_out + offs_outer < outer_dim + mask_out_quant = start_out_quant + offs_out_quant < quant_dim + full_mask_out = mask_out_quant & mask_outer + + mask_src_quant = start_mxt_quant + offs_src_quant < quant_dim + full_mask_src = mask_src_quant & mask_outer + + mask_scale = start_mx_scale_quant + offs_scale < tl.cdiv(quant_dim, 32) + full_scale_mask = mask_scale & mask_outer + + tensor_offsets = ( + offs_src_quant * stride_tensor_quant + offs_outer * stride_tensor_outer + ) + scale_offsets = offs_scale * stride_scale_quant + offs_outer * stride_scale_outer + out_offsets = offs_out_quant * stride_o_quant + offs_outer * stride_o_outer + + # Load the packed tensor and scale. + tensor = tl.load(mx_tensor_ptr + tensor_offsets, mask=full_mask_src) + scale = tl.load(mx_scale_ptr + scale_offsets, mask=full_scale_mask) + + # Upcast the scale. + dst_scale = (scale.to(tl.uint32) << 23).to(tl.float32, bitcast=True) + + # Now upcast the tensor. 
@triton.jit
def _compute_mx_quant_and_scale_2d(
    src_tensor,
    valid_src_mask,
    mx_tensor_dtype: tl.constexpr,
    SCALE_ROUNDING_MODE: tl.constexpr = 0,
):
    """Quantize a 2-D tile to fp8 with one shared scale per 32x32 sub-block.

    Args:
        src_tensor: (BLOCK_SIZE_M, BLOCK_SIZE_N) tile; both extents are divided
            by 32 to obtain the scale-grid shape.
        valid_src_mask: Mask of in-bounds elements. Masked-out elements are
            excluded from the max and written as 0.
        mx_tensor_dtype: Target fp8 element type.
        SCALE_ROUNDING_MODE: 0 rounds the scale up to a power of two, 1 rounds
            it down.

    Returns:
        ``(out_tensor, dequant_scale_exponent)``: the quantized tile in
        ``mx_tensor_dtype`` and a (BLOCK_SIZE_M // 32, BLOCK_SIZE_N // 32)
        uint8 tile holding the biased exponent of each block's scale.
    """
    BLOCK_SIZE_M: tl.constexpr = src_tensor.shape[0]
    BLOCK_SIZE_N: tl.constexpr = src_tensor.shape[1]
    BLOCK_SIZE_M_SCALE: tl.constexpr = BLOCK_SIZE_M // 32
    BLOCK_SIZE_N_SCALE: tl.constexpr = BLOCK_SIZE_N // 32

    # Explicit cast to fp32 since most ops are not supported on bfloat16.
    f32_tensor = src_tensor.to(tl.float32)
    abs_tensor = tl.abs(f32_tensor)
    # Don't consider padding elements in the scale computation.
    abs_tensor = tl.where(valid_src_mask, abs_tensor, -1.0)

    # Reshape to (M_SCALE, 32, N_SCALE, 32) so that for each (i, j) scale block
    # the elements live at abs_4d[i, :, j, :] -- a 32x32 sub-block.
    abs_4d = tl.reshape(abs_tensor, [BLOCK_SIZE_M_SCALE, 32, BLOCK_SIZE_N_SCALE, 32])
    # Two sequential reductions give the max over each 32x32 block.
    max_val = tl.max(abs_4d, axis=3, keep_dims=True)
    max_val = tl.max(max_val, axis=1, keep_dims=True)

    if SCALE_ROUNDING_MODE == 0:
        # ROUND_UP: 2 ** ceil(log2(max / max_representable)).
        # Adding 0x007FFFFF bumps the exponent unless the mantissa is all zeros.
        dequant_scale = max_val / _get_max_quant_val(mx_tensor_dtype)
        dequant_scale_exponent = (
            dequant_scale.to(tl.uint32, bitcast=True) + 0x007FFFFF
        ) & 0x7F800000
    else:
        # ROUND_DOWN: 2 ** floor(log2(max / max_power_of_2)).
        # Dividing by the largest representable power of two (as the 1-D
        # helper _compute_mx_quant_and_scale does) rather than the largest
        # representable value keeps the floored scale from leaving values at
        # up to ~2x the fp8 range.
        tl.static_assert(SCALE_ROUNDING_MODE == 1)
        dequant_scale = max_val / _get_max_power_of_2_quant_val(mx_tensor_dtype)
        dequant_scale_exponent = dequant_scale.to(tl.uint32, bitcast=True) & 0x7F800000
    dequant_scale_rounded = dequant_scale_exponent.to(tl.float32, bitcast=True)
    # A zero scale (all-zero block) maps to a zero multiplier instead of inf.
    quant_scale = tl.where(dequant_scale_rounded == 0, 0, 1.0 / dequant_scale_rounded)

    # Broadcast (M_SCALE, 1, N_SCALE, 1) over (M_SCALE, 32, N_SCALE, 32).
    f32_tensor_4d = tl.reshape(
        f32_tensor, [BLOCK_SIZE_M_SCALE, 32, BLOCK_SIZE_N_SCALE, 32]
    )
    quant_tensor = f32_tensor_4d * quant_scale

    quant_tensor = quant_tensor.reshape([BLOCK_SIZE_M, BLOCK_SIZE_N])
    # Force padding positions to 0 in the quantized output.
    quant_tensor = tl.where(valid_src_mask, quant_tensor, 0)

    dequant_scale_exponent = dequant_scale_exponent.reshape(
        [BLOCK_SIZE_M_SCALE, BLOCK_SIZE_N_SCALE]
    )
    # Keep only the 8-bit biased exponent of the scale.
    dequant_scale_exponent = (dequant_scale_exponent >> 23).to(tl.uint8)
    out_tensor = quant_tensor.to(mx_tensor_dtype)

    return out_tensor, dequant_scale_exponent
== tl.float8e4nv or mx_tensor_dtype == tl.float8e5, + f"Invalid {mx_tensor_dtype=}. Must be float8.", + ) + + src_dtype: tl.constexpr = src_ptr.dtype.element_ty + tl.static_assert( + mx_scale_ptr.dtype.element_ty == tl.uint8, + f"{mx_scale_ptr.dtype.element_ty=} must be uint8", + ) + tl.static_assert( + (src_dtype == tl.float32) + or (src_dtype == tl.bfloat16) + or (src_dtype == tl.float16), + f"{src_dtype=} must be float32 or bfloat16 or float16", + ) + + batch = tl.program_id(0).to(tl.int64) + m_block = tl.program_id(1).to(tl.int64) + n_block = tl.program_id(2).to(tl.int64) + + BLOCK_SIZE_M_SCALE: tl.constexpr = BLOCK_SIZE_M // 32 + BLOCK_SIZE_N_SCALE: tl.constexpr = BLOCK_SIZE_N // 32 + + start_m = m_block * BLOCK_SIZE_M + start_n = n_block * BLOCK_SIZE_N + start_scale_m = m_block * BLOCK_SIZE_M_SCALE + start_scale_n = n_block * BLOCK_SIZE_N_SCALE + + src_ptr += batch * stride_src_b + start_m * stride_src_m + start_n * stride_src_n + mx_tensor_ptr += ( + batch * stride_mxt_b + start_m * stride_mxt_m + start_n * stride_mxt_n + ) + mx_scale_ptr += ( + batch * stride_mx_scale_b + + start_scale_m * stride_mx_scale_m + + start_scale_n * stride_mx_scale_n + ) + + offs_m = tl.arange(0, BLOCK_SIZE_M)[:, None].to(tl.int64) + offs_n = tl.arange(0, BLOCK_SIZE_N)[None, :].to(tl.int64) + offs_scale_m = tl.arange(0, BLOCK_SIZE_M_SCALE)[:, None].to(tl.int64) + offs_scale_n = tl.arange(0, BLOCK_SIZE_N_SCALE)[None, :].to(tl.int64) + + mask_m = start_m + offs_m < M + mask_n = start_n + offs_n < N + full_mask = mask_m & mask_n + + mask_scale_m = start_scale_m + offs_scale_m < tl.cdiv(M, 32) + mask_scale_n = start_scale_n + offs_scale_n < tl.cdiv(N, 32) + full_scale_mask = mask_scale_m & mask_scale_n + + src_offsets = offs_m * stride_src_m + offs_n * stride_src_n + mx_tensor_offsets = offs_m * stride_mxt_m + offs_n * stride_mxt_n + scale_offsets = offs_scale_m * stride_mx_scale_m + offs_scale_n * stride_mx_scale_n + + src_tensor = tl.load(src_ptr + src_offsets, mask=full_mask) + + 
out_tensor, scale_tensor = _compute_mx_quant_and_scale_2d( + src_tensor, full_mask, mx_tensor_dtype, SCALE_ROUNDING_MODE + ) + + tl.store(mx_scale_ptr + scale_offsets, scale_tensor, mask=full_scale_mask) + tl.store(mx_tensor_ptr + mx_tensor_offsets, out_tensor, mask=full_mask) + + +@triton.jit +def _upcast_from_mxfp8_2d( + out_ptr, + stride_o_b, + stride_o_m, + stride_o_n: tl.constexpr, + mx_scale_ptr, + stride_scale_b, + stride_scale_m, + stride_scale_n, + mx_tensor_ptr, + stride_tensor_b, + stride_tensor_m, + stride_tensor_n: tl.constexpr, + M, + N, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + + tl.static_assert( + stride_o_n == 1, + "the weight must be contiguous in the n dimension for mx", + ) + tl.static_assert(BLOCK_SIZE_M % 32 == 0, "BLOCK_SIZE_M must be a multiple of 32") + tl.static_assert(BLOCK_SIZE_N % 32 == 0, "BLOCK_SIZE_N must be a multiple of 32") + mx_tensor_dtype: tl.constexpr = mx_tensor_ptr.dtype.element_ty + dst_dtype: tl.constexpr = out_ptr.dtype.element_ty + tl.static_assert( + dst_dtype == tl.float32 or dst_dtype == tl.float16 or dst_dtype == tl.bfloat16 + ) + tl.static_assert( + (mx_tensor_dtype == tl.float8e4nv or mx_tensor_dtype == tl.float8e5) + or mx_tensor_dtype == dst_dtype, + "mx_tensor_ptr must be float8 or dst_dtype", + ) + tl.static_assert( + mx_scale_ptr.dtype.element_ty == tl.uint8, "mx_scale_ptr must be uint8" + ) + + BLOCK_SIZE_M_SCALE: tl.constexpr = BLOCK_SIZE_M // 32 + BLOCK_SIZE_N_SCALE: tl.constexpr = BLOCK_SIZE_N // 32 + + batch = tl.program_id(0).to(tl.int64) + m_block = tl.program_id(1).to(tl.int64) + n_block = tl.program_id(2).to(tl.int64) + + start_m = m_block * BLOCK_SIZE_M + start_n = n_block * BLOCK_SIZE_N + start_scale_m = m_block * BLOCK_SIZE_M_SCALE + start_scale_n = n_block * BLOCK_SIZE_N_SCALE + + mx_tensor_ptr += ( + batch * stride_tensor_b + start_m * stride_tensor_m + start_n * stride_tensor_n + ) + mx_scale_ptr += ( + batch * stride_scale_b + + start_scale_m * stride_scale_m + + 
start_scale_n * stride_scale_n + ) + out_ptr += batch * stride_o_b + start_m * stride_o_m + start_n * stride_o_n + + offs_m = tl.arange(0, BLOCK_SIZE_M)[:, None].to(tl.int64) + offs_n = tl.arange(0, BLOCK_SIZE_N)[None, :].to(tl.int64) + offs_scale_m = tl.arange(0, BLOCK_SIZE_M_SCALE)[:, None].to(tl.int64) + offs_scale_n = tl.arange(0, BLOCK_SIZE_N_SCALE)[None, :].to(tl.int64) + + mask_m = start_m + offs_m < M + mask_n = start_n + offs_n < N + full_mask = mask_m & mask_n + + mask_scale_m = start_scale_m + offs_scale_m < tl.cdiv(M, 32) + mask_scale_n = start_scale_n + offs_scale_n < tl.cdiv(N, 32) + full_scale_mask = mask_scale_m & mask_scale_n + + tensor_offsets = offs_m * stride_tensor_m + offs_n * stride_tensor_n + scale_offsets = offs_scale_m * stride_scale_m + offs_scale_n * stride_scale_n + out_offsets = offs_m * stride_o_m + offs_n * stride_o_n + + tensor = tl.load(mx_tensor_ptr + tensor_offsets, mask=full_mask) + scale = tl.load(mx_scale_ptr + scale_offsets, mask=full_scale_mask) + + dst_scale = (scale.to(tl.uint32) << 23).to(tl.float32, bitcast=True) + dst_tensor = tensor.to(tl.float32) + + # Broadcast the per-32x32-block scale across the full tile. 
def bwd_preprocess_mxfp8(
    o: torch.Tensor,
    o_scale: torch.Tensor,
    do: torch.Tensor,
    do_scale: torch.Tensor,
    config: Optional[Dict[str, any]] = None,
):
    """
    Backward mx8 preprocess function.

    Args:
        o (torch.Tensor): Output from forward pass. Shape (..., seqlen, head_dim)
        o_scale (torch.Tensor): MX scales for o computed along head dimension. Shape (..., seqlen, head_dim // 32)
        do (torch.Tensor): Output gradient. Shape (..., seqlen, head_dim)
        do_scale (torch.Tensor): MX scales for do computed along head dimension. Shape (..., seqlen, head_dim // 32)
        config (Optional[Dict[str, any]]): Kernel tuning parameters. Only
            ``config["preprocess_kernel"]["PRE_BLOCK"]`` is read here; when
            None, ``_get_config()`` supplies it.

    Returns:
        torch.Tensor: float32 delta = rowsum(do * o) of shape
        (batch, seqlen), where batch is the product of all leading
        dimensions of ``o`` when it has more than three dimensions.
    """

    # get strides and shape
    if o.dim() > 3:  # flatten batch and number of heads dimensions
        o = o.reshape(-1, o.shape[-2], o.shape[-1])
        do = do.reshape(-1, do.shape[-2], do.shape[-1])
        # The scale tensors have their own (smaller) last dimension, so they
        # must be flattened with their own trailing shape, not with o's.
        o_scale = o_scale.reshape(-1, o_scale.shape[-2], o_scale.shape[-1])
        do_scale = do_scale.reshape(-1, do_scale.shape[-2], do_scale.shape[-1])
    batch, seqlen, head_dim = o.shape

    # BLOCK_D, BLOCK_D_POW2
    # padding for head_dim. Power of 2 or 16
    # NOTE(review): the kernel derives its scale width as BLOCK_D_POW2 // 32,
    # so head dims below 32 look unsupported -- confirm against the kernel.
    BLOCK_D_POW2 = triton.next_power_of_2(head_dim)
    BLOCK_D_POW2 = max(BLOCK_D_POW2, 16)

    # init delta directly on the device that holds the inputs
    delta = torch.empty((batch, seqlen), dtype=torch.float32, device=o.device)

    # preprocess
    # compute D(delta) = rowsum(dO*O). Note, multiplication is element-wise.
    if config is None:
        config = _get_config()

    pre_block = config["preprocess_kernel"]["PRE_BLOCK"]
    pre_grid = (
        triton.cdiv(seqlen, pre_block),
        batch,
    )

    _bwd_preprocess_mxfp8[pre_grid](
        o,
        o_scale,
        do,
        do_scale,
        delta,
        *o.stride(),
        *o_scale.stride(),
        *do.stride(),
        *do_scale.stride(),
        *delta.stride(),
        seqlen,
        BLOCK_M=pre_block,
        BLOCK_D=head_dim,
        BLOCK_D_POW2=BLOCK_D_POW2,
    )

    return delta
+ +import triton +import torch +from aiter.ops.triton._triton_kernels.quant.mxfp8_quant import ( + _downcast_to_mxfp8, + _upcast_from_mxfp8, + _downcast_to_mxfp8_2d, + _upcast_from_mxfp8_2d, +) +from aiter.ops.triton.utils.logger import AiterTritonLogger + +__all__ = [ + "downcast_to_mxfp8", + "upcast_from_mxfp8", + "downcast_to_mxfp8_2d", + "upcast_from_mxfp8_2d", +] + + +_LOGGER = AiterTritonLogger() + + +def downcast_to_mxfp8( + src_tensor: torch.Tensor, + out_quant_type: torch.dtype, + axis: int, + SCALE_ROUNDING_MODE: int = 0, +): + """ + Convert the src weights to mx8 format. The src weight is quantized along the axis dimension. + + If weight_quant_type is torch.uint8, we output mxfp4 where two e2m1 values are packed into a single byte. + Note that this means the k_dim of the tensor will be half of the logical k_dim. + + If weight_quant_type is torch.float8_e4m3fn or torch.float8_e5m2, we output mxfp8 with the float8s are stored + in their respective formats. + """ + ndim = src_tensor.ndim + assert -ndim <= axis < ndim, f"Invalid axis {axis=}" + axis = axis if axis >= 0 else axis + ndim + # downcast + src_tensor = src_tensor.transpose(axis, src_tensor.ndim - 1) + L = src_tensor.shape[-1] + out_shape = src_tensor.shape[:-1] + (L,) + out_scale_shape = src_tensor.shape[:-1] + (triton.cdiv(L, 32),) + + out_quant_tensor = src_tensor.new_empty(out_shape, dtype=out_quant_type) + out_scale = src_tensor.new_empty(out_scale_shape, dtype=torch.uint8) + + kernel_src_tensor = src_tensor.reshape(-1, src_tensor.shape[-1]) + kernel_quant_tensor = out_quant_tensor.view(-1, out_quant_tensor.shape[-1]) + kernel_scale = out_scale.view(-1, out_scale.shape[-1]) + + BLOCK_OUT_DIM = 128 + BLOCK_QUANT_DIM = 32 + grid_out = triton.cdiv(kernel_src_tensor.shape[0], BLOCK_OUT_DIM) + grid_quant = triton.cdiv(kernel_src_tensor.shape[1], BLOCK_QUANT_DIM) + + _downcast_to_mxfp8[(grid_out, grid_quant)]( + kernel_quant_tensor, + *kernel_quant_tensor.stride(), + kernel_scale, + 
*kernel_scale.stride(), + kernel_src_tensor, + *kernel_src_tensor.stride(), + *kernel_src_tensor.shape, + BLOCK_OUT_DIM, + BLOCK_QUANT_DIM, + SCALE_ROUNDING_MODE, + num_warps=8, + ) + + out_quant_tensor = out_quant_tensor.transpose(axis, src_tensor.ndim - 1) + out_scale = out_scale.transpose(axis, src_tensor.ndim - 1) + return out_quant_tensor, out_scale + + +def upcast_from_mxfp8( + tensor: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype, axis: int +): + """ + Upcasts an mxfp8 weight tensor back to float16 or bfloat16. + + The function assumes that the tensors were quantized along the given axis. + It permutes the tensor so that the quantized axis is last, reshapes to 2D, + launches the Triton upcast kernel, and then unpermutes back to the original order. + """ + ndim = tensor.ndim + assert -ndim <= axis < ndim, f"Invalid axis {axis=}" + axis = axis if axis >= 0 else axis + ndim + assert tensor.ndim == scale.ndim, ( + f"Weight and scale must have the same number of dimensions. " + f"Got {tensor.ndim=} and {scale.ndim=}" + ) + # dtype checks + assert tensor.dtype in { + torch.float8_e5m2, + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + }, f"Invalid tensor dtype {tensor.dtype=}" + assert scale.dtype == torch.uint8, f"Invalid scale dtype {scale.dtype=}" + assert dtype in ( + torch.float16, + torch.bfloat16, + torch.float32, + ), f"Invalid output dtype {dtype=}" + # upcast + logical_quant_dim = tensor.shape[axis] + tensor = tensor.transpose(axis, tensor.ndim - 1).contiguous() + scale = scale.transpose(axis, scale.ndim - 1).contiguous() + out = torch.empty( + (*tensor.shape[:-1], logical_quant_dim), dtype=dtype, device=tensor.device + ) + reshaped_out = out.view(-1, out.shape[-1]) + reshaped_tensor = tensor.view(-1, tensor.shape[-1]) + reshaped_scale = scale.view(-1, scale.shape[-1]) + BLOCK_OUT_DIM = 128 + BLOCK_QUANT_DIM = 32 + blocks_out_dim = triton.cdiv(reshaped_out.shape[0], BLOCK_OUT_DIM) + blocks_quant_dim = triton.cdiv(reshaped_out.shape[1], 
BLOCK_QUANT_DIM) + _upcast_from_mxfp8[(blocks_out_dim, blocks_quant_dim)]( + reshaped_out, + *reshaped_out.stride(), + reshaped_scale, + *reshaped_scale.stride(), + reshaped_tensor, + *reshaped_tensor.stride(), + *reshaped_out.shape, + BLOCK_OUT_DIM, + BLOCK_QUANT_DIM, + num_warps=8, + ) + out = out.transpose(axis, scale.ndim - 1).contiguous() + return out + + +def downcast_to_mxfp8_2d( + src_tensor: torch.Tensor, + out_quant_type: torch.dtype, + SCALE_ROUNDING_MODE: int = 0, +): + """ + Convert the last two dimensions of ``src_tensor`` to the mxfp8 format, + where each 32x32 block of those two dimensions shares a single scale. + + The quantized tensor preserves the input shape; the scale tensor has shape + ``(..., cdiv(M, 32), cdiv(N, 32))`` and dtype uint8, with ``M`` and ``N`` + being the last two dims of ``src_tensor``. + + ``out_quant_type`` must be ``torch.float8_e4m3fn`` or ``torch.float8_e5m2``. + """ + assert ( + src_tensor.ndim >= 2 + ), f"src_tensor must have at least 2 dimensions, got {src_tensor.ndim}" + assert out_quant_type in { + torch.float8_e4m3fn, + torch.float8_e5m2, + }, f"Invalid out_quant_type {out_quant_type=}" + + src_tensor = src_tensor.contiguous() + M = src_tensor.shape[-2] + N = src_tensor.shape[-1] + M_scale = triton.cdiv(M, 32) + N_scale = triton.cdiv(N, 32) + + out_shape = src_tensor.shape + out_scale_shape = src_tensor.shape[:-2] + (M_scale, N_scale) + + out_quant_tensor = src_tensor.new_empty(out_shape, dtype=out_quant_type) + out_scale = src_tensor.new_empty(out_scale_shape, dtype=torch.uint8) + + kernel_src = src_tensor.reshape(-1, M, N) + kernel_quant = out_quant_tensor.view(-1, M, N) + kernel_scale = out_scale.view(-1, M_scale, N_scale) + + BLOCK_M = 128 + BLOCK_N = 128 + B = kernel_src.shape[0] + grid = (B, triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N)) + + _downcast_to_mxfp8_2d[grid]( + kernel_quant, + *kernel_quant.stride(), + kernel_scale, + *kernel_scale.stride(), + kernel_src, + *kernel_src.stride(), + M, + N, + 
BLOCK_M, + BLOCK_N, + SCALE_ROUNDING_MODE, + num_warps=8, + ) + + return out_quant_tensor, out_scale + + +def upcast_from_mxfp8_2d( + tensor: torch.Tensor, + scale: torch.Tensor, + dtype: torch.dtype, +): + """ + Inverse of :func:`downcast_to_mxfp8_2d`. ``scale`` is expected to be a + ``(..., cdiv(M, 32), cdiv(N, 32))`` uint8 tensor where each entry + corresponds to a 32x32 block of ``tensor``'s last two dimensions. + """ + assert tensor.ndim >= 2 and scale.ndim >= 2 + assert tensor.dtype in { + torch.float8_e5m2, + torch.float8_e4m3fn, + torch.float8_e4m3fnuz, + }, f"Invalid tensor dtype {tensor.dtype=}" + assert scale.dtype == torch.uint8, f"Invalid scale dtype {scale.dtype=}" + assert dtype in ( + torch.float16, + torch.bfloat16, + torch.float32, + ), f"Invalid output dtype {dtype=}" + assert tensor.ndim == scale.ndim, ( + f"tensor and scale must have the same number of dimensions. " + f"Got {tensor.ndim=} and {scale.ndim=}" + ) + + tensor = tensor.contiguous() + scale = scale.contiguous() + M = tensor.shape[-2] + N = tensor.shape[-1] + M_scale = scale.shape[-2] + N_scale = scale.shape[-1] + assert M_scale == triton.cdiv( + M, 32 + ), f"scale shape mismatch: got {M_scale=} expected {triton.cdiv(M, 32)}" + assert N_scale == triton.cdiv( + N, 32 + ), f"scale shape mismatch: got {N_scale=} expected {triton.cdiv(N, 32)}" + + out = torch.empty(tensor.shape, dtype=dtype, device=tensor.device) + + kernel_out = out.view(-1, M, N) + kernel_tensor = tensor.view(-1, M, N) + kernel_scale = scale.view(-1, M_scale, N_scale) + + BLOCK_M = 128 + BLOCK_N = 128 + B = kernel_out.shape[0] + grid = (B, triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N)) + + _upcast_from_mxfp8_2d[grid]( + kernel_out, + *kernel_out.stride(), + kernel_scale, + *kernel_scale.stride(), + kernel_tensor, + *kernel_tensor.stride(), + M, + N, + BLOCK_M, + BLOCK_N, + num_warps=8, + ) + + return out diff --git a/op_tests/flydsl_tests/test_attn_bwd_mxfp8_gfx950.py 
def check_result(test_out, ref_out, atol=0.01, rtol=0.01, pass_pct=95.0):
    """Compare two tensors and report whether enough elements agree.

    Both tensors are compared in float32 with ``torch.isclose``. The check
    passes when strictly more than ``pass_pct`` percent of the elements are
    close. On failure a short diagnostic is printed.

    Args:
        test_out: Tensor produced by the implementation under test.
        ref_out: Reference tensor.
        atol: Absolute tolerance passed to ``torch.isclose``.
        rtol: Relative tolerance passed to ``torch.isclose``.
        pass_pct: Percentage of close elements that must be exceeded.

    Returns:
        bool: True when the check passes, False otherwise.
    """
    test_f32 = test_out.float()
    ref_f32 = ref_out.float()
    close_mask = torch.isclose(test_f32, ref_f32, atol=atol, rtol=rtol)
    pct_close = close_mask.float().mean().item() * 100
    passed = pct_close > pass_pct
    if passed:
        return True

    max_delta = (ref_f32 - test_f32).abs().max().item()
    print(
        f" max_delta={max_delta:.4f}, {pct_close:.1f}% close (atol={atol}, rtol={rtol})"
    )
    print(f" ref sample: {ref_out.reshape(-1)[:8]}")
    print(f" test sample: {test_out.reshape(-1)[:8]}")
    print(f" --> {'PASS' if passed else 'FAIL'}")
    # Return an explicit bool so callers never receive None on failure.
    return False
def run_torch(
    q_fp32,
    k_fp32,
    v,
    do_fp32,
    m,
    D,
    sm_scale,
    causal,
    gqa_size,
):
    """PyTorch reference for the attention backward pass.

    Recomputes the probabilities from ``q_fp32``, ``k_fp32`` and the row
    statistic ``m``, then forms dq, dk and dv. The probability and score
    gradient matrices are round-tripped through ``mx_quant`` before the
    matmuls that consume them. dk and dv are summed over each group of
    ``gqa_size`` query heads.

    Returns:
        ``(dq, dk, dv)``.
    """
    batch, num_heads_q, seqlen, head_dim = (q_fp32.shape[i] for i in range(4))
    num_heads_kv = num_heads_q // gqa_size

    # Probabilities: exp(q @ k^T * scale - m).
    scores = torch.matmul(q_fp32, k_fp32.transpose(-2, -1)) * sm_scale
    p = torch.exp(scores - m[:, :, :, None])
    if causal:
        lower = torch.tril(torch.ones((seqlen, seqlen), device=q_fp32.device))
        p[:, :, lower == 0] = 0.0

    # dv = quant(p)^T @ do, with p quantized along its row axis.
    p_quant, _, _ = mx_quant(p, -2)
    dv = torch.matmul(p_quant.transpose(-2, -1), do_fp32)

    # ds = p * (do @ v^T - D)
    dp = torch.matmul(do_fp32, v.to(torch.float32).transpose(-2, -1))
    ds_raw = p * (dp - D[:, :, :, None])

    # Both quantized copies are derived from the same unquantized ds.
    ds_for_dk, _, _ = mx_quant(ds_raw, -1)
    ds_for_dq, _, _ = mx_quant(ds_raw, -2)
    dk = torch.matmul(ds_for_dk.transpose(-2, -1), q_fp32) * sm_scale
    dq = torch.matmul(ds_for_dq, k_fp32) * sm_scale

    # Reduce over the query heads that share each kv head.
    grouped = (batch, num_heads_kv, gqa_size, seqlen, head_dim)
    dk = dk.view(grouped).sum(dim=2)
    dv = dv.view(grouped).sum(dim=2)

    return dq, dk, dv
sm_scale=sm_scale, + causal=causal, + waves_per_eu=_wpe, + ) + + device = torch.device("cuda") + gqa_size = num_heads_q // num_heads_kv + q_fp32 = ( + torch.randn( + batch, num_heads_q, seqlen, head_dim, device=device, dtype=torch.float32 + ) + * 0.5 + ) + k_fp32 = ( + torch.randn( + batch, num_heads_kv, seqlen, head_dim, device=device, dtype=torch.float32 + ) + * 0.5 + ) + v_fp32 = ( + torch.randn( + batch, num_heads_kv, seqlen, head_dim, device=device, dtype=torch.float32 + ) + * 0.5 + ) + do_fp32 = ( + torch.randn( + batch, num_heads_q, seqlen, head_dim, device=device, dtype=torch.float32 + ) + * 0.5 + ) + + q_fp32, q_quant, q_scale = mx_quant_2d(q_fp32) + k_fp32, k_quant, k_scale = mx_quant_2d(k_fp32) + v_fp32, v_quant, v_scale = mx_quant(v_fp32) + do_fp32, do_quant, do_scale = mx_quant_2d(do_fp32) + + k_fp32 = k_fp32.repeat_interleave(gqa_size, dim=1) + v_fp32 = v_fp32.repeat_interleave(gqa_size, dim=1) + + qk = q_fp32 @ k_fp32.transpose(-2, -1) + qk = qk * sm_scale + m = qk.max(dim=-1)[0] + p = (qk - m[:, :, :, None]).exp() + L = p.sum(dim=-1) + p = p / L[:, :, :, None] + o_fp32 = torch.matmul(p, v_fp32) + m = m + torch.log(L) + D = (o_fp32 * do_fp32).sum(dim=-1) + + dq_ref, dk_ref, dv_ref = run_torch( + q_fp32, + k_fp32, + v_fp32, + do_fp32, + m, + D, + sm_scale, + causal, + gqa_size, + ) + + dq_fly = torch.zeros( + (batch, num_heads_q, seqlen, head_dim), dtype=torch.float32, device=device + ) + dk_fly = torch.zeros( + (batch, num_heads_kv, seqlen, head_dim), dtype=torch.float32, device=device + ) + dv_fly = torch.zeros( + (batch, num_heads_kv, seqlen, head_dim), dtype=torch.float32, device=device + ) + + def launch_kernel( + dq, + dk, + dv, + q, + q_scale, + k, + k_scale, + v, + v_scale, + do, + do_scale, + m, + D, + batch, + ): + launch_fn( + dq.contiguous().view(-1), + dk.contiguous().view(-1), + dv.contiguous().view(-1), + q.contiguous().view(-1), + q_scale.contiguous().view(-1), + k.contiguous().view(-1), + k_scale.contiguous().view(-1), + 
v.contiguous().view(-1), + v_scale.contiguous().view(-1), + do.contiguous().view(-1), + do_scale.contiguous().view(-1), + m.contiguous().view(-1), + D.contiguous().view(-1), + batch, + q.stride(0), + k.stride(0), + m.stride(0), + q.stride(1), + m.stride(1), + q_scale.stride(0), + q_scale.stride(1), + k_scale.stride(0), + k_scale.stride(1), + v_scale.stride(0), + v_scale.stride(1), + do_scale.stride(0), + do_scale.stride(1), + torch.cuda.current_stream(), + ) + + launch_kernel( + dq_fly, + dk_fly, + dv_fly, + q_quant, + q_scale, + k_quant, + k_scale, + v_quant, + v_scale, + do_quant, + do_scale, + m, + D, + batch, + ) + + dq_fly_fp32 = dq_fly.to(torch.float32) + dk_fly_fp32 = dk_fly.to(torch.float32) + dv_fly_fp32 = dv_fly.to(torch.float32) + + assert check_result(dq_fly_fp32, dq_ref, rtol=0.01, atol=0.01, pass_pct=99.0) + assert check_result(dk_fly_fp32, dk_ref, rtol=0.01, atol=0.01, pass_pct=99.0) + assert check_result(dv_fly_fp32, dv_ref, rtol=0.01, atol=0.01, pass_pct=99.0) diff --git a/op_tests/op_benchmarks/flydsl/bench_attn_bwd_mxfp8_gfx950.py b/op_tests/op_benchmarks/flydsl/bench_attn_bwd_mxfp8_gfx950.py new file mode 100644 index 0000000000..32c25933b8 --- /dev/null +++ b/op_tests/op_benchmarks/flydsl/bench_attn_bwd_mxfp8_gfx950.py @@ -0,0 +1,286 @@ +#!/usr/bin/env python3 +"""Attention backward test — @flyc.kernel API. + +Kernel implementation lives in `flydsl/kernels/attn_bwd_mxfp8_gfx950.py`. +This file is the perf and correctness harness. 
+""" + +import logging +import torch +from aiter.ops.flydsl.kernels.attn_bwd_mxfp8_gfx950 import compile_attn_bwd_mxfp8_gfx950 +from utils import run_perftest +from op_tests.flydsl_tests.test_attn_bwd_mxfp8_gfx950 import ( + run_torch, + mx_quant, + mx_quant_2d, + check_result, +) +from flydsl.runtime.device import get_rocm_arch + +logging.basicConfig(level=logging.INFO) +ARCH = str(get_rocm_arch()) +DEFAULT_BENCH_ITERS = 20 +DEFAULT_BENCH_WARMUP = 3 + + +def bench_attn_bwd_flyc( + batch, + num_heads_q, + num_heads_kv, + seqlen, + head_dim, + tile_m, + tile_n, + causal, + test_graph, + bench_iters: int = DEFAULT_BENCH_ITERS, + bench_warmup: int = DEFAULT_BENCH_WARMUP, + waves_per_eu: int = 0, + check_correctness: bool = False, +): + """Attention bwd using the @flyc.kernel / @flyc.jit API.""" + tile_head = head_dim + print("=" * 80) + print(f"[flyc] Attention Backward Test (Tile: {tile_m}x{tile_n}x{tile_head})") + print("=" * 80) + + sm_scale = 0.5 + _wpe = int(waves_per_eu) if waves_per_eu else 0 + launch_fn = compile_attn_bwd_mxfp8_gfx950( + num_heads_q=num_heads_q, + num_heads_kv=num_heads_kv, + seqlen=seqlen, + head_dim=head_dim, + tile_m=tile_m, + tile_n=tile_n, + tile_head=tile_head, + sm_scale=sm_scale, + causal=causal, + waves_per_eu=_wpe, + ) + print("✓ Kernel prepared") + + device = torch.device("cuda") + gqa_size = num_heads_q // num_heads_kv + q_fp32 = ( + torch.randn( + batch, num_heads_q, seqlen, head_dim, device=device, dtype=torch.float32 + ) + * 0.5 + ) + k_fp32 = ( + torch.randn( + batch, num_heads_kv, seqlen, head_dim, device=device, dtype=torch.float32 + ) + * 0.5 + ) + v_fp32 = ( + torch.randn( + batch, num_heads_kv, seqlen, head_dim, device=device, dtype=torch.float32 + ) + * 0.5 + ) + o_fp32 = ( + torch.randn( + batch, num_heads_q, seqlen, head_dim, device=device, dtype=torch.float32 + ) + * 0.5 + ) + do_fp32 = ( + torch.randn( + batch, num_heads_q, seqlen, head_dim, device=device, dtype=torch.float32 + ) + * 0.5 + ) + + q_fp32, q_quant, 
q_scale = mx_quant_2d(q_fp32) + k_fp32, k_quant, k_scale = mx_quant_2d(k_fp32) + v_fp32, v_quant, v_scale = mx_quant(v_fp32) + do_fp32, do_quant, do_scale = mx_quant_2d(do_fp32) + + k_fp32 = k_fp32.repeat_interleave(gqa_size, dim=1) + v_fp32 = v_fp32.repeat_interleave(gqa_size, dim=1) + + qk = torch.matmul(q_fp32, k_fp32.transpose(-2, -1)) + qk = qk * sm_scale + m = qk.max(dim=-1)[0] + p = (qk - m[:, :, :, None]).exp() + L = p.sum(dim=-1) + m = m + torch.log(L) + D = (o_fp32 * do_fp32).sum(dim=-1) + + if check_correctness: + dq_ref, dk_ref, dv_ref = run_torch( + q_fp32, + k_fp32, + v_fp32, + do_fp32, + m, + D, + sm_scale, + causal, + gqa_size, + ) + dq_fly = torch.zeros( + (batch, num_heads_q, seqlen, head_dim), dtype=torch.float32, device=device + ) + dk_fly = torch.zeros( + (batch, num_heads_kv, seqlen, head_dim), dtype=torch.float32, device=device + ) + dv_fly = torch.zeros( + (batch, num_heads_kv, seqlen, head_dim), dtype=torch.float32, device=device + ) + + def launch_kernel( + dq, + dk, + dv, + q, + q_scale, + k, + k_scale, + v, + v_scale, + do, + do_scale, + m, + D, + batch, + ): + launch_fn( + dq.contiguous().view(-1), + dk.contiguous().view(-1), + dv.contiguous().view(-1), + q.contiguous().view(-1), + q_scale.contiguous().view(-1), + k.contiguous().view(-1), + k_scale.contiguous().view(-1), + v.contiguous().view(-1), + v_scale.contiguous().view(-1), + do.contiguous().view(-1), + do_scale.contiguous().view(-1), + m.contiguous().view(-1), + D.contiguous().view(-1), + batch, + q.stride(0), + k.stride(0), + m.stride(0), + q.stride(1), + m.stride(1), + q_scale.stride(0), + q_scale.stride(1), + k_scale.stride(0), + k_scale.stride(1), + v_scale.stride(0), + v_scale.stride(1), + do_scale.stride(0), + do_scale.stride(1), + torch.cuda.current_stream(), + ) + + bench_iters = max(2, int(bench_iters)) + bench_warmup = int(bench_warmup) + _, us = run_perftest( + launch_kernel, + dq_fly, + dk_fly, + dv_fly, + q_quant, + q_scale, + k_quant, + k_scale, + v_quant, + 
v_scale, + do_quant, + do_scale, + m, + D, + batch, + num_iters=bench_iters, + num_warmup=bench_warmup, + testGraph=test_graph, + ) + + if check_correctness: + torch.cuda.synchronize() + + dq_fly.zero_() + dk_fly.zero_() + dv_fly.zero_() + launch_kernel( + dq_fly, + dk_fly, + dv_fly, + q_quant, + q_scale, + k_quant, + k_scale, + v_quant, + v_scale, + do_quant, + do_scale, + m, + D, + batch, + ) + + assert check_result(dq_fly, dq_ref, rtol=0.01, atol=0.01) + assert check_result(dk_fly, dk_ref, rtol=0.01, atol=0.01) + assert check_result(dv_fly, dv_ref, rtol=0.01, atol=0.01) + + bytes_moved = ( + (2 + 4) * batch * num_heads_q * seqlen * head_dim + + (2 + 2 * 4) * batch * num_heads_kv * seqlen * head_dim + + 2 * 4 * batch * num_heads_q * seqlen + ) + flops = ( + batch + * num_heads_q + * ( + 5 * 2 * seqlen * seqlen * head_dim + + 5 * seqlen * seqlen + + 2 * 3 * seqlen * seqlen + ) + ) + if causal: + flops /= 2 + tflops = flops / (us / 1e6) / 1e12 + tbps = bytes_moved / 1e12 / (us / 1e6) + print(f"[flyc] Throughput: {us:.1f} us, {tflops:.2f} TFLOPS, BW: {tbps:.3f} TB/s") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Preshuffle GEMM benchmark") + parser.add_argument("--batch", type=int, default=1) + parser.add_argument("--num_heads_q", type=int, default=128) + parser.add_argument("--num_heads_kv", type=int, default=128) + parser.add_argument("--seqlen", type=int, default=1024) + parser.add_argument("--head", type=int, default=128) + parser.add_argument("--tile_m", type=int, default=128) + parser.add_argument("--tile_n", type=int, default=128) + parser.add_argument("--causal", action="store_true", default=False) + parser.add_argument("--num_iters", type=int, default=DEFAULT_BENCH_ITERS) + parser.add_argument("--num_warmup", type=int, default=DEFAULT_BENCH_WARMUP) + parser.add_argument("--waves_per_eu", type=int, default=0, choices=[0, 1, 2, 3, 4]) + parser.add_argument("--test_graph", action="store_true", 
default=False) + parser.add_argument("--check_correctness", action="store_true", default=False) + args = parser.parse_args() + torch.set_default_device("cuda") + + bench_attn_bwd_flyc( + batch=args.batch, + num_heads_q=args.num_heads_q, + num_heads_kv=args.num_heads_kv, + seqlen=args.seqlen, + head_dim=args.head, + tile_m=args.tile_m, + tile_n=args.tile_n, + causal=args.causal, + test_graph=bool(args.test_graph), + bench_iters=args.num_iters, + bench_warmup=args.num_warmup, + waves_per_eu=int(args.waves_per_eu), + check_correctness=args.check_correctness, + ) diff --git a/op_tests/op_benchmarks/flydsl/utils.py b/op_tests/op_benchmarks/flydsl/utils.py new file mode 100644 index 0000000000..8bf3824b30 --- /dev/null +++ b/op_tests/op_benchmarks/flydsl/utils.py @@ -0,0 +1,367 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# from https://github.com/ROCm/aiter/blob/main/aiter/test_common.py +import torch +import torch.profiler as tpf +import os +import copy +import numpy as np +import pandas as pd +import logging + +logger = logging.getLogger("flydsl") + +pd.set_option("display.max_rows", 200) +## debug ## +# pd.set_option("display.max_rows", None) +# pd.set_option("display.max_columns", None) +# pd.set_option("display.width", None) +# pd.set_option("display.max_colwidth", None) +# pd.set_option("display.expand_frame_repr", False) + + +def perftest( + num_iters=20, num_warmup=3, testGraph=False, num_rotate_args=0, needTrace=False +): + def decorator(func): + def wrapper(*args, **kwargs): + # ROCm torch.profiler (ROCTracer) is not always stable when invoked repeatedly + # under pytest (multiple tests, repeated init/teardown). For unit tests, the + # profiler is not required; fall back to simple timing. 
+ # + num = num_rotate_args + if num < 1: + gpu_id = torch.cuda.current_device() + iter_used_memory, inputSize, _, _ = device_memory_profiling( + func, *args, **kwargs + ) + + properties = torch.cuda.get_device_properties(gpu_id) + free_memory = torch.cuda.mem_get_info(gpu_id)[0] + cache_size = min( + getattr(properties, "L2_cache_size", 4096 * 1024) * 64 * 128, + (free_memory - iter_used_memory + inputSize) * 0.9, + ) + cache_size = max(cache_size, 0) + num = int((cache_size + inputSize - 1) // inputSize) + num = min(num, num_iters) + + rotate_args = [ + (copy.deepcopy(args), copy.deepcopy(kwargs)) for _ in range(num - 1) + ] + [(args, kwargs)] + run_iters(num_warmup, func, *args, **kwargs) + torch.cuda.synchronize() + + if int(os.environ.get("FLYDSL_LOG_MORE", 0)): + latencies = [] + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + for _ in range(num_iters): + start_event.record() + data = func(*args, **kwargs) + end_event.record() + end_event.synchronize() + latencies.append(start_event.elapsed_time(end_event)) + avg = np.mean(latencies) * 1000 + logger.info(f"avg: {avg} us/iter from cuda.Event") + with tpf.profile( + activities=[tpf.ProfilerActivity.CPU, tpf.ProfilerActivity.CUDA], + profile_memory=False, + with_stack=False, + with_modules=True, + ) as prof: + data = run_iters_rotate(num_iters, func, rotate_args) + torch.cuda.synchronize() + torch.cuda.empty_cache() + avg = get_trace_perf(prof, num_iters) + + if testGraph: + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + data = run_iters_rotate(num_iters, func, rotate_args) + with tpf.profile( + activities=[tpf.ProfilerActivity.CPU, tpf.ProfilerActivity.CUDA], + profile_memory=True, + with_stack=True, + with_modules=True, + ) as prof: + run_iters(1, graph.replay) + avg = get_trace_perf(prof, num_iters) + logger.info(f"avg: {avg} us/iter with hipgraph") + + return data, avg + + return wrapper + + return decorator + + +def benchmark(): + def 
decorator(func): + def wrapper(*args, **kwargs): + callargs = log_args(func, *args, **kwargs) + ret = func(*args, **kwargs) + if ret is not None: + callargs.update(ret) + return callargs + + return wrapper + + return decorator + + +def device_memory_profiling(func, *args, **kwargs): + gpu_id = torch.cuda.current_device() + inputSize = ( + sum( + [ + el.nbytes + for el in args + if isinstance(el, torch.Tensor) and el.device.index == gpu_id + ] + ) + + 1 + ) + torch.cuda.reset_peak_memory_stats(gpu_id) + cuda_memory_before = ( + torch.cuda.mem_get_info(gpu_id)[1] - torch.cuda.mem_get_info(gpu_id)[0] + ) + torch_memory_before = torch.cuda.memory_reserved(gpu_id) + torch_peak_before = torch.cuda.memory_stats(gpu_id).get( + "allocated_bytes.all.peak", 0 + ) + non_torch_memory_before = cuda_memory_before - torch_memory_before + + data = func(*args, **kwargs) + + torch.cuda.reset_peak_memory_stats(gpu_id) + cuda_memory_after = ( + torch.cuda.mem_get_info(gpu_id)[1] - torch.cuda.mem_get_info(gpu_id)[0] + ) + torch_memory_after = torch.cuda.memory_reserved(gpu_id) + torch_peak_after = torch.cuda.memory_stats(gpu_id).get( + "allocated_bytes.all.peak", 0 + ) + non_torch_memory_after = cuda_memory_after - torch_memory_after + + torch_peak_increase = torch_peak_after - torch_peak_before + non_torch_increase = non_torch_memory_after - non_torch_memory_before + iter_used_memory = torch_peak_increase + non_torch_increase + inputSize + + return iter_used_memory, inputSize, torch_peak_increase, non_torch_increase + + +def run_iters(num_iters, func, *args, **kwargs): + data = None + for _ in range(num_iters): + data = func(*args, **kwargs) + return data + + +def run_iters_rotate(num_iters, func, rotate_args): + data = None + num_rotate_args = len(rotate_args) + for _ in range(num_iters): + args, kwargs = rotate_args[_ % num_rotate_args] + data = func(*args, **kwargs) + + return data + + +def run_perftest( + func, + *args, + num_iters=20, + num_warmup=3, + testGraph=False, + 
num_rotate_args=0, + needTrace=False, + **kwargs, +): + + @perftest( + num_iters=num_iters, + num_warmup=num_warmup, + testGraph=testGraph, + num_rotate_args=num_rotate_args, + needTrace=needTrace, + ) + def worker(*args, **kwargs): + return func(*args, **kwargs) + + return worker(*args, **kwargs) + + +def log_args(func, *args, **kwargs): + import inspect + + callargs = inspect.getcallargs(func, *args, **kwargs) + + prefix = f"calling {func.__name__}(" + blanks = " " * (len(prefix)) + + def getTensorInfo(el): + if isinstance(el, torch.Tensor): + return f"{el.shape} {el.dtype} {el.device} {hex(el.data_ptr())}" + elif isinstance(el, tuple): + viewNum = 5 + if len(el) > viewNum: + el = list(el[:viewNum]) + ["..."] + return f'\n{" "*(len(prefix)+31)}'.join( + ["("] + [f" {getTensorInfo(e)}" for e in el] + [")"] + ) + return el + + info = [f"{el:<28} = {getTensorInfo(callargs[el])}" for el in callargs] + info = f",\n{blanks}".join(info) + logger.info(f"\n{prefix}{info})") + return callargs + + +def post_process_data(df, num_iters, warm_iter=1): + """remove abnormal data""" + + device_df = df[df["device_type"].astype(str).str.contains("DeviceType.CUDA")] + # print("devicedf is ", device_df) + if device_df.empty: + return [], 0 + kernels_num = int(len(device_df) / num_iters) + + act_iters = num_iters + valid_n = len(device_df) + dropped_indexs = [] + if len(device_df) % num_iters == 0: + kernels_num = int(len(device_df) / num_iters) + else: + ##get correct kernel num + name_list = device_df["name"].tolist() + max_kernel_num = 20 + n = len(name_list) + for step in range(1, min(max_kernel_num, n // 2 + 1)): + sub_list = [name_list[i] for i in range(step)] + m = len(sub_list) + + valid_n = int(n / m) * m + pattern_match = all( + name_list[i] == sub_list[i % m] for i in range(int(n / m) * m) + ) + if pattern_match: + kernels_num = m + act_iters = valid_n / m + break + dropped_indexs = device_df.iloc[valid_n:].index.tolist() + if kernels_num == 0: + print("data missed, the 
time may be inaccurate!") + + test_df = device_df.iloc[:valid_n].reset_index() + grouped_kernel_df = test_df.groupby(test_df.index // kernels_num, sort=False).agg( + {"self_device_time_total": "sum", "index": list} + ) + + # rm warm iters + sum_df = grouped_kernel_df.iloc[warm_iter:].reset_index(drop=True) + out_range_idx = [] + if num_iters > 30: + # IQR to remove abnormal data + k = 1.5 + Q1 = sum_df["self_device_time_total"].quantile(0.25) + Q3 = sum_df["self_device_time_total"].quantile(0.75) + IQR = Q3 - Q1 + lower = Q1 - k * IQR + upper = Q3 + k * IQR + out_range_idx = sum_df.index[ + (sum_df["self_device_time_total"] < lower) + | (sum_df["self_device_time_total"] > upper) + ].tolist() + out_range_num = len(out_range_idx) + + indices = {idx for i in out_range_idx for idx in sum_df.iloc[i]["index"]} + + index_sublists = grouped_kernel_df["index"].head(warm_iter).tolist() + indices_to_add = [idx for sublist in index_sublists for idx in sublist] + indices.update(indices_to_add) + indices.update(dropped_indexs) + if int(os.environ.get("FLYDSL_LOG_MORE", 0)): + logger.info(f"abnormal data indices: {indices}") + for i in indices: + logger.info(f"abnormal data: {df.iloc[i]['self_device_time_total']}") + return list(indices), out_range_num + warm_iter + num_iters - act_iters + + +def get_trace_perf(prof, num_iters): + assert num_iters > 1 + warm_iter = 1 + num_iters -= warm_iter + df = [] + cols = [ + "name", + "self_cpu_time_total", + "self_device_time_total", + "device_type", + "device_index", + ] + for el in prof.events(): + df.append([getattr(el, x, None) for x in cols]) + df = pd.DataFrame(df, columns=cols) + ###remove abnormal data + dropped_num = warm_iter + dropped_indexs, dropped_num = post_process_data( + df, num_iters + warm_iter, warm_iter + ) + df = df.drop(dropped_indexs) + iter_init = 0 # warm_iter dropped + df["cnt"] = 1 + rets = [] + + for name, d in df.groupby("name", sort=False): + kernel_num_per_iter = iter_init + if 
str(d["device_type"].iat[0]).split(".")[-1] != "CUDA": + kernel_num_per_iter = 1 + r = d.iloc[kernel_num_per_iter:][ + ["cnt", "self_cpu_time_total", "self_device_time_total"] + ].sum() + if not r.empty: + device_type = str(d["device_type"].iat[0]).split(".")[-1] + r["name"] = name + r["device_type"] = device_type + r["device_index"] = str(d["device_index"].iat[0]) + if device_type == "CUDA": + r["device_time_sum"] = r["self_device_time_total"] + r["host_time_sum"] = 0 + else: + r["host_time_sum"] = r["self_device_time_total"] + r["device_time_sum"] = 0 + rets.append(r) + df = pd.DataFrame(rets) + cols = [ + "name", + "cnt", + "host_time_sum", + "device_time_sum", + "device_type", + "device_index", + ] + cols = [el for el in cols if el in df.columns] + df = df[(df.host_time_sum > 0) | (df.device_time_sum > 0)] + + timerList = [ + "host_time_sum", + "device_time_sum", + ] + df = df[cols].sort_values(timerList, ignore_index=True) + actual_iters = num_iters + warm_iter - dropped_num + if df.empty: + logger.info("no valida data after post process!") + + avg_name = "[avg us/iter]" + for el in timerList: + if el == "host_time_sum": + df.at[avg_name, el] = df[el].sum() / num_iters + else: + df.at[avg_name, el] = df[el].sum() / actual_iters + if int(os.environ.get("FLYDSL_LOG_MORE", 0)): + pd.set_option("display.expand_frame_repr", False) + pd.set_option("display.max_colwidth", 90) + pd.set_option("display.float_format", "{:,.1f}".format) + logger.info(f"{df}") + return df.at[avg_name, "device_time_sum"] diff --git a/op_tests/triton_tests/attention/test_mha.py b/op_tests/triton_tests/attention/test_mha.py index 73721efe2b..4480cae9e7 100644 --- a/op_tests/triton_tests/attention/test_mha.py +++ b/op_tests/triton_tests/attention/test_mha.py @@ -10,6 +10,7 @@ mha_set_use_fused_bwd_kernel, mha_set_use_int64_strides, ) +from aiter.ops.triton.attention.mha_fused_bwd import bwd_preprocess_mxfp8 from aiter.test_mha_common import ( attention_ref, attention_ref_with_tol, @@ -17,6 
+18,7 @@ generate_qkv, ) from op_tests.triton_tests.attention.mha_test_utils import pad_rearrange_dropout_mask +from aiter.ops.triton.quant.mxfp8_quant import downcast_to_mxfp8, upcast_from_mxfp8 logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) @@ -445,6 +447,35 @@ def test_mha_varlen_with_dropout( ) +@pytest.mark.parametrize("BATCH", [1, 4, 32, 128]) +@pytest.mark.parametrize("SEQLEN", [512, 1024, 2048]) +@pytest.mark.parametrize("HEAD_SZ", [64, 128]) +def test_mha_backward_preprocess_mxfp8( + BATCH: int, + SEQLEN: int, + HEAD_SZ: int, +): + torch.cuda.empty_cache() + torch.manual_seed(20) + + o_fp32 = torch.randn(BATCH, SEQLEN, HEAD_SZ, device="cuda", dtype=torch.float32) + do_fp32 = torch.randn(BATCH, SEQLEN, HEAD_SZ, device="cuda", dtype=torch.float32) + o_fp8, o_scale = downcast_to_mxfp8(o_fp32, torch.float8_e4m3fn, -1) + do_fp8, do_scale = downcast_to_mxfp8(do_fp32, torch.float8_e4m3fn, -1) + o_fp32 = upcast_from_mxfp8(o_fp8, o_scale, torch.float32, -1) + do_fp32 = upcast_from_mxfp8(do_fp8, do_scale, torch.float32, -1) + + triton_out = bwd_preprocess_mxfp8( + o_fp8, + o_scale, + do_fp8, + do_scale, + ) + torch_out = (o_fp32 * do_fp32).sum(-1) + + torch.testing.assert_close(triton_out, torch_out, atol=0.01, rtol=0.01) + + # Production shapes based on real models: # HQ=32, HK=8: Llama 3 8B (GQA 4:1) # HQ=64, HK=8: Llama 3 70B (GQA 8:1) diff --git a/op_tests/triton_tests/quant/test_quant_mxfp8.py b/op_tests/triton_tests/quant/test_quant_mxfp8.py new file mode 100644 index 0000000000..85d72b9b89 --- /dev/null +++ b/op_tests/triton_tests/quant/test_quant_mxfp8.py @@ -0,0 +1,348 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 

import torch
import pytest
from aiter.ops.triton.quant.mxfp8_quant import (
    downcast_to_mxfp8,
    upcast_from_mxfp8,
    downcast_to_mxfp8_2d,
    upcast_from_mxfp8_2d,
)


def get_max_quant_val(dtype):
    """Return the saturation bound used when quantizing to ``dtype``.

    Returns 448.0 for ``torch.float8_e4m3fn`` and 57344.0 for any other
    dtype (57344.0 is the largest finite ``torch.float8_e5m2`` value).
    """
    if dtype == torch.float8_e4m3fn:
        return 448.0
    else:
        return 57344.0


def get_max_quant_power_of_2_val(dtype):
    """Return the power-of-two divisor used by the round-down scale mode.

    Returns 256.0 for ``torch.float8_e4m3fn`` and 32768.0 for any other
    dtype.
    """
    if dtype == torch.float8_e4m3fn:
        return 256.0
    else:
        return 32768.0


def torch_downcast_to_mxfp8(
    x: torch.Tensor, dtype: torch.dtype, axis: int, SCALE_ROUNDING_MODE: int = 0
):
    """Pure-PyTorch reference for the 1D MXFP8 downcast (blocks of 32 along ``axis``).

    Args:
        x: Input tensor; it is converted to float32 before quantization.
        dtype: Target FP8 dtype, used for the value round-trip and to pick
            the clamp bound / scale divisor.
        axis: Dimension along which groups of 32 elements share one scale.
            Negative values count from the end.
        SCALE_ROUNDING_MODE: ``0`` divides the block max by
            ``get_max_quant_val(dtype)`` and rounds the result up to a power
            of two; any other value divides by
            ``get_max_quant_power_of_2_val(dtype)`` and rounds down to a
            power of two.

    Returns:
        ``(x_q, scale)``, both float32. ``x_q`` has the shape of ``x`` and
        holds the scaled values after a round-trip through ``dtype``.
        ``scale`` holds one power-of-two scale per block, with the block
        index on dimension ``axis``.

    Raises:
        AssertionError: if ``axis`` is out of range for ``x``.
    """
    # returns tensor and scale in fp32 post quantization

    ndim = x.ndim
    assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
    axis = axis if axis >= 0 else axis + ndim

    x = x.to(torch.float32)
    # Move the quantization axis last so blocks are contiguous along dim -1.
    x = x.transpose(axis, x.ndim - 1)
    orig_shape = x.shape
    quant_dim = orig_shape[-1]
    # Number of elements needed to reach the next multiple of 32 (0 if aligned).
    pad_length = (-quant_dim) % 32
    # Allocate the padding on the input's device so the reference also works
    # for CPU tensors and for tensors on a non-default GPU.
    padding = torch.full(
        x.shape[:-1] + (pad_length,), -1.0, dtype=x.dtype, device=x.device
    )
    x_padded = torch.cat((x, padding), -1)
    # The -1.0 padding can never win the amax against absolute values (>= 0).
    x_abs_padded = torch.cat((torch.abs(x), padding), -1)
    padded_shape = x_padded.shape

    new_shape = padded_shape[:-1] + (padded_shape[-1] // 32, 32)
    x_padded = x_padded.reshape(new_shape)
    x_abs_padded = x_abs_padded.reshape(new_shape)
    scale = torch.amax(x_abs_padded, -1)
    if SCALE_ROUNDING_MODE == 0:
        scale = scale / get_max_quant_val(dtype)
        # Adding the full mantissa mask carries into the exponent unless the
        # mantissa is already zero; masking then keeps sign-less exponent
        # bits only, i.e. rounds up to a power of two.
        scale = (scale.view(torch.int32) + 0x007FFFFF) & 0x7F800000
    else:
        scale = scale / get_max_quant_power_of_2_val(dtype)
        # Dropping the mantissa rounds down to a power of two.
        scale = scale.view(torch.int32) & 0x7F800000
    scale = scale.view(torch.float32)

    # All-zero blocks have scale 0; use 0 as the inverse to avoid inf/NaN.
    scale_inv = torch.where(scale == 0.0, 0.0, 1.0 / scale).unsqueeze(-1)
    x_padded = x_padded * scale_inv
    x_padded = x_padded.reshape(padded_shape)
    # Discard the padding again.
    x = x_padded[..., :quant_dim].clone()
    max_val = get_max_quant_val(dtype)
    x = x.clamp_(-max_val, max_val)
    x = x.to(dtype).to(torch.float32)
    # Restore the caller's dimension order for both outputs.
    x = x.transpose(axis, x.ndim - 1)
    scale = scale.transpose(axis, x.ndim - 1)
    return x, scale


def upcast_scale(scale):
    """Reinterpret integer exponent codes as float32 scale factors.

    Each element is converted to int32, shifted left by 23 bits (placing it
    in the float32 exponent field) and the resulting bits are viewed as
    float32.
    """
    scale = scale.to(torch.int32) << 23
    scale = scale.view(torch.float32)
    return scale


def torch_upcast_from_mxfp8(
    x: torch.Tensor,
    scale: torch.Tensor,
    dtype: torch.dtype,
    axis: int,
):
    """Pure-PyTorch reference for the 1D MXFP8 upcast (blocks of 32 along ``axis``).

    Args:
        x: Quantized values; converted to float32 before scaling.
        scale: One exponent code per block of 32 elements along ``axis``,
            decoded with ``upcast_scale``.
        dtype: Output dtype.
        axis: Dimension along which groups of 32 elements share one scale.
            Negative values count from the end.

    Returns:
        Tensor with the shape of ``x`` and dtype ``dtype`` holding
        ``x * scale`` with the scale broadcast over each block.

    Raises:
        AssertionError: if ``axis`` is out of range for ``x``.
    """
    ndim = x.ndim
    assert -ndim <= axis < ndim, f"Invalid axis {axis=}"
    axis = axis if axis >= 0 else axis + ndim

    x = x.to(torch.float32)
    # Move the quantization axis last so blocks are contiguous along dim -1.
    x = x.transpose(axis, x.ndim - 1)
    scale = scale.transpose(axis, x.ndim - 1)
    orig_shape = x.shape
    quant_dim = orig_shape[-1]
    # Number of elements needed to reach the next multiple of 32 (0 if aligned).
    pad_length = (-quant_dim) % 32
    # Allocate the padding on the input's device so the reference also works
    # for CPU tensors and for tensors on a non-default GPU.
    padding = torch.full(
        x.shape[:-1] + (pad_length,), -1.0, dtype=x.dtype, device=x.device
    )
    x_padded = torch.cat((x, padding), -1)
    padded_shape = x_padded.shape

    new_shape = padded_shape[:-1] + (padded_shape[-1] // 32, 32)
    x_padded = x_padded.reshape(new_shape)
    scale = upcast_scale(scale).unsqueeze(-1)
    x_padded = x_padded * scale
    x_padded = x_padded.reshape(padded_shape)
    # Discard the padding again and restore the caller's dimension order.
    x = x_padded[..., :quant_dim].clone()
    x = x.transpose(axis, x.ndim - 1)
    x = x.to(dtype)
    return x


# (shape, axis) cases shared by the 1D downcast and upcast tests.
_SHAPE_AXIS_CASES = [
    ((1, 4), -1),
    ((1, 28), -1),
    ((1, 32), -1),
    ((1, 64), -1),
    ((1, 68), -1),
    ((2, 4), -1),
    ((2, 28), -1),
    ((2, 32), -1),
    ((2, 200, 64), 1),
    ((2, 68), -1),
    ((128, 4), 0),
    ((128, 28), -1),
    ((128, 32), -1),
    ((128, 64), -1),
    ((128, 68), -1),
    ((256, 32), -1),
    ((160, 40), -1),
    ((280, 20), -1),
]


@pytest.mark.parametrize("shape, axis", _SHAPE_AXIS_CASES)
@pytest.mark.parametrize("in_dtype", [torch.bfloat16, torch.float32])
@pytest.mark.parametrize("out_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
@pytest.mark.parametrize("SCALE_ROUNDING_MODE", [0, 1])
def test_downcast_to_mxfp8(shape, axis, in_dtype, out_dtype, SCALE_ROUNDING_MODE):
    """Compare ``downcast_to_mxfp8`` with the PyTorch reference (values and scales)."""
    torch.cuda.empty_cache()
    torch.manual_seed(20)
    x = torch.randn(*shape, dtype=in_dtype, device="cuda")

    out_torch, out_scale_torch = torch_downcast_to_mxfp8(
        x, out_dtype, axis, SCALE_ROUNDING_MODE
    )
    out_triton, out_scale_triton = downcast_to_mxfp8(
        x, out_dtype, axis, SCALE_ROUNDING_MODE
    )
    out_triton = out_triton.to(torch.float32)
    # The reference returns float32 scales, so decode the kernel's scale
    # codes before comparing.
    out_scale_triton = upcast_scale(out_scale_triton)

    torch.testing.assert_close(out_triton, out_torch, atol=0.01, rtol=0.01)
    torch.testing.assert_close(out_scale_triton, out_scale_torch, atol=0.01, rtol=0.01)


@pytest.mark.parametrize("shape, axis", _SHAPE_AXIS_CASES)
@pytest.mark.parametrize("in_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32])
def test_upcast_from_mxfp8(shape, axis, in_dtype, out_dtype):
    """Compare ``upcast_from_mxfp8`` with the PyTorch reference on quantized data."""
    torch.cuda.empty_cache()
    torch.manual_seed(20)
    x = torch.randn(*shape, dtype=out_dtype, device="cuda")
    # Produce the quantized input and its scales with the kernel under test's
    # companion downcast, then upcast it with both implementations.
    x, x_scale = downcast_to_mxfp8(x, in_dtype, axis)
    out_triton = upcast_from_mxfp8(x, x_scale, out_dtype, axis)
    out_torch = torch_upcast_from_mxfp8(x, x_scale, out_dtype, axis)

    torch.testing.assert_close(out_triton, out_torch, atol=0.01, rtol=0.01)


def torch_downcast_to_mxfp8_2d(
    x: torch.Tensor, dtype: torch.dtype, SCALE_ROUNDING_MODE: int = 0
):
    """Reference implementation for the 2D 32x32-block mxfp8 downcast.

    Args:
        x: Input tensor with at least two dimensions; the last two are tiled
            into 32x32 blocks (zero-padded up to multiples of 32).
        dtype: Target FP8 dtype, used for the value round-trip and the
            clamp bound.
        SCALE_ROUNDING_MODE: ``0`` rounds each block scale up to a power of
            two; any other value rounds it down.

    Returns:
        ``(x_q, scale)``, both float32. ``x_q`` has the shape of ``x``;
        ``scale`` has one entry per 32x32 block in its last two dimensions.
    """
    x = x.to(torch.float32)
    orig_shape = x.shape
    M = orig_shape[-2]
    N = orig_shape[-1]
    pad_m = (-M) % 32
    pad_n = (-N) % 32
    Mp = M + pad_m
    Np = N + pad_n

    x_padded = torch.zeros(orig_shape[:-2] + (Mp, Np), dtype=x.dtype, device=x.device)
    x_padded[..., :M, :N] = x
    # Padding is -1.0 so it can never win the amax against absolute values.
    abs_padded = torch.full(
        orig_shape[:-2] + (Mp, Np), -1.0, dtype=x.dtype, device=x.device
    )
    abs_padded[..., :M, :N] = torch.abs(x)

    M_s = Mp // 32
    N_s = Np // 32
    leading = orig_shape[:-2]

    abs_blocked = abs_padded.reshape(*leading, M_s, 32, N_s, 32)
    # max over the two 32-axes -> (..., M_s, N_s)
    scale = abs_blocked.amax(dim=-1).amax(dim=-2)
    scale = scale / get_max_quant_val(dtype)
    if SCALE_ROUNDING_MODE == 0:
        # Round up to a power of two (carry from the mantissa into the exponent).
        scale = (scale.view(torch.int32) + 0x007FFFFF) & 0x7F800000
    else:
        # Round down to a power of two (drop the mantissa).
        scale = scale.view(torch.int32) & 0x7F800000
    scale = scale.view(torch.float32)

    # All-zero blocks have scale 0; use 0 as the inverse to avoid inf/NaN.
    scale_inv = torch.where(scale == 0.0, 0.0, 1.0 / scale)
    # broadcast over the inner 32x32 block
    scale_inv_b = scale_inv.unsqueeze(-1).unsqueeze(-3)

    x_blocked = x_padded.reshape(*leading, M_s, 32, N_s, 32)
    x_blocked = x_blocked * scale_inv_b
    x_padded = x_blocked.reshape(*leading, Mp, Np)
    x = x_padded[..., :M, :N].contiguous()
    # Triton's fp32 -> fp8 cast saturates on overflow; torch's .to(fp8_e4m3fn)
    # turns too large values into NaN. ROUND_DOWN can push the per-block max
    # over fp8's representable range, so clamp here to match Triton's behavior.
    max_val = get_max_quant_val(dtype)
    x = x.clamp_(-max_val, max_val)
    x = x.to(dtype).to(torch.float32)
    return x, scale


def torch_upcast_from_mxfp8_2d(
    x: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype
):
    """Reference implementation for the 2D 32x32-block mxfp8 upcast.

    Args:
        x: Quantized values with at least two dimensions; converted to
            float32 before scaling.
        scale: One exponent code per 32x32 block of the last two dimensions
            of ``x``, decoded with ``upcast_scale``.
        dtype: Output dtype.

    Returns:
        Tensor with the shape of ``x`` and dtype ``dtype``.
    """
    x = x.to(torch.float32)
    orig_shape = x.shape
    M = orig_shape[-2]
    N = orig_shape[-1]
    pad_m = (-M) % 32
    pad_n = (-N) % 32
    Mp = M + pad_m
    Np = N + pad_n

    x_padded = torch.zeros(orig_shape[:-2] + (Mp, Np), dtype=x.dtype, device=x.device)
    x_padded[..., :M, :N] = x

    M_s = Mp // 32
    N_s = Np // 32
    leading = orig_shape[:-2]

    scale_f = upcast_scale(scale)
    # (..., M_s, N_s) -> (..., M_s, 1, N_s, 1) to broadcast over each block.
    scale_b = scale_f.unsqueeze(-1).unsqueeze(-3)
    x_blocked = x_padded.reshape(*leading, M_s, 32, N_s, 32)
    x_blocked = x_blocked * scale_b
    x_padded = x_blocked.reshape(*leading, Mp, Np)
    x = x_padded[..., :M, :N].contiguous()
    x = x.to(dtype)
    return x


# Shapes shared by the 2D downcast and upcast tests.
_SHAPES_2D = [
    (32, 32),
    (32, 64),
    (64, 32),
    (64, 64),
    (96, 128),
    (128, 128),
    (128, 256),
    (256, 256),
    (40, 50),
    (68, 68),
    (4, 20),
    (1, 32),
    (32, 1),
    (2, 64, 64),
    (3, 96, 128),
    (2, 3, 64, 64),
]


@pytest.mark.parametrize("shape", _SHAPES_2D)
@pytest.mark.parametrize("in_dtype", [torch.bfloat16, torch.float32])
@pytest.mark.parametrize("out_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
@pytest.mark.parametrize("SCALE_ROUNDING_MODE", [0, 1])
def test_downcast_to_mxfp8_2d(shape, in_dtype, out_dtype, SCALE_ROUNDING_MODE):
    """Compare ``downcast_to_mxfp8_2d`` with the PyTorch reference (values and scales)."""
    torch.cuda.empty_cache()
    torch.manual_seed(20)
    x = torch.randn(*shape, dtype=in_dtype, device="cuda")

    out_torch, out_scale_torch = torch_downcast_to_mxfp8_2d(
        x, out_dtype, SCALE_ROUNDING_MODE
    )
    out_triton, out_scale_triton = downcast_to_mxfp8_2d(
        x, out_dtype, SCALE_ROUNDING_MODE
    )
    out_triton = out_triton.to(torch.float32)
    # The reference returns float32 scales, so decode the kernel's scale
    # codes before comparing.
    out_scale_triton = upcast_scale(out_scale_triton)

    torch.testing.assert_close(out_triton, out_torch, atol=0.01, rtol=0.01)
    torch.testing.assert_close(out_scale_triton, out_scale_torch, atol=0.01, rtol=0.01)


@pytest.mark.parametrize("shape", _SHAPES_2D)
@pytest.mark.parametrize("in_dtype", [torch.float8_e4m3fn, torch.float8_e5m2])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32])
def test_upcast_from_mxfp8_2d(shape, in_dtype, out_dtype):
    """Compare ``upcast_from_mxfp8_2d`` with the PyTorch reference on quantized data."""
    torch.cuda.empty_cache()
    torch.manual_seed(20)
    x = torch.randn(*shape, dtype=out_dtype, device="cuda")
    x, x_scale = downcast_to_mxfp8_2d(x, in_dtype)
    out_triton = upcast_from_mxfp8_2d(x, x_scale, out_dtype)
    out_torch = torch_upcast_from_mxfp8_2d(x, x_scale, out_dtype)

    torch.testing.assert_close(out_triton, out_torch, atol=0.01, rtol=0.01)