diff --git a/kernels/mfma_preshuffle_pipeline.py b/kernels/mfma_preshuffle_pipeline.py index ad14b36d..49537bb1 100644 --- a/kernels/mfma_preshuffle_pipeline.py +++ b/kernels/mfma_preshuffle_pipeline.py @@ -881,3 +881,283 @@ def load_b_raw_w4a16_groupwise( def unpack_b_w4a16_groupwise(packed32, scale_val, arith, vector, use_gfx950_cvt=False): """Phase 2 of W4A16 groupwise: unpack + scale + convert to bf16.""" return unpack_b_w4a16(packed32, arith, vector, scale_val=scale_val, use_gfx950_cvt=use_gfx950_cvt) + + +def _cvt_scalef32_pk_bf16_fp4(packed_i32, scale_f32, byte_idx, arith, vector): + """GFX950 hardware: v_cvt_scalef32_pk_bf16_fp4. + + Converts 2 FP4 E2M1 nibbles (from *byte_idx* of *packed_i32*) to + 2 bf16 values (already scaled by *scale_f32*), returned as i32 + (2 packed bf16). + + One instruction replaces ~36 VALU of the software path. + """ + from flydsl._mlir.dialects import llvm + + byte_idx_i32 = arith.constant(byte_idx, type=T.i32) + result_v2bf16 = llvm.call_intrinsic( + T.vec(2, T.bf16), + "llvm.amdgcn.cvt.scalef32.pk.bf16.fp4", + [packed_i32, scale_f32, byte_idx_i32], + [], [], + ) + vec1_i32_t = T.vec(1, T.i32) + return vector.extract( + vector.bitcast(vec1_i32_t, result_v2bf16), + static_position=[0], dynamic_position=[], + ) + + +def _fp4x4_in_i32_to_bf16x4_i64(packed4, arith, vector, scale_f32=None): + """Convert 4 FP4 E2M1 nibbles (in 4 bytes of i32) to 4 bf16 packed as i64. + + Each byte of *packed4* holds one nibble in bits [3:0]: + bit[3] = sign, bits[2:1] = exponent (bias=1), bit[0] = mantissa. + + Unsigned value table (3-bit index): + 000->0.0, 001->0.5, 010->1.0, 011->1.5, + 100->2.0, 101->3.0, 110->4.0, 111->6.0 + + *scale_f32*, when provided, is an f32 E8M0 block-scale multiplied + into every element before truncation to bf16. 
+ """ + vec1_i32_t = T.vec(1, T.i32) + vec2_i32 = T.i32x2 + vec4_i8 = T.i8x4 + vec1_i64 = T.vec(1, T.i64) + + v1 = vector.from_elements(vec1_i32_t, [packed4]) + i8x4 = vector.bitcast(vec4_i8, v1) + + c1 = arith.constant(1, type=T.i32) + c3_shift = arith.constant(3, type=T.i32) + c7 = arith.constant(7, type=T.i32) + c22 = arith.constant(22, type=T.i32) + c23 = arith.constant(23, type=T.i32) + c31 = arith.constant(31, type=T.i32) + c126 = arith.constant(126, type=T.i32) + c_zero = arith.constant(0, type=T.i32) + c_half_bits = arith.constant(0x3F000000, type=T.i32) # 0.5f + + f32_vals = [] + for i in range(4): + nibble_i8 = vector.extract(i8x4, static_position=[i], dynamic_position=[]) + n = arith.extui(T.i32, nibble_i8) + + sign_bit = arith.andi(arith.shrui(n, c3_shift), c1) + unsigned_val = arith.andi(n, c7) + exp_field = arith.shrui(unsigned_val, c1) + mant_field = arith.andi(unsigned_val, c1) + + f32_norm = arith.ori( + arith.shli(arith.addi(exp_field, c126), c23), + arith.shli(mant_field, c22), + ) + + is_zero = arith.cmpi(arith.CmpIPredicate.eq, unsigned_val, c_zero) + is_subnorm = arith.cmpi(arith.CmpIPredicate.eq, unsigned_val, c1) + + f32_bits = arith.select( + is_zero, c_zero, + arith.select(is_subnorm, c_half_bits, f32_norm), + ) + f32_bits = arith.ori(f32_bits, arith.shli(sign_bit, c31)) + + v = arith.bitcast(T.f32, f32_bits) + if scale_f32 is not None: + v = v * scale_f32 + f32_vals.append(v) + + c16 = arith.constant(16, type=T.i32) + c_ffff0000 = arith.constant(0xFFFF0000, type=T.i32) + bits0 = arith.bitcast(T.i32, f32_vals[0]) + bits1 = arith.bitcast(T.i32, f32_vals[1]) + bits2 = arith.bitcast(T.i32, f32_vals[2]) + bits3 = arith.bitcast(T.i32, f32_vals[3]) + i32_lo = arith.shrui(bits0, c16) | (bits1 & c_ffff0000) + i32_hi = arith.shrui(bits2, c16) | (bits3 & c_ffff0000) + + v2 = vector.from_elements(vec2_i32, [i32_lo, i32_hi]) + v64 = vector.bitcast(vec1_i64, v2) + return vector.extract(v64, static_position=[0], dynamic_position=[]) + + +def load_b_raw_mxfp4( + buffer_ops, + arith, + vector, + *, + arg_b, + b_rsrc, + layout_b, + base_k: ir.Value, + ku: int, + n_blk: ir.Value, + n_intra: ir.Value, + lane_div_16: ir.Value, + elem_type: ir.Type, + kpack_bytes: int = 16, +): + """Load 4 bytes of packed FP4 from a kpack=16 preshuffle layout. + + Addressing for kpack=16 (``shuffle_weight_a16w4`` format): + - Layout shape: ``(n0, k0, klane=4, nlane=16, kpack=16)`` + - The A-side LDS has klane stride = 8 bf16 elements, advancing + by 32 bf16 per ku step. B must match: each klane loads 4 bytes + (8 FP4 = 8 K elements) at K_start = base_k + ku*32 + lane*8. + - In the preshuffle layout this maps to: + k0 = base_k//128 + ku//4 + klane_hw = ku % 4 (compile-time) + kpack_byte = lane_div_16*4 (runtime) + + Returns a single i32 containing 4 packed bytes (8 FP4 nibbles). 
+ """ + if kpack_bytes != 16: + raise ValueError(f"MXFP4 requires kpack_bytes=16, got {kpack_bytes!r}") + + c128 = arith.constant(128, index=True) + c4 = arith.constant(4, index=True) + + k0_base = base_k // c128 + k0 = k0_base + arith.constant(ku // 4, index=True) + klane_hw = arith.constant(ku % 4, index=True) + byte_offset = lane_div_16 * c4 + + coord_pack = (n_blk, k0, klane_hw, n_intra, arith.constant(0, index=True)) + idx_pack = crd2idx(coord_pack, layout_b) + idx_bytes = idx_pack + byte_offset + + b4 = _buffer_load_vec( + buffer_ops, + vector, + b_rsrc, + idx_bytes, + elem_type=elem_type, + vec_elems=4, + elem_bytes=1, + offset_in_bytes=True, + ) + packed32 = vector.extract( + vector.bitcast(T.vec(1, T.i32), b4), + static_position=[0], + dynamic_position=[], + ) + return packed32 + + +def load_b_raw_mxfp4_dwordx4( + buffer_ops, + arith, + vector, + *, + arg_b, + b_rsrc, + layout_b, + base_k: "ir.Value", + n_blk: "ir.Value", + n_intra: "ir.Value", + lane_div_16: "ir.Value", + elem_type: "ir.Type", + kpack_bytes: int = 16, + cache_modifier: int = 0, +): + """Load 16 bytes (vec4_i32) of packed FP4 via buffer_load_dwordx4. + + CK-style addressing: klane = lane_div_16, loading the full kpack + for the thread's sub-lane. Returns vec4_i32 where i32[j] contains + 8 FP4 elements for kIter j. + + Layout: ``(n0, k0, klane=4, nlane=16, kpack=16)`` + """ + if kpack_bytes != 16: + raise ValueError(f"MXFP4 requires kpack_bytes=16, got {kpack_bytes!r}") + + c128 = arith.constant(128, index=True) + k0 = base_k // c128 + + coord_pack = (n_blk, k0, lane_div_16, n_intra, arith.constant(0, index=True)) + idx_pack = crd2idx(coord_pack, layout_b) + + b16 = _buffer_load_vec( + buffer_ops, + vector, + b_rsrc, + idx_pack, + elem_type=elem_type, + vec_elems=16, + elem_bytes=1, + offset_in_bytes=True, + ) + return vector.bitcast(T.vec(4, T.i32), b16) + + +def _unpack_b_mxfp4_bf16_hw(packed32, arith, vector, scale_f32): + """Hardware fast-path: 4 x v_cvt_scalef32_pk_bf16_fp4.""" + vec2_i32 = T.i32x2 + vec1_i64 = T.vec(1, T.i64) + + lo0 = _cvt_scalef32_pk_bf16_fp4(packed32, scale_f32, 0, arith, vector) + lo1 = _cvt_scalef32_pk_bf16_fp4(packed32, scale_f32, 1, arith, vector) + v2_lo = vector.from_elements(vec2_i32, [lo0, lo1]) + v64_lo = vector.bitcast(vec1_i64, v2_lo) + b0 = vector.extract(v64_lo, static_position=[0], dynamic_position=[]) + + hi0 = _cvt_scalef32_pk_bf16_fp4(packed32, scale_f32, 2, arith, vector) + hi1 = _cvt_scalef32_pk_bf16_fp4(packed32, scale_f32, 3, arith, vector) + v2_hi = vector.from_elements(vec2_i32, [hi0, hi1]) + v64_hi = vector.bitcast(vec1_i64, v2_hi) + b1 = vector.extract(v64_hi, static_position=[0], dynamic_position=[]) + + return (b0, b1) + + +def _unpack_b_mxfp4_bf16_sw(packed32, arith, vector, scale_f32): + """Software fallback for non-GFX950 targets.""" + c_0f = arith.constant(0x0F, type=T.i32) + c4 = arith.constant(4, type=T.i32) + c8 = arith.constant(8, type=T.i32) + c12 = arith.constant(12, type=T.i32) + c16 = arith.constant(16, type=T.i32) + c20 = arith.constant(20, type=T.i32) + c24 = arith.constant(24, type=T.i32) + c28 = arith.constant(28, type=T.i32) + + n0 = packed32 & c_0f + n1 = arith.shrui(packed32, c4) & c_0f + n2 = arith.shrui(packed32, c8) & c_0f + n3 = arith.shrui(packed32, c12) & c_0f + first = n0 | arith.shli(n1, c8) | arith.shli(n2, c16) | arith.shli(n3, c24) + + n4 = arith.shrui(packed32, c16) & c_0f + n5 = arith.shrui(packed32, c20) & c_0f + n6 = arith.shrui(packed32, c24) & c_0f + n7 = arith.shrui(packed32, c28) & c_0f + second = n4 | arith.shli(n5, c8) | 
arith.shli(n6, c16) | arith.shli(n7, c24) + + b0 = _fp4x4_in_i32_to_bf16x4_i64(first, arith, vector, scale_f32=scale_f32) + b1 = _fp4x4_in_i32_to_bf16x4_i64(second, arith, vector, scale_f32=scale_f32) + return (b0, b1) + + +def unpack_b_mxfp4_bf16(packed32, arith, vector, scale_f32=None, use_hw_cvt=True): + """Unpack 8 FP4 E2M1 nibbles (packed in i32) to 2 x i64 (8 bf16). + + Each byte of *packed32* holds two FP4 nibbles: low nibble = K_even, + high nibble = K_even+1. For ``mfma_f32_16x16x16bf16_1k`` the B + operand needs 4 consecutive K values per i64. So we unpack the + lower 2 bytes (4 consecutive nibbles) into b0 and the upper 2 bytes + into b1. + + *scale_f32* is the decoded E8M0 block-scale (as f32). + + When *use_hw_cvt* is True (default), uses the GFX950 hardware + instruction ``v_cvt_scalef32_pk_bf16_fp4`` which converts 2 FP4 + nibbles to 2 bf16 (with scale) in a single VALU cycle. This + replaces ~144 VALU of the software fallback with 4 instructions. + + Returns ``(b0, b1)`` -- two i64 values, each containing 4 bf16 for + one ``mfma_f32_16x16x16bf16_1k`` call. + """ + if use_hw_cvt and scale_f32 is not None: + return _unpack_b_mxfp4_bf16_hw(packed32, arith, vector, scale_f32) + return _unpack_b_mxfp4_bf16_sw(packed32, arith, vector, scale_f32) diff --git a/kernels/mixed_moe_gemm_2stage.py b/kernels/mixed_moe_gemm_2stage.py index 3b442d7b..dab64d51 100644 --- a/kernels/mixed_moe_gemm_2stage.py +++ b/kernels/mixed_moe_gemm_2stage.py @@ -29,6 +29,11 @@ from flydsl.expr import range_constexpr from flydsl.runtime.device import get_rocm_arch as get_hip_arch +try: + from flydsl.runtime.device import supports_bf16_global_atomics +except ImportError: + def supports_bf16_global_atomics(arch: str) -> bool: + return str(arch).startswith(("gfx94", "gfx95", "gfx12")) from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr @@ -37,6 +42,7 @@ from flydsl.expr import arith, gpu, buffer_ops, vector, rocdl, const_expr from flydsl._mlir.dialects import llvm, scf, memref +from flydsl._mlir.dialects import fly as _fly from flydsl._mlir.dialects.arith import CmpIPredicate from kernels.mfma_preshuffle_pipeline import ( @@ -47,10 +53,13 @@ lds_store_4b_xor16, make_preshuffle_b_layout, make_preshuffle_scale_layout, + load_b_raw_mxfp4, + load_b_raw_mxfp4_dwordx4, + unpack_b_mxfp4_bf16, tile_chunk_coord_i32, swizzle_xor16, ) -from kernels.mfma_epilogues import c_shuffle_epilog +from kernels.mfma_epilogues import c_shuffle_epilog, default_epilog, mfma_epilog from kernels.layout_utils import crd2idx, idx2crd, get as layout_get from kernels.kernels_common import _if_then, validate_moe_dtypes @@ -126,6 +135,7 @@ def compile_mixed_moe_gemm1( gate_mode: GateMode = GateMode.SEPARATED, a_scale_one: bool = False, xcd_swizzle: int = 0, + split_k_intra: int = 1, ): """Compile stage1 kernel (gate+up with silu/swiglu). @@ -138,45 +148,119 @@ def compile_mixed_moe_gemm1( gate_mode controls the gate/up computation strategy — see GateMode enum. """ + is_a16w4_stage1 = a_dtype == "bf16" and b_dtype in ("fp4", "mxfp4") + + # ---- Shared prelude (used by both A16W4 and generic stage1 paths) ---- + # Padding semantics: model_dim and inter_dim INCLUDE padding. + # model_dim = model_dim_true + model_dim_pad (K direction) + # inter_dim = inter_dim_true + inter_dim_pad (N direction) + # The grid simply does not launch tiles for padding columns. 
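# Illustrative host-side reference (a sketch, not part of either kernel file):
# decodes the same 8 FP4 E2M1 nibbles that unpack_b_mxfp4_bf16 and
# _fp4x4_in_i32_to_bf16x4_i64 handle on-device in the hunk above, assuming the
# value table and nibble order (low nibble = K_even, high nibble = K_even+1)
# stated in those docstrings. Names below are hypothetical; the sketch is only
# meant as a unit-test oracle for a packed i32 versus the kernel's scaled
# bf16 output.
_E2M1_MAG = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]  # 3-bit unsigned index

def decode_fp4x8_ref(packed32: int, scale: float = 1.0) -> list:
    """Return 8 floats for the 8 FP4 nibbles of *packed32*, low nibble first."""
    out = []
    for i in range(8):
        nib = (packed32 >> (4 * i)) & 0xF
        mag = _E2M1_MAG[nib & 0x7]           # bits [2:0]: magnitude index
        sign = -1.0 if (nib & 0x8) else 1.0  # bit [3]: sign
        out.append(sign * mag * scale)
    return out

# Nibbles 0..7 decode to the docstring's value table, in order.
assert decode_fp4x8_ref(0x76543210) == [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]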
gpu_arch = get_hip_arch() allocator_pong = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem0") allocator_ping = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem1") + mock_gate_only = gate_mode is GateMode.MOCK_GATE_ONLY + gate_up_interleave = gate_mode is GateMode.INTERLEAVE + _inter_dim_valid = inter_dim - inter_dim_pad + _is_splitk = k_batch > 1 + _use_lds128 = os.environ.get("FLIR_CK_LDS128", "1") in ( + "1", "true", "True", "YES", "yes", + ) + pad_k = 0 if _use_lds128 else 8 + lds_stride = tile_k + pad_k + + # ==== Unified setup (merged A16W4 + Generic stage1) ==== + # All paths share the same setup variables; A16W4-specific blocks are + # gated by `if is_a16w4_stage1:` and generic-only blocks by + # `if not is_a16w4_stage1:`. This mirrors the stage2 design. + + # A16W4 alias used throughout the body: gate_only ≡ mock_gate_only + gate_only = mock_gate_only + if gate_only and gate_up_interleave: + raise ValueError( + "gate_only / mock_gate_only and gate_up_interleave are mutually exclusive" + ) + if gate_only and not _is_splitk: + raise ValueError( + "gate_only / mock_gate_only requires k_batch > 1 (split-K)" + ) + _state = {} - validate_moe_dtypes(a_dtype, b_dtype) + # ---- dtype probing (A16W4 hardcoded; generic dynamic) ---- + if not is_a16w4_stage1: + validate_moe_dtypes(a_dtype, b_dtype) - is_f16_a = a_dtype == "fp16" - is_f16_b = b_dtype == "fp16" - is_f8_a = a_dtype == "fp8" - is_f4_a = a_dtype == "fp4" - is_f4_b = b_dtype == "fp4" + is_f16_a = is_a16w4_stage1 or (a_dtype == "fp16") + is_f16_b = (not is_a16w4_stage1) and (b_dtype == "fp16") + is_f8_a = (not is_a16w4_stage1) and (a_dtype == "fp8") + is_f4_a = (not is_a16w4_stage1) and (a_dtype == "fp4") + is_f4_b = (b_dtype in ("fp4", "mxfp4")) if is_a16w4_stage1 else (b_dtype == "fp4") + is_int4 = (not is_a16w4_stage1) and (b_dtype == "int4") + is_int8 = False + # ---- wave / pack config ---- sort_block_m = max(32, tile_m) - num_waves = min(4, tile_n // 32) - total_threads = num_waves * 64 - pack_M = 1 if tile_m < 32 else 2 + if is_a16w4_stage1: + num_waves = 4 + total_threads = 256 + pack_M = 1 + pack_N = 1 + pack_K = 2 + scale_mn_pack = 2 + a_elem_vec_pack = 1 + cbsz = 0 + blgp = 4 + else: + num_waves = min(4, tile_n // 32) + total_threads = num_waves * 64 + pack_M = 1 if tile_m < 32 else 2 + pack_N = min(2, (tile_n // num_waves) // 16) + pack_K = 2 + scale_mn_pack = 2 + a_elem_vec_pack = 2 if is_f4_a else 1 + cbsz = 0 if is_f8_a else 4 + blgp = 4 n_per_wave = tile_n // num_waves - pack_N = min(2, n_per_wave // 16) - pack_K = 2 - scale_mn_pack = 2 - elem_bytes = 1 + + # ---- byte / element parameters ---- + elem_bytes = 2 if is_a16w4_stage1 else 1 a_elem_bytes = 2 if is_f16_a else 1 b_elem_bytes = 1 tile_k_bytes = int(tile_k) * int(a_elem_bytes) - a_elem_vec_pack = 2 if is_f4_a else 1 - cbsz = 0 if is_f8_a else 4 - blgp = 4 - if (tile_k_bytes % 64) != 0: raise ValueError(f"tile_k_bytes must be divisible by 64, got {tile_k_bytes}") + # ---- output dtype parsing ---- out_s = str(out_dtype).strip().lower() out_is_f32 = out_s in ("f32", "fp32", "float") out_is_bf16 = out_s in ("bf16", "bfloat16") - is_int4 = b_dtype == "int4" - is_int8 = False + + # ---- A16W4-specific BF16 K32 MFMA helper ---- + mfma_f32_bf16_k32 = None + if is_a16w4_stage1: + if out_dtype not in ("f16", "bf16"): + raise ValueError(f"out_dtype must be 'f16' or 'bf16', got {out_dtype!r}") + _mfma_k32_raw = getattr(rocdl, "mfma_f32_16x16x32_bf16_", None) + if _mfma_k32_raw is None: + raise AttributeError( + "BF16 K32 MFMA op not found: expected 
`rocdl.mfma_f32_16x16x32_bf16_`" + ) + _split_mfma = rocdl._split_mfma_operands + + def mfma_f32_bf16_k32(result_type, operands, *, loc=None, ip=None): + a, b, c, cbsz, abid, blgp = _split_mfma(operands, loc=loc) + return _mfma_k32_raw(result_type, a, b, c, cbsz, abid, blgp, loc=loc, ip=ip) + + # ---- type helpers (A16W4 forces bf16 X / i8 W; generic dispatches on dtype flags) ---- + out_mlir = lambda: ( + (lambda ty: ty() if callable(ty) else ty)( + T.f16 if out_dtype == "f16" else T.bf16 + ) + ) def _x_elem_type(): + if is_a16w4_stage1: + return T.bf16 if is_f4_b: return T.f8 if is_f8_a else T.i8 return T.f16 if is_f16_a else (T.i8 if is_int8 else T.f8) @@ -189,540 +273,550 @@ def _w_elem_type(): def out_elem(): return T.f32 if out_is_f32 else (T.bf16 if out_is_bf16 else T.f16) - mock_gate_only = gate_mode is GateMode.MOCK_GATE_ONLY - gate_up_interleave = gate_mode is GateMode.INTERLEAVE - - # Padding semantics: model_dim and inter_dim INCLUDE padding. - # model_dim = model_dim_true + model_dim_pad (K direction) - # inter_dim = inter_dim_true + inter_dim_pad (N direction) - # Tensor sizes use the padded dimensions (inter_dim, model_dim). - # Padding only affects kernel internal logic and grid computation. - _inter_dim_valid = inter_dim - inter_dim_pad + def x_lds_elem(): + if is_a16w4_stage1: + return T.bf16 + return T.f16 if is_f16_a else (T.i8 if is_int8 else T.f8) - # Split-K validation - _is_splitk = k_batch > 1 - if mock_gate_only and not _is_splitk: - raise ValueError("mock_gate_only requires k_batch > 1 (split-K)") - if _is_splitk: - _k_per_batch = model_dim // k_batch - assert ( - model_dim % k_batch == 0 - ), f"model_dim={model_dim} not divisible by k_batch={k_batch}" - assert ( - _k_per_batch % tile_k == 0 - ), f"K_per_batch={_k_per_batch} not divisible by tile_k={tile_k}" - - out_dtype = "bf16" + # ---- K dimension & split-K validation ---- + if is_a16w4_stage1: + accumulate = _is_splitk + if const_expr(_is_splitk): + _k_per_batch = model_dim // k_batch + else: + _k_per_batch = model_dim + _k_dim = _k_per_batch else: - _k_per_batch = model_dim - _k_dim = _k_per_batch + accumulate = False + if _is_splitk: + _k_per_batch = model_dim // k_batch + assert ( + model_dim % k_batch == 0 + ), f"model_dim={model_dim} not divisible by k_batch={k_batch}" + assert ( + _k_per_batch % tile_k == 0 + ), f"K_per_batch={_k_per_batch} not divisible by tile_k={tile_k}" + out_dtype = "bf16" + else: + _k_per_batch = model_dim + _k_dim = _k_per_batch - bytes_x_per_tile = int(tile_m) * int(tile_k) * int(a_elem_bytes) + # A16W4 body uses _single_b to skip the up-projection in single-output paths + _single_b = gate_only or gate_up_interleave + + # ---- Tile bytes / thread distribution ---- + bytes_x_per_tile = int(tile_m) * int(tile_k) * int(a_elem_bytes if not is_a16w4_stage1 else elem_bytes) if bytes_x_per_tile % total_threads != 0: raise ValueError( - f"tile_m*tile_k*elem_bytes must be divisible by {total_threads}" + "tile_m*tile_k*elem_bytes must be divisible by " + f"{total_threads}: tile_m={tile_m}, tile_k={tile_k}" ) bytes_per_thread_x = bytes_x_per_tile // total_threads - _use_lds128 = os.environ.get("FLIR_CK_LDS128", "1") in ( - "1", - "true", - "True", - "YES", - "yes", - ) - pad_k = 0 if _use_lds128 else 8 - lds_stride = tile_k + pad_k - + # ---- CShuffle epilog handling ---- if use_cshuffle_epilog is None: _use_cshuffle_epilog = os.environ.get("FLIR_MOE_STAGE1_CSHUFFLE", "1") in ( - "1", - "true", - "True", - "YES", - "yes", + "1", "true", "True", "YES", "yes", ) else: _use_cshuffle_epilog = 
bool(use_cshuffle_epilog) - _need_fp4 = out_dtype == "fp4" - _need_fp8 = out_dtype == "fp8" - _need_quant = _need_fp4 or _need_fp8 - _need_sort = _need_quant - - if _need_quant: - _use_cshuffle_epilog = True - - _fp4q_tag = "_fp4q" if _need_fp4 else "" - _fp8q_tag = "_fp8q" if _need_fp8 else "" - _sort_tag = "_sort" if _need_sort else "" - _async_tag = "_async" if use_async_copy else "" - _sk_tag = f"_sk{k_batch}" if _is_splitk else "" - _go_tag = "_go" if mock_gate_only else "" - _gui_tag = "_gui" if gate_up_interleave else "" - _as1_tag = "_as1" if a_scale_one else "" - _xcd_tag = f"_xcd{xcd_swizzle}" if xcd_swizzle > 0 else "" - module_name = ( - f"mfma_moe1_silu_mul_a{a_dtype}_w{b_dtype}_{out_s}" - f"_t{tile_m}x{tile_n}x{tile_k}_pm{persist_m}{_fp4q_tag}{_fp8q_tag}{_sort_tag}{_async_tag}{_sk_tag}{_go_tag}{_gui_tag}{_as1_tag}{_xcd_tag}_v32" - ).replace("-", "_") - - # -- LDS sizing -- - _cshuffle_elem_bytes = 4 if _need_quant else (4 if out_is_f32 else 2) - _single_x_bytes = int(tile_m) * int(lds_stride) * int(a_elem_bytes) - lds_out_bytes = ( - _cshuffle_elem_bytes * int(tile_m) * int(tile_n) if _use_cshuffle_epilog else 0 - ) - lds_tid_bytes = int(tile_m) * 4 - _input_elems = _single_x_bytes if a_elem_bytes == 1 else (_single_x_bytes // 2) - - # Determine whether we need wave-group split for lds_out. - # Standard layout: pong = max(input, lds_out) + tid, ping = input. - # When this overflows, split lds_out into two halves across pong & ping. - _GLOBAL_ALIGN = 1024 - _std_pong = max(_single_x_bytes, lds_out_bytes) + lds_tid_bytes - _std_ping = _single_x_bytes - _std_pong_aligned = allocator_pong._align(_std_pong, 128) - _std_total = allocator_pong._align( - _std_pong_aligned, _GLOBAL_ALIGN - ) + allocator_pong._align(_std_ping, 128) - _lds_limit = {"gfx950": 163840, "gfx942": 65536}.get(gpu_arch, 0) - - _split_lds_out = ( - _lds_limit > 0 - and lds_out_bytes > 0 - and _std_total > _lds_limit - and num_waves >= 2 - ) - - if _split_lds_out: - _half_out_bytes = _cshuffle_elem_bytes * int(tile_m) * (int(tile_n) // 2) - _pong_buffer_bytes = max(_single_x_bytes, _half_out_bytes) - _ping_buffer_bytes = max(_single_x_bytes, _half_out_bytes) + if is_a16w4_stage1: + if out_dtype not in ("f16", "bf16") and _use_cshuffle_epilog: + raise ValueError("stage1 cshuffle epilog supports only f16/bf16 output") + _split_k_intra = split_k_intra + if const_expr(_split_k_intra > 1): + _use_cshuffle_epilog = False + _waves_per_group = 4 // _split_k_intra + _n_per_wave_check = tile_n // _waves_per_group + if _n_per_wave_check < 16: + raise ValueError( + f"split_k_intra={_split_k_intra} with tile_n={tile_n}: " + f"n_per_wave={_n_per_wave_check} < 16 (MFMA minimum)" + ) + # GUI cross-wave fusion: needed when num_acc_n < 2 per wave + # (standard pair fusion requires gate+up in same wave, i.e. 
num_acc_n >= 2) + _n_per_wave_eff = tile_n // ((4 // _split_k_intra) if _split_k_intra > 1 else 4) + _gui_xwave_fuse = ( + gate_up_interleave and not _is_splitk + and (_n_per_wave_eff // 16) < 2 + ) + if const_expr(_gui_xwave_fuse): + _use_cshuffle_epilog = False + # Auto-disable cshuffle when tile_m doesn't meet CShuffleMLane constraint + if const_expr(_use_cshuffle_epilog): + _eff_out_n = (tile_n // 2) if (gate_up_interleave and not _is_splitk) else tile_n + _cs_nlane_chk = min(32, _eff_out_n // 4) + if _cs_nlane_chk > 0: + _cs_mlane_chk = 256 // _cs_nlane_chk + if tile_m % _cs_mlane_chk != 0: + _use_cshuffle_epilog = False + else: + _split_k_intra = 1 # not used by generic + _need_fp4 = out_dtype == "fp4" + _need_fp8 = out_dtype == "fp8" + _need_quant = _need_fp4 or _need_fp8 + _need_sort = _need_quant + if _need_quant: + _use_cshuffle_epilog = True + + # ---- module_name (different format per path; cache key) ---- + if is_a16w4_stage1: + _mode_tag = "gui" if gate_up_interleave else ("go" if gate_only else "sep") + epilog_tag = "cshuffle" if _use_cshuffle_epilog else "direct" + _wpe_tag = f"_wpe{waves_per_eu}" if waves_per_eu >= 1 else "" + _ski_tag = f"_ski{_split_k_intra}" if _split_k_intra > 1 else "" + _bias_tag = "_bias" if enable_bias else "" + _act_tag = f"_{act}" if act != "silu" else "" + _pad_tag = f"_mp{model_dim_pad}_ip{inter_dim_pad}" if (model_dim_pad or inter_dim_pad) else "" + module_name = ( + f"mfma_a16w4_moe1_mxfp4_{out_dtype}_{_mode_tag}_{epilog_tag}" + f"_t{tile_m}x{tile_n}x{tile_k}_kb{k_batch}" + f"{_wpe_tag}{_ski_tag}{_bias_tag}{_act_tag}{_pad_tag}_abi1" + ).replace("-", "_") else: + _fp4q_tag = "_fp4q" if _need_fp4 else "" + _fp8q_tag = "_fp8q" if _need_fp8 else "" + _sort_tag = "_sort" if _need_sort else "" + _async_tag = "_async" if use_async_copy else "" + _sk_tag = f"_sk{k_batch}" if _is_splitk else "" + _go_tag = "_go" if mock_gate_only else "" + _gui_tag = "_gui" if gate_up_interleave else "" + _as1_tag = "_as1" if a_scale_one else "" + _xcd_tag = f"_xcd{xcd_swizzle}" if xcd_swizzle > 0 else "" + module_name = ( + f"mfma_moe1_silu_mul_a{a_dtype}_w{b_dtype}_{out_s}" + f"_t{tile_m}x{tile_n}x{tile_k}_pm{persist_m}{_fp4q_tag}{_fp8q_tag}{_sort_tag}{_async_tag}{_sk_tag}{_go_tag}{_gui_tag}{_as1_tag}{_xcd_tag}_v32" + ).replace("-", "_") + + # ---- LDS sizing ---- + if is_a16w4_stage1: + kpack_bytes = 16 # MXFP4 preshuffle + # For interleave+non-splitk cshuffle, the epilogue output tile_n is halved + _gui_out_tile_n = tile_n // 2 if (gate_up_interleave and not _is_splitk) else tile_n + _single_x_bytes = int(tile_m) * int(lds_stride) * int(elem_bytes) + _single_x_elems = _single_x_bytes // int(elem_bytes) + lds_out_bytes = 2 * int(tile_m) * int(_gui_out_tile_n) if _use_cshuffle_epilog else 0 + # Ping-pong: pong holds max(input, output), ping holds input only _pong_buffer_bytes = max(_single_x_bytes, lds_out_bytes) _ping_buffer_bytes = _single_x_bytes + lds_tid_bytes = 0 + _split_lds_out = False + else: + kpack_bytes = 8 if is_int4 else 16 + _cshuffle_elem_bytes = 4 if _need_quant else (4 if out_is_f32 else 2) + _single_x_bytes = int(tile_m) * int(lds_stride) * int(a_elem_bytes) + lds_out_bytes = ( + _cshuffle_elem_bytes * int(tile_m) * int(tile_n) if _use_cshuffle_epilog else 0 + ) + lds_tid_bytes = int(tile_m) * 4 + _input_elems = _single_x_bytes if a_elem_bytes == 1 else (_single_x_bytes // 2) + # Determine whether we need wave-group split for lds_out. + # Standard layout: pong = max(input, lds_out) + tid, ping = input. 
+ # When this overflows, split lds_out into two halves across pong & ping. + _GLOBAL_ALIGN = 1024 + _std_pong = max(_single_x_bytes, lds_out_bytes) + lds_tid_bytes + _std_ping = _single_x_bytes + _std_pong_aligned = allocator_pong._align(_std_pong, 128) + _std_total = allocator_pong._align( + _std_pong_aligned, _GLOBAL_ALIGN + ) + allocator_pong._align(_std_ping, 128) + _lds_limit = {"gfx950": 163840, "gfx942": 65536}.get(gpu_arch, 0) + _split_lds_out = ( + _lds_limit > 0 + and lds_out_bytes > 0 + and _std_total > _lds_limit + and num_waves >= 2 + ) + if _split_lds_out: + _half_out_bytes = _cshuffle_elem_bytes * int(tile_m) * (int(tile_n) // 2) + _pong_buffer_bytes = max(_single_x_bytes, _half_out_bytes) + _ping_buffer_bytes = max(_single_x_bytes, _half_out_bytes) + else: + _pong_buffer_bytes = max(_single_x_bytes, lds_out_bytes) + _ping_buffer_bytes = _single_x_bytes - def x_lds_elem(): - return T.f16 if is_f16_a else (T.i8 if is_int8 else T.f8) - + # ---- LDS allocator pointers ---- lds_pong_offset = allocator_pong._align(allocator_pong.ptr, 16) allocator_pong.ptr = lds_pong_offset + _pong_buffer_bytes - _lds_tid_offset_pong = allocator_pong._align(allocator_pong.ptr, 4) - allocator_pong.ptr = _lds_tid_offset_pong + lds_tid_bytes - + if not is_a16w4_stage1: + _lds_tid_offset_pong = allocator_pong._align(allocator_pong.ptr, 4) + allocator_pong.ptr = _lds_tid_offset_pong + lds_tid_bytes lds_ping_offset = allocator_ping._align(allocator_ping.ptr, 16) allocator_ping.ptr = lds_ping_offset + _ping_buffer_bytes - if waves_per_eu is not None and waves_per_eu >= 1: - _total_cu_lds = 160 * 1024 - _min_lds = _total_cu_lds // (waves_per_eu + 1) + 1 - _pong_sz = allocator_pong._align(allocator_pong.ptr, 128) - _ping_sz = allocator_ping._align(allocator_ping.ptr, 128) - _cur_lds = _pong_sz + _ping_sz - if _cur_lds < _min_lds: - allocator_ping.ptr += _min_lds - _cur_lds - - kpack_bytes = 8 if is_int4 else 16 - out_elem_bytes = 4 if out_is_f32 else 2 - - _e_vec_s1 = min(tile_n // 32, 8) - if _need_quant: - _e_vec_s1 = max(2, _e_vec_s1) - _num_threads_per_quant_blk_s1 = 32 // _e_vec_s1 - _shuffle_dists_s1 = [] - _sh_val = 1 - while _sh_val < _num_threads_per_quant_blk_s1: - _shuffle_dists_s1.append(_sh_val) - _sh_val *= 2 - _num_shuffle_steps_s1 = len(_shuffle_dists_s1) - - # ---- Unified pipeline schedule (outside @flyc.kernel) ---- - # Each scheduling phase is a dict: - # mfma: [(k_idx, mi_idx, ikxdl, imxdl, asv_idx), ...] - # a_reads: [(k, mi), ...] # A ds_read subtiles - # b_loads: [('gate'/'up', ku, ni), ...] 
# B VMEM loads - # has_scale: bool # A/B scale VMEM loads - _pipe_m_repeat = tile_m // 16 - _pipe_k_unroll = tile_k_bytes // 128 - _pipe_k_unroll_packed = _pipe_k_unroll // pack_K - _pipe_m_repeat_packed = _pipe_m_repeat // pack_M - _pipe_num_acc_n = n_per_wave // 16 - - # A ds_read groups: group by mi (same mi, all k values together) - _pipe_a_groups = [] - for _mi in range(_pipe_m_repeat): - _grp = [] - for _k in range(_pipe_k_unroll): - _grp.append((_k, _mi)) - if len(_grp) == 2: + # ---- Generic-only postlude (pipeline schedule + quant block layout) ---- + if not is_a16w4_stage1: + out_elem_bytes = 4 if out_is_f32 else 2 + + _e_vec_s1 = min(tile_n // 32, 8) + if _need_quant: + _e_vec_s1 = max(2, _e_vec_s1) + _num_threads_per_quant_blk_s1 = 32 // _e_vec_s1 + _shuffle_dists_s1 = [] + _sh_val = 1 + while _sh_val < _num_threads_per_quant_blk_s1: + _shuffle_dists_s1.append(_sh_val) + _sh_val *= 2 + _num_shuffle_steps_s1 = len(_shuffle_dists_s1) + + # Unified pipeline schedule (outside @flyc.kernel). Each phase descriptor: + # mfma: [(k_idx, mi_idx, ikxdl, imxdl, asv_idx), ...] + # a_reads: [(k, mi), ...] # A ds_read subtiles + # b_loads: [('gate'/'up', ku, ni)] # B VMEM loads + # has_scale: bool # A/B scale VMEM loads + _pipe_m_repeat = tile_m // 16 + _pipe_k_unroll = tile_k_bytes // 128 + _pipe_k_unroll_packed = _pipe_k_unroll // pack_K + _pipe_m_repeat_packed = _pipe_m_repeat // pack_M + _pipe_num_acc_n = n_per_wave // 16 + + # A ds_read groups: same mi, all k-values together + _pipe_a_groups = [] + for _mi in range(_pipe_m_repeat): + _grp = [] + for _k in range(_pipe_k_unroll): + _grp.append((_k, _mi)) + if len(_grp) == 2: + _pipe_a_groups.append(_grp) + _grp = [] + if _grp: _pipe_a_groups.append(_grp) - _grp = [] - if _grp: - _pipe_a_groups.append(_grp) - - # B VMEM loads: individual gate/up loads - _pipe_b_loads = [] - for ku in range(_pipe_k_unroll): - for ni in range(_pipe_num_acc_n): - _pipe_b_loads.append(("gate", ku, ni)) - if not mock_gate_only and not gate_up_interleave: - _pipe_b_loads.append(("up", ku, ni)) - - # MFMA order: B-major (fix B, cycle all A tiles before next B) - # Each entry: one (k, ni) pair; the compute function loops over all mi. - # This keeps B operands (from VMEM) fixed while cycling A (from LDS, no wait). 
- _pipe_num_acc_n_packed = _pipe_num_acc_n // pack_N - _pipe_all_mfma = [] - for _ku128 in range(_pipe_k_unroll_packed): - for _ni_packed in range(_pipe_num_acc_n_packed): - for _ikxdl in range(pack_K): - for _inxdl in range(pack_N): - _k_idx = _ku128 * pack_K + _ikxdl - _ni_idx = _ni_packed * pack_N + _inxdl - _pipe_all_mfma.append((_k_idx, _ni_idx, _ikxdl, _inxdl, _ku128)) - - # Group MFMAs per scheduling phase (wider M -> more MFMAs per phase) - _pipe_mfma_per_phase = max(1, len(_pipe_all_mfma) // 4) - _pipe_n_phases = len(_pipe_all_mfma) // _pipe_mfma_per_phase - - # Build unified phase descriptors - _a_groups_per_phase = (len(_pipe_a_groups) + _pipe_n_phases - 1) // _pipe_n_phases - _pipe_phases = [] - _mfma_i = 0 - _a_i = 0 - for _p in range(_pipe_n_phases): - _a_reads = [] - for _ in range(_a_groups_per_phase): - if _a_i < len(_pipe_a_groups): - _a_reads.extend(_pipe_a_groups[_a_i]) - _a_i += 1 - _phase = { - "mfma": _pipe_all_mfma[_mfma_i : _mfma_i + _pipe_mfma_per_phase], - "a_reads": _a_reads, - "b_loads": [], - "has_scale": (_p == 0), - } - _mfma_i += _pipe_mfma_per_phase - _pipe_phases.append(_phase) - - # Distribute B loads evenly across phases 1..n-1 (phase 0 has scales) - _bi = 0 - for _p in range(1, _pipe_n_phases): - _rem_b = len(_pipe_b_loads) - _bi - _rem_p = _pipe_n_phases - _p - _n_b = (_rem_b + _rem_p - 1) // _rem_p if _rem_p > 0 else 0 - for _ in range(_n_b): - if _bi < len(_pipe_b_loads): - _pipe_phases[_p]["b_loads"].append(_pipe_b_loads[_bi]) - _bi += 1 - - # Extract flat lists for kernel access (avoids dict access in AST rewriter) - _pp_mfma = [p["mfma"] for p in _pipe_phases] - _pp_a_reads = [p["a_reads"] for p in _pipe_phases] - _pp_b_loads = [p["b_loads"] for p in _pipe_phases] - _pp_has_scale = [p["has_scale"] for p in _pipe_phases] - - fp4_ratio = 2 if a_dtype == "fp4" else 1 - gui_ratio = 1 if gate_up_interleave else 2 - _vmcnt_before_barrier = tile_m // 32 // fp4_ratio + tile_n // 32 * gui_ratio - - if True: - - @flyc.kernel - def moe_gemm1( - arg_out: fx.Tensor, - arg_x: fx.Tensor, - arg_w: fx.Tensor, - arg_scale_x: fx.Tensor, - arg_scale_w: fx.Tensor, - arg_sorted_token_ids: fx.Tensor, - arg_expert_ids: fx.Tensor, - arg_sorted_weights: fx.Tensor, - arg_num_valid_ids: fx.Tensor, - arg_bias: fx.Tensor, - arg_out_scale_sorted: fx.Tensor, - i32_tokens_in: fx.Int32, - i32_n_in: fx.Int32, - i32_k_in: fx.Int32, - i32_size_expert_ids_in: fx.Int32, - ): - tokens_in = arith.index_cast(ir.IndexType.get(), i32_tokens_in.ir_value()) - n_in = arith.index_cast(ir.IndexType.get(), i32_n_in.ir_value()) - k_in = arith.index_cast(ir.IndexType.get(), i32_k_in.ir_value()) + # B VMEM loads: gate/up + _pipe_b_loads = [] + for ku in range(_pipe_k_unroll): + for ni in range(_pipe_num_acc_n): + _pipe_b_loads.append(("gate", ku, ni)) + if not mock_gate_only and not gate_up_interleave: + _pipe_b_loads.append(("up", ku, ni)) + + # MFMA order: B-major (fix B, cycle all A tiles before next B) + _pipe_num_acc_n_packed = _pipe_num_acc_n // pack_N + _pipe_all_mfma = [] + for _ku128 in range(_pipe_k_unroll_packed): + for _ni_packed in range(_pipe_num_acc_n_packed): + for _ikxdl in range(pack_K): + for _inxdl in range(pack_N): + _k_idx = _ku128 * pack_K + _ikxdl + _ni_idx = _ni_packed * pack_N + _inxdl + _pipe_all_mfma.append((_k_idx, _ni_idx, _ikxdl, _inxdl, _ku128)) + + _pipe_mfma_per_phase = max(1, len(_pipe_all_mfma) // 4) + _pipe_n_phases = len(_pipe_all_mfma) // _pipe_mfma_per_phase + + _a_groups_per_phase = (len(_pipe_a_groups) + _pipe_n_phases - 1) // _pipe_n_phases + _pipe_phases 
= [] + _mfma_i = 0 + _a_i = 0 + for _p in range(_pipe_n_phases): + _a_reads = [] + for _ in range(_a_groups_per_phase): + if _a_i < len(_pipe_a_groups): + _a_reads.extend(_pipe_a_groups[_a_i]) + _a_i += 1 + _phase = { + "mfma": _pipe_all_mfma[_mfma_i : _mfma_i + _pipe_mfma_per_phase], + "a_reads": _a_reads, + "b_loads": [], + "has_scale": (_p == 0), + } + _mfma_i += _pipe_mfma_per_phase + _pipe_phases.append(_phase) + + # Distribute B loads evenly across phases 1..n-1 (phase 0 has scales) + _bi = 0 + for _p in range(1, _pipe_n_phases): + _rem_b = len(_pipe_b_loads) - _bi + _rem_p = _pipe_n_phases - _p + _n_b = (_rem_b + _rem_p - 1) // _rem_p if _rem_p > 0 else 0 + for _ in range(_n_b): + if _bi < len(_pipe_b_loads): + _pipe_phases[_p]["b_loads"].append(_pipe_b_loads[_bi]) + _bi += 1 + + _pp_mfma = [p["mfma"] for p in _pipe_phases] + _pp_a_reads = [p["a_reads"] for p in _pipe_phases] + _pp_b_loads = [p["b_loads"] for p in _pipe_phases] + _pp_has_scale = [p["has_scale"] for p in _pipe_phases] + + fp4_ratio = 2 if a_dtype == "fp4" else 1 + gui_ratio = 1 if gate_up_interleave else 2 + _vmcnt_before_barrier = tile_m // 32 // fp4_ratio + tile_n // 32 * gui_ratio + + # ==== End unified setup ==== + @flyc.kernel + def moe_gemm1( + arg_out: fx.Tensor, + arg_x: fx.Tensor, + arg_w: fx.Tensor, + arg_scale_x: fx.Tensor, + arg_scale_w: fx.Tensor, + arg_sorted_token_ids: fx.Tensor, + arg_expert_ids: fx.Tensor, + arg_sorted_weights: fx.Tensor, + arg_num_valid_ids: fx.Tensor, + arg_bias: fx.Tensor, + arg_out_scale_sorted: fx.Tensor, + i32_tokens_in: fx.Int32, + i32_n_in: fx.Int32, + i32_k_in: fx.Int32, + i32_size_expert_ids_in: fx.Int32, + ): + if const_expr(is_a16w4_stage1): + _ = arg_scale_x + _ = arg_out_scale_sorted + tokens_in = arith.index_cast(T.index, i32_tokens_in.ir_value()) + inter_in = arith.ArithValue( + arith.index_cast(T.index, i32_n_in.ir_value()) + ) + k_in = arith.index_cast(T.index, i32_k_in.ir_value()) size_expert_ids_in = arith.index_cast( - ir.IndexType.get(), i32_size_expert_ids_in.ir_value() + T.index, i32_size_expert_ids_in.ir_value() ) + k_i32_v = i32_k_in.ir_value() - x_elem = T.f16 if is_f16_a else (T.i8 if is_int8 else T.f8) + x_elem = T.bf16 + w_elem = T.i8 + f16 = T.f16 f32 = T.f32 i32 = T.i32 i64 = T.i64 vec4_f32 = T.vec(4, f32) - vec16_elems = 16 if a_elem_bytes == 1 else 8 - vec16_x = T.vec(vec16_elems, x_elem) + vec1_f16 = T.vec(1, f16) + vec8_bf16 = T.vec(8, x_elem) + vec16_x = T.vec(8, x_elem) vec2_i64 = T.vec(2, i64) + def _silu_elem(g): + neg_log2e = arith.constant(-1.4426950408889634, type=f32) + t = g * neg_log2e + emu = llvm.call_intrinsic(f32, "llvm.amdgcn.exp2.f32", [t], [], []) + one = arith.constant(1.0, type=f32) + den = one + emu + sig = llvm.call_intrinsic(f32, "llvm.amdgcn.rcp.f32", [den], [], []) + return g * sig + + silu = _silu_elem + + def _silu_mul_vec4(gate_v4, up_v4): + result_elems = [] + for ei in range_constexpr(4): + g = vector.extract( + gate_v4, static_position=[ei], dynamic_position=[] + ) + u = vector.extract( + up_v4, static_position=[ei], dynamic_position=[] + ) + result_elems.append(_silu_elem(g) * u) + return vector.from_elements(vec4_f32, result_elems) + + def _swiglu_mul_vec4(gate_v4, up_v4): + result_elems = [] + _alpha = arith.constant(1.702, type=f32) + _limit = arith.constant(7.0, type=f32) + _neg_limit = arith.constant(-7.0, type=f32) + _one = arith.constant(1.0, type=f32) + _neg_log2e = arith.constant(-1.4426950408889634, type=f32) + for ei in range_constexpr(4): + g = vector.extract( + gate_v4, static_position=[ei], 
dynamic_position=[] + ) + u = vector.extract( + up_v4, static_position=[ei], dynamic_position=[] + ) + g = arith.minimumf(g, _limit) + u = arith.minimumf(u, _limit) + u = arith.maximumf(u, _neg_limit) + t = g * _alpha * _neg_log2e + emu = llvm.call_intrinsic( + f32, "llvm.amdgcn.exp2.f32", [t], [], [] + ) + den = _one + emu + sig = llvm.call_intrinsic( + f32, "llvm.amdgcn.rcp.f32", [den], [], [] + ) + result_elems.append(g * sig * (u + _one)) + return vector.from_elements(vec4_f32, result_elems) + + def _act_vec4(gate_v4, up_v4): + if const_expr(act == "swiglu"): + return _swiglu_mul_vec4(gate_v4, up_v4) + else: + return _silu_mul_vec4(gate_v4, up_v4) + + def _act_elem(g_e, u_e): + if const_expr(act == "swiglu"): + _alpha = arith.constant(1.702, type=f32) + _limit = arith.constant(7.0, type=f32) + _neg_limit = arith.constant(-7.0, type=f32) + _one = arith.constant(1.0, type=f32) + _neg_log2e = arith.constant(-1.4426950408889634, type=f32) + g_e = arith.minimumf(g_e, _limit) + u_e = arith.minimumf(u_e, _limit) + u_e = arith.maximumf(u_e, _neg_limit) + t = g_e * _alpha * _neg_log2e + emu = llvm.call_intrinsic( + f32, "llvm.amdgcn.exp2.f32", [t], [], [] + ) + den = _one + emu + sig = llvm.call_intrinsic( + f32, "llvm.amdgcn.rcp.f32", [den], [], [] + ) + return g_e * sig * (u_e + _one) + else: + return _silu_elem(g_e) * u_e + acc_init = arith.constant_vector(0.0, vec4_f32) - # --- Stage1 dimension mapping --- - # X: [tokens, model_dim] -- M = sorted tokens, K = model_dim - # W: [E*2*inter_dim, model_dim] gate portion -- N = inter_dim - # Out: [tokens*topk, inter_dim] + layout_x = fx.make_layout( + (arith.index_cast(i32, tokens_in), k_i32_v), stride=(k_i32_v, 1) + ) - # B preshuffle layout: [E*2*inter_dim, model_dim] - # Gate rows for expert e: [e*2*inter_dim, e*2*inter_dim + inter_dim) - c_n_total = arith.constant(experts * (2 * inter_dim), index=True) + # Gate+up interleaved: N_total = experts * 2 * inter_dim + c_n_total = arith.index(experts * (2 * inter_dim)) + c2 = arith.index(2) + c_k_packed = k_in // c2 b_layout = make_preshuffle_b_layout( arith, c_n=c_n_total, - c_k=k_in // pack_K, + c_k=c_k_packed, kpack_bytes=kpack_bytes, - elem_bytes=b_elem_bytes, - # k_major=True, + elem_bytes=1, ) layout_b = b_layout.layout_b - # A-scale: [sorted_size, K/32] -- pre-scattered by caller into sorted layout - # Same as stage2: indexed by sorted_row position, not by token_id. 
- sorted_m = size_expert_ids_in * arith.constant(sort_block_m, index=True) - layout_a_scale = make_preshuffle_scale_layout( - arith, c_mn=sorted_m, c_k=arith.constant(model_dim, index=True) - ) - # B-scale: [E*2*inter_dim, K/32] layout_b_scale = make_preshuffle_scale_layout( - arith, c_mn=c_n_total, c_k=arith.constant(model_dim, index=True) + arith, + c_mn=c_n_total, + c_k=k_in, + mn_pack=2, + k_pack=2, + elem_bytes=4, + scale_block_size=32, ) - _eff_lds_stride = lds_stride - _eff_tile_k_bytes = tile_k_bytes - if const_expr(use_async_copy and a_elem_vec_pack > 1): - _eff_lds_stride = lds_stride // a_elem_vec_pack - _eff_tile_k_bytes = tile_k_bytes // a_elem_vec_pack - - shape_lds = fx.make_shape(tile_m, _eff_lds_stride) - stride_lds = fx.make_stride(_eff_lds_stride, 1) + shape_lds = fx.make_shape(tile_m, tile_k) + stride_lds = fx.make_stride(lds_stride, 1) layout_lds = fx.make_layout(shape_lds, stride_lds) tx = gpu.thread_id("x") - by = gpu.block_id("x") # tile along inter_dim (N) - bx_persist = gpu.block_id("y") # persistent WG index - - if const_expr(xcd_swizzle > 0): - _NUM_XCDS_S1 = 8 - _c1_sw = arith.constant(1, index=True) - _c_tn_sw = arith.constant(tile_n, index=True) - _c_idp_sw = arith.constant(2 * inter_dim_pad, index=True) - if const_expr(mock_gate_only or gate_up_interleave): - _gx = (n_in - _c_idp_sw + _c_tn_sw - _c1_sw) / _c_tn_sw - else: - _c2_sw = arith.constant(2, index=True) - _gx = ( - (n_in - _c_idp_sw + _c2_sw * _c_tn_sw - _c1_sw) - / _c_tn_sw - / _c2_sw - ) - _c_pm_sw = arith.constant(persist_m, index=True) - _gy = (size_expert_ids_in + _c_pm_sw - _c1_sw) / _c_pm_sw - - _linear_id = bx_persist * _gx + by - _num_wgs = _gx * _gy - - _c_xcds = arith.constant(_NUM_XCDS_S1, index=True) - _wgs_per_xcd = _num_wgs / _c_xcds - _wgid = (_linear_id % _c_xcds) * _wgs_per_xcd + (_linear_id / _c_xcds) - - _WGM_S1 = xcd_swizzle - _c_wgm = arith.constant(_WGM_S1, index=True) - _num_wgid_in_group = _c_wgm * _gx - _group_id = _wgid / _num_wgid_in_group - _first_pid_m = _group_id * _c_wgm - _remaining_m = _gy - _first_pid_m - _cmp_m = arith.cmpi(CmpIPredicate.ult, _remaining_m, _c_wgm) - _group_size_m = arith.select(_cmp_m, _remaining_m, _c_wgm) - - _wgid_in_group = _wgid % _num_wgid_in_group - bx_persist = _first_pid_m + (_wgid_in_group % _group_size_m) - by = _wgid_in_group / _group_size_m - by_n = by * arith.constant(tile_n, index=True) + by = gpu.block_id("x") + bx = gpu.block_id("y") - k_base_idx = arith.index(0) - if const_expr(_is_splitk): - bz = gpu.block_id("z") # K-batch id - k_base_idx = bz * arith.constant(_k_dim, index=True) + bx_m = bx * arith.index(tile_m) + numids_rsrc = buffer_ops.create_buffer_resource( + arg_num_valid_ids, max_size=False, + num_records_bytes=arith.constant(4, type=i32), + ) + num_valid_i32 = buffer_ops.buffer_load( + numids_rsrc, arith.constant(0, index=True), vec_width=1, dtype=i32 + ) + num_valid_idx = arith.index_cast(T.index, num_valid_i32) + bx_m_i32 = arith.index_cast(i32, bx_m) + blk_valid = arith.cmpi(arith.CmpIPredicate.ult, bx_m_i32, num_valid_i32) - k_blocks16 = arith.constant(_eff_tile_k_bytes // 16, index=True) - layout_tx_wave_lane = fx.make_layout((num_waves, 64), stride=(64, 1)) + k_blocks16 = arith.index(tile_k_bytes // 16) + layout_tx_wave_lane = fx.make_layout((4, 64), stride=(64, 1)) layout_lane16 = fx.make_layout((4, 16), stride=(16, 1)) - base_ptr_pong = allocator_pong.get_base() - base_ptr_ping = allocator_ping.get_base() - lds_x_pong = SmemPtr( - base_ptr_pong, lds_pong_offset, x_lds_elem(), shape=(_input_elems,) - ).get() - 
lds_x_ping = SmemPtr( - base_ptr_ping, lds_ping_offset, x_lds_elem(), shape=(_input_elems,) - ).get() - _lds_out_elem_type = ( - T.f32 if _need_quant else (T.bf16 if out_is_bf16 else T.f16) - ) - if const_expr(_split_lds_out and _use_cshuffle_epilog): - _half_out_elems = int(tile_m) * (int(tile_n) // 2) - lds_out = SmemPtr( - base_ptr_pong, - lds_pong_offset, - _lds_out_elem_type, - shape=(_half_out_elems,), + _if_blk = scf.IfOp(blk_valid) + with _if_then(_if_blk): + base_ptr_pong = allocator_pong.get_base() + base_ptr_ping = allocator_ping.get_base() + lds_x_pong = SmemPtr( + base_ptr_pong, lds_pong_offset, T.bf16, + shape=(_single_x_elems,), ).get() - lds_out_B = SmemPtr( - base_ptr_ping, - lds_ping_offset, - _lds_out_elem_type, - shape=(_half_out_elems,), + lds_x_ping = SmemPtr( + base_ptr_ping, lds_ping_offset, T.bf16, + shape=(_single_x_elems,), ).get() - else: lds_out = ( SmemPtr( - base_ptr_pong, - lds_pong_offset, - _lds_out_elem_type, - shape=(tile_m * tile_n,), + base_ptr_pong, lds_pong_offset, out_mlir(), + shape=(tile_m * _gui_out_tile_n,), ).get() if _use_cshuffle_epilog else None ) - lds_out_B = None - lds_tid = SmemPtr( - base_ptr_pong, _lds_tid_offset_pong, T.i32, shape=(tile_m,) - ).get() - - # Buffer resources - c_a_pack = arith.constant(int(a_elem_vec_pack), index=True) - c_elem_bytes = arith.constant(int(a_elem_bytes), index=True) - - # X: [tokens, model_dim] - x_nbytes_idx = (tokens_in * k_in * c_elem_bytes) / c_a_pack - x_nbytes_i32 = arith.index_cast(T.i32, x_nbytes_idx) - x_rsrc = buffer_ops.create_buffer_resource( - arg_x, max_size=False, num_records_bytes=x_nbytes_i32 - ) - - w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False) - - # Out: [tokens*topk, inter_dim] - numids_rsrc = buffer_ops.create_buffer_resource( - arg_num_valid_ids, - max_size=False, - num_records_bytes=arith.constant(4, type=T.i32), - ) - num_valid_i32 = buffer_ops.buffer_load( - numids_rsrc, arith.constant(0, index=True), vec_width=1, dtype=T.i32 - ) - sx_rsrc = 1 - sw_rsrc = 1 - if const_expr(not (is_f16_a or a_scale_one)): - # A scale: [sorted_size, model_dim/32] pre-scattered by caller - c32 = arith.constant(32, index=True) - kblk = k_in / c32 - sx_nbytes_idx = sorted_m * kblk - sx_nbytes_i32 = arith.index_cast(T.i32, sx_nbytes_idx) - sx_rsrc = buffer_ops.create_buffer_resource( - arg_scale_x, max_size=False, num_records_bytes=sx_nbytes_i32 + c_topk = arith.index(topk) + x_nbytes_idx = tokens_in * k_in * arith.index(int(elem_bytes)) + x_rsrc = buffer_ops.create_buffer_resource( + arg_x, max_size=False, num_records_bytes=x_nbytes_idx ) + w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False) + sw_rsrc = buffer_ops.create_buffer_resource(arg_scale_w, max_size=False) - if const_expr(not is_f16_b): - c32 = arith.constant(32, index=True) - kblk_w = k_in / c32 - mn_w = arith.constant(experts * (2 * inter_dim), index=True) - sw_nbytes_idx = mn_w * kblk_w - sw_nbytes_i32 = arith.index_cast(T.i32, sw_nbytes_idx) - sw_rsrc = buffer_ops.create_buffer_resource( - arg_scale_w, max_size=False, num_records_bytes=sw_nbytes_i32 + # Split-K uses f32 atomics → out_elem_bytes = 4 + # Rename to avoid shadowing the enclosing-scope setup var. 
+ _a16_out_elem_bytes = 4 if _is_splitk else 2 + out_nbytes_idx = ( + tokens_in * c_topk * inter_in * arith.index(_a16_out_elem_bytes) + ) + out_rsrc = buffer_ops.create_buffer_resource( + arg_out, max_size=False, num_records_bytes=out_nbytes_idx ) - sorted_nbytes_idx = size_expert_ids_in * arith.constant( - sort_block_m * 4, index=True - ) - sorted_nbytes_i32 = arith.index_cast(T.i32, sorted_nbytes_idx) - sorted_rsrc = buffer_ops.create_buffer_resource( - arg_sorted_token_ids, - max_size=False, - num_records_bytes=sorted_nbytes_i32, - ) - sorted_w_rsrc = buffer_ops.create_buffer_resource( - arg_sorted_weights, max_size=False, num_records_bytes=sorted_nbytes_i32 - ) + sorted_rsrc = buffer_ops.create_buffer_resource( + arg_sorted_token_ids, max_size=False + ) + sorted_w_rsrc = buffer_ops.create_buffer_resource( + arg_sorted_weights, max_size=False + ) + expert_rsrc = buffer_ops.create_buffer_resource( + arg_expert_ids, max_size=False, + num_records_bytes=(size_expert_ids_in * arith.index(4)), + ) - eid_nbytes_idx = size_expert_ids_in * arith.constant(4, index=True) - eid_nbytes_i32 = arith.index_cast(T.i32, eid_nbytes_idx) - expert_rsrc = buffer_ops.create_buffer_resource( - arg_expert_ids, max_size=False, num_records_bytes=eid_nbytes_i32 - ) - bias_rsrc = ( - buffer_ops.create_buffer_resource(arg_bias, max_size=False) - if enable_bias - else None - ) + expert_i32 = buffer_ops.buffer_load( + expert_rsrc, bx, vec_width=1, dtype=i32 + ) + exp_valid = arith.cmpi( + arith.CmpIPredicate.ult, + expert_i32, + arith.constant(experts, type=i32), + ) + expert_idx = arith.index_cast(T.index, expert_i32) + inter2_idx = arith.index(2 * inter_dim) + expert_off_idx = expert_idx * inter2_idx - # Sorted-scale buffer resource for fused mxfp4 quantization - _sorted_scale_cols = inter_dim // 32 - _sorted_scale_cols_i32 = arith.constant(_sorted_scale_cols, type=T.i32) - sorted_scale_rsrc = None - if const_expr(_need_sort): - sorted_scale_rsrc = buffer_ops.create_buffer_resource( - arg_out_scale_sorted, max_size=False + bias_rsrc = ( + buffer_ops.create_buffer_resource(arg_bias, max_size=False) + if enable_bias + else None ) - # ---- persist_m loop (same pattern as stage2) ---- - _PERSIST_M = persist_m - _c0_p = arith.constant(0, index=True) - _c1_p = arith.constant(1, index=True) - _c_pm = arith.constant(_PERSIST_M, index=True) - _for_persist = scf.ForOp(_c0_p, _c_pm, _c1_p) - _for_ip = ir.InsertionPoint(_for_persist.body) - _for_ip.__enter__() - _mi_p = _for_persist.induction_variable - bx = bx_persist * _c_pm + _mi_p - bx_m = bx * arith.constant(sort_block_m, index=True) - - # Block validity - bx_m_i32 = arith.index_cast(T.i32, bx_m) - blk_valid = arith.cmpi(CmpIPredicate.ult, bx_m_i32, num_valid_i32) - expert_i32 = buffer_ops.buffer_load( - expert_rsrc, bx, vec_width=1, dtype=T.i32 - ) - expert_idx = arith.index_cast(ir.IndexType.get(), expert_i32) - exp_valid = arith.cmpi( - CmpIPredicate.ult, expert_i32, arith.constant(experts, type=T.i32) - ) - - def _moe_gemm1_body(): - # Gate expert offset: first inter_dim rows of each expert's 2*inter_dim block - expert_off_idx = expert_idx * arith.constant(2 * inter_dim, index=True) - - # X loading -- KEY DIFFERENCE from stage2: X row = token_id only - x_load_bytes = 16 + if const_expr(bytes_per_thread_x >= 16 and bytes_per_thread_x % 16 == 0): + x_load_bytes = 16 + elif const_expr(bytes_per_thread_x >= 8 and bytes_per_thread_x % 8 == 0): + x_load_bytes = 8 + elif const_expr(bytes_per_thread_x >= 4 and bytes_per_thread_x % 4 == 0): + x_load_bytes = 4 + else: + raise 
ValueError( + f"bytes_per_thread_x ({bytes_per_thread_x}) must be " + f"divisible by 4" + ) num_x_loads = bytes_per_thread_x // x_load_bytes chunk_i32 = x_load_bytes // 4 + x_vec_elems = x_load_bytes // elem_bytes + x_vec_i32_ty = T.vec(chunk_i32, i32) if chunk_i32 > 1 else T.vec(1, i32) + x_vec_x_ty = T.vec(x_vec_elems, x_elem) - c_k_div4 = ( - (k_in / c_a_pack) * arith.constant(int(a_elem_bytes), index=True) - ) / arith.index(4) - tile_k_dwords = (int(tile_k) * int(a_elem_bytes)) // ( - 4 * int(a_elem_vec_pack) - ) + c_k_div4 = (k_in * arith.index(int(elem_bytes))) // arith.index(4) + c_k_div4_i32 = arith.index_cast(i32, c_k_div4) + tile_k_dwords = (int(tile_k) * int(elem_bytes)) // 4 layout_x_tile_div4 = fx.make_layout( (tile_m, tile_k_dwords), stride=(tile_k_dwords, 1) ) - c_chunk_i32 = arith.constant(chunk_i32, index=True) + c_chunk_i32 = arith.index(chunk_i32) tx_i32_base = tx * c_chunk_i32 - - topk_i32 = arith.constant(topk) - mask24 = arith.constant(0xFFFFFF) - tokens_i32 = arith.index_cast(T.i32, tokens_in) + mask24 = arith.constant(0xFFFFFF, type=T.i32) + tokens_i32 = arith.index_cast(i32, tokens_in) + topk_i32 = arith.constant(topk, type=T.i32) def x_tile_chunk_coord_i32(i: int): return tile_chunk_coord_i32( @@ -734,25 +828,9 @@ def x_tile_chunk_coord_i32(i: int): chunk_i32=chunk_i32, ) - def load_x(idx_i32): - idx_elem = ( - idx_i32 if a_elem_bytes == 1 else (idx_i32 * arith.index(2)) - ) - return buffer_copy_gmem16_dwordx4( - buffer_ops, - vector, - elem_type=x_elem, - idx_i32=idx_elem, - rsrc=x_rsrc, - vec_elems=vec16_elems, - ) - - # Decode sorted token ids -- stage1: X row = token_id (not t*topk+s) x_row_base_div4 = [] x_col_local_i32 = [] x_row_local = [] - # Also store token_id and slot_id for output indexing - for i in range_constexpr(num_x_loads): row_local, col_local_i32 = x_tile_chunk_coord_i32(i) x_row_local.append(row_local) @@ -760,320 +838,340 @@ def load_x(idx_i32): sorted_row_i = bx_m + row_local fused_i = buffer_ops.buffer_load( - sorted_rsrc, sorted_row_i, vec_width=1, dtype=T.i32 + sorted_rsrc, sorted_row_i, vec_width=1, dtype=i32 ) - t_i32 = arith.andi(fused_i, mask24) - s_i32 = arith.shrui(fused_i, arith.constant(24)) - t_valid = arith.cmpi(CmpIPredicate.ult, t_i32, tokens_i32) - s_valid = arith.cmpi(CmpIPredicate.ult, s_i32, topk_i32) + t_i32 = fused_i & mask24 + s_i32 = arith.shrui(fused_i, arith.constant(24, type=i32)) + t_valid = arith.cmpi(arith.CmpIPredicate.ult, t_i32, tokens_i32) + s_valid = arith.cmpi(arith.CmpIPredicate.ult, s_i32, topk_i32) ts_valid = arith.andi(t_valid, s_valid) - t_safe = arith.select(ts_valid, t_i32, arith.constant(0)) - - # KEY: X row base uses token_id only (not t*topk+s) - t_idx = arith.index_cast(ir.IndexType.get(), t_safe) - x_row_base_div4.append(t_idx * c_k_div4) + t_safe = arith.select( + ts_valid, + arith.index_cast(T.index, t_i32), + arith.index(0), + ) + x_row_base_div4.append(t_safe * c_k_div4) - def load_x_tile(base_k): - base_k_div4 = ( - (base_k / c_a_pack) - * arith.constant(int(a_elem_bytes), index=True) - ) / arith.index(4) - parts = [] - for i in range_constexpr(num_x_loads): - idx_i32 = x_row_base_div4[i] + base_k_div4 + x_col_local_i32[i] - x_vec = load_x(idx_i32) - parts.append(vector.bitcast(T.vec(4, i32), x_vec)) - return parts + # NOTE: load_x_tile is unused on the stage1 A16W4 path -- B-cycle + # consumes X via dma_x_tile_to_lds + LDS reads. Removed. 
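# Illustrative sketch (not part of the kernel): the layout of the fused
# arg_sorted_token_ids entries decoded above, assuming the packing the decode
# implies -- token_id in bits [23:0] (mask24 = 0xFFFFFF) and the top-k slot in
# bits [31:24]. Entries whose token or slot is out of range fail the ts_valid
# check and are clamped to token 0 for the address computation. Both helpers
# below are hypothetical, intended only for host-side tests.
def pack_sorted_id_ref(token_id: int, slot: int) -> int:
    assert 0 <= token_id < (1 << 24) and 0 <= slot < (1 << 8)
    return (slot << 24) | token_id

def unpack_sorted_id_ref(fused: int):
    return fused & 0xFFFFFF, (fused >> 24) & 0xFF

t, s = unpack_sorted_id_ref(pack_sorted_id_ref(token_id=1234, slot=3))
assert (t, s) == (1234, 3)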
- # Wave/lane decomposition (identical to stage2) coord_wl = idx2crd(tx, layout_tx_wave_lane) wave_id = layout_get(coord_wl, 0) lane_id = layout_get(coord_wl, 1) + + _dma_bytes = 16 + _wave_size = 64 + + def dma_x_tile_to_lds(base_k, lds_buffer): + """Async DMA: global -> LDS via buffer_load_lds, no VGPR.""" + c4_idx = arith.index(4) + base_k_div4 = ( + base_k * arith.index(int(elem_bytes)) + ) // arith.index(4) + + lds_ptr_i64 = None + for i in range_constexpr(num_x_loads): + row_local_i = x_row_local[i] + col_local_i32_i = x_col_local_i32[i] + col_local_sw = swizzle_xor16( + row_local_i, col_local_i32_i * c4_idx, k_blocks16 + ) + row_k_dw = x_row_base_div4[i] + base_k_div4 + global_byte_idx = row_k_dw * c4_idx + col_local_sw + global_offset = arith.index_cast(i32, global_byte_idx) + + if const_expr(i == 0): + lds_addr = memref.extract_aligned_pointer_as_index( + lds_buffer + ) + wave_id * arith.constant( + _wave_size * _dma_bytes, index=True + ) + lds_ptr_i64 = rocdl.readfirstlane( + i64, arith.index_cast(i64, lds_addr) + ) + else: + lds_ptr_i64 = lds_ptr_i64 + arith.constant( + total_threads * _dma_bytes, type=i64 + ) + + lds_ptr_type = ir.Type.parse("!llvm.ptr<3>") + lds_ptr = llvm.inttoptr(lds_ptr_type, lds_ptr_i64) + + rocdl.raw_ptr_buffer_load_lds( + x_rsrc, + lds_ptr, + arith.constant(_dma_bytes, type=i32), + global_offset, + arith.constant(0, type=i32), + arith.constant(0, type=i32), + arith.constant(0, type=i32), + ) + + def prefetch_x_to_lds(base_k, lds_buffer): + dma_x_tile_to_lds(base_k, lds_buffer) + coord_l16 = idx2crd(lane_id, layout_lane16) lane_div_16 = layout_get(coord_l16, 0) lane_mod_16 = layout_get(coord_l16, 1) + row_a_lds = lane_mod_16 - col_offset_base = lane_div_16 * arith.constant(16, index=True) + # CK-style A addressing: sub-lane L covers K[L*32..L*32+31] + # within each k0 block (128 K elements). + # col_offset_base = lane_div_16 * 32 bf16 = lane_div_16 * 64 bytes + # stride per ku step = 8 bf16 = 16 bytes + # For k0 boundaries: ku=4 wraps to next k0 block (+256 bytes) + _a_sublane_stride = 64 # 32 bf16 * 2 bytes + _a_ku_stride_bytes = 16 # 8 bf16 * 2 bytes + col_offset_base_bytes = lane_div_16 * arith.index(_a_sublane_stride) + + by_n = by * arith.index(tile_n) + _waves_per_group = 4 // _split_k_intra + # NOTE: shadows enclosing-scope `n_per_wave`; rename to avoid + # making it look like a local of moe_gemm1 (which would break + # the generic body's enclosing-scope reference under FlyDSL). 
+ _a16_n_per_wave = tile_n // _waves_per_group + num_acc_n = _a16_n_per_wave // 16 + c_n_per_wave = arith.index(_a16_n_per_wave) + if const_expr(_split_k_intra > 1): + _wave_group = wave_id // arith.index(_waves_per_group) + _wave_in_group = wave_id % arith.index(_waves_per_group) + n_tile_base = _wave_in_group * c_n_per_wave + # A LDS: each wave group reads different K half + _k_half_bytes = (tile_k // _split_k_intra) * elem_bytes + col_offset_base_bytes = ( + col_offset_base_bytes + + _wave_group * arith.index(_k_half_bytes) + ) + else: + _wave_group = None + wave_mod_4 = wave_id % arith.index(4) + n_tile_base = wave_mod_4 * c_n_per_wave - num_acc_n = n_per_wave // 16 - c_n_per_wave = arith.constant(n_per_wave, index=True) - wave_n_id = wave_id % arith.constant(num_waves, index=True) - n_tile_base = wave_n_id * c_n_per_wave - - # N-tile precompute for gate AND up weights - gate_n_intra_list = [] - gate_n_blk_list = [] - up_n_intra_list = [] - up_n_blk_list = [] + n_intra_gate = [] + n_blk_gate = [] col_g_list = [] + inter_idx = arith.index(inter_dim) c_n0_static = experts * (2 * inter_dim) // 16 layout_n_blk_intra = fx.make_layout((c_n0_static, 16), stride=(16, 1)) - inter_idx = arith.constant(inter_dim, index=True) - for i in range_constexpr(num_acc_n): - offset = i * 16 - c_offset = arith.constant(offset, index=True) - if const_expr(not gate_up_interleave): - col_g = by_n + n_tile_base + c_offset + lane_mod_16 + if const_expr(not _single_b): + # Separated mode: gate and up are distinct N regions. + n_intra_up = [] + n_blk_up = [] + for ni in range_constexpr(num_acc_n): + offset = arith.index(ni * 16) + col_g = by_n + n_tile_base + offset + lane_mod_16 col_g_list.append(col_g) - global_n = by_n + n_tile_base + c_offset + lane_mod_16 - # Gate/interleave: rows [expert_off, expert_off + 2*inter_dim) - gate_row_w = expert_off_idx + global_n - gate_coord = idx2crd(gate_row_w, layout_n_blk_intra) - gate_n_blk_list.append(layout_get(gate_coord, 0)) - gate_n_intra_list.append(layout_get(gate_coord, 1)) - if const_expr(not mock_gate_only and not gate_up_interleave): - up_row_w = gate_row_w + inter_idx - up_coord = idx2crd(up_row_w, layout_n_blk_intra) - up_n_blk_list.append(layout_get(up_coord, 0)) - up_n_intra_list.append(layout_get(up_coord, 1)) - - if const_expr(gate_up_interleave): - _gui_num_acc_n_out = num_acc_n // pack_N - for _gui_i in range_constexpr(_gui_num_acc_n_out): - _gui_offset = _gui_i * 16 - _gui_c_offset = arith.constant(_gui_offset, index=True) - _gui_col_g = ( - (by_n + n_tile_base) // arith.constant(2, index=True) - + _gui_c_offset - + lane_mod_16 - ) - col_g_list.append(_gui_col_g) + row_gate = expert_off_idx + col_g + row_up = row_gate + inter_idx - m_repeat = tile_m // 16 - k_unroll = tile_k_bytes // 128 - k_unroll_packed = k_unroll // pack_K - m_repeat_packed = m_repeat // pack_M - num_acc_n_packed = num_acc_n // pack_N + coord_gate = idx2crd(row_gate, layout_n_blk_intra) + n_blk_gate.append(layout_get(coord_gate, 0)) + n_intra_gate.append(layout_get(coord_gate, 1)) - _K_per_ku = tile_k // k_unroll - _pad_k_elems = ( - (model_dim_pad % tile_k) - if (not _is_splitk and model_dim_pad > 0) - else 0 - ) - _pad_ku_skip = _pad_k_elems // _K_per_ku - _tail_ku = k_unroll - _pad_ku_skip - _tail_ku_packed = ( - (_tail_ku + pack_K - 1) // pack_K if _pad_ku_skip > 0 else None - ) + coord_up = idx2crd(row_up, layout_n_blk_intra) + n_blk_up.append(layout_get(coord_up, 0)) + n_intra_up.append(layout_get(coord_up, 1)) + else: + # gate_only / gate_up_interleave: single B stream + 
n_intra_up = None + n_blk_up = None + for ni in range_constexpr(num_acc_n): + offset = arith.index(ni * 16) + global_n = by_n + n_tile_base + offset + lane_mod_16 + gate_row_w = expert_off_idx + global_n + + coord_gate = idx2crd(gate_row_w, layout_n_blk_intra) + n_blk_gate.append(layout_get(coord_gate, 0)) + n_intra_gate.append(layout_get(coord_gate, 1)) + + if const_expr(gate_up_interleave and not _is_splitk): + if const_expr(_gui_xwave_fuse): + if const_expr(_split_k_intra > 1): + # Split_k xwave: all waves share same output cols + _gui_col_g = ( + by_n // arith.index(2) + lane_mod_16 + ) + else: + # 4-wave xwave: per-pair output cols + # pair (0,1)→cols[0:15], pair (2,3)→cols[16:31] + _xw_pair_off = ( + (wave_id // arith.index(2)) + * c_n_per_wave + ) + _gui_col_g = ( + by_n // arith.index(2) + + _xw_pair_off + + lane_mod_16 + ) + col_g_list.append(_gui_col_g) + else: + # Standard pair fusion: pairs of N subtiles → one output col + # Renamed from `pack_N` to avoid shadowing the enclosing scope. + _a16_pack_N = 2 + _gui_num_acc_n_out = num_acc_n // _a16_pack_N + for _gui_i in range_constexpr(_gui_num_acc_n_out): + _gui_offset = arith.index(_gui_i * 16) + _gui_col_g = ( + (by_n + n_tile_base) + // arith.index(2) + + _gui_offset + + lane_mod_16 + ) + col_g_list.append(_gui_col_g) + else: + # gate_only or interleave+splitk: output covers full N + for ni in range_constexpr(num_acc_n): + offset = arith.index(ni * 16) + col_g = by_n + n_tile_base + offset + lane_mod_16 + col_g_list.append(col_g) - # B load for gate and up separately - def load_b_packs_k64(base_k, ku: int, n_blk, n_intra): - c64 = arith.constant(64, index=True) - base_k_bytes = base_k * arith.constant( - int(b_elem_bytes), index=True - ) - k0 = base_k_bytes // c64 + arith.constant(ku, index=True) - k1 = lane_div_16 - coord_pack = (n_blk, k0, k1, n_intra, arith.constant(0, index=True)) - idx_pack = crd2idx(coord_pack, layout_b) - vec_elems = kpack_bytes // int(b_elem_bytes) - b16 = _buffer_load_vec( - buffer_ops, - vector, - w_rsrc, - idx_pack, - elem_type=_w_elem_type(), - vec_elems=vec_elems, - elem_bytes=b_elem_bytes, - offset_in_bytes=(b_elem_bytes == 1), - cache_modifier=b_nt, - ) - b_i64x2 = vector.bitcast(vec2_i64, b16) - b0 = vector.extract( - b_i64x2, static_position=[0], dynamic_position=[] - ) - b1 = vector.extract( - b_i64x2, static_position=[1], dynamic_position=[] - ) - return b0, b1 + m_repeat = tile_m // 16 + k_unroll = (tile_k_bytes // 64) // _split_k_intra - def load_b_tile(base_k, ku_limit=k_unroll): - """Load B tiles. Returns (gate_b_tile, up_b_tile). 
- When mock_gate_only or gate_up_interleave, up_b_tile is None.""" - gate_b_tile = [] - up_b_tile = ( - [] if (not mock_gate_only and not gate_up_interleave) else None + _pad_k_elems = (model_dim_pad % tile_k) if (not _is_splitk and _split_k_intra == 1 and model_dim_pad > 0) else 0 + _pad_ku_skip = _pad_k_elems // 32 + _tail_ku = k_unroll - _pad_ku_skip + _tail_k0_count = (_tail_ku + 3) // 4 if _pad_ku_skip > 0 else None + + # Each dwordx4 covers 128 K elements; per-group count + _k_per_dwordx4 = 128 + _k0_count = (tile_k // _k_per_dwordx4) // _split_k_intra + + # Scale mni for gate (and up if separated) + scale_mni_gate = [] + scale_n_pack_gate = [] + for ni in range_constexpr(num_acc_n): + n_gate = expert_off_idx + by_n + n_tile_base + arith.index(ni * 16) + if const_expr(not _single_b): + n_gate_phys = n_gate + else: + n_gate_phys = n_gate + scale_mni_gate.append(n_gate_phys // arith.index(32)) + scale_n_pack_gate.append( + (n_gate_phys // arith.index(16)) % arith.index(2) ) - for ku in range_constexpr(ku_limit): - g_packs0, g_packs1 = [], [] - u_packs0, u_packs1 = [], [] - for ni in range_constexpr(num_acc_n): - gb0, gb1 = load_b_packs_k64( - base_k, ku, gate_n_blk_list[ni], gate_n_intra_list[ni] - ) - g_packs0.append(gb0) - g_packs1.append(gb1) - if const_expr( - not mock_gate_only and not gate_up_interleave - ): - ub0, ub1 = load_b_packs_k64( - base_k, ku, up_n_blk_list[ni], up_n_intra_list[ni] - ) - u_packs0.append(ub0) - u_packs1.append(ub1) - gate_b_tile.append((g_packs0, g_packs1)) - if const_expr(not mock_gate_only and not gate_up_interleave): - up_b_tile.append((u_packs0, u_packs1)) - return gate_b_tile, up_b_tile - - # Pre-compute scale base element indices (K-loop invariant). - # idx = mni * stride_n0 + ku * stride_k0 + k_lane * stride_klane + n_lane - # Split into: base_elem = mni * stride_n0 + lane_elem (invariant) - # k_elem = ku * stride_k0 (per-iteration) - _scale_lane_elem = ( - lane_div_16 * layout_b_scale.stride_klane + lane_mod_16 - ) - _gate_scale_bases = [] - _up_scale_bases = [] - for _ni in range_constexpr(num_acc_n_packed): - _col_base = ( - by_n - + n_tile_base - + arith.constant(_ni * 16 * pack_N, index=True) - ) - _gate_mni = (expert_off_idx + _col_base) // arith.constant( - 32, index=True - ) - _gate_scale_bases.append( - _gate_mni * layout_b_scale.stride_n0 + _scale_lane_elem - ) - if const_expr(not mock_gate_only and not gate_up_interleave): - _up_mni = ( - expert_off_idx + inter_idx + _col_base - ) // arith.constant(32, index=True) - _up_scale_bases.append( - _up_mni * layout_b_scale.stride_n0 + _scale_lane_elem + if const_expr(not _single_b): + scale_mni_up = [] + scale_n_pack_up = [] + for ni in range_constexpr(num_acc_n): + n_up = ( + expert_off_idx + by_n + n_tile_base + + arith.index(ni * 16) + inter_idx ) - - if const_expr(not a_scale_one): - _a_scale_bases = [] - for _mi in range_constexpr(m_repeat_packed): - _a_mni = _mi + bx_m // scale_mn_pack // 16 - _a_scale_bases.append( - _a_mni * layout_a_scale.stride_n0 + _scale_lane_elem + scale_mni_up.append(n_up // arith.index(32)) + scale_n_pack_up.append( + (n_up // arith.index(16)) % arith.index(2) ) - - _c16_idx = arith.constant(16, index=True) - _c2_idx = arith.constant(2, index=True) - _scale_mask_lo = arith.constant(0xFF, type=T.i32) - - _m_half_idx = arith.constant(0, type=T.i32) - _m_half_i32 = arith.constant(0, type=T.i32) - _scale_shift = arith.constant(0, type=T.i32) - _scale_shift_hi = arith.constant(0, type=T.i32) - _n_half_idx = arith.constant(0, type=T.i32) - _n_half_i32 = 
arith.constant(0, type=T.i32) - _bscale_shift = arith.constant(0, type=T.i32) - _bscale_shift_hi = arith.constant(0, type=T.i32) - if const_expr(pack_M < scale_mn_pack): - _m_half_idx = (bx_m // _c16_idx) % _c2_idx - _m_half_i32 = arith.index_cast(T.i32, _m_half_idx) - _scale_shift = _m_half_i32 * arith.constant(8, type=T.i32) - _scale_shift_hi = _scale_shift + arith.constant(16, type=T.i32) - - if const_expr(pack_N < scale_mn_pack): - _n_half_idx = (n_tile_base // _c16_idx) % _c2_idx - _n_half_i32 = arith.index_cast(T.i32, _n_half_idx) - _bscale_shift = _n_half_i32 * arith.constant(8, type=T.i32) - _bscale_shift_hi = _bscale_shift + arith.constant(16, type=T.i32) - - def _rearrange_a_scale(raw_i32): - """Rearrange scale bytes for pack_M=1: extract m_half's k0,k1 bytes.""" - if const_expr(pack_M >= scale_mn_pack): - return raw_i32 - b_k0 = arith.andi( - arith.shrui(raw_i32, _scale_shift), _scale_mask_lo - ) - b_k1 = arith.andi( - arith.shrui(raw_i32, _scale_shift_hi), _scale_mask_lo - ) - return arith.ori( - b_k0, arith.shli(b_k1, arith.constant(8, type=T.i32)) - ) - - def _rearrange_b_scale(raw_i32): - """Rearrange scale bytes for pack_N=1: extract n_half's k0,k1 bytes.""" - if const_expr(pack_N >= scale_mn_pack): - return raw_i32 - b_k0 = arith.andi( - arith.shrui(raw_i32, _bscale_shift), _scale_mask_lo - ) - b_k1 = arith.andi( - arith.shrui(raw_i32, _bscale_shift_hi), _scale_mask_lo - ) - return arith.ori( - b_k0, arith.shli(b_k1, arith.constant(8, type=T.i32)) - ) - - if const_expr(a_scale_one): - _as1_const = arith.constant(0x7F7F7F7F, type=T.i32) - _as1_vec = vector.from_elements(T.vec(1, T.i32), [_as1_const]) - - def prefetch_ab_scale_tile(base_k, ku_packed_limit=k_unroll_packed): - a_scale_tile = [] - gate_b_scale = [] - up_b_scale = ( - [] if (not mock_gate_only and not gate_up_interleave) else None + else: + scale_mni_up = None + scale_n_pack_up = None + + def _load_scale_i32(scale_ku_idx, mni_val, scale_klane=None): + _klane = scale_klane if scale_klane is not None else lane_div_16 + idx = (mni_val * layout_b_scale.stride_n0 + + scale_ku_idx * layout_b_scale.stride_k0 + + _klane * layout_b_scale.stride_klane + + lane_mod_16) + return buffer_ops.buffer_load(sw_rsrc, idx, vec_width=1, dtype=i32) + + def _extract_e8m0_f32_dynamic(packed_i32, byte_pos_idx): + """Extract E8M0 byte at runtime byte_pos and decode to f32.""" + shift = arith.index_cast(i32, byte_pos_idx) * arith.constant(8, type=i32) + byte_i32 = arith.shrui(packed_i32, shift) & arith.constant(0xFF, type=i32) + scale_bits = arith.shli(byte_i32, arith.constant(23, type=i32)) + return arith.bitcast(f32, scale_bits) + + def _get_scale_f32(base_k, ku, ni, mni_list, n_pack_list, scale_cache): + # CK addressing: adj_ku = base_k//32 + (ku//4)*4 + lane_div_16 + # scale_klane = lane_div_16, k_pack_sub = (ku//4) % 2 + _k0_blk = ku // 4 + adj_ku = (base_k // arith.index(32) + + arith.index(_k0_blk * 4) + + lane_div_16) + scale_klane_rt = lane_div_16 + k_pack_sub_rt = (adj_ku // arith.index(4)) % arith.index(2) + s_ku = adj_ku // arith.index(8) + + cache_key = (_k0_blk, ni, id(mni_list)) + if cache_key not in scale_cache: + scale_cache[cache_key] = _load_scale_i32( + s_ku, mni_list[ni], scale_klane=scale_klane_rt + ) + packed = scale_cache[cache_key] + n_pack_sub_val = n_pack_list[ni] + byte_pos_even = k_pack_sub_rt * arith.index(2) + byte_pos_odd = byte_pos_even + arith.index(1) + scale_even = _extract_e8m0_f32_dynamic(packed, byte_pos_even) + scale_odd = _extract_e8m0_f32_dynamic(packed, byte_pos_odd) + n_pack_is_zero = 
arith.cmpi( + arith.CmpIPredicate.eq, + arith.index_cast(i32, n_pack_sub_val), + arith.constant(0, type=i32), ) - for ku in range_constexpr(ku_packed_limit): - k_off = (ku + base_k) * layout_b_scale.stride_k0 - for mi in range_constexpr(m_repeat_packed): - if const_expr(a_scale_one): - a_scale_tile.append(_as1_vec) - else: - s = buffer_ops.buffer_load( - sx_rsrc, - _a_scale_bases[mi] + k_off, - vec_width=1, - dtype=T.i32, - cache_modifier=0, - ) - s = _rearrange_a_scale(s) - a_scale_tile.append( - vector.from_elements(T.vec(1, T.i32), [s]) - ) - for ni in range_constexpr(num_acc_n_packed): - gs = buffer_ops.buffer_load( - sw_rsrc, - _gate_scale_bases[ni] + k_off, - vec_width=1, - dtype=T.i32, - cache_modifier=0, + return arith.select(n_pack_is_zero, scale_even, scale_odd) + + def load_b_raw(base_k, blk_list, intra_list, k0_limit=_k0_count): + """Load raw FP4 data via dwordx4. Returns raw_v4[k0_idx][ni] = vec4_i32.""" + raw_all = [] + for k0_idx in range_constexpr(k0_limit): + raw_k0 = [] + k_off = base_k + arith.index(k0_idx * _k_per_dwordx4) + for ni in range_constexpr(num_acc_n): + v4 = load_b_raw_mxfp4_dwordx4( + buffer_ops, arith, vector, + arg_b=arg_w, + b_rsrc=w_rsrc, + layout_b=layout_b, + base_k=k_off, + n_blk=blk_list[ni], + n_intra=intra_list[ni], + lane_div_16=lane_div_16, + elem_type=w_elem, + kpack_bytes=kpack_bytes, + cache_modifier=2, ) - gs = _rearrange_b_scale(gs) - gate_b_scale.append( - vector.from_elements(T.vec(1, T.i32), [gs]) + raw_k0.append(v4) + raw_all.append(raw_k0) + return raw_all + + def load_b_scale(base_k, mni_list, n_pack_list, ku_limit=k_unroll): + """Load scales for all ku × ni. Returns scales[ku][ni] = f32.""" + scale_cache = {} + scales = [] + for ku in range_constexpr(ku_limit): + scales_ku = [] + for ni in range_constexpr(num_acc_n): + sf = _get_scale_f32( + base_k, ku, ni, mni_list, n_pack_list, scale_cache ) - if const_expr( - not mock_gate_only and not gate_up_interleave - ): - us = buffer_ops.buffer_load( - sw_rsrc, - _up_scale_bases[ni] + k_off, - vec_width=1, - dtype=T.i32, - cache_modifier=0, - ) - us = _rearrange_b_scale(us) - up_b_scale.append( - vector.from_elements(T.vec(1, T.i32), [us]) - ) - return [a_scale_tile, gate_b_scale, up_b_scale] - - _lds_base_zero = arith.index(0) + scales_ku.append(sf) + scales.append(scales_ku) + return scales + + def load_all_b_raw(base_k, k0_limit=_k0_count, ku_limit=k_unroll): + """Load raw B (dwordx4) + scales for gate (and up if separated).""" + g_scales = load_b_scale(base_k, scale_mni_gate, scale_n_pack_gate, ku_limit=ku_limit) + u_scales = None + if const_expr(not _single_b): + u_scales = load_b_scale(base_k, scale_mni_up, scale_n_pack_up, ku_limit=ku_limit) + + g_v4 = load_b_raw(base_k, n_blk_gate, n_intra_gate, k0_limit=k0_limit) + u_v4 = None + if const_expr(not _single_b): + u_v4 = load_b_raw(base_k, n_blk_up, n_intra_up, k0_limit=k0_limit) + return g_v4, g_scales, u_v4, u_scales def store_x_tile_to_lds(vec_x_in_parts, lds_buffer): + _lds_base_zero = arith.index(0) for i in range_constexpr(num_x_loads): row_local = x_row_local[i] col_local_i32 = x_col_local_i32[i] - if const_expr(x_load_bytes == 16): + if const_expr(x_load_bytes >= 16): lds_store_16b_xor16( - arith, - vector, + arith, vector, lds_memref=lds_buffer, vec16_ty=vec16_x, layout_lds=layout_lds, @@ -1085,1548 +1183,3090 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_buffer): vec_part_i32x4=vec_x_in_parts[i], elem_bytes=elem_bytes, ) + else: + lds_store_8b_xor16( + arith, vector, + lds_memref=lds_buffer, + vec8_ty=x_vec_x_ty, + 
layout_lds=layout_lds, + row_local=row_local, + col_local_i32=col_local_i32, + tx_c4=arith.index(4), + k_blocks16=k_blocks16, + lds_base=_lds_base_zero, + vec_part_i32x2=vec_x_in_parts[i], + elem_bytes=elem_bytes, + ) - if const_expr(use_async_copy): - _dma_bytes = 16 - _wave_size = 64 - _eff_bytes_per_buffer = ( - int(tile_m) * int(_eff_lds_stride) * int(a_elem_bytes) + def lds_load_packs_k64(curr_row_a_lds, col_base_bytes, lds_buffer): + col_base_swz_bytes = swizzle_xor16( + curr_row_a_lds, col_base_bytes, k_blocks16 ) - _num_dma_loads = max( - 1, _eff_bytes_per_buffer // (total_threads * _dma_bytes) + col_base_swz = col_base_swz_bytes // arith.index(int(elem_bytes)) + idx_a16 = crd2idx((curr_row_a_lds, col_base_swz), layout_lds) + loaded_a16 = vector.load_op(vec16_x, lds_buffer, [idx_a16]) + a_i64x2 = vector.bitcast(vec2_i64, loaded_a16) + a0 = vector.extract(a_i64x2, static_position=[0], dynamic_position=[]) + a1 = vector.extract(a_i64x2, static_position=[1], dynamic_position=[]) + return a0, a1 + + def _a_col_bytes_for_ku(ku_val): + """CK-style A col address: L*64 + (ku%4)*16 + (ku//4)*256.""" + _k0_blk = ku_val // 4 + _ku_in = ku_val % 4 + return col_offset_base_bytes + arith.index( + _ku_in * _a_ku_stride_bytes + _k0_blk * 256 ) - def dma_x_tile_to_lds(base_k, lds_buffer): - c4_idx = arith.index(4) - base_k_div4 = ( - (base_k / c_a_pack) - * arith.constant(int(elem_bytes), index=True) - ) / arith.index(4) + _can_full_preload = (m_repeat == 1) + _m_preload = (k_unroll * m_repeat) if _can_full_preload else min(2, k_unroll * m_repeat) + _total_a_slots = k_unroll * m_repeat + + def preload_a_from_lds(lds_buffer): + """Load first _m_preload A tiles from LDS into VGPRs.""" + a_tiles = [None] * _m_preload + for pl in range_constexpr(_m_preload): + _pl_mi = pl % m_repeat + _pl_ku = pl // m_repeat + col = _a_col_bytes_for_ku(_pl_ku) + row = row_a_lds + arith.index(_pl_mi * 16) + a_tiles[pl] = lds_load_packs_k64(row, col, lds_buffer) + return a_tiles + + def _mfma_k32(acc_in, a0, a1, b0, b1): + a_v2 = vector.from_elements(vec2_i64, [a0, a1]) + a_v8 = vector.bitcast(vec8_bf16, a_v2) + b_v2 = vector.from_elements(vec2_i64, [b0, b1]) + b_v8 = vector.bitcast(vec8_bf16, b_v2) + return mfma_f32_bf16_k32(vec4_f32, [a_v8, b_v8, acc_in, 0, 0, 0]) - lds_ptr_i64 = None - for i in range_constexpr(_num_dma_loads): - row_local_i = x_row_local[i] - col_local_i32_i = x_col_local_i32[i] - col_local_sw = swizzle_xor16( - row_local_i, col_local_i32_i * c4_idx, k_blocks16 - ) - row_k_dw = x_row_base_div4[i] + base_k_div4 - global_byte_idx = row_k_dw * c4_idx + col_local_sw - global_offset = arith.index_cast(T.i32, global_byte_idx) + def compute_tile( + acc_gate_in, acc_up_in, + g_v4, g_scales, u_v4, u_scales, + a_preloaded, cur_lds_buffer, next_lds_buffer, + ku_count=k_unroll, + ): + """Compute GEMM tile with preloaded A. - if const_expr(i == 0): - lds_addr = memref.extract_aligned_pointer_as_index( - lds_buffer - ) + wave_id * arith.constant( - _wave_size * _dma_bytes, index=True - ) - lds_ptr_i64 = rocdl.readfirstlane( - T.i64, arith.index_cast(T.i64, lds_addr) - ) - else: - lds_ptr_i64 = lds_ptr_i64 + arith.constant( - total_threads * _dma_bytes, type=T.i64 - ) + Full preload (m_repeat=1): ni→ku, all A in VGPRs. + Partial preload (m_repeat>1): ku→mi→ni, m_preload=2 + pipeline, reload remaining A from cur_lds_buffer. 
- lds_ptr_type = ir.Type.parse("!llvm.ptr<3>") - lds_ptr = llvm.inttoptr(lds_ptr_type, lds_ptr_i64) + ku_count: number of k_unroll iterations to execute + (< k_unroll for the last tile when K-padding is active). - rocdl.raw_ptr_buffer_load_lds( - x_rsrc, - lds_ptr, - arith.constant(_dma_bytes, type=T.i32), - global_offset, - arith.constant(0, type=T.i32), - arith.constant(0, type=T.i32), - arith.constant(0, type=T.i32), - ) + Returns: (gate_list, up_list, a_tiles_next). + a_tiles_next has _m_preload tiles from next_lds_buffer. + """ + gate_list = list(acc_gate_in) + up_list = list(acc_up_in) if acc_up_in is not None else None + a_tiles_next = [None] * _m_preload - def prefetch_x_to_lds(base_k, lds_buffer): - dma_x_tile_to_lds(base_k, lds_buffer) + if _can_full_preload: + # --- Full preload: ni → ku → mi --- + _is_last_ni = num_acc_n - 1 - def lds_load_packs_k64(curr_row_a_lds, col_base, lds_buffer): - col_base_swz_bytes = swizzle_xor16( - curr_row_a_lds, col_base, k_blocks16 - ) - col_base_swz = ( - col_base_swz_bytes - if elem_bytes == 1 - else (col_base_swz_bytes / arith.index(2)) - ) - idx_a16 = crd2idx([curr_row_a_lds, col_base_swz], layout_lds) - loaded_a16 = vector.load_op(vec16_x, lds_buffer, [idx_a16]) - a_i64x2 = vector.bitcast(vec2_i64, loaded_a16) - a0 = vector.extract( - a_i64x2, static_position=[0], dynamic_position=[] - ) - a1 = vector.extract( - a_i64x2, static_position=[1], dynamic_position=[] - ) - return a0, a1 + for ni in range_constexpr(num_acc_n): + for ku in range_constexpr(ku_count): + _k0_idx = ku // 4 + _ku_in_k0 = ku % 4 - def prefetch_full_a_from_lds(lds_buffer, ku_limit=k_unroll): - """Load entire A tile from LDS into registers before compute.""" - a_regs = [] - for k_idx in range_constexpr(ku_limit): - col_base = col_offset_base + (k_idx * 128) // a_elem_vec_pack - for mi_idx in range_constexpr(m_repeat): - mi_val = arith.constant(mi_idx * 16, index=True) - curr_row = row_a_lds + mi_val - a0, a1 = lds_load_packs_k64(curr_row, col_base, lds_buffer) - if const_expr(is_f8_a): - a2, a3 = lds_load_packs_k64( - curr_row, col_base + 64, lds_buffer + g_raw_ku = vector.extract( + g_v4[_k0_idx][ni], + static_position=[_ku_in_k0], + dynamic_position=[], ) - a_regs.append((a0, a1, a2, a3)) - else: - a_regs.append((a0, a1)) - return a_regs - - # Compute tile: gate + up MFMA interleaved, same A data, different B data. - # Two accumulator sets; after all K tiles, acc = acc_gate + acc_up (f32 add). 
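# Host-side reference for the E8M0 block-scale decode performed by
# _extract_e8m0_f32_dynamic / _get_scale_f32 above: the scale byte is shifted
# into the f32 exponent field (byte << 23), so it decodes to 2**(byte - 127)
# over the normal exponent range.  Minimal NumPy sketch for checking values
# on the host; the names below are illustrative and not part of the kernel.
import numpy as np

def e8m0_byte_to_f32(byte: int) -> float:
    # bitcast(byte << 23): the byte lands in the f32 exponent field
    bits = np.array([byte << 23], dtype=np.uint32)
    return float(bits.view(np.float32)[0])

def extract_block_scale(packed_i32: int, byte_pos: int) -> float:
    # same byte extraction as the kernel: (packed >> 8*pos) & 0xFF, then decode
    return e8m0_byte_to_f32((packed_i32 >> (8 * byte_pos)) & 0xFF)

for b in (118, 127, 130):
    assert e8m0_byte_to_f32(b) == 2.0 ** (b - 127)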
- def compute_tile( - acc_gate_in, - acc_up_in, - gate_b_tile_in, - up_b_tile_in, - a_tile_regs, - a_scale=None, - gate_b_scale=None, - up_b_scale=None, - *, - prefetch_epilogue=False, - ku_count=k_unroll, - ): - gate_list = list(acc_gate_in) - _single_b = mock_gate_only or gate_up_interleave - up_list = None if _single_b else list(acc_up_in) - mfma_res_ty = vec4_f32 - epilogue_pf = None - bias_pf = None - if const_expr(prefetch_epilogue): - if const_expr(enable_bias): - bias_pf = [] - for ni in range_constexpr(num_acc_n): - if const_expr(gate_up_interleave): - _logical_col = ( - (by_n + n_tile_base) - // arith.constant(2, index=True) - + arith.constant((ni // 2) * 16, index=True) - + lane_mod_16 - ) - _up_off = ( - inter_idx - if (ni % 2 == 1) - else arith.constant(0, index=True) - ) - bias_offset = ( - expert_off_idx + _up_off + _logical_col - ) - else: - global_n = ( - by_n - + n_tile_base - + arith.constant(ni * 16, index=True) - + lane_mod_16 - ) - bias_offset = expert_off_idx + global_n - bias_pf.append( - buffer_ops.buffer_load( - bias_rsrc, bias_offset, vec_width=1, dtype=f32 - ) + gb0, gb1 = unpack_b_mxfp4_bf16( + g_raw_ku, arith, vector, + scale_f32=g_scales[ku][ni], ) - tw_pf = None - if const_expr(doweight_stage1): - tw_pf = [] - lane_div_16_mul4_pf = lane_div_16 * arith.index(4) - ii_idx_list_pf = [ - arith.constant(ii, index=True) for ii in range(4) - ] - for mi in range_constexpr(m_repeat): - mi_base_pf = arith.constant(mi * 16, index=True) - for ii in range_constexpr(4): - row_off_pf = ( - lane_div_16_mul4_pf + ii_idx_list_pf[ii] + if const_expr(up_list is not None): + u_raw_ku = vector.extract( + u_v4[_k0_idx][ni], + static_position=[_ku_in_k0], + dynamic_position=[], ) - sorted_row_pf = bx_m + mi_base_pf + row_off_pf - tw_pf.append( - buffer_ops.buffer_load( - sorted_w_rsrc, - sorted_row_pf, - vec_width=1, - dtype=f32, - ) + ub0, ub1 = unpack_b_mxfp4_bf16( + u_raw_ku, arith, vector, + scale_f32=u_scales[ku][ni], ) - epilogue_pf = (None, tw_pf, bias_pf) - - c0_i64 = arith.constant(0, type=T.i64) - vec4_i64 = T.vec(4, T.i64) - vec8_i32 = T.vec(8, T.i32) - - def pack_i64x4_to_i32x8(x0, x1, x2, x3): - v4 = vector.from_elements(vec4_i64, [x0, x1, x2, x3]) - return vector.bitcast(vec8_i32, v4) - _eff_packed = (ku_count + pack_K - 1) // pack_K - # B-major: fix B (ni), cycle A (mi) -- B from VMEM stays - # in registers while A from LDS is repacked per mi. 
- for ku128 in range_constexpr(_eff_packed): - for ni in range_constexpr(num_acc_n_packed): - gate_bs_i32 = gate_b_scale[ku128 * num_acc_n_packed + ni] - gate_bs_val = vector.extract( - gate_bs_i32, - static_position=[0], - dynamic_position=[], - ) - if const_expr(not _single_b): - up_bs_i32 = up_b_scale[ku128 * num_acc_n_packed + ni] - up_bs_val = vector.extract( - up_bs_i32, static_position=[0], dynamic_position=[] - ) - for ikxdl in range_constexpr(pack_K): - k_idx = ku128 * pack_K + ikxdl - if const_expr(k_idx < ku_count): - gate_bp0, gate_bp1 = gate_b_tile_in[k_idx] - if const_expr(not _single_b): - up_bp0, up_bp1 = up_b_tile_in[k_idx] - for inxdl in range_constexpr(pack_N): - ni_idx = ni * pack_N + inxdl - gb0 = gate_bp0[ni_idx] - gb1 = gate_bp1[ni_idx] - gb128 = pack_i64x4_to_i32x8( - gb0, gb1, c0_i64, c0_i64 + for mi in range_constexpr(m_repeat): + _flat = ku * m_repeat + mi + a0, a1 = a_preloaded[_flat] + acc_idx = mi * num_acc_n + ni + gate_list[acc_idx] = _mfma_k32( + gate_list[acc_idx], a0, a1, gb0, gb1, + ) + if const_expr(up_list is not None): + up_list[acc_idx] = _mfma_k32( + up_list[acc_idx], a0, a1, ub0, ub1, ) - if const_expr(not _single_b): - ub0 = up_bp0[ni_idx] - ub1 = up_bp1[ni_idx] - ub128 = pack_i64x4_to_i32x8( - ub0, ub1, c0_i64, c0_i64 - ) - for mi in range_constexpr(m_repeat_packed): - a_scale_i32 = a_scale[ - ku128 * m_repeat_packed + mi - ] - a_scale_val = vector.extract( - a_scale_i32, - static_position=[0], - dynamic_position=[], - ) - for imxdl in range_constexpr(pack_M): - mi_idx = mi * pack_M + imxdl - _a_reg_idx = k_idx * m_repeat + mi_idx - if const_expr(is_f8_a): - a0, a1, a2, a3 = a_tile_regs[ - _a_reg_idx - ] - a128 = pack_i64x4_to_i32x8( - a0, a1, a2, a3 - ) - else: - a0, a1 = a_tile_regs[_a_reg_idx] - a128 = pack_i64x4_to_i32x8( - a0, a1, c0_i64, c0_i64 - ) - acc_idx = mi_idx * num_acc_n + ni_idx - gate_list[acc_idx] = ( - rocdl.mfma_scale_f32_16x16x128_f8f6f4( - mfma_res_ty, - [ - a128, - gb128, - gate_list[acc_idx], - cbsz, - blgp, - ikxdl * pack_M + imxdl, - a_scale_val, - ikxdl * pack_N + inxdl, - gate_bs_val, - ], - ) - ) - if const_expr(not _single_b): - up_list[acc_idx] = ( - rocdl.mfma_scale_f32_16x16x128_f8f6f4( - mfma_res_ty, - [ - a128, - ub128, - up_list[acc_idx], - cbsz, - blgp, - ikxdl * pack_M + imxdl, - a_scale_val, - ikxdl * pack_N + inxdl, - up_bs_val, - ], - ) - ) - return gate_list, up_list, epilogue_pf - def load_a_subtile(k_idx, mi_idx, lds_buffer): - """Load a single A sub-tile from LDS (one ds_read).""" - col_base = col_offset_base + (k_idx * 128) // a_elem_vec_pack - mi_val = arith.constant(mi_idx * 16, index=True) - curr_row = row_a_lds + mi_val - a0, a1 = lds_load_packs_k64(curr_row, col_base, lds_buffer) - if const_expr(is_f8_a): - a2, a3 = lds_load_packs_k64(curr_row, col_base + 64, lds_buffer) - return (a0, a1, a2, a3) + if next_lds_buffer is not None and ni == _is_last_ni: + for mi in range_constexpr(m_repeat): + _flat = ku * m_repeat + mi + _nxt_col = _a_col_bytes_for_ku(ku) + _nxt_row = row_a_lds + arith.index(mi * 16) + a_tiles_next[_flat] = lds_load_packs_k64( + _nxt_row, _nxt_col, next_lds_buffer + ) else: - return (a0, a1) - - _single_b_pipe = mock_gate_only or gate_up_interleave - - def compute_bmajor_mfma_phase( - all_a_tiles, - gate_b_single, - up_b_single, - a_scale_vals, - gate_bs_val, - up_bs_val, - gate_list, - up_list, - k_idx, - ni_idx, - ikxdl, - inxdl, - ): - """B-major MFMA: fix one B (ni), cycle all A tiles (mi). - - Packs B once and reuses across all mi iterations. 
- A tiles come from LDS (already available, no VMEM wait). - - all_a_tiles: flat list indexed by [k*m_repeat + mi]. - gate_b_single/up_b_single: (b0, b1) for one specific ni. - When _single_b_pipe (mock_gate_only or interleave), up_b_single is None. - a_scale_vals: list of A scale scalars indexed by mi_packed. - """ - c0_i64 = arith.constant(0, type=T.i64) - vec4_i64 = T.vec(4, T.i64) - vec8_i32 = T.vec(8, T.i32) - - def _pack(x0, x1, x2, x3): - v4 = vector.from_elements(vec4_i64, [x0, x1, x2, x3]) - return vector.bitcast(vec8_i32, v4) + # --- bm=32: full load from cur_lds inside compute --- + _local_a_slots = ku_count * m_repeat + all_a = [None] * _local_a_slots + for ku in range_constexpr(ku_count): + for mi in range_constexpr(m_repeat): + _col = _a_col_bytes_for_ku(ku) + _row = row_a_lds + arith.index(mi * 16) + all_a[ku * m_repeat + mi] = lds_load_packs_k64( + _row, _col, cur_lds_buffer + ) - mfma_res_ty = vec4_f32 - gb128 = _pack(gate_b_single[0], gate_b_single[1], c0_i64, c0_i64) - if const_expr(not _single_b_pipe): - ub128 = _pack(up_b_single[0], up_b_single[1], c0_i64, c0_i64) + for ni in range_constexpr(num_acc_n): + for ku in range_constexpr(ku_count): + _k0_idx = ku // 4 + _ku_in_k0 = ku % 4 - for mi_p in range_constexpr(m_repeat_packed): - a_scale_val = a_scale_vals[mi_p] - for imxdl in range_constexpr(pack_M): - mi_idx = mi_p * pack_M + imxdl - a_reg = all_a_tiles[k_idx * m_repeat + mi_idx] + g_raw_ku = vector.extract( + g_v4[_k0_idx][ni], + static_position=[_ku_in_k0], + dynamic_position=[], + ) + gb0, gb1 = unpack_b_mxfp4_bf16( + g_raw_ku, arith, vector, + scale_f32=g_scales[ku][ni], + ) + if const_expr(up_list is not None): + u_raw_ku = vector.extract( + u_v4[_k0_idx][ni], + static_position=[_ku_in_k0], + dynamic_position=[], + ) + ub0, ub1 = unpack_b_mxfp4_bf16( + u_raw_ku, arith, vector, + scale_f32=u_scales[ku][ni], + ) - if const_expr(is_f8_a): - a128 = _pack(a_reg[0], a_reg[1], a_reg[2], a_reg[3]) - else: - a128 = _pack(a_reg[0], a_reg[1], c0_i64, c0_i64) - - acc_idx = mi_idx * num_acc_n + ni_idx - gate_list[acc_idx] = rocdl.mfma_scale_f32_16x16x128_f8f6f4( - mfma_res_ty, - [ - a128, - gb128, - gate_list[acc_idx], - cbsz, - blgp, - ikxdl * pack_M + imxdl, - a_scale_val, - ikxdl * pack_N + inxdl, - gate_bs_val, - ], - ) - if const_expr(not _single_b_pipe): - up_list[acc_idx] = ( - rocdl.mfma_scale_f32_16x16x128_f8f6f4( - mfma_res_ty, - [ - a128, - ub128, - up_list[acc_idx], - cbsz, - blgp, - ikxdl * pack_M + imxdl, - a_scale_val, - ikxdl * pack_N + inxdl, - up_bs_val, - ], + for mi in range_constexpr(m_repeat): + _flat = ku * m_repeat + mi + a0, a1 = all_a[_flat] + acc_idx = mi * num_acc_n + ni + gate_list[acc_idx] = _mfma_k32( + gate_list[acc_idx], a0, a1, gb0, gb1, ) - ) + if const_expr(up_list is not None): + up_list[acc_idx] = _mfma_k32( + up_list[acc_idx], a0, a1, ub0, ub1, + ) - def _interleaved_half( - lds_read, - lds_write, - next_k_dma_py, - next_k_load, - prev_a_tile, - prev_gate_w, - prev_up_w, - prev_a_scale, - prev_gate_bs, - prev_up_bs, - acc_gate, - acc_up, - ): - """One flatmm-style interleaved half-iteration (deep pipeline). + # Load next round's preload from next_lds + if next_lds_buffer is not None: + for pl in range_constexpr(_m_preload): + _pl_mi = pl % m_repeat + _pl_ku = pl // m_repeat + _nxt_col = _a_col_bytes_for_ku(_pl_ku) + _nxt_row = row_a_lds + arith.index(_pl_mi * 16) + a_tiles_next[pl] = lds_load_packs_k64( + _nxt_row, _nxt_col, next_lds_buffer + ) - Generalized for arbitrary m_repeat (block_m=32, 64, ...). 
- DMA targets lds_write (OTHER buffer) while ds_read uses - lds_read (already DMA'd in previous half). + return gate_list, up_list, a_tiles_next - Interleaving schedule (per half): - Phase 0: scale VMEM + 2 ds_read(A) -> 4 MFMA(prev) - Phase 1..N: B VMEM(distributed) + 2 ds_read(A, if avail) -> 4 MFMA(prev) - Phase N+1..: remaining B VMEM -> 4 MFMA(prev) - """ - _abs_k = k_base_idx + arith.constant(next_k_load, index=True) - _bk = _abs_k // arith.constant(2, index=True) - _sk = _abs_k // arith.constant(pack_K * 128, index=True) - _k_off = _sk * layout_b_scale.stride_k0 + rocdl.sched_barrier(0) - rocdl.sched_barrier(0) - rocdl.s_waitcnt(_vmcnt_before_barrier) - _barrier() - rocdl.sched_barrier(0) + _b_stream_mult = 1 if _single_b else 2 - # DMA A to OTHER buffer (for next half), non-blocking - _abs_k_dma = k_base_idx + arith.constant(next_k_dma_py, index=True) - if const_expr(use_async_copy and next_k_dma_py < int(_k_dim)): - prefetch_x_to_lds(_abs_k_dma, lds_write) - if const_expr(not use_async_copy): - _x_regs = load_x_tile(_abs_k_dma) - - # ---- Extract previous scale values ---- - _prev_asvs = [] - for _mi_p in range_constexpr(m_repeat_packed): - _prev_asvs.append( - vector.extract( - prev_a_scale[_mi_p], - static_position=[0], - dynamic_position=[], - ) - ) - _prev_gsv_list = [] - for _gs_ni in range_constexpr(num_acc_n_packed): - _prev_gsv_list.append( - vector.extract( - prev_gate_bs[_gs_ni], - static_position=[0], - dynamic_position=[], - ) - ) - if const_expr(not _single_b_pipe): - _prev_usv_list = [] - for _us_ni in range_constexpr(num_acc_n_packed): - _prev_usv_list.append( - vector.extract( - prev_up_bs[_us_ni], - static_position=[0], - dynamic_position=[], - ) - ) + def hot_loop_scheduler(): + """CK-style scheduler: interleave MFMA, DS_READ, VMEM_READ.""" + # Constants matching CK: + # dsread_per_wg = WG_M * WG_K * sizeof(bf16) / 64 / VectorLoadSize(16) = 16*32*2/64/16 = 1 + _dsread_per_wg = 1 + _mfma_per_wg = 1 + _NIterPerWarp = num_acc_n * _b_stream_mult # 2 or 4 + _mfma_perM_perK = _NIterPerWarp * _mfma_per_wg + + _HalfMIter = (m_repeat + 1) // 2 + + _Aload_num_perK = _dsread_per_wg * m_repeat # num buffer_load_lds per K iter + _Aload_rep = max((_Aload_num_perK + m_repeat - 1) // m_repeat, 1) + _Bload_num_perK = num_acc_n * _b_stream_mult # B loads per K iter + _Bload_rep = max((_Bload_num_perK + _HalfMIter - 1) // _HalfMIter, 1) + + for _ku in range_constexpr(k_unroll): + for _mi in range_constexpr(m_repeat): + _dsread_perM = _dsread_per_wg + _load_perM = 0 - # ---- Execute phases from unified schedule ---- - _a_all = {} - _b_gate_all = {} - _b_up_all = {} - - for _p in range_constexpr(_pipe_n_phases): - # Scale VMEM loads (phase 0 only) - if const_expr(_pp_has_scale[_p]): - _new_as_list = [] - for _mi_p in range_constexpr(m_repeat_packed): - if const_expr(a_scale_one): - _new_as_list.append(_as1_const) - else: - _raw_as = buffer_ops.buffer_load( - sx_rsrc, - _a_scale_bases[_mi_p] + _k_off, - vec_width=1, - dtype=T.i32, - cache_modifier=0, - ) - _new_as_list.append(_rearrange_a_scale(_raw_as)) - _new_gs_list = [] - for _gs_ni in range_constexpr(num_acc_n_packed): - _gs_raw = buffer_ops.buffer_load( - sw_rsrc, - _gate_scale_bases[_gs_ni] + _k_off, - vec_width=1, - dtype=T.i32, - cache_modifier=0, - ) - _new_gs_list.append(_rearrange_b_scale(_gs_raw)) - if const_expr(not _single_b_pipe): - _new_us_list = [] - for _us_ni in range_constexpr(num_acc_n_packed): - _us_raw = buffer_ops.buffer_load( - sw_rsrc, - _up_scale_bases[_us_ni] + _k_off, - vec_width=1, - dtype=T.i32, - 
cache_modifier=0, - ) - _new_us_list.append(_rearrange_b_scale(_us_raw)) - - # B VMEM loads - for _b_j in range_constexpr(len(_pp_b_loads[_p])): - _b_type, _b_ku, _b_ni = _pp_b_loads[_p][_b_j] - if const_expr(_b_type == "gate"): - _b_gate_all[(_b_ku, _b_ni)] = load_b_packs_k64( - _bk, - _b_ku, - gate_n_blk_list[_b_ni], - gate_n_intra_list[_b_ni], + if const_expr(_mi < _HalfMIter): + _load_perM = ( + (_Aload_rep if (_Aload_num_perK - (m_repeat - 1 - _mi) * _Aload_rep) > 0 else 0) + + (_Bload_rep if (_Bload_num_perK - (_HalfMIter - 1 - _mi) * _Bload_rep) > 0 else 0) ) else: - _b_up_all[(_b_ku, _b_ni)] = load_b_packs_k64( - _bk, - _b_ku, - up_n_blk_list[_b_ni], - up_n_intra_list[_b_ni], + _load_perM = ( + _Aload_rep if (_Aload_num_perK - (m_repeat - 1 - _mi) * _Aload_rep) > 0 else 0 ) - # A ds_reads - rocdl.sched_barrier(0) - for _a_j in range_constexpr(len(_pp_a_reads[_p])): - _ak, _ami = _pp_a_reads[_p][_a_j] - _a_all[(_ak, _ami)] = load_a_subtile( - _ak, - _ami, - lds_read, - ) - rocdl.sched_barrier(0) - - # MFMAs on prev data - rocdl.s_setprio(1) - for _m_j in range_constexpr(len(_pp_mfma[_p])): - _k_idx, _ni_idx, _ikxdl, _inxdl, _ku128 = _pp_mfma[_p][_m_j] - _ni_packed_idx = _ni_idx // pack_N - _up_b_single = ( - ( - prev_up_w[_k_idx][0][_ni_idx], - prev_up_w[_k_idx][1][_ni_idx], - ) - if not _single_b_pipe - else None - ) - compute_bmajor_mfma_phase( - prev_a_tile, - ( - prev_gate_w[_k_idx][0][_ni_idx], - prev_gate_w[_k_idx][1][_ni_idx], - ), - _up_b_single, - _prev_asvs, - _prev_gsv_list[_ni_packed_idx], - ( - _prev_usv_list[_ni_packed_idx] - if not _single_b_pipe - else None - ), - acc_gate, - acc_up, - _k_idx, - _ni_idx, - _ikxdl, - _inxdl, - ) - rocdl.s_setprio(0) - rocdl.sched_barrier(0) + _sum_data = _dsread_perM + _load_perM + _round_data = max((_sum_data + _mfma_perM_perK - 1) // _mfma_perM_perK, 1) + + # Build instruction order: 2=VMEM, 3=DS_READ + _inst_order = [] + _max_data = max(_load_perM, _dsread_perM) + for _j in range_constexpr(_max_data): + if const_expr(_load_perM > _j): + _inst_order.append(2) + if const_expr(_dsread_perM > _j): + _inst_order.append(3) + # Pad to mfma_perM_perK * round_data + _pad_len = _mfma_perM_perK * _round_data - len(_inst_order) + _inst_order.extend([0] * _pad_len) + + for _nj in range_constexpr(_mfma_perM_perK): + if const_expr(_nj == 0): + _inst_idx = 0 + elif const_expr(_nj == 1): + _inst_idx = _mfma_perM_perK - 2 if _mfma_perM_perK > 2 else 1 + elif const_expr(_nj == 2): + _inst_idx = _mfma_perM_perK - 1 + else: + _inst_idx = _mfma_perM_perK - _nj - # ---- Assemble loaded data for next half-iteration ---- - cur_a_tile = [] - for _k in range_constexpr(k_unroll): - for _mi in range_constexpr(m_repeat): - cur_a_tile.append(_a_all[(_k, _mi)]) + rocdl.sched_mfma(1) - cur_gate_w = [] - cur_up_w = None if _single_b_pipe else [] - for ku in range_constexpr(k_unroll): - g_packs0, g_packs1 = [], [] - u_packs0, u_packs1 = [], [] - for ni in range_constexpr(num_acc_n): - g = _b_gate_all[(ku, ni)] - g_packs0.append(g[0]) - g_packs1.append(g[1]) - if const_expr(not _single_b_pipe): - u = _b_up_all[(ku, ni)] - u_packs0.append(u[0]) - u_packs1.append(u[1]) - cur_gate_w.append((g_packs0, g_packs1)) - if const_expr(not _single_b_pipe): - cur_up_w.append((u_packs0, u_packs1)) + for _r in range_constexpr(_round_data): + if const_expr(_r % 2 == 0): + _oi = _inst_idx + _r * _mfma_perM_perK + else: + _oi = (_r + 1) * _mfma_perM_perK - 1 - _inst_idx + if const_expr(_oi < len(_inst_order)): + if const_expr(_inst_order[_oi] == 2): + rocdl.sched_vmem(1) + elif 
_inst_order[_oi] == 3: + rocdl.sched_dsrd(1) + + if const_expr(_Aload_num_perK == 0): + rocdl.sched_vmem(1) + rocdl.sched_barrier(0) - cur_a_scale = [] - for _mi_p in range_constexpr(m_repeat_packed): - cur_a_scale.append( - vector.from_elements( - T.vec(1, T.i32), - [_new_as_list[_mi_p]], - ) - ) - cur_gate_bs = [] - for _gs_ni in range_constexpr(num_acc_n_packed): - cur_gate_bs.append( - vector.from_elements( - T.vec(1, T.i32), [_new_gs_list[_gs_ni]] - ) - ) - if const_expr(not _single_b_pipe): - cur_up_bs = [] - for _us_ni in range_constexpr(num_acc_n_packed): - cur_up_bs.append( - vector.from_elements( - T.vec(1, T.i32), [_new_us_list[_us_ni]] - ) - ) - else: - cur_up_bs = None - - if const_expr(not use_async_copy): - store_x_tile_to_lds(_x_regs, lds_write) - - return ( - cur_a_tile, - cur_gate_w, - cur_up_w, - cur_a_scale, - cur_gate_bs, - cur_up_bs, - acc_gate, - acc_up, + def _s1_barrier(vmcnt=63, lgkmcnt=63): + """s_waitcnt + s_barrier via inline asm (bypasses LLVM).""" + parts = [] + needs_waitcnt = vmcnt < 63 or lgkmcnt < 63 + if needs_waitcnt: + wc = [] + if vmcnt < 63: + wc.append(f"vmcnt({vmcnt})") + if lgkmcnt < 63: + wc.append(f"lgkmcnt({lgkmcnt})") + parts.append("s_waitcnt " + " ".join(wc)) + parts.append("s_barrier") + llvm.InlineAsmOp( + res=None, + operands_=[], + asm_string="\n".join(parts), + constraints="", + has_side_effects=True, + is_align_stack=False, ) - # Pipeline (split ping/pong allocators) - rocdl.sched_barrier(0) + # ---- CK-style constants ---- + # NOTE: original A16W4 path computed `_vmcnt_before_barrier = + # num_x_loads` here but never actually consumed it; the value + # is dead code. It was removed because (a) it shadows the + # generic body's enclosing-scope binding, and (b) it would + # otherwise leak as a closure local of `moe_gemm1`. 
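# Host-side model of the interleaving that hot_loop_scheduler() requests via
# rocdl.sched_mfma / sched_vmem / sched_dsrd hints.  Purely illustrative
# (plain Python, no FlyDSL); it reproduces the inner per-mi body for a single
# k_unroll step, with the trailing per-mi sched_barrier omitted.  The default
# parameter values are examples, not the kernel's actual configuration.
def model_schedule(m_repeat=1, num_acc_n=2, b_streams=2,
                   dsread_per_wg=1, mfma_per_wg=1):
    mfma_per_mk = num_acc_n * b_streams * mfma_per_wg
    half_m = (m_repeat + 1) // 2
    a_loads = dsread_per_wg * m_repeat
    a_rep = max((a_loads + m_repeat - 1) // m_repeat, 1)
    b_loads = num_acc_n * b_streams
    b_rep = max((b_loads + half_m - 1) // half_m, 1)

    out = []
    for mi in range(m_repeat):
        ds = dsread_per_wg
        if mi < half_m:
            ld = ((a_rep if a_loads - (m_repeat - 1 - mi) * a_rep > 0 else 0)
                  + (b_rep if b_loads - (half_m - 1 - mi) * b_rep > 0 else 0))
        else:
            ld = a_rep if a_loads - (m_repeat - 1 - mi) * a_rep > 0 else 0
        rounds = max((ds + ld + mfma_per_mk - 1) // mfma_per_mk, 1)

        # Build the data-instruction order (VMEM/DS_READ), padded with None
        order = []
        for j in range(max(ld, ds)):
            if ld > j:
                order.append("VMEM")
            if ds > j:
                order.append("DS_READ")
        order += [None] * (mfma_per_mk * rounds - len(order))

        # Issue one MFMA slot, then pick data instructions for each round
        for nj in range(mfma_per_mk):
            if nj == 0:
                idx = 0
            elif nj == 1:
                idx = mfma_per_mk - 2 if mfma_per_mk > 2 else 1
            elif nj == 2:
                idx = mfma_per_mk - 1
            else:
                idx = mfma_per_mk - nj
            out.append("MFMA")
            for r in range(rounds):
                oi = (idx + r * mfma_per_mk if r % 2 == 0
                      else (r + 1) * mfma_per_mk - 1 - idx)
                if oi < len(order) and order[oi]:
                    out.append(order[oi])
    return out

# e.g. model_schedule() -> the MFMA/VMEM/DS_READ issue pattern for one K step.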
+ + # Split-K: base K offset + if const_expr(_is_splitk): + bz = gpu.block_id("z") + k_base = bz * arith.index(_k_dim) + else: + k_base = arith.index(0) - k0 = k_base_idx - if const_expr(use_async_copy): - prefetch_x_to_lds(k0, lds_x_pong) + # Intra-WG split-K: B offset per wave group + if const_expr(_split_k_intra > 1): + _k_half = tile_k // _split_k_intra + _wg_k_off = _wave_group * arith.index(_k_half) else: - x_regs0 = load_x_tile(k0) - store_x_tile_to_lds(x_regs0, lds_x_pong) + _wg_k_off = arith.index(0) + + # ---- CK-style pipeline: HEAD ---- + # DMA A[0] → pong (full tile_k, shared by all wave groups) + k0 = k_base + prefetch_x_to_lds(k0, lds_x_pong) rocdl.sched_barrier(0) - _k0_scale = k_base_idx // arith.constant(pack_K * 128, index=True) - a_scale_pong, gate_bs_pong, up_bs_pong = prefetch_ab_scale_tile( - _k0_scale - ) - _c_tile_m_idx = arith.constant(tile_m, index=True) - _tid_in_range = arith.cmpi(CmpIPredicate.ult, tx, _c_tile_m_idx) - _if_tid = scf.IfOp(_tid_in_range) - with ir.InsertionPoint(_if_tid.then_block): - _tid_row = bx_m + tx - _tid_val = buffer_ops.buffer_load( - sorted_rsrc, _tid_row, vec_width=1, dtype=T.i32 - ) - _tid_vec1 = vector.from_elements(T.vec(1, T.i32), [_tid_val]) - vector.store(_tid_vec1, lds_tid, [tx]) - scf.YieldOp([]) - acc_gate = [acc_init] * num_acc_n * m_repeat - acc_up = ( - [acc_init] * num_acc_n * m_repeat if not _single_b_pipe else None + # Load B[0] (raw + scale, per-group K range) + g_raw_ping, g_sc_ping, u_raw_ping, u_sc_ping = load_all_b_raw( + k0 + _wg_k_off ) - _k1 = k_base_idx + arith.constant(tile_k, index=True) rocdl.sched_barrier(0) - if const_expr(use_async_copy): - prefetch_x_to_lds(_k1, lds_x_ping) - else: - _x_regs_prime = load_x_tile(_k1) - store_x_tile_to_lds(_x_regs_prime, lds_x_ping) - _k0_b = k_base_idx // arith.constant(2, index=True) - gate_w0, up_w0 = load_b_tile(_k0_b) - # Prime the deep pipeline: DMA K=tile_k -> ping (1 tile ahead) - if const_expr(use_async_copy): - rocdl.s_waitcnt(0) + # DMA A[1] → ping + _k1 = k_base + arith.index(tile_k) + + prefetch_x_to_lds(_k1, lds_x_ping) + rocdl.sched_barrier(0) + + # Init C + acc_gate = [acc_init] * (num_acc_n * m_repeat) + acc_up = [acc_init] * (num_acc_n * m_repeat) if not _single_b else None + + # Wait for all DMA + barrier to sync all threads + rocdl.s_waitcnt(0) gpu.barrier() rocdl.sched_barrier(0) - a_tile_pong = prefetch_full_a_from_lds(lds_x_pong) + # Preload A[0] from pong LDS → VGPRs (safe: all threads' DMA done) + a_cur = preload_a_from_lds(lds_x_pong) rocdl.sched_barrier(0) - rocdl.s_waitcnt(6) - num_k_tiles_py = int(_k_dim) // int(tile_k) - odd_k_tiles = (num_k_tiles_py % 2) == 1 - tail_tiles = 1 if odd_k_tiles else 2 - k_main2_py = (num_k_tiles_py - tail_tiles) * int(tile_k) - if const_expr(k_main2_py < 0): - k_main2_py = 0 + total_tiles = int(_k_dim) // int(tile_k) + pair_iters = max((total_tiles - 2) // 2, 0) - gate_w_pong = gate_w0 - up_w_pong = up_w0 + for pair_i in range_constexpr(pair_iters): + k_iv = k_base + arith.index(pair_i * (tile_k * 2)) - rocdl.sched_barrier(0) + # ---- Half 2i ---- + _k_b1 = k_iv + arith.index(tile_k) + g_raw_pong, g_sc_pong, u_raw_pong, u_sc_pong = load_all_b_raw( + _k_b1 + _wg_k_off + ) - if const_expr(k_main2_py > 0): - for k_iv_py in range_constexpr(0, k_main2_py, tile_k * 2): - next_k_load_1 = k_iv_py + tile_k - next_k_load_2 = k_iv_py + tile_k * 2 - next_k_dma_1 = k_iv_py + tile_k * 2 - next_k_dma_2 = k_iv_py + tile_k * 3 - - # Half 1: read ping (DMA'd prev half), DMA->pong, MFMA(pong) - ( - a_tile_ping, - gate_w_ping, - 
up_w_ping, - a_scale_ping, - gate_bs_ping, - up_bs_ping, - acc_gate, - acc_up, - ) = _interleaved_half( - lds_x_ping, - lds_x_pong, - next_k_dma_1, - next_k_load_1, - a_tile_pong, - gate_w_pong, - up_w_pong, - a_scale_pong, - gate_bs_pong, - up_bs_pong, - acc_gate, - acc_up, - ) + acc_gate, acc_up, a_next = compute_tile( + acc_gate, acc_up, + g_raw_ping, g_sc_ping, u_raw_ping, u_sc_ping, + a_cur, lds_x_pong, lds_x_ping, + ) - # Half 2: read pong (DMA'd Half 1), DMA->ping, MFMA(ping) - ( - a_tile_pong, - gate_w_pong, - up_w_pong, - a_scale_pong, - gate_bs_pong, - up_bs_pong, - acc_gate, - acc_up, - ) = _interleaved_half( - lds_x_pong, - lds_x_ping, - next_k_dma_2, - next_k_load_2, - a_tile_ping, - gate_w_ping, - up_w_ping, - a_scale_ping, - gate_bs_ping, - up_bs_ping, - acc_gate, - acc_up, - ) + _k_a2 = k_iv + arith.index(tile_k * 2) + prefetch_x_to_lds(_k_a2, lds_x_pong) + rocdl.s_waitcnt(0) + gpu.barrier() + rocdl.sched_barrier(0) + a_cur = a_next - # _wave_mod2_b = wave_id % arith.constant(2, index=True) - # _wave_odd = arith.cmpi( - # CmpIPredicate.eq, _wave_mod2_b, arith.constant(1, index=True) - # ) - # _if_wave_odd = scf.IfOp(_wave_odd) - # with ir.InsertionPoint(_if_wave_odd.then_block): - # # gpu.barrier() - # _barrier() - # scf.YieldOp([]) + # ---- Half 2i+1 ---- + _k_b2 = k_iv + arith.index(tile_k * 2) + g_raw_ping, g_sc_ping, u_raw_ping, u_sc_ping = load_all_b_raw( + _k_b2 + _wg_k_off + ) - if const_expr(odd_k_tiles): - acc_gate, acc_up, epilogue_pf = compute_tile( - acc_gate, - acc_up, - gate_w_pong, - up_w_pong, - a_tile_pong, - a_scale_pong, - gate_bs_pong, - up_bs_pong, - prefetch_epilogue=True, - ku_count=_tail_ku if _pad_ku_skip > 0 else k_unroll, + acc_gate, acc_up, a_next = compute_tile( + acc_gate, acc_up, + g_raw_pong, g_sc_pong, u_raw_pong, u_sc_pong, + a_cur, lds_x_ping, lds_x_pong, + ) + + _k_a3 = k_iv + arith.index(tile_k * 3) + prefetch_x_to_lds(_k_a3, lds_x_ping) + rocdl.s_waitcnt(0) + gpu.barrier() + rocdl.sched_barrier(0) + a_cur = a_next + + # ---- TAIL: last 2 tiles ---- + # Load B for tail-1 (partial if K-padding active) + k_tail1 = k_base + arith.index(_k_dim) - arith.index(tile_k) + if const_expr(_pad_ku_skip > 0): + g_raw_pong, g_sc_pong, u_raw_pong, u_sc_pong = load_all_b_raw( + k_tail1 + _wg_k_off, + k0_limit=_tail_k0_count, + ku_limit=_tail_ku, ) else: - _k_tail_rel = arith.constant(_k_dim - tile_k, index=True) - k_tail1 = k_base_idx + _k_tail_rel - x_regs_ping = [] - if const_expr(use_async_copy): - prefetch_x_to_lds(k_tail1, lds_x_ping) - else: - x_regs_ping = load_x_tile(k_tail1) - if const_expr(_pad_ku_skip > 0): - gate_w_ping, up_w_ping = load_b_tile( - k_tail1 // arith.constant(2, index=True), - ku_limit=_tail_ku, + g_raw_pong, g_sc_pong, u_raw_pong, u_sc_pong = load_all_b_raw( + k_tail1 + _wg_k_off + ) + + # GEMM tail-0: use a_cur (from pong), load a_next from ping + acc_gate, acc_up, a_next = compute_tile( + acc_gate, acc_up, + g_raw_ping, g_sc_ping, u_raw_ping, u_sc_ping, + a_cur, lds_x_pong, lds_x_ping, + ) + hot_loop_scheduler() + rocdl.s_waitcnt(0) + # _s1_barrier() + + # GEMM tail-1: use a_next (from ping), no next round + # When K-padding is active, only compute the valid portion + acc_gate, acc_up, _ = compute_tile( + acc_gate, acc_up, + g_raw_pong, g_sc_pong, u_raw_pong, u_sc_pong, + a_next, lds_x_ping, None, + ku_count=_tail_ku if _pad_ku_skip > 0 else k_unroll, + ) + + # ---- Intra-WG split-K reduce via LDS ---- + if const_expr(_split_k_intra > 1): + _num_accs = num_acc_n * m_repeat + _has_up = not _single_b + _streams = 2 if _has_up 
else 1 + # 4 f32 per vec4_f32 acc × _num_accs × _streams per thread + _f32_per_thread = 4 * _num_accs * _streams + _reduce_stride = arith.index(_f32_per_thread) + # 4 waves × 64 threads × _f32_per_thread + _reduce_f32_total = 4 * 64 * _f32_per_thread + + reduce_lds = SmemPtr( + base_ptr_pong, lds_pong_offset, T.f32, + shape=(_reduce_f32_total,), + ).get() + + tx_local = lane_id + _lds_base = ( + wave_id * arith.index(64) * _reduce_stride + + tx_local * _reduce_stride + ) + + for _ai in range_constexpr(_num_accs): + _off = _lds_base + arith.index(_ai * 4) + vector.store(acc_gate[_ai], reduce_lds, [_off]) + if _has_up: + for _ai in range_constexpr(_num_accs): + _off = _lds_base + arith.index(_num_accs * 4 + _ai * 4) + vector.store(acc_up[_ai], reduce_lds, [_off]) + + gpu.barrier() + + _partner_wave = ( + (wave_id + arith.index(_waves_per_group)) + % arith.index(4) + ) + _partner_base = ( + _partner_wave * arith.index(64) * _reduce_stride + + tx_local * _reduce_stride + ) + + for _ai in range_constexpr(_num_accs): + _off = _partner_base + arith.index(_ai * 4) + other_acc = vector.load_op(vec4_f32, reduce_lds, [_off]) + acc_gate[_ai] = arith.addf(acc_gate[_ai], other_acc) + if _has_up: + for _ai in range_constexpr(_num_accs): + _off = _partner_base + arith.index( + _num_accs * 4 + _ai * 4 + ) + other_up = vector.load_op(vec4_f32, reduce_lds, [_off]) + acc_up[_ai] = arith.addf(acc_up[_ai], other_up) + + # Cross-wave gate-up fusion for GUI + split_k + num_acc_n=1 + if const_expr(_gui_xwave_fuse): + gpu.barrier() + for _ai in range_constexpr(_num_accs): + _off = _lds_base + arith.index(_ai * 4) + vector.store(acc_gate[_ai], reduce_lds, [_off]) + gpu.barrier() + + # Read partner within same wave group (0↔1) + _xw_partner_in_grp = ( + arith.index(1) - _wave_in_group ) - a_scale_ping, gate_bs_ping, up_bs_ping = prefetch_ab_scale_tile( - k_tail1 // arith.constant(pack_K * 128, index=True), - ku_packed_limit=_tail_ku_packed, + _xw_partner_wave = ( + _wave_group * arith.index(_waves_per_group) + + _xw_partner_in_grp ) - else: - gate_w_ping, up_w_ping = load_b_tile( - k_tail1 // arith.constant(2, index=True) + _xw_partner_base = ( + _xw_partner_wave * arith.index(64) * _reduce_stride + + tx_local * _reduce_stride ) - a_scale_ping, gate_bs_ping, up_bs_ping = prefetch_ab_scale_tile( - k_tail1 // arith.constant(pack_K * 128, index=True) + + # silu(gate) * up — gate is wave_in_group=0, up is wave_in_group=1 + _is_gate_wave = arith.cmpi( + arith.CmpIPredicate.eq, + arith.index_cast(i32, _wave_in_group), + arith.constant(0, type=T.i32), ) - acc_gate, acc_up, _ = compute_tile( - acc_gate, - acc_up, - gate_w_pong, - up_w_pong, - a_tile_pong, - a_scale_pong, - gate_bs_pong, - up_bs_pong, + for _ai in range_constexpr(_num_accs): + _xw_off = _xw_partner_base + arith.index(_ai * 4) + _xw_partner_acc = vector.load_op( + vec4_f32, reduce_lds, [_xw_off], + ) + _fused_elems = [None] * 4 + for _e in range_constexpr(4): + own_e = vector.extract( + acc_gate[_ai], + static_position=[_e], dynamic_position=[], + ) + par_e = vector.extract( + _xw_partner_acc, + static_position=[_e], dynamic_position=[], + ) + g_e = arith.select(_is_gate_wave, own_e, par_e) + u_e = arith.select(_is_gate_wave, par_e, own_e) + _fused_elems[_e] = _act_elem(g_e, u_e) + acc_gate[_ai] = vector.from_elements( + vec4_f32, _fused_elems + ) + + # ---- Cross-wave gate-up fusion (4-wave, no split_k) ---- + if const_expr(_gui_xwave_fuse and _split_k_intra <= 1): + _xw_num_accs = num_acc_n * m_repeat + _xw_f32_per_thr = 4 * _xw_num_accs + _xw_stride = 
arith.index(_xw_f32_per_thr) + _xw_total = 4 * 64 * _xw_f32_per_thr + + xw_lds = SmemPtr( + base_ptr_pong, lds_pong_offset, T.f32, + shape=(_xw_total,), + ).get() + + _xw_tx = lane_id + _xw_base = ( + wave_id * arith.index(64) * _xw_stride + + _xw_tx * _xw_stride ) - if const_expr(not use_async_copy): - store_x_tile_to_lds(x_regs_ping, lds_x_ping) - rocdl.s_waitcnt(0) - _barrier() - if const_expr(_pad_ku_skip > 0): - a_tile_ping = prefetch_full_a_from_lds( - lds_x_ping, ku_limit=_tail_ku - ) - else: - a_tile_ping = prefetch_full_a_from_lds(lds_x_ping) - acc_gate, acc_up, epilogue_pf = compute_tile( - acc_gate, - acc_up, - gate_w_ping, - up_w_ping, - a_tile_ping, - a_scale_ping, - gate_bs_ping, - up_bs_ping, - prefetch_epilogue=True, - ku_count=_tail_ku if _pad_ku_skip > 0 else k_unroll, + + for _ai in range_constexpr(_xw_num_accs): + _off = _xw_base + arith.index(_ai * 4) + vector.store(acc_gate[_ai], xw_lds, [_off]) + + gpu.barrier() + + # Partner: wave_id XOR 1 → pairs (0,1) and (2,3) + _xw_wid_i32 = arith.index_cast(i32, wave_id) + _xw_pid_i32 = arith.xori( + _xw_wid_i32, arith.constant(1, type=T.i32), + ) + _xw_partner = arith.index_cast(T.index, _xw_pid_i32) + _xw_pbase = ( + _xw_partner * arith.index(64) * _xw_stride + + _xw_tx * _xw_stride ) - bias_pf = None - if const_expr(epilogue_pf is not None): - _, _, bias_pf = epilogue_pf - - # Activation helpers (f32 element-wise on vec4_f32) - def _silu_elem(g): - """silu(x) = x * sigmoid(x); HW fast path: exp2, rcp""" - neg_log2e = arith.constant(-1.4426950408889634, type=f32) - t = g * neg_log2e - emu = llvm.call_intrinsic(f32, "llvm.amdgcn.exp2.f32", [t], [], []) - one = arith.constant(1.0, type=f32) - den = one + emu - sig = llvm.call_intrinsic(f32, "llvm.amdgcn.rcp.f32", [den], [], []) - return g * sig - - def _silu_mul_vec4(gate_v4, up_v4): - """Element-wise silu(gate) * up on vec4_f32.""" - result_elems = [] - for ei in range_constexpr(4): - g = vector.extract( - gate_v4, static_position=[ei], dynamic_position=[] - ) - u = vector.extract( - up_v4, static_position=[ei], dynamic_position=[] - ) - result_elems.append(_silu_elem(g) * u) - return vector.from_elements(vec4_f32, result_elems) + # even wave_id = gate, odd = up + _xw_is_gate = arith.cmpi( + arith.CmpIPredicate.eq, + arith.andi( + _xw_wid_i32, arith.constant(1, type=T.i32), + ), + arith.constant(0, type=T.i32), + ) - def _swiglu_mul_vec4(gate_v4, up_v4): - """Element-wise swiglu(gate, up) on vec4_f32. - swiglu(g, u) = g * sigmoid(alpha * g) * (u + 1) - with clamping: gate <= limit, -limit <= up <= limit. 
- """ - result_elems = [] - _alpha = arith.constant(1.702, type=f32) - _limit = arith.constant(7.0, type=f32) - _neg_limit = arith.constant(-7.0, type=f32) - _one = arith.constant(1.0, type=f32) - _neg_log2e = arith.constant(-1.4426950408889634, type=f32) - for ei in range_constexpr(4): - g = vector.extract( - gate_v4, static_position=[ei], dynamic_position=[] - ) - u = vector.extract( - up_v4, static_position=[ei], dynamic_position=[] + for _ai in range_constexpr(_xw_num_accs): + _xw_off = _xw_pbase + arith.index(_ai * 4) + _xw_pacc = vector.load_op( + vec4_f32, xw_lds, [_xw_off], ) - g = arith.minimumf(g, _limit) - u = arith.minimumf(u, _limit) - u = arith.maximumf(u, _neg_limit) - t = g * _alpha * _neg_log2e - emu = llvm.call_intrinsic( - f32, "llvm.amdgcn.exp2.f32", [t], [], [] - ) - den = _one + emu - sig = llvm.call_intrinsic( - f32, "llvm.amdgcn.rcp.f32", [den], [], [] + _fused = [None] * 4 + for _e in range_constexpr(4): + own_e = vector.extract( + acc_gate[_ai], + static_position=[_e], dynamic_position=[], + ) + par_e = vector.extract( + _xw_pacc, + static_position=[_e], dynamic_position=[], + ) + g_e = arith.select(_xw_is_gate, own_e, par_e) + u_e = arith.select(_xw_is_gate, par_e, own_e) + _fused[_e] = _act_elem(g_e, u_e) + acc_gate[_ai] = vector.from_elements( + vec4_f32, _fused, ) - result_elems.append(g * sig * (u + _one)) - return vector.from_elements(vec4_f32, result_elems) - def _act_vec4(gate_v4, up_v4): - """Dispatch activation based on `act` parameter.""" - if const_expr(act == "swiglu"): - return _swiglu_mul_vec4(gate_v4, up_v4) - else: - return _silu_mul_vec4(gate_v4, up_v4) - - # Add bias to raw GEMM accumulators before activation. - # bias layout: [E, 2*inter_dim] flat f32 (non-interleaved: gate then up). - # For gate_up_interleave, map physical column to logical bias offset. 
- if const_expr(enable_bias and not _is_splitk): - if const_expr(bias_pf is not None): - _bias_gate_vals = bias_pf - else: - _bias_gate_vals = [] - for _ni in range_constexpr(num_acc_n): - if const_expr(gate_up_interleave): - _logical_col = ( - (by_n + n_tile_base) - // arith.constant(2, index=True) - + arith.constant((_ni // 2) * 16, index=True) - + lane_mod_16 - ) - _up_off = ( - inter_idx - if (_ni % 2 == 1) - else arith.constant(0, index=True) - ) - _bias_off = expert_off_idx + _up_off + _logical_col - else: - _bn = ( - by_n - + n_tile_base - + arith.constant(_ni * 16, index=True) - + lane_mod_16 - ) - _bias_off = expert_off_idx + _bn - _bias_gate_vals.append( - buffer_ops.buffer_load( - bias_rsrc, _bias_off, vec_width=1, dtype=f32 - ) + # ---- Bias: add to raw accumulators before activation ---- + if enable_bias and not _is_splitk: + _bias_gate_vals = [] + for _ni in range_constexpr(num_acc_n): + _bn = by_n + n_tile_base + arith.index(_ni * 16) + lane_mod_16 + _bias_gate_vals.append( + buffer_ops.buffer_load( + bias_rsrc, expert_off_idx + _bn, + vec_width=1, dtype=f32 ) + ) for _mi in range_constexpr(m_repeat): for _ni in range_constexpr(num_acc_n): _aidx = _mi * num_acc_n + _ni - _bsplat = vector.from_elements( - vec4_f32, [_bias_gate_vals[_ni]] * 4 - ) + _bsplat = vector.splat(vec4_f32, _bias_gate_vals[_ni]) acc_gate[_aidx] = arith.addf(acc_gate[_aidx], _bsplat) - if const_expr(not (mock_gate_only or gate_up_interleave)): + if not (gate_only or gate_up_interleave): _bias_up_vals = [] for _ni in range_constexpr(num_acc_n): - _bn = ( - by_n - + n_tile_base - + arith.constant(_ni * 16, index=True) - + lane_mod_16 - ) + _bn = by_n + n_tile_base + arith.index(_ni * 16) + lane_mod_16 _bias_up_vals.append( buffer_ops.buffer_load( - bias_rsrc, - expert_off_idx + inter_idx + _bn, - vec_width=1, - dtype=f32, + bias_rsrc, expert_off_idx + inter_idx + _bn, + vec_width=1, dtype=f32 ) ) for _mi in range_constexpr(m_repeat): for _ni in range_constexpr(num_acc_n): _aidx = _mi * num_acc_n + _ni - _bsplat = vector.from_elements( - vec4_f32, [_bias_up_vals[_ni]] * 4 - ) + _bsplat = vector.splat(vec4_f32, _bias_up_vals[_ni]) acc_up[_aidx] = arith.addf(acc_up[_aidx], _bsplat) - if const_expr(gate_up_interleave and not _is_splitk): - _gui_out_n = num_acc_n // pack_N + # ---- Epilogue ---- + expert_off = expert_off_idx + bx_m0 = bx_m + tokens_i32_v = tokens_i32 + topk_i32_v = topk_i32 + inter_i32_v = arith.constant(inter_dim, type=T.i32) + inter2_i32_v = arith.constant(inter_dim * 2, type=T.i32) + mask24_i32 = arith.constant(0xFFFFFF, type=T.i32) + + # Fuse activation for non-split-K paths + if const_expr(gate_up_interleave and not _is_splitk and _gui_xwave_fuse): + acc = acc_gate + _eff_num_acc_n = 1 + _eff_tile_n = _gui_out_tile_n + elif const_expr(gate_up_interleave and not _is_splitk): + pack_N_e = 2 + _gui_out_n = num_acc_n // pack_N_e acc = [None] * (_gui_out_n * m_repeat) for _mi in range_constexpr(m_repeat): for _ni in range_constexpr(_gui_out_n): - _g_idx = _mi * num_acc_n + _ni * pack_N + _g_idx = _mi * num_acc_n + _ni * pack_N_e _u_idx = _g_idx + 1 _out_idx = _mi * _gui_out_n + _ni acc[_out_idx] = _act_vec4( acc_gate[_g_idx], acc_gate[_u_idx] ) + _eff_num_acc_n = _gui_out_n + _eff_tile_n = _gui_out_tile_n elif const_expr(not _is_splitk): - acc = [None] * (int(num_acc_n) * int(m_repeat)) + acc = [None] * (num_acc_n * m_repeat) for _mi in range_constexpr(m_repeat): for _ni in range_constexpr(num_acc_n): _aidx = _mi * num_acc_n + _ni - acc[_aidx] = _silu_mul_vec4(acc_gate[_aidx], acc_up[_aidx]) + 
acc[_aidx] = _act_vec4(acc_gate[_aidx], acc_up[_aidx]) + _eff_num_acc_n = num_acc_n + _eff_tile_n = tile_n + else: + acc = acc_gate + _eff_num_acc_n = num_acc_n + _eff_tile_n = tile_n - # ---- Epilogue: CShuffle + direct store (accumulate=False) ---- - # Output: out[(t*topk+s) * inter_dim + col] = silu(gate) * up - # For split-K: skip silu, output gate/up separately with atomic add - tw_pf = None - bias_pf = None - if const_expr(epilogue_pf is not None): - _, tw_pf, bias_pf = epilogue_pf + col_i32_list = [] + for ni in range_constexpr(len(col_g_list)): + col_i32_list.append(arith.index_cast(i32, col_g_list[ni])) - mask24_i32 = arith.constant(0xFFFFFF) - topk_i32_v = topk_i32 - tokens_i32_v = tokens_i32 + # Row stride for output indexing + if const_expr(_is_splitk): + _out_row_stride_i32 = inter2_i32_v + else: + _out_row_stride_i32 = inter_i32_v - from flydsl._mlir.dialects import fly as _fly + _sk_n_offset = [0] - _llvm_ptr_ty = ir.Type.parse("!llvm.ptr") - out_base_ptr = _fly.extract_aligned_pointer_as_index( - _llvm_ptr_ty, arg_out - ) - out_base_i64 = llvm.ptrtoint(T.i64, out_base_ptr) - out_base_idx = arith.index_cast(ir.IndexType.get(), out_base_i64) + zero_i32_sk = arith.constant(0, type=T.i32) + c4_i32_sk = arith.constant(4, type=T.i32) - if const_expr(lds_out is None): - raise RuntimeError("CShuffle epilogue requires lds_out") + def _decode_fused2_at_row(row): + """Load fused2 from sorted_rsrc[row]; decode t/s and validity flags.""" + fused2 = buffer_ops.buffer_load( + sorted_rsrc, row, vec_width=1, dtype=i32 + ) + t2 = fused2 & mask24_i32 + s2 = fused2 >> 24 + row_i32 = arith.index_cast(i32, row) + row_ok = arith.cmpi(arith.CmpIPredicate.ult, row_i32, num_valid_i32) + t_ok = arith.cmpi(arith.CmpIPredicate.ult, t2, tokens_i32_v) + s_ok = arith.cmpi(arith.CmpIPredicate.ult, s2, topk_i32_v) + return fused2, t2, s2, row_ok, t_ok, s_ok + + if const_expr(_is_splitk): + if const_expr(gate_only or gate_up_interleave): + # Single-pass atomic: no activation fusion + def _splitk_store_row(*, mi: int, ii: int, row_in_tile, row): + _, t2, s2, row_ok, t_ok, s_ok = _decode_fused2_at_row(row) + all_ok = row_ok & t_ok & s_ok + sx = arith.select( + all_ok, + arith.constant(1.0, type=T.f32), + arith.constant(0.0, type=T.f32), + ) + t2_safe = arith.select(all_ok, t2, arith.constant(0, type=T.i32)) + s2_safe = arith.select(all_ok, s2, arith.constant(0, type=T.i32)) + idx0 = (t2_safe * topk_i32_v + s2_safe) * _out_row_stride_i32 + + for ni in range_constexpr(_eff_num_acc_n): + col_i32 = col_i32_list[ni] + acc_idx = mi * _eff_num_acc_n + ni + val = vector.extract( + acc[acc_idx], + static_position=[ii], + dynamic_position=[], + ) + val = val * sx + idx_elem = idx0 + col_i32 + byte_off = idx_elem * c4_i32_sk + rocdl.raw_ptr_buffer_atomic_fadd( + val, out_rsrc, byte_off, + zero_i32_sk, zero_i32_sk, + ) - _apply_weight = doweight_stage1 and not _is_splitk + mfma_epilog( + use_cshuffle=False, + arith=arith, + range_constexpr=range_constexpr, + m_repeat=m_repeat, + lane_div_16=lane_div_16, + bx_m=bx_m, + body_row=_splitk_store_row, + ) + else: + # Separated split-K: two-pass atomic (gate then up) + def _splitk_sep_store_row( + *, mi: int, ii: int, row_in_tile, row + ): + _, t2, s2, row_ok, t_ok, s_ok = _decode_fused2_at_row(row) + all_ok = row_ok & t_ok & s_ok + sx = arith.select( + all_ok, + arith.constant(1.0, type=T.f32), + arith.constant(0.0, type=T.f32), + ) + t2_safe = arith.select(all_ok, t2, arith.constant(0, type=T.i32)) + s2_safe = arith.select(all_ok, s2, arith.constant(0, type=T.i32)) + idx0 = ( 
+ (t2_safe * topk_i32_v + s2_safe) * _out_row_stride_i32 + + arith.constant(_sk_n_offset[0], type=T.i32) + ) - def write_row_to_lds( - *, - mi: int, - ii: int, - row_in_tile, - row, - row_base_lds, - col_base_local, - num_acc_n: int, - lds_out, - ): - if const_expr(_apply_weight): - tw_idx = (mi * 4) + ii - if const_expr(tw_pf is not None): - tw = tw_pf[tw_idx] - else: + for ni in range_constexpr(num_acc_n): + col_i32 = col_i32_list[ni] + acc_idx = mi * num_acc_n + ni + val = vector.extract( + acc[acc_idx], + static_position=[ii], + dynamic_position=[], + ) + val = val * sx + idx_elem = idx0 + col_i32 + byte_off = idx_elem * c4_i32_sk + rocdl.raw_ptr_buffer_atomic_fadd( + val, out_rsrc, byte_off, + zero_i32_sk, zero_i32_sk, + ) + + # Pass 1: gate (offset 0) + acc = acc_gate + _sk_n_offset[0] = 0 + mfma_epilog( + use_cshuffle=False, + arith=arith, + range_constexpr=range_constexpr, + m_repeat=m_repeat, + lane_div_16=lane_div_16, + bx_m=bx_m, + body_row=_splitk_sep_store_row, + ) + gpu.barrier() + # Pass 2: up (offset inter_dim) + acc = acc_up + _sk_n_offset[0] = inter_dim + mfma_epilog( + use_cshuffle=False, + arith=arith, + range_constexpr=range_constexpr, + m_repeat=m_repeat, + lane_div_16=lane_div_16, + bx_m=bx_m, + body_row=_splitk_sep_store_row, + ) + elif const_expr(_use_cshuffle_epilog): + if const_expr(lds_out is None): + raise RuntimeError("CShuffle requires lds_out.") + + def write_row_to_lds( + *, mi: int, ii: int, row_in_tile, row, + row_base_lds, col_base_local, num_acc_n: int, lds_out, + ): + _, _, _, _, t_valid, _ = _decode_fused2_at_row(row) + + if const_expr(doweight_stage1): tw = buffer_ops.buffer_load( sorted_w_rsrc, row, vec_width=1, dtype=f32 ) - for ni in range_constexpr(num_acc_n): - col_local = col_base_local + (ni * 16) - acc_idx = mi * num_acc_n + ni - v = vector.extract( - acc[acc_idx], static_position=[ii], dynamic_position=[] - ) - if const_expr(_apply_weight): - v = v * tw - if const_expr(_need_quant): - lds_idx = row_base_lds + col_local - vec1_f32 = T.vec(1, f32) - v1 = vector.from_elements(vec1_f32, [v]) - vector.store(v1, lds_out, [lds_idx], alignment=4) - else: - v_out = arith.trunc_f(out_elem(), v) + + _out_ty = out_mlir() + _vec1_out = T.vec(1, _out_ty) + for ni in range_constexpr(_eff_num_acc_n): + col_local = col_base_local + (ni * 16) + acc_idx = mi * _eff_num_acc_n + ni + y = vector.extract( + acc[acc_idx], + static_position=[ii], dynamic_position=[], + ) + if const_expr(doweight_stage1): + y = y * tw + y16 = arith.trunc_f(_out_ty, y) lds_idx = row_base_lds + col_local - vec1_out = T.vec(1, out_elem()) - v1 = vector.from_elements(vec1_out, [v_out]) + v1 = vector.from_elements(_vec1_out, [y16]) vector.store(v1, lds_out, [lds_idx], alignment=2) - _out_row_stride = ( - inter_dim * 2 * out_elem_bytes - if _is_splitk - else ( - inter_dim // 2 - if _need_fp4 - else (inter_dim if _need_fp8 else inter_dim * out_elem_bytes) + def precompute_row(*, row_local, row): + _, t2, s2, row_ok, t_ok, s_ok = _decode_fused2_at_row(row) + row_valid = row_ok & t_ok & s_ok + row_byte_base = (t2 * topk_i32_v + s2) * inter_i32_v + return (row_byte_base, row_valid) + + def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): + _, _, _, _, t_valid, _ = _decode_fused2_at_row(row) + _if_valid = scf.IfOp(t_valid) + with _if_then(_if_valid): + idx0 = row_ctx + col_i32 = arith.index_cast(i32, col_g0) + idx_out = idx0 + col_i32 + buffer_ops.buffer_store(frag, out_rsrc, idx_out) + + _cs_by_n = by_n + _cs_n_tile_base = n_tile_base + if const_expr(gate_up_interleave): + 
_cs_by_n = by_n // arith.index(2) + _cs_n_tile_base = n_tile_base // arith.index(2) + + _cs_nlane = min(32, _eff_tile_n // 4) + mfma_epilog( + use_cshuffle=True, + arith=arith, vector=vector, gpu=gpu, scf=scf, + range_constexpr=range_constexpr, + tile_m=tile_m, tile_n=_eff_tile_n, e_vec=4, + cshuffle_nlane=_cs_nlane, + m_repeat=m_repeat, num_acc_n=_eff_num_acc_n, + tx=tx, lane_div_16=lane_div_16, lane_mod_16=lane_mod_16, + bx_m=bx_m, by_n=_cs_by_n, n_tile_base=_cs_n_tile_base, + lds_out=lds_out, + frag_elem_type=out_mlir(), + write_row_to_lds=write_row_to_lds, + precompute_row=precompute_row, + store_pair=store_pair, ) - ) + else: + # Direct epilogue (non-split-K, non-cshuffle) + def _stage1_store_row(*, mi: int, ii: int, row_in_tile, row): + _, t2, s2, row_ok, t_ok, s_ok = _decode_fused2_at_row(row) + row_valid = row_ok & t_ok & s_ok - def precompute_row(*, row_local, row): - fused2 = memref.load(lds_tid, [row_local]) - row_i32 = arith.index_cast(T.i32, row) - row_valid0 = arith.cmpi(CmpIPredicate.ult, row_i32, num_valid_i32) - t = fused2 & mask24_i32 - s = fused2 >> 24 - t_ok = arith.cmpi(CmpIPredicate.ult, t, tokens_i32_v) - s_ok = arith.cmpi(CmpIPredicate.ult, s, topk_i32_v) - row_valid = arith.andi(row_valid0, arith.andi(t_ok, s_ok)) - t_idx = arith.index_cast(ir.IndexType.get(), t) - s_idx = arith.index_cast(ir.IndexType.get(), s) - ts_idx = t_idx * arith.constant(topk, index=True) + s_idx - row_byte_base = out_base_idx + ts_idx * arith.constant( - _out_row_stride, index=True + idx0 = (t2 * topk_i32_v + s2) * inter_i32_v + + if const_expr(doweight_stage1): + tw = buffer_ops.buffer_load( + sorted_w_rsrc, row, vec_width=1, dtype=f32 + ) + + _if_valid = scf.IfOp(row_valid) + with _if_then(_if_valid): + for ni in range_constexpr(_eff_num_acc_n): + col_i32 = col_i32_list[ni] + acc_idx = mi * _eff_num_acc_n + ni + y = vector.extract( + acc[acc_idx], + static_position=[ii], dynamic_position=[], + ) + if const_expr(doweight_stage1): + y = y * tw + y = arith.trunc_f(out_mlir(), y) + idx_out0 = idx0 + col_i32 + buffer_ops.buffer_store(y, out_rsrc, idx_out0) + + mfma_epilog( + use_cshuffle=False, + arith=arith, + range_constexpr=range_constexpr, + m_repeat=m_repeat, + lane_div_16=lane_div_16, + bx_m=bx_m, + body_row=_stage1_store_row, ) - return ((fused2, row_byte_base), row_valid) - def _idx_to_llvm_ptr(idx_val, addr_space=1): - idx_v = idx_val._value if hasattr(idx_val, "_value") else idx_val - i64_v = arith.index_cast(T.i64, idx_v) - i64_raw = i64_v._value if hasattr(i64_v, "_value") else i64_v - ptr_ty = ir.Type.parse(f"!llvm.ptr<{addr_space}>") - return llvm.inttoptr(ptr_ty, i64_raw) + return - _e_vec = _e_vec_s1 - _e_vec_sk = 2 - _cshuffle_nlane = min(32, tile_n // _e_vec) - _cshuffle_nlane_sk = min(32, tile_n // _e_vec_sk) - _num_threads_per_quant_blk = _num_threads_per_quant_blk_s1 - - _c0_i32 = arith.constant(0, type=T.i32) - _c1_i32 = arith.constant(1, type=T.i32) - _c2_i32 = arith.constant(2, type=T.i32) - _c3_i32 = arith.constant(3, type=T.i32) - _c4_i32 = arith.constant(4, type=T.i32) - _c5_i32 = arith.constant(5, type=T.i32) - _c7_i32 = arith.constant(7, type=T.i32) - _c15_i32 = arith.constant(15, type=T.i32) - _c21_i32 = arith.constant(21, type=T.i32) - _c23_i32 = arith.constant(23, type=T.i32) - _c28_i32 = arith.constant(28, type=T.i32) - _c31_i32 = arith.constant(31, type=T.i32) - _c32_i32 = arith.constant(32, type=T.i32) - _c64_i32 = arith.constant(64, type=T.i32) - _c126_i32 = arith.constant(126, type=T.i32) - _c127_i32 = arith.constant(127, type=T.i32) - _c254_i32 = 
arith.constant(254, type=T.i32) - _c256_i32 = arith.constant(256, type=T.i32) - _c0xFF_i32 = arith.constant(0xFF, type=T.i32) - _c0x200000_i32 = arith.constant(0x200000, type=T.i32) - _c0xFF800000_i32 = arith.constant(0xFF800000, type=T.i32) - _c0x400000_i32 = arith.constant(0x400000, type=T.i32) - _c0x7FFFFF_i32 = arith.constant(0x7FFFFF, type=T.i32) - _c0x80000000_i32 = arith.constant(0x80000000, type=T.i32) - _c0_f32 = arith.constant(0.0, type=T.f32) - - _c8_i32 = arith.constant(8, type=T.i32) - _fp_headroom = 2 if _need_fp4 else (8 if _need_fp8 else 0) - _c_headroom_i32 = arith.constant(_fp_headroom, type=T.i32) - - def _f32_to_e2m1(qx_f32): - """Convert a scaled f32 value to fp4 (e2m1) 4-bit integer.""" - qx = qx_f32.bitcast(T.i32) - s = qx & _c0x80000000_i32 - e = (qx >> _c23_i32) & _c0xFF_i32 - m = qx & _c0x7FFFFF_i32 - adj_exp = arith.maxsi(_c126_i32 - e, _c0_i32) - m_denorm = (_c0x400000_i32 | (m >> _c1_i32)) >> adj_exp - is_denorm = arith.cmpi(CmpIPredicate.ult, e, _c127_i32) - m = arith.select(is_denorm, m_denorm, m) - e = arith.maxsi(e - _c126_i32, _c0_i32) - combined = (e << _c2_i32) | (m >> _c21_i32) - rounded = (combined + _c1_i32) >> _c1_i32 - e2m1 = arith.minui(rounded, _c7_i32) - return (s >> _c28_i32) | e2m1 - - if const_expr(_need_sort): - _n32_sort = _sorted_scale_cols_i32 * _c32_i32 - - # Mutable slot for split-K N-offset (gate=0, up=inter_dim) - _sk_n_offset = [0] - def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): - fused, row_byte_base = row_ctx - if const_expr(_need_quant and not _is_splitk): - frag_vals = [] - for i in range_constexpr(_e_vec): - frag_vals.append( - vector.extract( - frag, static_position=[i], dynamic_position=[] - ) - ) + tokens_in = arith.index_cast(ir.IndexType.get(), i32_tokens_in.ir_value()) + n_in = arith.index_cast(ir.IndexType.get(), i32_n_in.ir_value()) + k_in = arith.index_cast(ir.IndexType.get(), i32_k_in.ir_value()) + size_expert_ids_in = arith.index_cast( + ir.IndexType.get(), i32_size_expert_ids_in.ir_value() + ) - local_max = _c0_f32 - for i in range_constexpr(_e_vec): - abs_v = llvm.call_intrinsic( - f32, "llvm.fabs.f32", [frag_vals[i]], [], [] - ) - local_max = arith.maximumf(local_max, abs_v) - - for _si in range_constexpr(_num_shuffle_steps_s1): - off = arith.constant(_shuffle_dists_s1[_si], type=T.i32) - peer = local_max.shuffle_xor(off, _c64_i32) - local_max = arith.maximumf(local_max, peer) - - max_i32 = local_max.bitcast(T.i32) - max_rounded = (max_i32 + _c0x200000_i32) & _c0xFF800000_i32 - exp_field = max_rounded >> _c23_i32 - e8m0_biased = arith.maxsi(exp_field - _c_headroom_i32, _c0_i32) - - quant_exp = _c254_i32 - e8m0_biased - quant_scale = (quant_exp << _c23_i32).bitcast(T.f32) - - if const_expr(_need_fp4): - fp4_vals = [] - for i in range_constexpr(_e_vec): - scaled_v = frag_vals[i] * quant_scale - fp4_vals.append(_f32_to_e2m1(scaled_v)) - - packed_i32 = fp4_vals[0] | (fp4_vals[1] << _c4_i32) - for k in range_constexpr(1, _e_vec // 2): - byte_k = fp4_vals[2 * k] | ( - fp4_vals[2 * k + 1] << _c4_i32 - ) - packed_i32 = packed_i32 | ( - byte_k << arith.constant(k * 8, type=T.i32) - ) + x_elem = T.f16 if is_f16_a else (T.i8 if is_int8 else T.f8) + f32 = T.f32 + i32 = T.i32 + i64 = T.i64 + vec4_f32 = T.vec(4, f32) + vec16_elems = 16 if a_elem_bytes == 1 else 8 + vec16_x = T.vec(vec16_elems, x_elem) + vec2_i64 = T.vec(2, i64) + + acc_init = arith.constant_vector(0.0, vec4_f32) + + # --- Stage1 dimension mapping --- + # X: [tokens, model_dim] -- M = sorted tokens, K = model_dim + # W: [E*2*inter_dim, 
model_dim] gate portion -- N = inter_dim + # Out: [tokens*topk, inter_dim] + + # B preshuffle layout: [E*2*inter_dim, model_dim] + # Gate rows for expert e: [e*2*inter_dim, e*2*inter_dim + inter_dim) + c_n_total = arith.constant(experts * (2 * inter_dim), index=True) + b_layout = make_preshuffle_b_layout( + arith, + c_n=c_n_total, + c_k=k_in // pack_K, + kpack_bytes=kpack_bytes, + elem_bytes=b_elem_bytes, + # k_major=True, + ) + layout_b = b_layout.layout_b - ptr_addr_idx = row_byte_base + col_g0 / arith.constant( - 2, index=True - ) - out_ptr_v = _idx_to_llvm_ptr(ptr_addr_idx) - _pack_bytes = _e_vec // 2 - if const_expr(_pack_bytes == 1): - store_val = arith.TruncIOp(T.i8, packed_i32) - store_raw = ( - store_val._value - if hasattr(store_val, "_value") - else store_val - ) - llvm.StoreOp( - store_raw, out_ptr_v, alignment=1, nontemporal=True - ) - elif const_expr(_pack_bytes == 2): - store_val = arith.TruncIOp(T.i16, packed_i32) - store_raw = ( - store_val._value - if hasattr(store_val, "_value") - else store_val + # A-scale: [sorted_size, K/32] -- pre-scattered by caller into sorted layout + # Same as stage2: indexed by sorted_row position, not by token_id. + sorted_m = size_expert_ids_in * arith.constant(sort_block_m, index=True) + layout_a_scale = make_preshuffle_scale_layout( + arith, c_mn=sorted_m, c_k=arith.constant(model_dim, index=True) + ) + # B-scale: [E*2*inter_dim, K/32] + layout_b_scale = make_preshuffle_scale_layout( + arith, c_mn=c_n_total, c_k=arith.constant(model_dim, index=True) + ) + + _eff_lds_stride = lds_stride + _eff_tile_k_bytes = tile_k_bytes + if const_expr(use_async_copy and a_elem_vec_pack > 1): + _eff_lds_stride = lds_stride // a_elem_vec_pack + _eff_tile_k_bytes = tile_k_bytes // a_elem_vec_pack + + shape_lds = fx.make_shape(tile_m, _eff_lds_stride) + stride_lds = fx.make_stride(_eff_lds_stride, 1) + layout_lds = fx.make_layout(shape_lds, stride_lds) + + tx = gpu.thread_id("x") + by = gpu.block_id("x") # tile along inter_dim (N) + bx_persist = gpu.block_id("y") # persistent WG index + + if const_expr(xcd_swizzle > 0): + _NUM_XCDS_S1 = 8 + _c1_sw = arith.constant(1, index=True) + _c_tn_sw = arith.constant(tile_n, index=True) + _c_idp_sw = arith.constant(2 * inter_dim_pad, index=True) + if const_expr(mock_gate_only or gate_up_interleave): + _gx = (n_in - _c_idp_sw + _c_tn_sw - _c1_sw) / _c_tn_sw + else: + _c2_sw = arith.constant(2, index=True) + _gx = ( + (n_in - _c_idp_sw + _c2_sw * _c_tn_sw - _c1_sw) + / _c_tn_sw + / _c2_sw + ) + _c_pm_sw = arith.constant(persist_m, index=True) + _gy = (size_expert_ids_in + _c_pm_sw - _c1_sw) / _c_pm_sw + + _linear_id = bx_persist * _gx + by + _num_wgs = _gx * _gy + + _c_xcds = arith.constant(_NUM_XCDS_S1, index=True) + _wgs_per_xcd = _num_wgs / _c_xcds + _wgid = (_linear_id % _c_xcds) * _wgs_per_xcd + (_linear_id / _c_xcds) + + _WGM_S1 = xcd_swizzle + _c_wgm = arith.constant(_WGM_S1, index=True) + _num_wgid_in_group = _c_wgm * _gx + _group_id = _wgid / _num_wgid_in_group + _first_pid_m = _group_id * _c_wgm + _remaining_m = _gy - _first_pid_m + _cmp_m = arith.cmpi(CmpIPredicate.ult, _remaining_m, _c_wgm) + _group_size_m = arith.select(_cmp_m, _remaining_m, _c_wgm) + + _wgid_in_group = _wgid % _num_wgid_in_group + bx_persist = _first_pid_m + (_wgid_in_group % _group_size_m) + by = _wgid_in_group / _group_size_m + + by_n = by * arith.constant(tile_n, index=True) + + k_base_idx = arith.index(0) + if const_expr(_is_splitk): + bz = gpu.block_id("z") # K-batch id + k_base_idx = bz * arith.constant(_k_dim, index=True) + + k_blocks16 
= arith.constant(_eff_tile_k_bytes // 16, index=True) + layout_tx_wave_lane = fx.make_layout((num_waves, 64), stride=(64, 1)) + layout_lane16 = fx.make_layout((4, 16), stride=(16, 1)) + + base_ptr_pong = allocator_pong.get_base() + base_ptr_ping = allocator_ping.get_base() + lds_x_pong = SmemPtr( + base_ptr_pong, lds_pong_offset, x_lds_elem(), shape=(_input_elems,) + ).get() + lds_x_ping = SmemPtr( + base_ptr_ping, lds_ping_offset, x_lds_elem(), shape=(_input_elems,) + ).get() + _lds_out_elem_type = ( + T.f32 if _need_quant else (T.bf16 if out_is_bf16 else T.f16) + ) + if const_expr(_split_lds_out and _use_cshuffle_epilog): + _half_out_elems = int(tile_m) * (int(tile_n) // 2) + lds_out = SmemPtr( + base_ptr_pong, + lds_pong_offset, + _lds_out_elem_type, + shape=(_half_out_elems,), + ).get() + lds_out_B = SmemPtr( + base_ptr_ping, + lds_ping_offset, + _lds_out_elem_type, + shape=(_half_out_elems,), + ).get() + else: + lds_out = ( + SmemPtr( + base_ptr_pong, + lds_pong_offset, + _lds_out_elem_type, + shape=(tile_m * tile_n,), + ).get() + if _use_cshuffle_epilog + else None + ) + lds_out_B = None + lds_tid = SmemPtr( + base_ptr_pong, _lds_tid_offset_pong, T.i32, shape=(tile_m,) + ).get() + + # Buffer resources + c_a_pack = arith.constant(int(a_elem_vec_pack), index=True) + c_elem_bytes = arith.constant(int(a_elem_bytes), index=True) + + # X: [tokens, model_dim] + x_nbytes_idx = (tokens_in * k_in * c_elem_bytes) / c_a_pack + x_nbytes_i32 = arith.index_cast(T.i32, x_nbytes_idx) + x_rsrc = buffer_ops.create_buffer_resource( + arg_x, max_size=False, num_records_bytes=x_nbytes_i32 + ) + + w_rsrc = buffer_ops.create_buffer_resource(arg_w, max_size=False) + + # Out: [tokens*topk, inter_dim] + numids_rsrc = buffer_ops.create_buffer_resource( + arg_num_valid_ids, + max_size=False, + num_records_bytes=arith.constant(4, type=T.i32), + ) + num_valid_i32 = buffer_ops.buffer_load( + numids_rsrc, arith.constant(0, index=True), vec_width=1, dtype=T.i32 + ) + + sx_rsrc = 1 + sw_rsrc = 1 + if const_expr(not (is_f16_a or a_scale_one)): + # A scale: [sorted_size, model_dim/32] pre-scattered by caller + c32 = arith.constant(32, index=True) + kblk = k_in / c32 + sx_nbytes_idx = sorted_m * kblk + sx_nbytes_i32 = arith.index_cast(T.i32, sx_nbytes_idx) + sx_rsrc = buffer_ops.create_buffer_resource( + arg_scale_x, max_size=False, num_records_bytes=sx_nbytes_i32 + ) + + if const_expr(not is_f16_b): + c32 = arith.constant(32, index=True) + kblk_w = k_in / c32 + mn_w = arith.constant(experts * (2 * inter_dim), index=True) + sw_nbytes_idx = mn_w * kblk_w + sw_nbytes_i32 = arith.index_cast(T.i32, sw_nbytes_idx) + sw_rsrc = buffer_ops.create_buffer_resource( + arg_scale_w, max_size=False, num_records_bytes=sw_nbytes_i32 + ) + + sorted_nbytes_idx = size_expert_ids_in * arith.constant( + sort_block_m * 4, index=True + ) + sorted_nbytes_i32 = arith.index_cast(T.i32, sorted_nbytes_idx) + sorted_rsrc = buffer_ops.create_buffer_resource( + arg_sorted_token_ids, + max_size=False, + num_records_bytes=sorted_nbytes_i32, + ) + sorted_w_rsrc = buffer_ops.create_buffer_resource( + arg_sorted_weights, max_size=False, num_records_bytes=sorted_nbytes_i32 + ) + + eid_nbytes_idx = size_expert_ids_in * arith.constant(4, index=True) + eid_nbytes_i32 = arith.index_cast(T.i32, eid_nbytes_idx) + expert_rsrc = buffer_ops.create_buffer_resource( + arg_expert_ids, max_size=False, num_records_bytes=eid_nbytes_i32 + ) + bias_rsrc = ( + buffer_ops.create_buffer_resource(arg_bias, max_size=False) + if enable_bias + else None + ) + + # Sorted-scale 
buffer resource for fused mxfp4 quantization + _sorted_scale_cols = inter_dim // 32 + _sorted_scale_cols_i32 = arith.constant(_sorted_scale_cols, type=T.i32) + sorted_scale_rsrc = None + if const_expr(_need_sort): + sorted_scale_rsrc = buffer_ops.create_buffer_resource( + arg_out_scale_sorted, max_size=False + ) + + # ---- persist_m loop (same pattern as stage2) ---- + _PERSIST_M = persist_m + _c0_p = arith.constant(0, index=True) + _c1_p = arith.constant(1, index=True) + _c_pm = arith.constant(_PERSIST_M, index=True) + _for_persist = scf.ForOp(_c0_p, _c_pm, _c1_p) + _for_ip = ir.InsertionPoint(_for_persist.body) + _for_ip.__enter__() + _mi_p = _for_persist.induction_variable + bx = bx_persist * _c_pm + _mi_p + bx_m = bx * arith.constant(sort_block_m, index=True) + + # Block validity + bx_m_i32 = arith.index_cast(T.i32, bx_m) + blk_valid = arith.cmpi(CmpIPredicate.ult, bx_m_i32, num_valid_i32) + expert_i32 = buffer_ops.buffer_load( + expert_rsrc, bx, vec_width=1, dtype=T.i32 + ) + expert_idx = arith.index_cast(ir.IndexType.get(), expert_i32) + exp_valid = arith.cmpi( + CmpIPredicate.ult, expert_i32, arith.constant(experts, type=T.i32) + ) + + def _moe_gemm1_body(): + # Gate expert offset: first inter_dim rows of each expert's 2*inter_dim block + expert_off_idx = expert_idx * arith.constant(2 * inter_dim, index=True) + + # X loading -- KEY DIFFERENCE from stage2: X row = token_id only + x_load_bytes = 16 + num_x_loads = bytes_per_thread_x // x_load_bytes + chunk_i32 = x_load_bytes // 4 + + c_k_div4 = ( + (k_in / c_a_pack) * arith.constant(int(a_elem_bytes), index=True) + ) / arith.index(4) + tile_k_dwords = (int(tile_k) * int(a_elem_bytes)) // ( + 4 * int(a_elem_vec_pack) + ) + layout_x_tile_div4 = fx.make_layout( + (tile_m, tile_k_dwords), stride=(tile_k_dwords, 1) + ) + c_chunk_i32 = arith.constant(chunk_i32, index=True) + tx_i32_base = tx * c_chunk_i32 + + topk_i32 = arith.constant(topk) + mask24 = arith.constant(0xFFFFFF) + tokens_i32 = arith.index_cast(T.i32, tokens_in) + + def x_tile_chunk_coord_i32(i: int): + return tile_chunk_coord_i32( + arith, + tx_i32_base=tx_i32_base, + i=i, + total_threads=total_threads, + layout_tile_div4=layout_x_tile_div4, + chunk_i32=chunk_i32, + ) + + def load_x(idx_i32): + idx_elem = ( + idx_i32 if a_elem_bytes == 1 else (idx_i32 * arith.index(2)) + ) + return buffer_copy_gmem16_dwordx4( + buffer_ops, + vector, + elem_type=x_elem, + idx_i32=idx_elem, + rsrc=x_rsrc, + vec_elems=vec16_elems, + ) + + # Decode sorted token ids -- stage1: X row = token_id (not t*topk+s) + x_row_base_div4 = [] + x_col_local_i32 = [] + x_row_local = [] + # Also store token_id and slot_id for output indexing + + for i in range_constexpr(num_x_loads): + row_local, col_local_i32 = x_tile_chunk_coord_i32(i) + x_row_local.append(row_local) + x_col_local_i32.append(col_local_i32) + + sorted_row_i = bx_m + row_local + fused_i = buffer_ops.buffer_load( + sorted_rsrc, sorted_row_i, vec_width=1, dtype=T.i32 + ) + t_i32 = arith.andi(fused_i, mask24) + s_i32 = arith.shrui(fused_i, arith.constant(24)) + t_valid = arith.cmpi(CmpIPredicate.ult, t_i32, tokens_i32) + s_valid = arith.cmpi(CmpIPredicate.ult, s_i32, topk_i32) + ts_valid = arith.andi(t_valid, s_valid) + t_safe = arith.select(ts_valid, t_i32, arith.constant(0)) + + # KEY: X row base uses token_id only (not t*topk+s) + t_idx = arith.index_cast(ir.IndexType.get(), t_safe) + x_row_base_div4.append(t_idx * c_k_div4) + + def load_x_tile(base_k): + base_k_div4 = ( + (base_k / c_a_pack) + * arith.constant(int(a_elem_bytes), index=True) + ) 
/ arith.index(4) + parts = [] + for i in range_constexpr(num_x_loads): + idx_i32 = x_row_base_div4[i] + base_k_div4 + x_col_local_i32[i] + x_vec = load_x(idx_i32) + parts.append(vector.bitcast(T.vec(4, i32), x_vec)) + return parts + + # Wave/lane decomposition (identical to stage2) + coord_wl = idx2crd(tx, layout_tx_wave_lane) + wave_id = layout_get(coord_wl, 0) + lane_id = layout_get(coord_wl, 1) + coord_l16 = idx2crd(lane_id, layout_lane16) + lane_div_16 = layout_get(coord_l16, 0) + lane_mod_16 = layout_get(coord_l16, 1) + row_a_lds = lane_mod_16 + col_offset_base = lane_div_16 * arith.constant(16, index=True) + + num_acc_n = n_per_wave // 16 + c_n_per_wave = arith.constant(n_per_wave, index=True) + wave_n_id = wave_id % arith.constant(num_waves, index=True) + n_tile_base = wave_n_id * c_n_per_wave + + # N-tile precompute for gate AND up weights + gate_n_intra_list = [] + gate_n_blk_list = [] + up_n_intra_list = [] + up_n_blk_list = [] + col_g_list = [] + c_n0_static = experts * (2 * inter_dim) // 16 + layout_n_blk_intra = fx.make_layout((c_n0_static, 16), stride=(16, 1)) + inter_idx = arith.constant(inter_dim, index=True) + + for i in range_constexpr(num_acc_n): + offset = i * 16 + c_offset = arith.constant(offset, index=True) + if const_expr(not gate_up_interleave): + col_g = by_n + n_tile_base + c_offset + lane_mod_16 + col_g_list.append(col_g) + + global_n = by_n + n_tile_base + c_offset + lane_mod_16 + # Gate/interleave: rows [expert_off, expert_off + 2*inter_dim) + gate_row_w = expert_off_idx + global_n + gate_coord = idx2crd(gate_row_w, layout_n_blk_intra) + gate_n_blk_list.append(layout_get(gate_coord, 0)) + gate_n_intra_list.append(layout_get(gate_coord, 1)) + if const_expr(not mock_gate_only and not gate_up_interleave): + up_row_w = gate_row_w + inter_idx + up_coord = idx2crd(up_row_w, layout_n_blk_intra) + up_n_blk_list.append(layout_get(up_coord, 0)) + up_n_intra_list.append(layout_get(up_coord, 1)) + + if const_expr(gate_up_interleave): + _gui_num_acc_n_out = num_acc_n // pack_N + for _gui_i in range_constexpr(_gui_num_acc_n_out): + _gui_offset = _gui_i * 16 + _gui_c_offset = arith.constant(_gui_offset, index=True) + _gui_col_g = ( + (by_n + n_tile_base) // arith.constant(2, index=True) + + _gui_c_offset + + lane_mod_16 + ) + col_g_list.append(_gui_col_g) + + m_repeat = tile_m // 16 + k_unroll = tile_k_bytes // 128 + k_unroll_packed = k_unroll // pack_K + m_repeat_packed = m_repeat // pack_M + num_acc_n_packed = num_acc_n // pack_N + + _K_per_ku = tile_k // k_unroll + _pad_k_elems = ( + (model_dim_pad % tile_k) + if (not _is_splitk and model_dim_pad > 0) + else 0 + ) + _pad_ku_skip = _pad_k_elems // _K_per_ku + _tail_ku = k_unroll - _pad_ku_skip + _tail_ku_packed = ( + (_tail_ku + pack_K - 1) // pack_K if _pad_ku_skip > 0 else None + ) + + # B load for gate and up separately + def load_b_packs_k64(base_k, ku: int, n_blk, n_intra): + c64 = arith.constant(64, index=True) + base_k_bytes = base_k * arith.constant( + int(b_elem_bytes), index=True + ) + k0 = base_k_bytes // c64 + arith.constant(ku, index=True) + k1 = lane_div_16 + coord_pack = (n_blk, k0, k1, n_intra, arith.constant(0, index=True)) + idx_pack = crd2idx(coord_pack, layout_b) + vec_elems = kpack_bytes // int(b_elem_bytes) + b16 = _buffer_load_vec( + buffer_ops, + vector, + w_rsrc, + idx_pack, + elem_type=_w_elem_type(), + vec_elems=vec_elems, + elem_bytes=b_elem_bytes, + offset_in_bytes=(b_elem_bytes == 1), + cache_modifier=b_nt, + ) + b_i64x2 = vector.bitcast(vec2_i64, b16) + b0 = vector.extract( + b_i64x2, 
static_position=[0], dynamic_position=[] + ) + b1 = vector.extract( + b_i64x2, static_position=[1], dynamic_position=[] + ) + return b0, b1 + + def load_b_tile(base_k, ku_limit=k_unroll): + """Load B tiles. Returns (gate_b_tile, up_b_tile). + When mock_gate_only or gate_up_interleave, up_b_tile is None.""" + gate_b_tile = [] + up_b_tile = ( + [] if (not mock_gate_only and not gate_up_interleave) else None + ) + for ku in range_constexpr(ku_limit): + g_packs0, g_packs1 = [], [] + u_packs0, u_packs1 = [], [] + for ni in range_constexpr(num_acc_n): + gb0, gb1 = load_b_packs_k64( + base_k, ku, gate_n_blk_list[ni], gate_n_intra_list[ni] + ) + g_packs0.append(gb0) + g_packs1.append(gb1) + if const_expr( + not mock_gate_only and not gate_up_interleave + ): + ub0, ub1 = load_b_packs_k64( + base_k, ku, up_n_blk_list[ni], up_n_intra_list[ni] + ) + u_packs0.append(ub0) + u_packs1.append(ub1) + gate_b_tile.append((g_packs0, g_packs1)) + if const_expr(not mock_gate_only and not gate_up_interleave): + up_b_tile.append((u_packs0, u_packs1)) + return gate_b_tile, up_b_tile + + # Pre-compute scale base element indices (K-loop invariant). + # idx = mni * stride_n0 + ku * stride_k0 + k_lane * stride_klane + n_lane + # Split into: base_elem = mni * stride_n0 + lane_elem (invariant) + # k_elem = ku * stride_k0 (per-iteration) + _scale_lane_elem = ( + lane_div_16 * layout_b_scale.stride_klane + lane_mod_16 + ) + + _gate_scale_bases = [] + _up_scale_bases = [] + for _ni in range_constexpr(num_acc_n_packed): + _col_base = ( + by_n + + n_tile_base + + arith.constant(_ni * 16 * pack_N, index=True) + ) + _gate_mni = (expert_off_idx + _col_base) // arith.constant( + 32, index=True + ) + _gate_scale_bases.append( + _gate_mni * layout_b_scale.stride_n0 + _scale_lane_elem + ) + if const_expr(not mock_gate_only and not gate_up_interleave): + _up_mni = ( + expert_off_idx + inter_idx + _col_base + ) // arith.constant(32, index=True) + _up_scale_bases.append( + _up_mni * layout_b_scale.stride_n0 + _scale_lane_elem + ) + + if const_expr(not a_scale_one): + _a_scale_bases = [] + for _mi in range_constexpr(m_repeat_packed): + _a_mni = _mi + bx_m // scale_mn_pack // 16 + _a_scale_bases.append( + _a_mni * layout_a_scale.stride_n0 + _scale_lane_elem + ) + + _c16_idx = arith.constant(16, index=True) + _c2_idx = arith.constant(2, index=True) + _scale_mask_lo = arith.constant(0xFF, type=T.i32) + + _m_half_idx = arith.constant(0, type=T.i32) + _m_half_i32 = arith.constant(0, type=T.i32) + _scale_shift = arith.constant(0, type=T.i32) + _scale_shift_hi = arith.constant(0, type=T.i32) + _n_half_idx = arith.constant(0, type=T.i32) + _n_half_i32 = arith.constant(0, type=T.i32) + _bscale_shift = arith.constant(0, type=T.i32) + _bscale_shift_hi = arith.constant(0, type=T.i32) + if const_expr(pack_M < scale_mn_pack): + _m_half_idx = (bx_m // _c16_idx) % _c2_idx + _m_half_i32 = arith.index_cast(T.i32, _m_half_idx) + _scale_shift = _m_half_i32 * arith.constant(8, type=T.i32) + _scale_shift_hi = _scale_shift + arith.constant(16, type=T.i32) + + if const_expr(pack_N < scale_mn_pack): + _n_half_idx = (n_tile_base // _c16_idx) % _c2_idx + _n_half_i32 = arith.index_cast(T.i32, _n_half_idx) + _bscale_shift = _n_half_i32 * arith.constant(8, type=T.i32) + _bscale_shift_hi = _bscale_shift + arith.constant(16, type=T.i32) + + def _rearrange_a_scale(raw_i32): + """Rearrange scale bytes for pack_M=1: extract m_half's k0,k1 bytes.""" + if const_expr(pack_M >= scale_mn_pack): + return raw_i32 + b_k0 = arith.andi( + arith.shrui(raw_i32, _scale_shift), 
_scale_mask_lo + ) + b_k1 = arith.andi( + arith.shrui(raw_i32, _scale_shift_hi), _scale_mask_lo + ) + return arith.ori( + b_k0, arith.shli(b_k1, arith.constant(8, type=T.i32)) + ) + + def _rearrange_b_scale(raw_i32): + """Rearrange scale bytes for pack_N=1: extract n_half's k0,k1 bytes.""" + if const_expr(pack_N >= scale_mn_pack): + return raw_i32 + b_k0 = arith.andi( + arith.shrui(raw_i32, _bscale_shift), _scale_mask_lo + ) + b_k1 = arith.andi( + arith.shrui(raw_i32, _bscale_shift_hi), _scale_mask_lo + ) + return arith.ori( + b_k0, arith.shli(b_k1, arith.constant(8, type=T.i32)) + ) + + if const_expr(a_scale_one): + _as1_const = arith.constant(0x7F7F7F7F, type=T.i32) + _as1_vec = vector.from_elements(T.vec(1, T.i32), [_as1_const]) + + def prefetch_ab_scale_tile(base_k, ku_packed_limit=k_unroll_packed): + a_scale_tile = [] + gate_b_scale = [] + up_b_scale = ( + [] if (not mock_gate_only and not gate_up_interleave) else None + ) + for ku in range_constexpr(ku_packed_limit): + k_off = (ku + base_k) * layout_b_scale.stride_k0 + for mi in range_constexpr(m_repeat_packed): + if const_expr(a_scale_one): + a_scale_tile.append(_as1_vec) + else: + s = buffer_ops.buffer_load( + sx_rsrc, + _a_scale_bases[mi] + k_off, + vec_width=1, + dtype=T.i32, + cache_modifier=0, + ) + s = _rearrange_a_scale(s) + a_scale_tile.append( + vector.from_elements(T.vec(1, T.i32), [s]) + ) + for ni in range_constexpr(num_acc_n_packed): + gs = buffer_ops.buffer_load( + sw_rsrc, + _gate_scale_bases[ni] + k_off, + vec_width=1, + dtype=T.i32, + cache_modifier=0, + ) + gs = _rearrange_b_scale(gs) + gate_b_scale.append( + vector.from_elements(T.vec(1, T.i32), [gs]) + ) + if const_expr( + not mock_gate_only and not gate_up_interleave + ): + us = buffer_ops.buffer_load( + sw_rsrc, + _up_scale_bases[ni] + k_off, + vec_width=1, + dtype=T.i32, + cache_modifier=0, + ) + us = _rearrange_b_scale(us) + up_b_scale.append( + vector.from_elements(T.vec(1, T.i32), [us]) + ) + return [a_scale_tile, gate_b_scale, up_b_scale] + + _lds_base_zero = arith.index(0) + + def store_x_tile_to_lds(vec_x_in_parts, lds_buffer): + for i in range_constexpr(num_x_loads): + row_local = x_row_local[i] + col_local_i32 = x_col_local_i32[i] + if const_expr(x_load_bytes == 16): + lds_store_16b_xor16( + arith, + vector, + lds_memref=lds_buffer, + vec16_ty=vec16_x, + layout_lds=layout_lds, + row_local=row_local, + col_local_i32=col_local_i32, + tx_c4=arith.index(4), + k_blocks16=k_blocks16, + lds_base=_lds_base_zero, + vec_part_i32x4=vec_x_in_parts[i], + elem_bytes=elem_bytes, + ) + + if const_expr(use_async_copy): + _dma_bytes = 16 + _wave_size = 64 + _eff_bytes_per_buffer = ( + int(tile_m) * int(_eff_lds_stride) * int(a_elem_bytes) + ) + _num_dma_loads = max( + 1, _eff_bytes_per_buffer // (total_threads * _dma_bytes) + ) + + def dma_x_tile_to_lds(base_k, lds_buffer): + c4_idx = arith.index(4) + base_k_div4 = ( + (base_k / c_a_pack) + * arith.constant(int(elem_bytes), index=True) + ) / arith.index(4) + + lds_ptr_i64 = None + for i in range_constexpr(_num_dma_loads): + row_local_i = x_row_local[i] + col_local_i32_i = x_col_local_i32[i] + col_local_sw = swizzle_xor16( + row_local_i, col_local_i32_i * c4_idx, k_blocks16 + ) + row_k_dw = x_row_base_div4[i] + base_k_div4 + global_byte_idx = row_k_dw * c4_idx + col_local_sw + global_offset = arith.index_cast(T.i32, global_byte_idx) + + if const_expr(i == 0): + lds_addr = memref.extract_aligned_pointer_as_index( + lds_buffer + ) + wave_id * arith.constant( + _wave_size * _dma_bytes, index=True + ) + lds_ptr_i64 = 
rocdl.readfirstlane( + T.i64, arith.index_cast(T.i64, lds_addr) + ) + else: + lds_ptr_i64 = lds_ptr_i64 + arith.constant( + total_threads * _dma_bytes, type=T.i64 + ) + + lds_ptr_type = ir.Type.parse("!llvm.ptr<3>") + lds_ptr = llvm.inttoptr(lds_ptr_type, lds_ptr_i64) + + rocdl.raw_ptr_buffer_load_lds( + x_rsrc, + lds_ptr, + arith.constant(_dma_bytes, type=T.i32), + global_offset, + arith.constant(0, type=T.i32), + arith.constant(0, type=T.i32), + arith.constant(0, type=T.i32), + ) + + def prefetch_x_to_lds(base_k, lds_buffer): + dma_x_tile_to_lds(base_k, lds_buffer) + + def lds_load_packs_k64(curr_row_a_lds, col_base, lds_buffer): + col_base_swz_bytes = swizzle_xor16( + curr_row_a_lds, col_base, k_blocks16 + ) + col_base_swz = ( + col_base_swz_bytes + if elem_bytes == 1 + else (col_base_swz_bytes / arith.index(2)) + ) + idx_a16 = crd2idx([curr_row_a_lds, col_base_swz], layout_lds) + loaded_a16 = vector.load_op(vec16_x, lds_buffer, [idx_a16]) + a_i64x2 = vector.bitcast(vec2_i64, loaded_a16) + a0 = vector.extract( + a_i64x2, static_position=[0], dynamic_position=[] + ) + a1 = vector.extract( + a_i64x2, static_position=[1], dynamic_position=[] + ) + return a0, a1 + + def prefetch_full_a_from_lds(lds_buffer, ku_limit=k_unroll): + """Load entire A tile from LDS into registers before compute.""" + a_regs = [] + for k_idx in range_constexpr(ku_limit): + col_base = col_offset_base + (k_idx * 128) // a_elem_vec_pack + for mi_idx in range_constexpr(m_repeat): + mi_val = arith.constant(mi_idx * 16, index=True) + curr_row = row_a_lds + mi_val + a0, a1 = lds_load_packs_k64(curr_row, col_base, lds_buffer) + if const_expr(is_f8_a): + a2, a3 = lds_load_packs_k64( + curr_row, col_base + 64, lds_buffer + ) + a_regs.append((a0, a1, a2, a3)) + else: + a_regs.append((a0, a1)) + return a_regs + + # Compute tile: gate + up MFMA interleaved, same A data, different B data. + # Two accumulator sets; after all K tiles, acc = acc_gate + acc_up (f32 add). 
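+        # Reference semantics of compute_tile below (illustrative shorthand only;
+        # the capitalised names are not kernel variables):
+        #   acc_gate[mi*num_acc_n + ni] += s_a * s_g * (A_blk[mi] @ B_gate_blk[ni]^T)
+        #   acc_up  [mi*num_acc_n + ni] += s_a * s_u * (A_blk[mi] @ B_up_blk[ni]^T)
+        # where each accumulator entry is one vec4_f32 (4 output elements per lane
+        # of a 16x16 MFMA tile) and s_a/s_g/s_u are the per-32-element E8M0 block
+        # scales fed to mfma_scale_f32_16x16x128_f8f6f4. The two accumulator sets
+        # are only combined in the epilogue: non-split-K applies the activation
+        # (silu/swiglu(gate) * up), while split-K stores gate and up separately
+        # with atomic adds.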
+ def compute_tile( + acc_gate_in, + acc_up_in, + gate_b_tile_in, + up_b_tile_in, + a_tile_regs, + a_scale=None, + gate_b_scale=None, + up_b_scale=None, + *, + prefetch_epilogue=False, + ku_count=k_unroll, + ): + gate_list = list(acc_gate_in) + _single_b = mock_gate_only or gate_up_interleave + up_list = None if _single_b else list(acc_up_in) + mfma_res_ty = vec4_f32 + epilogue_pf = None + bias_pf = None + if const_expr(prefetch_epilogue): + if const_expr(enable_bias): + bias_pf = [] + for ni in range_constexpr(num_acc_n): + if const_expr(gate_up_interleave): + _logical_col = ( + (by_n + n_tile_base) + // arith.constant(2, index=True) + + arith.constant((ni // 2) * 16, index=True) + + lane_mod_16 ) - llvm.StoreOp( - store_raw, out_ptr_v, alignment=2, nontemporal=True + _up_off = ( + inter_idx + if (ni % 2 == 1) + else arith.constant(0, index=True) + ) + bias_offset = ( + expert_off_idx + _up_off + _logical_col ) else: - packed_raw = ( - packed_i32._value - if hasattr(packed_i32, "_value") - else packed_i32 + global_n = ( + by_n + + n_tile_base + + arith.constant(ni * 16, index=True) + + lane_mod_16 ) - llvm.StoreOp( - packed_raw, out_ptr_v, alignment=4, nontemporal=True + bias_offset = expert_off_idx + global_n + bias_pf.append( + buffer_ops.buffer_load( + bias_rsrc, bias_offset, vec_width=1, dtype=f32 ) - - elif const_expr(_need_fp8): - scaled_vals = [] - for i in range_constexpr(_e_vec): - scaled_vals.append(frag_vals[i] * quant_scale) - - ptr_addr_idx = row_byte_base + col_g0 - if const_expr(_e_vec <= 4): - packed_i32 = _c0_i32 - for _w in range_constexpr(_e_vec // 2): - packed_i32 = rocdl.cvt_pk_fp8_f32( - T.i32, - scaled_vals[2 * _w], - scaled_vals[2 * _w + 1], - packed_i32, - _w, - ) - out_ptr_v = _idx_to_llvm_ptr(ptr_addr_idx) - if const_expr(_e_vec == 2): - store_val = arith.TruncIOp(T.i16, packed_i32) - store_raw = ( - store_val._value - if hasattr(store_val, "_value") - else store_val - ) - llvm.StoreOp( - store_raw, - out_ptr_v, - alignment=2, - nontemporal=True, - ) - else: - packed_raw = ( - packed_i32._value - if hasattr(packed_i32, "_value") - else packed_i32 - ) - llvm.StoreOp( - packed_raw, - out_ptr_v, - alignment=4, - nontemporal=True, - ) - else: - for _wg in range_constexpr(_e_vec // 4): - _b = _wg * 4 - packed_w = _c0_i32 - packed_w = rocdl.cvt_pk_fp8_f32( - T.i32, - scaled_vals[_b], - scaled_vals[_b + 1], - packed_w, - 0, - ) - packed_w = rocdl.cvt_pk_fp8_f32( - T.i32, - scaled_vals[_b + 2], - scaled_vals[_b + 3], - packed_w, - 1, - ) - word_ptr = ptr_addr_idx + arith.constant( - _wg * 4, index=True - ) - out_ptr_v = _idx_to_llvm_ptr(word_ptr) - packed_raw = ( - packed_w._value - if hasattr(packed_w, "_value") - else packed_w + ) + tw_pf = None + if const_expr(doweight_stage1): + tw_pf = [] + lane_div_16_mul4_pf = lane_div_16 * arith.index(4) + ii_idx_list_pf = [ + arith.constant(ii, index=True) for ii in range(4) + ] + for mi in range_constexpr(m_repeat): + mi_base_pf = arith.constant(mi * 16, index=True) + for ii in range_constexpr(4): + row_off_pf = ( + lane_div_16_mul4_pf + ii_idx_list_pf[ii] + ) + sorted_row_pf = bx_m + mi_base_pf + row_off_pf + tw_pf.append( + buffer_ops.buffer_load( + sorted_w_rsrc, + sorted_row_pf, + vec_width=1, + dtype=f32, ) - llvm.StoreOp( - packed_raw, - out_ptr_v, - alignment=4, - nontemporal=True, + ) + epilogue_pf = (None, tw_pf, bias_pf) + + c0_i64 = arith.constant(0, type=T.i64) + vec4_i64 = T.vec(4, T.i64) + vec8_i32 = T.vec(8, T.i32) + + def pack_i64x4_to_i32x8(x0, x1, x2, x3): + v4 = vector.from_elements(vec4_i64, [x0, x1, x2, 
x3]) + return vector.bitcast(vec8_i32, v4) + + _eff_packed = (ku_count + pack_K - 1) // pack_K + # B-major: fix B (ni), cycle A (mi) -- B from VMEM stays + # in registers while A from LDS is repacked per mi. + for ku128 in range_constexpr(_eff_packed): + for ni in range_constexpr(num_acc_n_packed): + gate_bs_i32 = gate_b_scale[ku128 * num_acc_n_packed + ni] + gate_bs_val = vector.extract( + gate_bs_i32, + static_position=[0], + dynamic_position=[], + ) + if const_expr(not _single_b): + up_bs_i32 = up_b_scale[ku128 * num_acc_n_packed + ni] + up_bs_val = vector.extract( + up_bs_i32, static_position=[0], dynamic_position=[] + ) + for ikxdl in range_constexpr(pack_K): + k_idx = ku128 * pack_K + ikxdl + if const_expr(k_idx < ku_count): + gate_bp0, gate_bp1 = gate_b_tile_in[k_idx] + if const_expr(not _single_b): + up_bp0, up_bp1 = up_b_tile_in[k_idx] + for inxdl in range_constexpr(pack_N): + ni_idx = ni * pack_N + inxdl + gb0 = gate_bp0[ni_idx] + gb1 = gate_bp1[ni_idx] + gb128 = pack_i64x4_to_i32x8( + gb0, gb1, c0_i64, c0_i64 ) + if const_expr(not _single_b): + ub0 = up_bp0[ni_idx] + ub1 = up_bp1[ni_idx] + ub128 = pack_i64x4_to_i32x8( + ub0, ub1, c0_i64, c0_i64 + ) + for mi in range_constexpr(m_repeat_packed): + a_scale_i32 = a_scale[ + ku128 * m_repeat_packed + mi + ] + a_scale_val = vector.extract( + a_scale_i32, + static_position=[0], + dynamic_position=[], + ) + for imxdl in range_constexpr(pack_M): + mi_idx = mi * pack_M + imxdl + _a_reg_idx = k_idx * m_repeat + mi_idx + if const_expr(is_f8_a): + a0, a1, a2, a3 = a_tile_regs[ + _a_reg_idx + ] + a128 = pack_i64x4_to_i32x8( + a0, a1, a2, a3 + ) + else: + a0, a1 = a_tile_regs[_a_reg_idx] + a128 = pack_i64x4_to_i32x8( + a0, a1, c0_i64, c0_i64 + ) + acc_idx = mi_idx * num_acc_n + ni_idx + gate_list[acc_idx] = ( + rocdl.mfma_scale_f32_16x16x128_f8f6f4( + mfma_res_ty, + [ + a128, + gb128, + gate_list[acc_idx], + cbsz, + blgp, + ikxdl * pack_M + imxdl, + a_scale_val, + ikxdl * pack_N + inxdl, + gate_bs_val, + ], + ) + ) + if const_expr(not _single_b): + up_list[acc_idx] = ( + rocdl.mfma_scale_f32_16x16x128_f8f6f4( + mfma_res_ty, + [ + a128, + ub128, + up_list[acc_idx], + cbsz, + blgp, + ikxdl * pack_M + imxdl, + a_scale_val, + ikxdl * pack_N + inxdl, + up_bs_val, + ], + ) + ) + return gate_list, up_list, epilogue_pf + + def load_a_subtile(k_idx, mi_idx, lds_buffer): + """Load a single A sub-tile from LDS (one ds_read).""" + col_base = col_offset_base + (k_idx * 128) // a_elem_vec_pack + mi_val = arith.constant(mi_idx * 16, index=True) + curr_row = row_a_lds + mi_val + a0, a1 = lds_load_packs_k64(curr_row, col_base, lds_buffer) + if const_expr(is_f8_a): + a2, a3 = lds_load_packs_k64(curr_row, col_base + 64, lds_buffer) + return (a0, a1, a2, a3) + else: + return (a0, a1) + + _single_b_pipe = mock_gate_only or gate_up_interleave + + def compute_bmajor_mfma_phase( + all_a_tiles, + gate_b_single, + up_b_single, + a_scale_vals, + gate_bs_val, + up_bs_val, + gate_list, + up_list, + k_idx, + ni_idx, + ikxdl, + inxdl, + ): + """B-major MFMA: fix one B (ni), cycle all A tiles (mi). + + Packs B once and reuses across all mi iterations. + A tiles come from LDS (already available, no VMEM wait). + + all_a_tiles: flat list indexed by [k*m_repeat + mi]. + gate_b_single/up_b_single: (b0, b1) for one specific ni. + When _single_b_pipe (mock_gate_only or interleave), up_b_single is None. + a_scale_vals: list of A scale scalars indexed by mi_packed. 
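+
+        Illustrative shape of the work done here (shorthand, not extra code):
+        for the fixed (k_idx, ni_idx) passed in, the loop issues one gate MFMA
+        (plus one up MFMA when up_b_single is not None) per mi, m_repeat in
+        total, all reusing the same packed gb128/ub128 B operands and updating
+        gate_list[mi*num_acc_n + ni_idx] (and up_list at the same index).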
+ """ + c0_i64 = arith.constant(0, type=T.i64) + vec4_i64 = T.vec(4, T.i64) + vec8_i32 = T.vec(8, T.i32) + + def _pack(x0, x1, x2, x3): + v4 = vector.from_elements(vec4_i64, [x0, x1, x2, x3]) + return vector.bitcast(vec8_i32, v4) + + mfma_res_ty = vec4_f32 + gb128 = _pack(gate_b_single[0], gate_b_single[1], c0_i64, c0_i64) + if const_expr(not _single_b_pipe): + ub128 = _pack(up_b_single[0], up_b_single[1], c0_i64, c0_i64) + + for mi_p in range_constexpr(m_repeat_packed): + a_scale_val = a_scale_vals[mi_p] + for imxdl in range_constexpr(pack_M): + mi_idx = mi_p * pack_M + imxdl + a_reg = all_a_tiles[k_idx * m_repeat + mi_idx] + + if const_expr(is_f8_a): + a128 = _pack(a_reg[0], a_reg[1], a_reg[2], a_reg[3]) + else: + a128 = _pack(a_reg[0], a_reg[1], c0_i64, c0_i64) + + acc_idx = mi_idx * num_acc_n + ni_idx + gate_list[acc_idx] = rocdl.mfma_scale_f32_16x16x128_f8f6f4( + mfma_res_ty, + [ + a128, + gb128, + gate_list[acc_idx], + cbsz, + blgp, + ikxdl * pack_M + imxdl, + a_scale_val, + ikxdl * pack_N + inxdl, + gate_bs_val, + ], + ) + if const_expr(not _single_b_pipe): + up_list[acc_idx] = ( + rocdl.mfma_scale_f32_16x16x128_f8f6f4( + mfma_res_ty, + [ + a128, + ub128, + up_list[acc_idx], + cbsz, + blgp, + ikxdl * pack_M + imxdl, + a_scale_val, + ikxdl * pack_N + inxdl, + up_bs_val, + ], + ) + ) + + def _interleaved_half( + lds_read, + lds_write, + next_k_dma_py, + next_k_load, + prev_a_tile, + prev_gate_w, + prev_up_w, + prev_a_scale, + prev_gate_bs, + prev_up_bs, + acc_gate, + acc_up, + ): + """One flatmm-style interleaved half-iteration (deep pipeline). + + Generalized for arbitrary m_repeat (block_m=32, 64, ...). + DMA targets lds_write (OTHER buffer) while ds_read uses + lds_read (already DMA'd in previous half). + + Interleaving schedule (per half): + Phase 0: scale VMEM + 2 ds_read(A) -> 4 MFMA(prev) + Phase 1..N: B VMEM(distributed) + 2 ds_read(A, if avail) -> 4 MFMA(prev) + Phase N+1..: remaining B VMEM -> 4 MFMA(prev) + """ + _abs_k = k_base_idx + arith.constant(next_k_load, index=True) + _bk = _abs_k // arith.constant(2, index=True) + _sk = _abs_k // arith.constant(pack_K * 128, index=True) + _k_off = _sk * layout_b_scale.stride_k0 - if const_expr(_need_sort): - col_g0_i32 = arith.index_cast(T.i32, col_g0) - is_scale_writer = arith.cmpi( - CmpIPredicate.eq, col_g0_i32 & _c31_i32, _c0_i32 + rocdl.sched_barrier(0) + rocdl.s_waitcnt(_vmcnt_before_barrier) + _barrier() + rocdl.sched_barrier(0) + + # DMA A to OTHER buffer (for next half), non-blocking + _abs_k_dma = k_base_idx + arith.constant(next_k_dma_py, index=True) + if const_expr(use_async_copy and next_k_dma_py < int(_k_dim)): + prefetch_x_to_lds(_abs_k_dma, lds_write) + if const_expr(not use_async_copy): + _x_regs = load_x_tile(_abs_k_dma) + + # ---- Extract previous scale values ---- + _prev_asvs = [] + for _mi_p in range_constexpr(m_repeat_packed): + _prev_asvs.append( + vector.extract( + prev_a_scale[_mi_p], + static_position=[0], + dynamic_position=[], + ) + ) + _prev_gsv_list = [] + for _gs_ni in range_constexpr(num_acc_n_packed): + _prev_gsv_list.append( + vector.extract( + prev_gate_bs[_gs_ni], + static_position=[0], + dynamic_position=[], + ) + ) + if const_expr(not _single_b_pipe): + _prev_usv_list = [] + for _us_ni in range_constexpr(num_acc_n_packed): + _prev_usv_list.append( + vector.extract( + prev_up_bs[_us_ni], + static_position=[0], + dynamic_position=[], ) - _if_scale = scf.IfOp(is_scale_writer) - with ir.InsertionPoint(_if_scale.then_block): - row_i32_s = arith.index_cast(T.i32, row) - col_s_i32 = col_g0_i32 >> 
_c5_i32 - d0 = row_i32_s >> _c5_i32 - d1 = (row_i32_s >> _c4_i32) & _c1_i32 - d2 = row_i32_s & _c15_i32 - d3 = col_s_i32 >> _c3_i32 - d4 = (col_s_i32 >> _c2_i32) & _c1_i32 - d5 = col_s_i32 & _c3_i32 - byte_off = ( - d0 * _n32_sort - + d3 * _c256_i32 - + d5 * _c64_i32 - + d2 * _c4_i32 - + d4 * _c2_i32 - + d1 + ) + + # ---- Execute phases from unified schedule ---- + _a_all = {} + _b_gate_all = {} + _b_up_all = {} + + for _p in range_constexpr(_pipe_n_phases): + # Scale VMEM loads (phase 0 only) + if const_expr(_pp_has_scale[_p]): + _new_as_list = [] + for _mi_p in range_constexpr(m_repeat_packed): + if const_expr(a_scale_one): + _new_as_list.append(_as1_const) + else: + _raw_as = buffer_ops.buffer_load( + sx_rsrc, + _a_scale_bases[_mi_p] + _k_off, + vec_width=1, + dtype=T.i32, + cache_modifier=0, ) - e8m0_i8 = arith.TruncIOp(T.i8, e8m0_biased) - buffer_ops.buffer_store( - e8m0_i8, - sorted_scale_rsrc, - byte_off, - offset_is_bytes=True, + _new_as_list.append(_rearrange_a_scale(_raw_as)) + _new_gs_list = [] + for _gs_ni in range_constexpr(num_acc_n_packed): + _gs_raw = buffer_ops.buffer_load( + sw_rsrc, + _gate_scale_bases[_gs_ni] + _k_off, + vec_width=1, + dtype=T.i32, + cache_modifier=0, + ) + _new_gs_list.append(_rearrange_b_scale(_gs_raw)) + if const_expr(not _single_b_pipe): + _new_us_list = [] + for _us_ni in range_constexpr(num_acc_n_packed): + _us_raw = buffer_ops.buffer_load( + sw_rsrc, + _up_scale_bases[_us_ni] + _k_off, + vec_width=1, + dtype=T.i32, + cache_modifier=0, ) - scf.YieldOp([]) - elif const_expr(_is_splitk): - col_idx = col_g0 + arith.constant(_sk_n_offset[0], index=True) - byte_off_col = col_idx * arith.constant( - out_elem_bytes, index=True + _new_us_list.append(_rearrange_b_scale(_us_raw)) + + # B VMEM loads + for _b_j in range_constexpr(len(_pp_b_loads[_p])): + _b_type, _b_ku, _b_ni = _pp_b_loads[_p][_b_j] + if const_expr(_b_type == "gate"): + _b_gate_all[(_b_ku, _b_ni)] = load_b_packs_k64( + _bk, + _b_ku, + gate_n_blk_list[_b_ni], + gate_n_intra_list[_b_ni], + ) + else: + _b_up_all[(_b_ku, _b_ni)] = load_b_packs_k64( + _bk, + _b_ku, + up_n_blk_list[_b_ni], + up_n_intra_list[_b_ni], + ) + + # A ds_reads + rocdl.sched_barrier(0) + for _a_j in range_constexpr(len(_pp_a_reads[_p])): + _ak, _ami = _pp_a_reads[_p][_a_j] + _a_all[(_ak, _ami)] = load_a_subtile( + _ak, + _ami, + lds_read, ) - ptr_addr_idx = row_byte_base + byte_off_col - out_ptr_v = _idx_to_llvm_ptr(ptr_addr_idx) - frag_v = frag._value if hasattr(frag, "_value") else frag - llvm.AtomicRMWOp( - llvm.AtomicBinOp.fadd, - out_ptr_v, - frag_v, - llvm.AtomicOrdering.monotonic, - syncscope="agent", - alignment=_e_vec_sk * out_elem_bytes, + rocdl.sched_barrier(0) + + # MFMAs on prev data + rocdl.s_setprio(1) + for _m_j in range_constexpr(len(_pp_mfma[_p])): + _k_idx, _ni_idx, _ikxdl, _inxdl, _ku128 = _pp_mfma[_p][_m_j] + _ni_packed_idx = _ni_idx // pack_N + _up_b_single = ( + ( + prev_up_w[_k_idx][0][_ni_idx], + prev_up_w[_k_idx][1][_ni_idx], + ) + if not _single_b_pipe + else None ) - else: - col_idx = col_g0 - byte_off_col = col_idx * arith.constant( - out_elem_bytes, index=True + compute_bmajor_mfma_phase( + prev_a_tile, + ( + prev_gate_w[_k_idx][0][_ni_idx], + prev_gate_w[_k_idx][1][_ni_idx], + ), + _up_b_single, + _prev_asvs, + _prev_gsv_list[_ni_packed_idx], + ( + _prev_usv_list[_ni_packed_idx] + if not _single_b_pipe + else None + ), + acc_gate, + acc_up, + _k_idx, + _ni_idx, + _ikxdl, + _inxdl, ) - ptr_addr_idx = row_byte_base + byte_off_col - out_ptr_v = _idx_to_llvm_ptr(ptr_addr_idx) - frag_v = 
frag._value if hasattr(frag, "_value") else frag - llvm.StoreOp( - frag_v, - out_ptr_v, - alignment=_e_vec * out_elem_bytes, - nontemporal=True, + rocdl.s_setprio(0) + rocdl.sched_barrier(0) + + # ---- Assemble loaded data for next half-iteration ---- + cur_a_tile = [] + for _k in range_constexpr(k_unroll): + for _mi in range_constexpr(m_repeat): + cur_a_tile.append(_a_all[(_k, _mi)]) + + cur_gate_w = [] + cur_up_w = None if _single_b_pipe else [] + for ku in range_constexpr(k_unroll): + g_packs0, g_packs1 = [], [] + u_packs0, u_packs1 = [], [] + for ni in range_constexpr(num_acc_n): + g = _b_gate_all[(ku, ni)] + g_packs0.append(g[0]) + g_packs1.append(g[1]) + if const_expr(not _single_b_pipe): + u = _b_up_all[(ku, ni)] + u_packs0.append(u[0]) + u_packs1.append(u[1]) + cur_gate_w.append((g_packs0, g_packs1)) + if const_expr(not _single_b_pipe): + cur_up_w.append((u_packs0, u_packs1)) + + cur_a_scale = [] + for _mi_p in range_constexpr(m_repeat_packed): + cur_a_scale.append( + vector.from_elements( + T.vec(1, T.i32), + [_new_as_list[_mi_p]], ) + ) + cur_gate_bs = [] + for _gs_ni in range_constexpr(num_acc_n_packed): + cur_gate_bs.append( + vector.from_elements( + T.vec(1, T.i32), [_new_gs_list[_gs_ni]] + ) + ) + if const_expr(not _single_b_pipe): + cur_up_bs = [] + for _us_ni in range_constexpr(num_acc_n_packed): + cur_up_bs.append( + vector.from_elements( + T.vec(1, T.i32), [_new_us_list[_us_ni]] + ) + ) + else: + cur_up_bs = None + + if const_expr(not use_async_copy): + store_x_tile_to_lds(_x_regs, lds_write) + + return ( + cur_a_tile, + cur_gate_w, + cur_up_w, + cur_a_scale, + cur_gate_bs, + cur_up_bs, + acc_gate, + acc_up, + ) + + # Pipeline (split ping/pong allocators) + rocdl.sched_barrier(0) - _frag_elem = ( - ir.F32Type.get() - if _need_quant - else (ir.BF16Type.get() if out_is_bf16 else ir.F16Type.get()) + k0 = k_base_idx + if const_expr(use_async_copy): + prefetch_x_to_lds(k0, lds_x_pong) + else: + x_regs0 = load_x_tile(k0) + store_x_tile_to_lds(x_regs0, lds_x_pong) + rocdl.sched_barrier(0) + _k0_scale = k_base_idx // arith.constant(pack_K * 128, index=True) + a_scale_pong, gate_bs_pong, up_bs_pong = prefetch_ab_scale_tile( + _k0_scale + ) + _c_tile_m_idx = arith.constant(tile_m, index=True) + _tid_in_range = arith.cmpi(CmpIPredicate.ult, tx, _c_tile_m_idx) + _if_tid = scf.IfOp(_tid_in_range) + with ir.InsertionPoint(_if_tid.then_block): + _tid_row = bx_m + tx + _tid_val = buffer_ops.buffer_load( + sorted_rsrc, _tid_row, vec_width=1, dtype=T.i32 ) + _tid_vec1 = vector.from_elements(T.vec(1, T.i32), [_tid_val]) + vector.store(_tid_vec1, lds_tid, [tx]) + scf.YieldOp([]) - if const_expr(gate_up_interleave and not _is_splitk): - # gui without splitk: acc has activation applied, halved N - _gui_eff_n = _gui_out_n - _gui_tile_n = tile_n // 2 - _gui_cshuffle_nlane = min(32, _gui_tile_n // _e_vec) - _gui_by_n = by_n / arith.constant(2, index=True) - _gui_n_tile_base = n_tile_base / arith.constant(2, index=True) - c_shuffle_epilog( - arith=arith, - vector=vector, - gpu=gpu, - scf=scf, - range_constexpr=range_constexpr, - tile_m=tile_m, - tile_n=_gui_tile_n, - e_vec=_e_vec, - cshuffle_nlane=_gui_cshuffle_nlane, - block_size=total_threads, - m_repeat=m_repeat, - num_acc_n=_gui_eff_n, - tx=tx, - lane_div_16=lane_div_16, - lane_mod_16=lane_mod_16, - bx_m=bx_m, - by_n=_gui_by_n, - n_tile_base=_gui_n_tile_base, - lds_out=lds_out, - frag_elem_type=_frag_elem, - write_row_to_lds=write_row_to_lds, - precompute_row=precompute_row, - store_pair=store_pair, + acc_gate = [acc_init] * num_acc_n * 
m_repeat + acc_up = ( + [acc_init] * num_acc_n * m_repeat if not _single_b_pipe else None + ) + + _k1 = k_base_idx + arith.constant(tile_k, index=True) + rocdl.sched_barrier(0) + if const_expr(use_async_copy): + prefetch_x_to_lds(_k1, lds_x_ping) + else: + _x_regs_prime = load_x_tile(_k1) + store_x_tile_to_lds(_x_regs_prime, lds_x_ping) + + _k0_b = k_base_idx // arith.constant(2, index=True) + gate_w0, up_w0 = load_b_tile(_k0_b) + # Prime the deep pipeline: DMA K=tile_k -> ping (1 tile ahead) + if const_expr(use_async_copy): + rocdl.s_waitcnt(0) + gpu.barrier() + rocdl.sched_barrier(0) + a_tile_pong = prefetch_full_a_from_lds(lds_x_pong) + + rocdl.sched_barrier(0) + rocdl.s_waitcnt(6) + + num_k_tiles_py = int(_k_dim) // int(tile_k) + odd_k_tiles = (num_k_tiles_py % 2) == 1 + tail_tiles = 1 if odd_k_tiles else 2 + k_main2_py = (num_k_tiles_py - tail_tiles) * int(tile_k) + if const_expr(k_main2_py < 0): + k_main2_py = 0 + + gate_w_pong = gate_w0 + up_w_pong = up_w0 + + rocdl.sched_barrier(0) + + if const_expr(k_main2_py > 0): + for k_iv_py in range_constexpr(0, k_main2_py, tile_k * 2): + next_k_load_1 = k_iv_py + tile_k + next_k_load_2 = k_iv_py + tile_k * 2 + next_k_dma_1 = k_iv_py + tile_k * 2 + next_k_dma_2 = k_iv_py + tile_k * 3 + + # Half 1: read ping (DMA'd prev half), DMA->pong, MFMA(pong) + ( + a_tile_ping, + gate_w_ping, + up_w_ping, + a_scale_ping, + gate_bs_ping, + up_bs_ping, + acc_gate, + acc_up, + ) = _interleaved_half( + lds_x_ping, + lds_x_pong, + next_k_dma_1, + next_k_load_1, + a_tile_pong, + gate_w_pong, + up_w_pong, + a_scale_pong, + gate_bs_pong, + up_bs_pong, + acc_gate, + acc_up, ) - elif const_expr(mock_gate_only or (gate_up_interleave and _is_splitk)): - # mock_gate_only: single pass, by_n covers full [0, 2*inter_dim) - _eff_e_vec = _e_vec_sk - acc = acc_gate - c_shuffle_epilog( - arith=arith, - vector=vector, - gpu=gpu, - scf=scf, - range_constexpr=range_constexpr, - tile_m=tile_m, - tile_n=tile_n, - e_vec=_eff_e_vec, - cshuffle_nlane=_cshuffle_nlane_sk, - block_size=total_threads, - m_repeat=m_repeat, - num_acc_n=num_acc_n, - tx=tx, - lane_div_16=lane_div_16, - lane_mod_16=lane_mod_16, - bx_m=bx_m, - by_n=by_n, - n_tile_base=n_tile_base, - lds_out=lds_out, - frag_elem_type=_frag_elem, - write_row_to_lds=write_row_to_lds, - precompute_row=precompute_row, - store_pair=store_pair, - lds_out_split=lds_out_B, + + # Half 2: read pong (DMA'd Half 1), DMA->ping, MFMA(ping) + ( + a_tile_pong, + gate_w_pong, + up_w_pong, + a_scale_pong, + gate_bs_pong, + up_bs_pong, + acc_gate, + acc_up, + ) = _interleaved_half( + lds_x_pong, + lds_x_ping, + next_k_dma_2, + next_k_load_2, + a_tile_ping, + gate_w_ping, + up_w_ping, + a_scale_ping, + gate_bs_ping, + up_bs_ping, + acc_gate, + acc_up, ) - elif const_expr(_is_splitk): - # Two-pass epilogue: gate then up, each with atomic add - _eff_e_vec = _e_vec_sk - # Pass 1: gate - acc = acc_gate - _sk_n_offset[0] = 0 - c_shuffle_epilog( - arith=arith, - vector=vector, - gpu=gpu, - scf=scf, - range_constexpr=range_constexpr, - tile_m=tile_m, - tile_n=tile_n, - e_vec=_eff_e_vec, - cshuffle_nlane=_cshuffle_nlane_sk, - block_size=total_threads, - m_repeat=m_repeat, - num_acc_n=num_acc_n, - tx=tx, - lane_div_16=lane_div_16, - lane_mod_16=lane_mod_16, - bx_m=bx_m, - by_n=by_n, - n_tile_base=n_tile_base, - lds_out=lds_out, - frag_elem_type=_frag_elem, - write_row_to_lds=write_row_to_lds, - precompute_row=precompute_row, - store_pair=store_pair, - lds_out_split=lds_out_B, + # _wave_mod2_b = wave_id % arith.constant(2, index=True) + # _wave_odd = 
arith.cmpi( + # CmpIPredicate.eq, _wave_mod2_b, arith.constant(1, index=True) + # ) + # _if_wave_odd = scf.IfOp(_wave_odd) + # with ir.InsertionPoint(_if_wave_odd.then_block): + # # gpu.barrier() + # _barrier() + # scf.YieldOp([]) + + if const_expr(odd_k_tiles): + acc_gate, acc_up, epilogue_pf = compute_tile( + acc_gate, + acc_up, + gate_w_pong, + up_w_pong, + a_tile_pong, + a_scale_pong, + gate_bs_pong, + up_bs_pong, + prefetch_epilogue=True, + ku_count=_tail_ku if _pad_ku_skip > 0 else k_unroll, + ) + else: + _k_tail_rel = arith.constant(_k_dim - tile_k, index=True) + k_tail1 = k_base_idx + _k_tail_rel + x_regs_ping = [] + if const_expr(use_async_copy): + prefetch_x_to_lds(k_tail1, lds_x_ping) + else: + x_regs_ping = load_x_tile(k_tail1) + if const_expr(_pad_ku_skip > 0): + gate_w_ping, up_w_ping = load_b_tile( + k_tail1 // arith.constant(2, index=True), + ku_limit=_tail_ku, + ) + a_scale_ping, gate_bs_ping, up_bs_ping = prefetch_ab_scale_tile( + k_tail1 // arith.constant(pack_K * 128, index=True), + ku_packed_limit=_tail_ku_packed, + ) + else: + gate_w_ping, up_w_ping = load_b_tile( + k_tail1 // arith.constant(2, index=True) + ) + a_scale_ping, gate_bs_ping, up_bs_ping = prefetch_ab_scale_tile( + k_tail1 // arith.constant(pack_K * 128, index=True) + ) + acc_gate, acc_up, _ = compute_tile( + acc_gate, + acc_up, + gate_w_pong, + up_w_pong, + a_tile_pong, + a_scale_pong, + gate_bs_pong, + up_bs_pong, + ) + if const_expr(not use_async_copy): + store_x_tile_to_lds(x_regs_ping, lds_x_ping) + rocdl.s_waitcnt(0) + _barrier() + if const_expr(_pad_ku_skip > 0): + a_tile_ping = prefetch_full_a_from_lds( + lds_x_ping, ku_limit=_tail_ku + ) + else: + a_tile_ping = prefetch_full_a_from_lds(lds_x_ping) + acc_gate, acc_up, epilogue_pf = compute_tile( + acc_gate, + acc_up, + gate_w_ping, + up_w_ping, + a_tile_ping, + a_scale_ping, + gate_bs_ping, + up_bs_ping, + prefetch_epilogue=True, + ku_count=_tail_ku if _pad_ku_skip > 0 else k_unroll, + ) + + bias_pf = None + if const_expr(epilogue_pf is not None): + _, _, bias_pf = epilogue_pf + + # Activation helpers (f32 element-wise on vec4_f32) + def _silu_elem(g): + """silu(x) = x * sigmoid(x); HW fast path: exp2, rcp""" + neg_log2e = arith.constant(-1.4426950408889634, type=f32) + t = g * neg_log2e + emu = llvm.call_intrinsic(f32, "llvm.amdgcn.exp2.f32", [t], [], []) + one = arith.constant(1.0, type=f32) + den = one + emu + sig = llvm.call_intrinsic(f32, "llvm.amdgcn.rcp.f32", [den], [], []) + return g * sig + + def _silu_mul_vec4(gate_v4, up_v4): + """Element-wise silu(gate) * up on vec4_f32.""" + result_elems = [] + for ei in range_constexpr(4): + g = vector.extract( + gate_v4, static_position=[ei], dynamic_position=[] + ) + u = vector.extract( + up_v4, static_position=[ei], dynamic_position=[] + ) + result_elems.append(_silu_elem(g) * u) + return vector.from_elements(vec4_f32, result_elems) + + def _swiglu_mul_vec4(gate_v4, up_v4): + """Element-wise swiglu(gate, up) on vec4_f32. + swiglu(g, u) = g * sigmoid(alpha * g) * (u + 1) + with clamping: gate <= limit, -limit <= up <= limit. 
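+
+        Worked example (clamps inactive): g = 1.0, u = 0.0 with alpha = 1.702
+        gives sigmoid(1.702 * 1.0) ~= 0.846, so
+        swiglu(1.0, 0.0) ~= 1.0 * 0.846 * (0.0 + 1.0) = 0.846.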
+ """ + result_elems = [] + _alpha = arith.constant(1.702, type=f32) + _limit = arith.constant(7.0, type=f32) + _neg_limit = arith.constant(-7.0, type=f32) + _one = arith.constant(1.0, type=f32) + _neg_log2e = arith.constant(-1.4426950408889634, type=f32) + for ei in range_constexpr(4): + g = vector.extract( + gate_v4, static_position=[ei], dynamic_position=[] + ) + u = vector.extract( + up_v4, static_position=[ei], dynamic_position=[] + ) + g = arith.minimumf(g, _limit) + u = arith.minimumf(u, _limit) + u = arith.maximumf(u, _neg_limit) + t = g * _alpha * _neg_log2e + emu = llvm.call_intrinsic( + f32, "llvm.amdgcn.exp2.f32", [t], [], [] ) + den = _one + emu + sig = llvm.call_intrinsic( + f32, "llvm.amdgcn.rcp.f32", [den], [], [] + ) + result_elems.append(g * sig * (u + _one)) + return vector.from_elements(vec4_f32, result_elems) + + def _act_vec4(gate_v4, up_v4): + """Dispatch activation based on `act` parameter.""" + if const_expr(act == "swiglu"): + return _swiglu_mul_vec4(gate_v4, up_v4) + else: + return _silu_mul_vec4(gate_v4, up_v4) + + # Add bias to raw GEMM accumulators before activation. + # bias layout: [E, 2*inter_dim] flat f32 (non-interleaved: gate then up). + # For gate_up_interleave, map physical column to logical bias offset. + if const_expr(enable_bias and not _is_splitk): + if const_expr(bias_pf is not None): + _bias_gate_vals = bias_pf + else: + _bias_gate_vals = [] + for _ni in range_constexpr(num_acc_n): + if const_expr(gate_up_interleave): + _logical_col = ( + (by_n + n_tile_base) + // arith.constant(2, index=True) + + arith.constant((_ni // 2) * 16, index=True) + + lane_mod_16 + ) + _up_off = ( + inter_idx + if (_ni % 2 == 1) + else arith.constant(0, index=True) + ) + _bias_off = expert_off_idx + _up_off + _logical_col + else: + _bn = ( + by_n + + n_tile_base + + arith.constant(_ni * 16, index=True) + + lane_mod_16 + ) + _bias_off = expert_off_idx + _bn + _bias_gate_vals.append( + buffer_ops.buffer_load( + bias_rsrc, _bias_off, vec_width=1, dtype=f32 + ) + ) + for _mi in range_constexpr(m_repeat): + for _ni in range_constexpr(num_acc_n): + _aidx = _mi * num_acc_n + _ni + _bsplat = vector.from_elements( + vec4_f32, [_bias_gate_vals[_ni]] * 4 + ) + acc_gate[_aidx] = arith.addf(acc_gate[_aidx], _bsplat) + + if const_expr(not (mock_gate_only or gate_up_interleave)): + _bias_up_vals = [] + for _ni in range_constexpr(num_acc_n): + _bn = ( + by_n + + n_tile_base + + arith.constant(_ni * 16, index=True) + + lane_mod_16 + ) + _bias_up_vals.append( + buffer_ops.buffer_load( + bias_rsrc, + expert_off_idx + inter_idx + _bn, + vec_width=1, + dtype=f32, + ) + ) + for _mi in range_constexpr(m_repeat): + for _ni in range_constexpr(num_acc_n): + _aidx = _mi * num_acc_n + _ni + _bsplat = vector.from_elements( + vec4_f32, [_bias_up_vals[_ni]] * 4 + ) + acc_up[_aidx] = arith.addf(acc_up[_aidx], _bsplat) + + if const_expr(gate_up_interleave and not _is_splitk): + _gui_out_n = num_acc_n // pack_N + acc = [None] * (_gui_out_n * m_repeat) + for _mi in range_constexpr(m_repeat): + for _ni in range_constexpr(_gui_out_n): + _g_idx = _mi * num_acc_n + _ni * pack_N + _u_idx = _g_idx + 1 + _out_idx = _mi * _gui_out_n + _ni + acc[_out_idx] = _act_vec4( + acc_gate[_g_idx], acc_gate[_u_idx] + ) + elif const_expr(not _is_splitk): + acc = [None] * (int(num_acc_n) * int(m_repeat)) + for _mi in range_constexpr(m_repeat): + for _ni in range_constexpr(num_acc_n): + _aidx = _mi * num_acc_n + _ni + acc[_aidx] = _silu_mul_vec4(acc_gate[_aidx], acc_up[_aidx]) + + # ---- Epilogue: CShuffle + direct store 
(accumulate=False) ---- + # Output: out[(t*topk+s) * inter_dim + col] = silu(gate) * up + # For split-K: skip silu, output gate/up separately with atomic add + tw_pf = None + bias_pf = None + if const_expr(epilogue_pf is not None): + _, tw_pf, bias_pf = epilogue_pf + + mask24_i32 = arith.constant(0xFFFFFF) + topk_i32_v = topk_i32 + tokens_i32_v = tokens_i32 + + from flydsl._mlir.dialects import fly as _fly + + _llvm_ptr_ty = ir.Type.parse("!llvm.ptr") + out_base_ptr = _fly.extract_aligned_pointer_as_index( + _llvm_ptr_ty, arg_out + ) + out_base_i64 = llvm.ptrtoint(T.i64, out_base_ptr) + out_base_idx = arith.index_cast(ir.IndexType.get(), out_base_i64) + + if const_expr(lds_out is None): + raise RuntimeError("CShuffle epilogue requires lds_out") + + _apply_weight = doweight_stage1 and not _is_splitk + + def write_row_to_lds( + *, + mi: int, + ii: int, + row_in_tile, + row, + row_base_lds, + col_base_local, + num_acc_n: int, + lds_out, + ): + if const_expr(_apply_weight): + tw_idx = (mi * 4) + ii + if const_expr(tw_pf is not None): + tw = tw_pf[tw_idx] + else: + tw = buffer_ops.buffer_load( + sorted_w_rsrc, row, vec_width=1, dtype=f32 + ) + for ni in range_constexpr(num_acc_n): + col_local = col_base_local + (ni * 16) + acc_idx = mi * num_acc_n + ni + v = vector.extract( + acc[acc_idx], static_position=[ii], dynamic_position=[] + ) + if const_expr(_apply_weight): + v = v * tw + if const_expr(_need_quant): + lds_idx = row_base_lds + col_local + vec1_f32 = T.vec(1, f32) + v1 = vector.from_elements(vec1_f32, [v]) + vector.store(v1, lds_out, [lds_idx], alignment=4) + else: + v_out = arith.trunc_f(out_elem(), v) + lds_idx = row_base_lds + col_local + vec1_out = T.vec(1, out_elem()) + v1 = vector.from_elements(vec1_out, [v_out]) + vector.store(v1, lds_out, [lds_idx], alignment=2) + + _out_row_stride = ( + inter_dim * 2 * out_elem_bytes + if _is_splitk + else ( + inter_dim // 2 + if _need_fp4 + else (inter_dim if _need_fp8 else inter_dim * out_elem_bytes) + ) + ) + + def precompute_row(*, row_local, row): + fused2 = memref.load(lds_tid, [row_local]) + row_i32 = arith.index_cast(T.i32, row) + row_valid0 = arith.cmpi(CmpIPredicate.ult, row_i32, num_valid_i32) + t = fused2 & mask24_i32 + s = fused2 >> 24 + t_ok = arith.cmpi(CmpIPredicate.ult, t, tokens_i32_v) + s_ok = arith.cmpi(CmpIPredicate.ult, s, topk_i32_v) + row_valid = arith.andi(row_valid0, arith.andi(t_ok, s_ok)) + t_idx = arith.index_cast(ir.IndexType.get(), t) + s_idx = arith.index_cast(ir.IndexType.get(), s) + ts_idx = t_idx * arith.constant(topk, index=True) + s_idx + row_byte_base = out_base_idx + ts_idx * arith.constant( + _out_row_stride, index=True + ) + return ((fused2, row_byte_base), row_valid) + + def _idx_to_llvm_ptr(idx_val, addr_space=1): + idx_v = idx_val._value if hasattr(idx_val, "_value") else idx_val + i64_v = arith.index_cast(T.i64, idx_v) + i64_raw = i64_v._value if hasattr(i64_v, "_value") else i64_v + ptr_ty = ir.Type.parse(f"!llvm.ptr<{addr_space}>") + return llvm.inttoptr(ptr_ty, i64_raw) + + _e_vec = _e_vec_s1 + _e_vec_sk = 2 + _cshuffle_nlane = min(32, tile_n // _e_vec) + _cshuffle_nlane_sk = min(32, tile_n // _e_vec_sk) + _num_threads_per_quant_blk = _num_threads_per_quant_blk_s1 + + _c0_i32 = arith.constant(0, type=T.i32) + _c1_i32 = arith.constant(1, type=T.i32) + _c2_i32 = arith.constant(2, type=T.i32) + _c3_i32 = arith.constant(3, type=T.i32) + _c4_i32 = arith.constant(4, type=T.i32) + _c5_i32 = arith.constant(5, type=T.i32) + _c7_i32 = arith.constant(7, type=T.i32) + _c15_i32 = arith.constant(15, type=T.i32) 
+ _c21_i32 = arith.constant(21, type=T.i32) + _c23_i32 = arith.constant(23, type=T.i32) + _c28_i32 = arith.constant(28, type=T.i32) + _c31_i32 = arith.constant(31, type=T.i32) + _c32_i32 = arith.constant(32, type=T.i32) + _c64_i32 = arith.constant(64, type=T.i32) + _c126_i32 = arith.constant(126, type=T.i32) + _c127_i32 = arith.constant(127, type=T.i32) + _c254_i32 = arith.constant(254, type=T.i32) + _c256_i32 = arith.constant(256, type=T.i32) + _c0xFF_i32 = arith.constant(0xFF, type=T.i32) + _c0x200000_i32 = arith.constant(0x200000, type=T.i32) + _c0xFF800000_i32 = arith.constant(0xFF800000, type=T.i32) + _c0x400000_i32 = arith.constant(0x400000, type=T.i32) + _c0x7FFFFF_i32 = arith.constant(0x7FFFFF, type=T.i32) + _c0x80000000_i32 = arith.constant(0x80000000, type=T.i32) + _c0_f32 = arith.constant(0.0, type=T.f32) + + _c8_i32 = arith.constant(8, type=T.i32) + _fp_headroom = 2 if _need_fp4 else (8 if _need_fp8 else 0) + _c_headroom_i32 = arith.constant(_fp_headroom, type=T.i32) + + def _f32_to_e2m1(qx_f32): + """Convert a scaled f32 value to fp4 (e2m1) 4-bit integer.""" + qx = qx_f32.bitcast(T.i32) + s = qx & _c0x80000000_i32 + e = (qx >> _c23_i32) & _c0xFF_i32 + m = qx & _c0x7FFFFF_i32 + adj_exp = arith.maxsi(_c126_i32 - e, _c0_i32) + m_denorm = (_c0x400000_i32 | (m >> _c1_i32)) >> adj_exp + is_denorm = arith.cmpi(CmpIPredicate.ult, e, _c127_i32) + m = arith.select(is_denorm, m_denorm, m) + e = arith.maxsi(e - _c126_i32, _c0_i32) + combined = (e << _c2_i32) | (m >> _c21_i32) + rounded = (combined + _c1_i32) >> _c1_i32 + e2m1 = arith.minui(rounded, _c7_i32) + return (s >> _c28_i32) | e2m1 + + if const_expr(_need_sort): + _n32_sort = _sorted_scale_cols_i32 * _c32_i32 + + # Mutable slot for split-K N-offset (gate=0, up=inter_dim) + _sk_n_offset = [0] + + def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): + fused, row_byte_base = row_ctx + if const_expr(_need_quant and not _is_splitk): + frag_vals = [] + for i in range_constexpr(_e_vec): + frag_vals.append( + vector.extract( + frag, static_position=[i], dynamic_position=[] + ) + ) + + local_max = _c0_f32 + for i in range_constexpr(_e_vec): + abs_v = llvm.call_intrinsic( + f32, "llvm.fabs.f32", [frag_vals[i]], [], [] + ) + local_max = arith.maximumf(local_max, abs_v) + + for _si in range_constexpr(_num_shuffle_steps_s1): + off = arith.constant(_shuffle_dists_s1[_si], type=T.i32) + peer = local_max.shuffle_xor(off, _c64_i32) + local_max = arith.maximumf(local_max, peer) + + max_i32 = local_max.bitcast(T.i32) + max_rounded = (max_i32 + _c0x200000_i32) & _c0xFF800000_i32 + exp_field = max_rounded >> _c23_i32 + e8m0_biased = arith.maxsi(exp_field - _c_headroom_i32, _c0_i32) + + quant_exp = _c254_i32 - e8m0_biased + quant_scale = (quant_exp << _c23_i32).bitcast(T.f32) + + if const_expr(_need_fp4): + fp4_vals = [] + for i in range_constexpr(_e_vec): + scaled_v = frag_vals[i] * quant_scale + fp4_vals.append(_f32_to_e2m1(scaled_v)) + + packed_i32 = fp4_vals[0] | (fp4_vals[1] << _c4_i32) + for k in range_constexpr(1, _e_vec // 2): + byte_k = fp4_vals[2 * k] | ( + fp4_vals[2 * k + 1] << _c4_i32 + ) + packed_i32 = packed_i32 | ( + byte_k << arith.constant(k * 8, type=T.i32) + ) + + ptr_addr_idx = row_byte_base + col_g0 / arith.constant( + 2, index=True + ) + out_ptr_v = _idx_to_llvm_ptr(ptr_addr_idx) + _pack_bytes = _e_vec // 2 + if const_expr(_pack_bytes == 1): + store_val = arith.TruncIOp(T.i8, packed_i32) + store_raw = ( + store_val._value + if hasattr(store_val, "_value") + else store_val + ) + llvm.StoreOp( + store_raw, out_ptr_v, 
alignment=1, nontemporal=True + ) + elif const_expr(_pack_bytes == 2): + store_val = arith.TruncIOp(T.i16, packed_i32) + store_raw = ( + store_val._value + if hasattr(store_val, "_value") + else store_val + ) + llvm.StoreOp( + store_raw, out_ptr_v, alignment=2, nontemporal=True + ) + else: + packed_raw = ( + packed_i32._value + if hasattr(packed_i32, "_value") + else packed_i32 + ) + llvm.StoreOp( + packed_raw, out_ptr_v, alignment=4, nontemporal=True + ) - gpu.barrier() + elif const_expr(_need_fp8): + scaled_vals = [] + for i in range_constexpr(_e_vec): + scaled_vals.append(frag_vals[i] * quant_scale) + + ptr_addr_idx = row_byte_base + col_g0 + if const_expr(_e_vec <= 4): + packed_i32 = _c0_i32 + for _w in range_constexpr(_e_vec // 2): + packed_i32 = rocdl.cvt_pk_fp8_f32( + T.i32, + scaled_vals[2 * _w], + scaled_vals[2 * _w + 1], + packed_i32, + _w, + ) + out_ptr_v = _idx_to_llvm_ptr(ptr_addr_idx) + if const_expr(_e_vec == 2): + store_val = arith.TruncIOp(T.i16, packed_i32) + store_raw = ( + store_val._value + if hasattr(store_val, "_value") + else store_val + ) + llvm.StoreOp( + store_raw, + out_ptr_v, + alignment=2, + nontemporal=True, + ) + else: + packed_raw = ( + packed_i32._value + if hasattr(packed_i32, "_value") + else packed_i32 + ) + llvm.StoreOp( + packed_raw, + out_ptr_v, + alignment=4, + nontemporal=True, + ) + else: + for _wg in range_constexpr(_e_vec // 4): + _b = _wg * 4 + packed_w = _c0_i32 + packed_w = rocdl.cvt_pk_fp8_f32( + T.i32, + scaled_vals[_b], + scaled_vals[_b + 1], + packed_w, + 0, + ) + packed_w = rocdl.cvt_pk_fp8_f32( + T.i32, + scaled_vals[_b + 2], + scaled_vals[_b + 3], + packed_w, + 1, + ) + word_ptr = ptr_addr_idx + arith.constant( + _wg * 4, index=True + ) + out_ptr_v = _idx_to_llvm_ptr(word_ptr) + packed_raw = ( + packed_w._value + if hasattr(packed_w, "_value") + else packed_w + ) + llvm.StoreOp( + packed_raw, + out_ptr_v, + alignment=4, + nontemporal=True, + ) - # Pass 2: up - acc = acc_up - _sk_n_offset[0] = inter_dim - c_shuffle_epilog( - arith=arith, - vector=vector, - gpu=gpu, - scf=scf, - range_constexpr=range_constexpr, - tile_m=tile_m, - tile_n=tile_n, - e_vec=_eff_e_vec, - cshuffle_nlane=_cshuffle_nlane_sk, - block_size=total_threads, - m_repeat=m_repeat, - num_acc_n=num_acc_n, - tx=tx, - lane_div_16=lane_div_16, - lane_mod_16=lane_mod_16, - bx_m=bx_m, - by_n=by_n, - n_tile_base=n_tile_base, - lds_out=lds_out, - frag_elem_type=_frag_elem, - write_row_to_lds=write_row_to_lds, - precompute_row=precompute_row, - store_pair=store_pair, - lds_out_split=lds_out_B, + if const_expr(_need_sort): + col_g0_i32 = arith.index_cast(T.i32, col_g0) + is_scale_writer = arith.cmpi( + CmpIPredicate.eq, col_g0_i32 & _c31_i32, _c0_i32 + ) + _if_scale = scf.IfOp(is_scale_writer) + with ir.InsertionPoint(_if_scale.then_block): + row_i32_s = arith.index_cast(T.i32, row) + col_s_i32 = col_g0_i32 >> _c5_i32 + d0 = row_i32_s >> _c5_i32 + d1 = (row_i32_s >> _c4_i32) & _c1_i32 + d2 = row_i32_s & _c15_i32 + d3 = col_s_i32 >> _c3_i32 + d4 = (col_s_i32 >> _c2_i32) & _c1_i32 + d5 = col_s_i32 & _c3_i32 + byte_off = ( + d0 * _n32_sort + + d3 * _c256_i32 + + d5 * _c64_i32 + + d2 * _c4_i32 + + d4 * _c2_i32 + + d1 + ) + e8m0_i8 = arith.TruncIOp(T.i8, e8m0_biased) + buffer_ops.buffer_store( + e8m0_i8, + sorted_scale_rsrc, + byte_off, + offset_is_bytes=True, + ) + scf.YieldOp([]) + elif const_expr(_is_splitk): + col_idx = col_g0 + arith.constant(_sk_n_offset[0], index=True) + byte_off_col = col_idx * arith.constant( + out_elem_bytes, index=True + ) + ptr_addr_idx = row_byte_base + 
byte_off_col + out_ptr_v = _idx_to_llvm_ptr(ptr_addr_idx) + frag_v = frag._value if hasattr(frag, "_value") else frag + llvm.AtomicRMWOp( + llvm.AtomicBinOp.fadd, + out_ptr_v, + frag_v, + llvm.AtomicOrdering.monotonic, + syncscope="agent", + alignment=_e_vec_sk * out_elem_bytes, ) else: - c_shuffle_epilog( - arith=arith, - vector=vector, - gpu=gpu, - scf=scf, - range_constexpr=range_constexpr, - tile_m=tile_m, - tile_n=tile_n, - e_vec=_e_vec, - cshuffle_nlane=_cshuffle_nlane, - block_size=total_threads, - m_repeat=m_repeat, - num_acc_n=num_acc_n, - tx=tx, - lane_div_16=lane_div_16, - lane_mod_16=lane_mod_16, - bx_m=bx_m, - by_n=by_n, - n_tile_base=n_tile_base, - lds_out=lds_out, - frag_elem_type=_frag_elem, - write_row_to_lds=write_row_to_lds, - precompute_row=precompute_row, - store_pair=store_pair, - lds_out_split=lds_out_B, + col_idx = col_g0 + byte_off_col = col_idx * arith.constant( + out_elem_bytes, index=True + ) + ptr_addr_idx = row_byte_base + byte_off_col + out_ptr_v = _idx_to_llvm_ptr(ptr_addr_idx) + frag_v = frag._value if hasattr(frag, "_value") else frag + llvm.StoreOp( + frag_v, + out_ptr_v, + alignment=_e_vec * out_elem_bytes, + nontemporal=True, ) - _if_blk = scf.IfOp(blk_valid) - with ir.InsertionPoint(_if_blk.then_block): - _ifexpert_of = scf.IfOp(exp_valid) - with ir.InsertionPoint(_ifexpert_of.then_block): - _moe_gemm1_body() - scf.YieldOp([]) - scf.YieldOp([]) + _frag_elem = ( + ir.F32Type.get() + if _need_quant + else (ir.BF16Type.get() if out_is_bf16 else ir.F16Type.get()) + ) - gpu.barrier() + if const_expr(gate_up_interleave and not _is_splitk): + # gui without splitk: acc has activation applied, halved N + _gui_eff_n = _gui_out_n + _gui_tile_n = tile_n // 2 + _gui_cshuffle_nlane = min(32, _gui_tile_n // _e_vec) + _gui_by_n = by_n / arith.constant(2, index=True) + _gui_n_tile_base = n_tile_base / arith.constant(2, index=True) + c_shuffle_epilog( + arith=arith, + vector=vector, + gpu=gpu, + scf=scf, + range_constexpr=range_constexpr, + tile_m=tile_m, + tile_n=_gui_tile_n, + e_vec=_e_vec, + cshuffle_nlane=_gui_cshuffle_nlane, + block_size=total_threads, + m_repeat=m_repeat, + num_acc_n=_gui_eff_n, + tx=tx, + lane_div_16=lane_div_16, + lane_mod_16=lane_mod_16, + bx_m=bx_m, + by_n=_gui_by_n, + n_tile_base=_gui_n_tile_base, + lds_out=lds_out, + frag_elem_type=_frag_elem, + write_row_to_lds=write_row_to_lds, + precompute_row=precompute_row, + store_pair=store_pair, + ) + elif const_expr(mock_gate_only or (gate_up_interleave and _is_splitk)): + # mock_gate_only: single pass, by_n covers full [0, 2*inter_dim) + _eff_e_vec = _e_vec_sk + acc = acc_gate + c_shuffle_epilog( + arith=arith, + vector=vector, + gpu=gpu, + scf=scf, + range_constexpr=range_constexpr, + tile_m=tile_m, + tile_n=tile_n, + e_vec=_eff_e_vec, + cshuffle_nlane=_cshuffle_nlane_sk, + block_size=total_threads, + m_repeat=m_repeat, + num_acc_n=num_acc_n, + tx=tx, + lane_div_16=lane_div_16, + lane_mod_16=lane_mod_16, + bx_m=bx_m, + by_n=by_n, + n_tile_base=n_tile_base, + lds_out=lds_out, + frag_elem_type=_frag_elem, + write_row_to_lds=write_row_to_lds, + precompute_row=precompute_row, + store_pair=store_pair, + lds_out_split=lds_out_B, + ) + elif const_expr(_is_splitk): + # Two-pass epilogue: gate then up, each with atomic add + _eff_e_vec = _e_vec_sk + + # Pass 1: gate + acc = acc_gate + _sk_n_offset[0] = 0 + c_shuffle_epilog( + arith=arith, + vector=vector, + gpu=gpu, + scf=scf, + range_constexpr=range_constexpr, + tile_m=tile_m, + tile_n=tile_n, + e_vec=_eff_e_vec, + cshuffle_nlane=_cshuffle_nlane_sk, + 
block_size=total_threads, + m_repeat=m_repeat, + num_acc_n=num_acc_n, + tx=tx, + lane_div_16=lane_div_16, + lane_mod_16=lane_mod_16, + bx_m=bx_m, + by_n=by_n, + n_tile_base=n_tile_base, + lds_out=lds_out, + frag_elem_type=_frag_elem, + write_row_to_lds=write_row_to_lds, + precompute_row=precompute_row, + store_pair=store_pair, + lds_out_split=lds_out_B, + ) + + gpu.barrier() + + # Pass 2: up + acc = acc_up + _sk_n_offset[0] = inter_dim + c_shuffle_epilog( + arith=arith, + vector=vector, + gpu=gpu, + scf=scf, + range_constexpr=range_constexpr, + tile_m=tile_m, + tile_n=tile_n, + e_vec=_eff_e_vec, + cshuffle_nlane=_cshuffle_nlane_sk, + block_size=total_threads, + m_repeat=m_repeat, + num_acc_n=num_acc_n, + tx=tx, + lane_div_16=lane_div_16, + lane_mod_16=lane_mod_16, + bx_m=bx_m, + by_n=by_n, + n_tile_base=n_tile_base, + lds_out=lds_out, + frag_elem_type=_frag_elem, + write_row_to_lds=write_row_to_lds, + precompute_row=precompute_row, + store_pair=store_pair, + lds_out_split=lds_out_B, + ) + else: + c_shuffle_epilog( + arith=arith, + vector=vector, + gpu=gpu, + scf=scf, + range_constexpr=range_constexpr, + tile_m=tile_m, + tile_n=tile_n, + e_vec=_e_vec, + cshuffle_nlane=_cshuffle_nlane, + block_size=total_threads, + m_repeat=m_repeat, + num_acc_n=num_acc_n, + tx=tx, + lane_div_16=lane_div_16, + lane_mod_16=lane_mod_16, + bx_m=bx_m, + by_n=by_n, + n_tile_base=n_tile_base, + lds_out=lds_out, + frag_elem_type=_frag_elem, + write_row_to_lds=write_row_to_lds, + precompute_row=precompute_row, + store_pair=store_pair, + lds_out_split=lds_out_B, + ) + + _if_blk = scf.IfOp(blk_valid) + with ir.InsertionPoint(_if_blk.then_block): + _ifexpert_of = scf.IfOp(exp_valid) + with ir.InsertionPoint(_ifexpert_of.then_block): + _moe_gemm1_body() + scf.YieldOp([]) scf.YieldOp([]) - _for_ip.__exit__(None, None, None) + + gpu.barrier() + scf.YieldOp([]) + _for_ip.__exit__(None, None, None) # -- Host launcher -- + # Unified cache key: gate_mode encodes gate_only / gate_up_interleave, + # so a single tuple covers both A16W4 and generic stage1 paths. The + # tag includes A16W4-only fields (`_split_k_intra`, `_use_cshuffle_epilog`) + # which are no-ops for the generic path (constant 1 / unchanged). _cache_tag = ( - module_name, - a_dtype, - b_dtype, - out_dtype, - tile_m, - tile_n, - tile_k, - doweight_stage1, - act, - enable_bias, - model_dim_pad, - inter_dim_pad, - use_cshuffle_epilog, - persist_m, - use_async_copy, - waves_per_eu, - k_batch, - gate_mode, - a_scale_one, - xcd_swizzle, + module_name, a_dtype, b_dtype, out_dtype, + tile_m, tile_n, tile_k, + doweight_stage1, act, enable_bias, + model_dim_pad, inter_dim_pad, + _use_cshuffle_epilog, persist_m, use_async_copy, + waves_per_eu, k_batch, gate_mode, + a_scale_one, xcd_swizzle, _split_k_intra, ) + + # ---- Shared postlude (used by both A16W4 and generic stage1 paths) ---- + # waves_per_eu LDS-clamp: pad LDS occupancy so that exactly waves_per_eu + # waves fit on each EU (the rest is reserved by inflating ping ptr). + if waves_per_eu is not None and waves_per_eu >= 1: + _total_cu_lds = 160 * 1024 + _min_lds = _total_cu_lds // (waves_per_eu + 1) + 1 + _pong_sz = allocator_pong._align(allocator_pong.ptr, 128) + _ping_sz = allocator_ping._align(allocator_ping.ptr, 128) + _cur_lds = _pong_sz + _ping_sz + if _cur_lds < _min_lds: + allocator_ping.ptr += _min_lds - _cur_lds + + # Shared host launcher (placed outside the if/else so both paths reuse it). 
+ # `moe_gemm1`, `_cache_tag`, and `total_threads` are bound in either branch + # above; `allocator_pong/ping`, `mock_gate_only`, and `gate_up_interleave` + # are bound in the shared prelude. @flyc.jit def launch_mixed_moe_gemm1( arg_out: fx.Tensor, @@ -2665,12 +4305,18 @@ def launch_mixed_moe_gemm1( / tile_n_index / arith.constant(2, index=True) ) - _c_pm_l = arith.constant(persist_m, index=True) - gy = ( - arith.index_cast(ir.IndexType.get(), i32_size_expert_ids_in.ir_value()) - + _c_pm_l - - arith.constant(1, index=True) - ) / _c_pm_l + if const_expr(is_a16w4_stage1): + # A16W4 stage1 has no persistent-M scheduling; gy = size_expert_ids. + gy = arith.index_cast( + ir.IndexType.get(), i32_size_expert_ids_in.ir_value() + ) + else: + _c_pm_l = arith.constant(persist_m, index=True) + gy = ( + arith.index_cast(ir.IndexType.get(), i32_size_expert_ids_in.ir_value()) + + _c_pm_l + - arith.constant(1, index=True) + ) / _c_pm_l moe_gemm1( arg_out, @@ -2719,6 +4365,8 @@ def compile_mixed_moe_gemm2( sort_block_m: int = 0, b_nt: int = 2, xcd_swizzle: int = 0, + waves_per_eu: int = 0, + split_k_intra: int = 1, ): """Compile stage2 kernel (`moe_gemm2`) and return the compiled executable. @@ -2750,62 +4398,63 @@ def compile_mixed_moe_gemm2( `sort_block_m` is the block_size used by moe_sorting / stage1. When 0 (default), assumed equal to `tile_m`. When set, stage2 can use a different tile_m from sorting/stage1. Requires sort_block_m % tile_m == 0. - """ - _sort_block_m = tile_m if sort_block_m <= 0 else sort_block_m - if _sort_block_m != tile_m and _sort_block_m % tile_m != 0: - raise ValueError( - f"sort_block_m ({_sort_block_m}) must be a multiple of tile_m ({tile_m})" - ) + `split_k_intra` is the A16W4 stage2 intra-block K partition factor + (same name as stage1's `split_k_intra`). It is distinct from stage1's + inter-block split-K `k_batch`. 
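+    Example (illustrative values): inter_dim=4096 with tile_k=256 and
+    split_k_intra=2 gives each z-block K = inter_dim / split_k_intra = 2048,
+    i.e. 8 tile_k tiles per block -- even and >= 2, as the ping-pong
+    pipeline requires.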
+ """ + is_a16w4 = b_dtype == "fp4" and a_dtype in ("bf16", "fp16") + # ---- Unified setup (A16W4 + Generic) ---- gpu_arch = get_hip_arch() allocator_pong = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem0") allocator_ping = SmemAllocator(None, arch=gpu_arch, global_sym_name="smem1") - _state = {} - validate_moe_dtypes(a_dtype, b_dtype) + _sort_block_m = tile_m if (is_a16w4 or sort_block_m <= 0) else sort_block_m + if _sort_block_m != tile_m and _sort_block_m % tile_m != 0: + raise ValueError( + f"sort_block_m ({_sort_block_m}) must be a multiple of tile_m ({tile_m})" + ) - is_f16_a = a_dtype == "fp16" - is_f16_b = b_dtype == "fp16" + if not is_a16w4: + validate_moe_dtypes(a_dtype, b_dtype) - is_f8_a = a_dtype == "fp8" - is_f4_a = a_dtype == "fp4" + is_f16_a = is_a16w4 or (a_dtype == "fp16") + is_f16_b = (not is_a16w4) and (b_dtype == "fp16") + is_f8_a = (not is_a16w4) and (a_dtype == "fp8") + is_f4_a = (not is_a16w4) and (a_dtype == "fp4") is_f4_b = b_dtype == "fp4" + is_int4 = (not is_a16w4) and (b_dtype == "int4") + is_int8 = False - _scale_pack_m = 2 # physical mn_pack in preshuffle microscale layout + _scale_pack_m = 2 _scale_pack_n = 2 - _scale_pack_k = 2 # physical k_pack in preshuffle scale layout + _scale_pack_k = 2 pack_M = min(_scale_pack_m, tile_m // 16) pack_N = min(_scale_pack_n, tile_n // 64) - _k_unroll_raw = (int(tile_k) * (2 if a_dtype == "fp16" else 1)) // 128 + _k_unroll_raw = (int(tile_k) * (2 if is_f16_a else 1)) // 128 pack_K = min(_scale_pack_k, _k_unroll_raw) - elem_bytes = 1 - + elem_bytes = 2 if is_a16w4 else 1 a_elem_bytes = 2 if is_f16_a else 1 b_elem_bytes = 1 tile_k_bytes = int(tile_k) * int(a_elem_bytes) - a_elem_vec_pack = 2 if is_f4_a else 1 + a_elem_vec_pack = 1 if is_a16w4 else (2 if is_f4_a else 1) cbsz = 0 if is_f8_a else 4 blgp = 4 # ---- Static B preshuffle strides (compile-time) ---- - # All values below are Python ints computable at kernel-compile time. - # Using them in an explicit multiply-add replaces the fly dialect's - # dynamic ``crd2idx`` path which emits Barrett reduction for the - # non-power-of-2 ``n0 = experts*model_dim//16`` shape. - _b_kpack_bytes_s = 8 if (b_dtype == "int4") else 16 + _b_kpack_bytes_s = 8 if is_int4 else 16 _b_kpack_elems_s = _b_kpack_bytes_s // b_elem_bytes _b_c_k_s = inter_dim // _scale_pack_k _b_c_k0_s = (_b_c_k_s * b_elem_bytes) // 64 - _b_stride_nlane = _b_kpack_elems_s # 16 - _b_stride_klane = 16 * _b_stride_nlane # 256 - _b_stride_k0 = 4 * _b_stride_klane # 1024 - _b_stride_n0 = _b_c_k0_s * _b_stride_k0 # c_k0 * 1024 + _b_stride_nlane = _b_kpack_elems_s + _b_stride_klane = 16 * _b_stride_nlane + _b_stride_k0 = 4 * _b_stride_klane + _b_stride_n0 = _b_c_k0_s * _b_stride_k0 assert model_dim % 16 == 0, "model_dim must be divisible by 16" _expert_b_stride = (model_dim // 16) * _b_stride_n0 - # K64-byte micro-step: always 64 bytes per `ku`. For fp16, this is 32 elements (2xK16 MFMA). 
if (tile_k_bytes % 64) != 0: raise ValueError( f"tile_k_bytes must be divisible by 64, got tile_k_bytes={tile_k_bytes} " @@ -2815,30 +4464,40 @@ def compile_mixed_moe_gemm2( out_s = str(out_dtype).strip().lower() if out_s not in ("f16", "fp16", "half", "bf16", "bfloat16", "f32", "fp32", "float"): raise ValueError( - f"out_dtype must be 'f16', 'bf16', or 'f32', got {out_dtype!r}" + f"out_dtype must be \'f16\', \'bf16\', or \'f32\', got {out_dtype!r}" ) out_is_f32 = out_s in ("f32", "fp32", "float") out_is_bf16 = out_s in ("bf16", "bfloat16") if (not bool(accumulate)) and out_is_f32: raise ValueError( - "compile_moe_gemm2(accumulate=False) only supports out_dtype in {'f16','bf16'}" + "compile_moe_gemm2(accumulate=False) only supports out_dtype in {\'f16\',\'bf16\'}" ) - is_int4 = b_dtype == "int4" - # INT4 here means W4A8: A2 is int8, W is packed int4 and unpacked to int8 in-kernel. - is_int8 = False + # A16W4-specific: BF16 K32 MFMA + mfma_f32_bf16_k32 = None + kpack_bytes = 16 # MXFP4 preshuffle (used by A16W4 body) + if is_a16w4: + _mfma_k32_raw = getattr(rocdl, "mfma_f32_16x16x32_bf16_", None) + if _mfma_k32_raw is None: + raise AttributeError( + "BF16 K32 MFMA op not found: expected `rocdl.mfma_f32_16x16x32_bf16_`" + ) + _split_mfma = rocdl._split_mfma_operands + + def mfma_f32_bf16_k32(result_type, operands, *, loc=None, ip=None): + a, b, c, cbsz, abid, blgp = _split_mfma(operands, loc=loc) + return _mfma_k32_raw(result_type, a, b, c, cbsz, abid, blgp, loc=loc, ip=ip) + + # Generic-specific: INT8 MFMA (rare, kept for completeness) mfma_i32_k32 = None if is_int8: mfma_i32_k32 = getattr(rocdl, "mfma_i32_16x16x32i8", None) or getattr( rocdl, "mfma_i32_16x16x32_i8", None ) - if mfma_i32_k32 is None: - raise AttributeError( - "INT8 K32 MFMA op not found: expected `rocdl.mfma_i32_16x16x32i8` " - "(or `rocdl.mfma_i32_16x16x32_i8`)." - ) def _x_elem_type(): + if is_a16w4: + return T.bf16 if is_f4_b: return T.f8 if is_f8_a else T.i8 return T.f16 if is_f16_a else (T.i8 if is_int8 else T.f8) @@ -2848,9 +4507,6 @@ def _w_elem_type(): return T.i8 return T.f16 if is_f16_b else (T.i8 if is_int8 else T.f8) - def _scale_elem_type(): - return T.i32 - total_threads = 256 bytes_x_per_tile = int(tile_m) * int(tile_k) * int(a_elem_bytes) if bytes_x_per_tile % total_threads != 0: @@ -2861,11 +4517,7 @@ def _scale_elem_type(): bytes_per_thread_x = bytes_x_per_tile // total_threads _use_lds128 = os.environ.get("FLIR_CK_LDS128", "1") in ( - "1", - "true", - "True", - "YES", - "yes", + "1", "true", "True", "YES", "yes", ) pad_k = 0 if _use_lds128 else 8 lds_stride = tile_k + pad_k @@ -2878,22 +4530,17 @@ def _scale_elem_type(): _eff_tile_k_bytes = tile_k_bytes if out_is_f32: - # Match origin/dev_a16w4: f32 output uses scalar atomics and does NOT use the CShuffle epilogue. _use_cshuffle_epilog = ( False if use_cshuffle_epilog is None else bool(use_cshuffle_epilog) ) if _use_cshuffle_epilog: raise ValueError( - "out_dtype='f32' does not support CShuffle epilogue (set use_cshuffle_epilog=False)." + "out_dtype=\'f32\' does not support CShuffle epilogue (set use_cshuffle_epilog=False)." ) else: if use_cshuffle_epilog is None: _use_cshuffle_epilog = os.environ.get("FLIR_MOE_STAGE2_CSHUFFLE", "1") in ( - "1", - "true", - "True", - "YES", - "yes", + "1", "true", "True", "YES", "yes", ) else: _use_cshuffle_epilog = bool(use_cshuffle_epilog) @@ -2902,36 +4549,78 @@ def _scale_elem_type(): "stage2 f16 output currently requires CShuffle epilogue (FLIR_MOE_STAGE2_CSHUFFLE=1)." 
) - # NOTE: Keep this as a callable so we don't require an MLIR Context at Python-time. + if out_is_bf16: + if not supports_bf16_global_atomics(gpu_arch): + raise ValueError( + f"out_dtype=\'bf16\' requires bf16 global atomics, got arch={gpu_arch!r}" + ) + def out_elem(): return T.f32 if out_is_f32 else (T.bf16 if out_is_bf16 else T.f16) - epilog_tag = "cshuffle" - # IMPORTANT: include tiling in the module name to avoid accidentally reusing a compiled - # binary for a different (tile_m, tile_n, tile_k) configuration. - # See stage1 note: include ABI tag to prevent binary reuse across signature changes. - # IMPORTANT: module name participates in the compiler cache key. - # Dynamic-shape variant: safe to reuse across (tokens/sorted_size/size_expert_ids) at runtime. - # Keep a distinct ABI tag so the compile cache never mixes with historical signatures. - _persistent = persist_m <= 0 - if _persistent: - from aiter.jit.utils.chip_info import get_cu_num - - _cu_num = get_cu_num() - else: + def x_lds_elem(): + if is_a16w4: + return T.bf16 + return T.f16 if is_f16_a else (T.i8 if is_int8 else T.f8) + + # ---- Persist / module name ---- + if is_a16w4: + _persistent = False _cu_num = 0 + persist_m = 1 # A16W4: single M-tile per WG + else: + _persistent = persist_m <= 0 + if _persistent: + from aiter.jit.utils.chip_info import get_cu_num + _cu_num = get_cu_num() + else: + _cu_num = 0 + + # Unified module name (cache key). Optional tags are appended only when + # they deviate from defaults so A8W4/A4W4 keys remain backward-compatible + # with previously cached binaries (ski/wpe/bias/pad default off). + _ski_tag = f"_ski{int(split_k_intra)}" if int(split_k_intra) > 1 else "" + _wpe_tag = f"_wpe{waves_per_eu}" if waves_per_eu >= 1 else "" + _bias_tag = "_bias" if enable_bias else "" + _pad_tag = ( + f"_mp{model_dim_pad}_ip{inter_dim_pad}" + if (model_dim_pad or inter_dim_pad) + else "" + ) _sbm_tag = "" if _sort_block_m == tile_m else f"_sbm{_sort_block_m}" _pm_tag = f"_persist_cu{_cu_num}" if _persistent else f"_pm{persist_m}" _xcd_tag = f"_xcd{xcd_swizzle}" if xcd_swizzle > 0 else "" module_name = ( - f"mfma_moe2_a{a_dtype}_w{b_dtype}_{out_s}_{epilog_tag}" + f"mfma_moe2_a{a_dtype}_w{b_dtype}_{out_s}_cshuffle" f"_t{tile_m}x{tile_n}x{tile_k}" - f"_vscale_fix3{_pm_tag}{_sbm_tag}{_xcd_tag}" + f"_vscale_fix3" + f"{_pm_tag}{_sbm_tag}{_xcd_tag}{_ski_tag}{_wpe_tag}{_bias_tag}{_pad_tag}" ).replace("-", "_") - # -- LDS sizing (pure Python; no MLIR Context needed) --------------------- - # Ping-pong A2 tiles via separate allocators (like stage1). 
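+    # Illustrative module name (assumed a16w4 config: bf16 x fp4 -> bf16,
+    # tile 32x128x256, split_k_intra=2, waves_per_eu=2, no bias/pad/xcd):
+    #   mfma_moe2_abf16_wfp4_bf16_cshuffle_t32x128x256_vscale_fix3_pm1_ski2_wpe2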
+ + # ---- A16W4 split_k_intra validation ---- + _split_k_intra = int(split_k_intra) + if is_a16w4 and _split_k_intra > 1: + if inter_dim % (_split_k_intra * tile_k) != 0: + raise ValueError( + f"inter_dim={inter_dim} must be divisible by " + f"split_k_intra*tile_k={_split_k_intra * tile_k}" + ) + _k_dim = inter_dim // _split_k_intra + _total_tiles_check = _k_dim // tile_k + if _total_tiles_check < 2 or _total_tiles_check % 2 != 0: + raise ValueError( + f"split_k_intra={_split_k_intra}: " + f"_k_dim/tile_k={_total_tiles_check} must be even and >= 2 " + f"for the ping-pong pipeline" + ) + elif is_a16w4: + _k_dim = inter_dim + else: + _k_dim = inter_dim # generic: unused but defined for uniformity + + # ---- LDS sizing ---- _single_x_bytes = int(tile_m) * int(_eff_lds_stride) * int(a_elem_bytes) - _cshuffle_elem_bytes_s2 = 2 # f16/bf16 = 2 bytes + _cshuffle_elem_bytes_s2 = 2 lds_out_bytes = ( _cshuffle_elem_bytes_s2 * int(tile_m) * int(tile_n) if _use_cshuffle_epilog @@ -2939,21 +4628,42 @@ def out_elem(): ) lds_tid_bytes = int(tile_m) * 4 _input_elems = _single_x_bytes if a_elem_bytes == 1 else (_single_x_bytes // 2) + _single_x_elems = _single_x_bytes // int(a_elem_bytes) if is_a16w4 else _input_elems _pong_buffer_bytes = max(_single_x_bytes, lds_out_bytes) _ping_buffer_bytes = _single_x_bytes - def x_lds_elem(): - return T.f16 if is_f16_a else (T.i8 if is_int8 else T.f8) - lds_pong_offset = allocator_pong._align(allocator_pong.ptr, 16) allocator_pong.ptr = lds_pong_offset + _pong_buffer_bytes + + # Unified sorted-token cache slot in LDS. tile_m i32 values (== lds_tid_bytes). + # A16W4 epilogue calls this `lds_sorted_info_offset`; generic uses `_lds_tid_offset_pong`. _lds_tid_offset_pong = allocator_pong._align(allocator_pong.ptr, 4) allocator_pong.ptr = _lds_tid_offset_pong + lds_tid_bytes lds_ping_offset = allocator_ping._align(allocator_ping.ptr, 16) allocator_ping.ptr = lds_ping_offset + _ping_buffer_bytes + if waves_per_eu >= 1: + _total_cu_lds = 160 * 1024 + _min_lds = _total_cu_lds // (waves_per_eu + 1) + 1 + _pong_sz = allocator_pong._align(allocator_pong.ptr, 128) + _ping_sz = allocator_ping._align(allocator_ping.ptr, 128) + _cur_lds = _pong_sz + _ping_sz + if _cur_lds < _min_lds: + allocator_ping.ptr += _min_lds - _cur_lds + + _cshuffle_nlane = 32 + if bool(accumulate): + _e_vec = 2 + else: + _e_vec = 8 if int(tile_n) % (_cshuffle_nlane * 8) == 0 else 2 + _cshuffle_stride = _cshuffle_nlane * _e_vec + if int(tile_n) % _cshuffle_stride != 0: + raise ValueError( + f"tile_n={tile_n} must be divisible by {_cshuffle_stride} when accumulate=False" + ) + if True: @flyc.kernel @@ -2973,7 +4683,6 @@ def moe_gemm2( i32_k_in: fx.Int32, i32_size_expert_ids_in: fx.Int32, ): - tokens_in = arith.index_cast(ir.IndexType.get(), i32_tokens_in.ir_value()) n_in = arith.index_cast(ir.IndexType.get(), i32_n_in.ir_value()) k_in = arith.index_cast(ir.IndexType.get(), i32_k_in.ir_value()) @@ -3004,7 +4713,7 @@ def moe_gemm2( # B preshuffle layout: [experts*model_dim, inter_dim] c_n_total = arith.constant(experts * model_dim, index=True) - kpack_bytes = 8 if is_int4 else 16 + mixed_kpack_bytes = 8 if is_int4 else 16 from .layout_utils import _div_pow2, _mod_pow2 def check_c_n_valid_gate(base_n): @@ -3145,35 +4854,29 @@ def check_c_k_valid_gate(base_k): # fp16 path ignores scales completely (implicit scale=1.0). sx_rsrc = 1 sw_rsrc = 1 - if const_expr(not is_f16_a): - if const_expr(is_f4_a or is_f8_a): - # A2 microscale: e8m0 in sorted layout [sorted_size, K/32]. 
- # Caller must pre-scatter a2_scale via moe_mxfp4_sort. - kblk = _div_pow2(k_in, 32) - sx_nbytes_idx = num_valid_idx * kblk - sx_nbytes_i32 = arith.index_cast(T.i32, sx_nbytes_idx) - sx_rsrc = buffer_ops.create_buffer_resource( - arg_scale_x, max_size=False, num_records_bytes=sx_nbytes_i32 - ) - else: - # scale_x (A2 scale): [tokens*topk] f32 -> bytes = tokens*topk*4 - sx_nbytes_idx = (tokens_in * c_topk) * arith.constant(4, index=True) - sx_nbytes_i32 = arith.index_cast(T.i32, sx_nbytes_idx) - sx_rsrc = buffer_ops.create_buffer_resource( - arg_scale_x, max_size=False, num_records_bytes=sx_nbytes_i32 - ) - - if const_expr(not is_f16_b): - # Weight microscale buffer (packed i32 holding e8m0 bytes). - # Use an exact descriptor size so hardware OOB checking works. - kblk_w = _div_pow2(k_in, 32) # K/32 - mn_w = arith.constant(experts * model_dim, index=True) - sw_nbytes_idx = mn_w * kblk_w # bytes (e8m0) - sw_nbytes_i32 = arith.index_cast(T.i32, sw_nbytes_idx) - sw_rsrc = buffer_ops.create_buffer_resource( - arg_scale_w, max_size=False, num_records_bytes=sw_nbytes_i32 + # A activation microscale (e8m0). Only present when activations are + # FP8 or FP4; A16W4 (bf16/fp16 activations) keeps sx_rsrc=1 sentinel. + # Scale layout is [sorted_size, K/32] -- caller pre-scatters via + # moe_mxfp4_sort. + if const_expr(is_f4_a or is_f8_a): + kblk = _div_pow2(k_in, 32) + sx_nbytes_idx = num_valid_idx * kblk + sx_nbytes_i32 = arith.index_cast(T.i32, sx_nbytes_idx) + sx_rsrc = buffer_ops.create_buffer_resource( + arg_scale_x, max_size=False, num_records_bytes=sx_nbytes_i32 ) + # Weight microscale buffer (packed i32 holding e8m0 bytes). All + # supported configurations (a16w4 / a8w4 / a4w4) have FP4 weights, + # so this descriptor is always created. + kblk_w = _div_pow2(k_in, 32) # K/32 + mn_w = arith.constant(experts * model_dim, index=True) + sw_nbytes_idx = mn_w * kblk_w # bytes (e8m0) + sw_nbytes_i32 = arith.index_cast(T.i32, sw_nbytes_idx) + sw_rsrc = buffer_ops.create_buffer_resource( + arg_scale_w, max_size=False, num_records_bytes=sw_nbytes_i32 + ) + # sorted_token_ids / sorted_weights: [blocks*tile_m] (padded length) sorted_nbytes_idx = ( size_expert_ids_in @@ -3300,10 +5003,858 @@ def check_c_k_valid_gate(base_k): _m_scale_shift_i32 = None def _moe_gemm2_then_body(): - # Expert id for this M tile. + # ===== Shared preamble (consumed by both A16W4 and generic paths) ===== + # Expert offset (in N elements) and tx -> wave/lane decomposition. n_idx = arith.constant(model_dim, index=True) expert_off_idx = expert_idx * n_idx # index + coord_wl = idx2crd(tx, layout_tx_wave_lane) + wave_id = layout_get(coord_wl, 0) + lane_id = layout_get(coord_wl, 1) + coord_l16 = idx2crd(lane_id, layout_lane16) + lane_div_16 = layout_get(coord_l16, 0) + lane_mod_16 = layout_get(coord_l16, 1) + + if const_expr(is_a16w4): + # ===== A16W4-only type aliases ===== + # i32 / f32 / i64 / vec4_f32 / vec2_i64 / acc_init come from + # the moe_gemm2 setup-level closure (already defined for the + # generic path and equivalent for A16W4 since is_int8=False). 
+ f16 = T.f16 + _a16_x_elem = T.bf16 + w_elem = T.i8 + vec1_i32 = T.vec(1, T.i32) + vec8_bf16 = T.vec(8, _a16_x_elem) + + # ===== A16W4 M / X layout (uses runtime k_in) ===== + topk_idx = arith.index(topk) + m_in = tokens_in * topk_idx + m_i32_v = arith.index_cast(i32, m_in) + k_i32_v = i32_k_in.ir_value() + layout_x = fx.make_layout((m_i32_v, k_i32_v), stride=(k_i32_v, 1)) + + # ===== A16W4 B layout (MXFP4 kpack=16, c_k = k_in // 2) ===== + c2 = arith.index(2) + c_k_packed = k_in // c2 + b_layout = make_preshuffle_b_layout( + arith, + c_n=c_n_total, + c_k=c_k_packed, + kpack_bytes=kpack_bytes, + elem_bytes=1, + ) + layout_b = b_layout.layout_b + + # ===== A16W4 B scale layout (mn_pack=2, k_pack=2, scale_block_size=32) ===== + _a16_layout_b_scale = make_preshuffle_scale_layout( + arith, + c_mn=c_n_total, + c_k=k_in, + mn_pack=2, + k_pack=2, + elem_bytes=4, + scale_block_size=32, + ) + + # ===== A16W4 LDS aliases ===== + # Generic upper preamble preloads sorted_idx into lds_tid; A16W4 path + # also writes to it during X-load decode (same offset, same data). + lds_sorted_cache = lds_tid + + # n_idx / expert_off_idx come from the shared preamble above. + # For A16W4, _sort_block_m == tile_m so generic's sort_blk = bx + # yields the same expert_idx as the original A16W4 path. + + if const_expr(bytes_per_thread_x >= 16 and bytes_per_thread_x % 16 == 0): + x_load_bytes = 16 + elif const_expr(bytes_per_thread_x >= 8 and bytes_per_thread_x % 8 == 0): + x_load_bytes = 8 + elif const_expr(bytes_per_thread_x >= 4 and bytes_per_thread_x % 4 == 0): + x_load_bytes = 4 + else: + raise ValueError( + f"bytes_per_thread_x ({bytes_per_thread_x}) must be " + f"divisible by 4" + ) + num_x_loads = bytes_per_thread_x // x_load_bytes + chunk_i32 = x_load_bytes // 4 + x_vec_elems = x_load_bytes // elem_bytes + x_vec_i32_ty = T.vec(chunk_i32, i32) if chunk_i32 > 1 else T.vec(1, i32) + _a16_vec16_x = T.vec(8, _a16_x_elem) + + c_k_div4 = (k_in * arith.index(int(elem_bytes))) // arith.index(4) + c_k_div4_i32 = arith.index_cast(i32, c_k_div4) + layout_x_div4 = fx.make_layout( + (m_i32_v, c_k_div4_i32), stride=(c_k_div4_i32, 1) + ) + tile_k_dwords = (int(tile_k) * int(elem_bytes)) // 4 + layout_x_tile_div4 = fx.make_layout( + (tile_m, tile_k_dwords), stride=(tile_k_dwords, 1) + ) + c_chunk_i32 = arith.index(chunk_i32) + tx_i32_base = tx * c_chunk_i32 + + topk_i32 = arith.constant(topk, type=T.i32) + mask24 = arith.constant(0xFFFFFF, type=T.i32) + tokens_i32 = arith.index_cast(i32, tokens_in) + + def x_tile_chunk_coord_i32(i: int): + return tile_chunk_coord_i32( + arith, + tx_i32_base=tx_i32_base, + i=i, + total_threads=total_threads, + layout_tile_div4=layout_x_tile_div4, + chunk_i32=chunk_i32, + ) + + x_row_base_div4 = [] + x_col_local_i32 = [] + x_row_local = [] + for i in range_constexpr(num_x_loads): + row_local, col_local_i32 = x_tile_chunk_coord_i32(i) + x_row_local.append(row_local) + x_col_local_i32.append(col_local_i32) + + sorted_row_i = bx_m + row_local + fused_i = buffer_ops.buffer_load( + sorted_rsrc, sorted_row_i, vec_width=1, dtype=i32 + ) + _fused_v1 = vector.from_elements(vec1_i32, [fused_i]) + vector.store(_fused_v1, lds_sorted_cache, [row_local]) + t_i32 = fused_i & mask24 + s_i32 = arith.shrui(fused_i, arith.constant(24, type=T.i32)) + t_valid = arith.cmpi(arith.CmpIPredicate.ult, t_i32, tokens_i32) + s_valid = arith.cmpi(arith.CmpIPredicate.ult, s_i32, topk_i32) + ts_valid = t_valid & s_valid + t_safe = arith.select( + ts_valid, t_i32, arith.constant(0, type=T.i32) + ) + s_safe = arith.select( + 
ts_valid, s_i32, arith.constant(0, type=T.i32) + ) + row_ts_i32 = t_safe * topk_i32 + s_safe + row_ts_idx = arith.index_cast(T.index, row_ts_i32) + x_row_base_div4.append(row_ts_idx * c_k_div4) + + # NOTE: load_x_tile is unused on the A16W4 path -- B-cycle + # consumes X via dma_x_tile_to_lds + LDS reads. Removed. + # wave_id / lane_id / lane_div_16 / lane_mod_16 come from the + # shared preamble above. + + _dma_bytes = 16 + _wave_size = 64 + + def dma_x_tile_to_lds(base_k, lds_buffer): + """Async DMA: global -> LDS via buffer_load_lds, no VGPR.""" + c4_idx = arith.index(4) + base_k_div4 = ( + base_k * arith.index(int(elem_bytes)) + ) // arith.index(4) + + lds_ptr_i64 = None + for i in range_constexpr(num_x_loads): + row_local_i = x_row_local[i] + col_local_i32_i = x_col_local_i32[i] + col_local_sw = swizzle_xor16( + row_local_i, col_local_i32_i * c4_idx, k_blocks16 + ) + row_k_dw = x_row_base_div4[i] + base_k_div4 + global_byte_idx = row_k_dw * c4_idx + col_local_sw + global_offset = arith.index_cast(i32, global_byte_idx) + + if const_expr(i == 0): + lds_addr = memref.extract_aligned_pointer_as_index( + lds_buffer + ) + wave_id * arith.constant( + _wave_size * _dma_bytes, index=True + ) + lds_ptr_i64 = rocdl.readfirstlane( + i64, arith.index_cast(i64, lds_addr) + ) + else: + lds_ptr_i64 = lds_ptr_i64 + arith.constant( + total_threads * _dma_bytes, type=i64 + ) + + lds_ptr_type = ir.Type.parse("!llvm.ptr<3>") + lds_ptr = llvm.inttoptr(lds_ptr_type, lds_ptr_i64) + + rocdl.raw_ptr_buffer_load_lds( + x_rsrc, + lds_ptr, + arith.constant(_dma_bytes, type=i32), + global_offset, + arith.constant(0, type=i32), + arith.constant(0, type=i32), + arith.constant(0, type=i32), + ) + + def prefetch_x_to_lds(base_k, lds_buffer): + dma_x_tile_to_lds(base_k, lds_buffer) + + row_a_lds = lane_mod_16 + _a_sublane_stride = 64 # 32 bf16 * 2 bytes + _a_ku_stride_bytes = 16 # 8 bf16 * 2 bytes + col_offset_base_bytes = lane_div_16 * arith.index(_a_sublane_stride) + + by_n = by * arith.index(tile_n) + num_waves = 4 + n_per_wave = tile_n // num_waves + num_acc_n = n_per_wave // 16 + c_n_per_wave = arith.index(n_per_wave) + wave_mod_4 = wave_id % arith.index(4) + n_tile_base = wave_mod_4 * c_n_per_wave + + n_intra_list = [] + n_blk_list = [] + col_g_list = [] + c_n0_static = experts * model_dim // 16 + layout_n_blk_intra = fx.make_layout((c_n0_static, 16), stride=(16, 1)) + for ni in range_constexpr(num_acc_n): + offset = arith.index(ni * 16) + col_g = by_n + n_tile_base + offset + lane_mod_16 + col_g_list.append(col_g) + + row_w = expert_off_idx + col_g + coord_w = fx.idx2crd(row_w, layout_n_blk_intra) + n_blk_list.append(fx.get(coord_w, 0)) + n_intra_list.append(fx.get(coord_w, 1)) + + m_repeat = tile_m // 16 + k_unroll = tile_k_bytes // 64 + + _pad_k_elems = ( + (inter_dim_pad % tile_k) + if (_split_k_intra == 1 and inter_dim_pad > 0) + else 0 + ) + _pad_ku_skip = _pad_k_elems // 32 + _tail_ku = k_unroll - _pad_ku_skip + _tail_k0_count = (_tail_ku + 3) // 4 if _pad_ku_skip > 0 else None + + # ---- Scale index helpers ---- + # mni for each ni: (expert_off + by_n + n_tile_base + ni*16) // 32 + scale_mni_list = [] + scale_n_pack_list = [] + for ni in range_constexpr(num_acc_n): + n_global = expert_off_idx + by_n + n_tile_base + arith.index(ni * 16) + scale_mni_list.append(n_global // arith.index(32)) + n_block_16 = n_global // arith.index(16) + scale_n_pack_list.append(n_block_16 % arith.index(2)) + + def _load_scale_i32(scale_ku_idx, ni, scale_klane=None): + """Load one packed i32 from the scale buffer.""" + _klane = 
scale_klane if scale_klane is not None else lane_div_16 + idx = (scale_mni_list[ni] * _a16_layout_b_scale.stride_n0 + + scale_ku_idx * _a16_layout_b_scale.stride_k0 + + _klane * _a16_layout_b_scale.stride_klane + + lane_mod_16) + return buffer_ops.buffer_load( + sw_rsrc, idx, vec_width=1, dtype=i32 + ) + + def _extract_e8m0_f32_dynamic(packed_i32, byte_pos_idx): + """Extract E8M0 byte at runtime byte_pos and decode to f32.""" + shift = arith.index_cast(i32, byte_pos_idx) * arith.constant(8, type=i32) + byte_i32 = arith.shrui(packed_i32, shift) & arith.constant(0xFF, type=i32) + scale_bits = arith.shli(byte_i32, arith.constant(23, type=i32)) + return arith.bitcast(f32, scale_bits) + + # ---- B Load (dwordx4) + Scale for MXFP4 ---- + def _get_scale_f32(base_k, ku, ni, scale_cache): + """CK addressing for scale: adj_ku = base_k//32 + (ku//4)*4 + lane_div_16.""" + _k0_blk = ku // 4 + adj_ku = (base_k // arith.index(32) + + arith.index(_k0_blk * 4) + + lane_div_16) + scale_klane_rt = lane_div_16 + k_pack_sub_rt = (adj_ku // arith.index(4)) % arith.index(2) + s_ku = adj_ku // arith.index(8) + + cache_key = (_k0_blk, ni) + if cache_key not in scale_cache: + scale_cache[cache_key] = _load_scale_i32( + s_ku, ni, scale_klane=scale_klane_rt + ) + packed = scale_cache[cache_key] + n_pack_sub_val = scale_n_pack_list[ni] + byte_pos_even = k_pack_sub_rt * arith.index(2) + byte_pos_odd = byte_pos_even + arith.index(1) + scale_even = _extract_e8m0_f32_dynamic(packed, byte_pos_even) + scale_odd = _extract_e8m0_f32_dynamic(packed, byte_pos_odd) + n_pack_is_zero = arith.cmpi( + arith.CmpIPredicate.eq, + arith.index_cast(i32, n_pack_sub_val), + arith.constant(0, type=i32), + ) + return arith.select(n_pack_is_zero, scale_even, scale_odd) + + _k_per_dwordx4 = 128 + _k0_count = tile_k // _k_per_dwordx4 + + def load_b_raw(base_k, k0_limit=_k0_count): + """Load raw FP4 data via dwordx4. Returns raw_v4[k0_idx][ni].""" + raw_all = [] + for k0_idx in range_constexpr(k0_limit): + raw_k0 = [] + k_off = base_k + arith.index(k0_idx * _k_per_dwordx4) + for ni in range_constexpr(num_acc_n): + v4 = load_b_raw_mxfp4_dwordx4( + buffer_ops, arith, vector, + arg_b=arg_w, + b_rsrc=w_rsrc, + layout_b=layout_b, + base_k=k_off, + n_blk=n_blk_list[ni], + n_intra=n_intra_list[ni], + lane_div_16=lane_div_16, + elem_type=w_elem, + kpack_bytes=kpack_bytes, + cache_modifier=2, + ) + raw_k0.append(v4) + raw_all.append(raw_k0) + return raw_all + + def load_b_scale_raw(base_k, k0_limit=_k0_count): + """Issue scale buffer_loads only (no extraction). + Returns (packed_dict, kps_dict): + packed_dict: {(k0_blk, ni): packed_i32} + kps_dict: {k0_blk: k_pack_sub_rt} + """ + packed_dict = {} + kps_dict = {} + for k0_blk in range_constexpr(k0_limit): + adj_ku = (base_k // arith.index(32) + + arith.index(k0_blk * 4) + + lane_div_16) + scale_klane_rt = lane_div_16 + kps_dict[k0_blk] = (adj_ku // arith.index(4)) % arith.index(2) + s_ku = adj_ku // arith.index(8) + for ni in range_constexpr(num_acc_n): + packed_dict[(k0_blk, ni)] = _load_scale_i32( + s_ku, ni, scale_klane=scale_klane_rt + ) + return packed_dict, kps_dict + + def extract_b_scales(packed_dict, kps_dict, ku_limit=k_unroll): + """Extract f32 scales from pre-loaded packed i32. + Returns scales[ku][ni] = f32. 
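+                Each selected byte is an E8M0 exponent e; the decode in
+                _extract_e8m0_f32_dynamic is bitcast(e << 23), i.e. the f32
+                scale 2**(e - 127) (e == 0 yields 0.0).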
+ """ + scales = [] + for ku in range_constexpr(ku_limit): + scales_ku = [] + _k0_blk = ku // 4 + k_pack_sub_rt = kps_dict[_k0_blk] + for ni in range_constexpr(num_acc_n): + packed = packed_dict[(_k0_blk, ni)] + n_pack_sub_val = scale_n_pack_list[ni] + byte_pos_even = k_pack_sub_rt * arith.index(2) + byte_pos_odd = byte_pos_even + arith.index(1) + scale_even = _extract_e8m0_f32_dynamic(packed, byte_pos_even) + scale_odd = _extract_e8m0_f32_dynamic(packed, byte_pos_odd) + n_pack_is_zero = arith.cmpi( + arith.CmpIPredicate.eq, + arith.index_cast(i32, n_pack_sub_val), + arith.constant(0, type=i32), + ) + sf = arith.select(n_pack_is_zero, scale_even, scale_odd) + scales_ku.append(sf) + scales.append(scales_ku) + return scales + + def lds_load_packs_k64(curr_row_a_lds, col_base_bytes, lds_buffer): + col_base_swz_bytes = swizzle_xor16( + curr_row_a_lds, col_base_bytes, k_blocks16 + ) + col_base_swz = col_base_swz_bytes // arith.index(int(elem_bytes)) + idx_a16 = crd2idx((curr_row_a_lds, col_base_swz), layout_lds) + loaded_a16 = vector.load_op(_a16_vec16_x, lds_buffer, [idx_a16]) + a_i64x2 = vector.bitcast(vec2_i64, loaded_a16) + a0 = vector.extract(a_i64x2, static_position=[0], dynamic_position=[]) + a1 = vector.extract(a_i64x2, static_position=[1], dynamic_position=[]) + return a0, a1 + + def _a_col_bytes_for_ku(ku_val): + """CK-style A col address: L*64 + (ku%4)*16 + (ku//4)*256.""" + _k0_blk = ku_val // 4 + _ku_in = ku_val % 4 + return col_offset_base_bytes + arith.index( + _ku_in * _a_ku_stride_bytes + _k0_blk * 256 + ) + + _total_a_slots = k_unroll * m_repeat + + def preload_a_from_lds(lds_buffer, ku_limit=k_unroll): + """Load all A tiles for ku_limit × m_repeat from LDS into VGPRs.""" + a_tiles = [None] * (ku_limit * m_repeat) + for ku in range_constexpr(ku_limit): + for mi in range_constexpr(m_repeat): + col = _a_col_bytes_for_ku(ku) + row = row_a_lds + arith.index(mi * 16) + a_tiles[ku * m_repeat + mi] = lds_load_packs_k64( + row, col, lds_buffer + ) + return a_tiles + + def _mfma_k32(acc_in, a0, a1, b0, b1): + a_v2 = vector.from_elements(vec2_i64, [a0, a1]) + a_v8 = vector.bitcast(vec8_bf16, a_v2) + b_v2 = vector.from_elements(vec2_i64, [b0, b1]) + b_v8 = vector.bitcast(vec8_bf16, b_v2) + return mfma_f32_bf16_k32(vec4_f32, [a_v8, b_v8, acc_in, 0, 0, 0]) + + def compute_tile( + acc_in, b_v4, b_scales, a_tiles_cur, + *, ku_count=k_unroll, prefetch_epilogue: bool = False, + ): + """Compute GEMM tile with preloaded A (pure compute, no ds_read). + + Returns: (acc_list, epilogue_pf). 
+ """ + acc_list = list(acc_in) + + epilogue_pf = None + if const_expr(prefetch_epilogue): + tw_pf = None + if const_expr(doweight_stage2): + tw_pf = [] + lane_div_16_mul4_pf = lane_div_16 * arith.index(4) + ii_idx_list_pf = [arith.index(ii) for ii in range(4)] + for mi in range_constexpr(m_repeat): + mi_base_pf = arith.index(mi * 16) + for ii in range_constexpr(4): + row_off_pf = lane_div_16_mul4_pf + ii_idx_list_pf[ii] + row_in_tile_pf = mi_base_pf + row_off_pf + sorted_row_pf = bx_m + row_in_tile_pf + tw_pf.append( + buffer_ops.buffer_load( + sorted_w_rsrc, sorted_row_pf, + vec_width=1, dtype=f32, + ) + ) + epilogue_pf = (None, tw_pf) + + for ni in range_constexpr(num_acc_n): + for ku in range_constexpr(ku_count): + _k0_idx = ku // 4 + _ku_in_k0 = ku % 4 + + b_raw_ku = vector.extract( + b_v4[_k0_idx][ni], + static_position=[_ku_in_k0], + dynamic_position=[], + ) + bb0, bb1 = unpack_b_mxfp4_bf16( + b_raw_ku, arith, vector, + scale_f32=b_scales[ku][ni], + ) + + for mi in range_constexpr(m_repeat): + _flat = ku * m_repeat + mi + a0, a1 = a_tiles_cur[_flat] + + acc_idx = mi * num_acc_n + ni + acc_list[acc_idx] = _mfma_k32( + acc_list[acc_idx], a0, a1, bb0, bb1, + ) + + return acc_list, epilogue_pf + + rocdl.sched_barrier(0) + + def hot_loop_scheduler(): + """CK-style scheduler: interleave MFMA, DS_READ, VMEM_READ.""" + _dsread_per_wg = 1 + _mfma_per_wg = 1 + _NIterPerWarp = num_acc_n + _mfma_perM_perK = _NIterPerWarp * _mfma_per_wg + + _HalfMIter = (m_repeat + 1) // 2 + + _Aload_num_perK = _dsread_per_wg * m_repeat + _Aload_rep = max((_Aload_num_perK + m_repeat - 1) // m_repeat, 1) + _Bload_num_perK = num_acc_n + _Bload_rep = max((_Bload_num_perK + _HalfMIter - 1) // _HalfMIter, 1) + + for _ku in range_constexpr(k_unroll): + for _mi in range_constexpr(m_repeat): + _dsread_perM = _dsread_per_wg + _load_perM = 0 + + if const_expr(_mi < _HalfMIter): + _load_perM = ( + (_Aload_rep if (_Aload_num_perK - (m_repeat - 1 - _mi) * _Aload_rep) > 0 else 0) + + (_Bload_rep if (_Bload_num_perK - (_HalfMIter - 1 - _mi) * _Bload_rep) > 0 else 0) + ) + else: + _load_perM = ( + _Aload_rep if (_Aload_num_perK - (m_repeat - 1 - _mi) * _Aload_rep) > 0 else 0 + ) + + _sum_data = _dsread_perM + _load_perM + _round_data = max((_sum_data + _mfma_perM_perK - 1) // _mfma_perM_perK, 1) + + _inst_order = [] + _max_data = max(_load_perM, _dsread_perM) + for _j in range_constexpr(_max_data): + if const_expr(_load_perM > _j): + _inst_order.append(2) + if const_expr(_dsread_perM > _j): + _inst_order.append(3) + _pad_len = _mfma_perM_perK * _round_data - len(_inst_order) + _inst_order.extend([0] * _pad_len) + + for _nj in range_constexpr(_mfma_perM_perK): + if const_expr(_nj == 0): + _inst_idx = 0 + elif const_expr(_nj == 1): + _inst_idx = _mfma_perM_perK - 2 if _mfma_perM_perK > 2 else 1 + elif const_expr(_nj == 2): + _inst_idx = _mfma_perM_perK - 1 + else: + _inst_idx = _mfma_perM_perK - _nj + + rocdl.sched_mfma(1) + + for _r in range_constexpr(_round_data): + if const_expr(_r % 2 == 0): + _oi = _inst_idx + _r * _mfma_perM_perK + else: + _oi = (_r + 1) * _mfma_perM_perK - 1 - _inst_idx + if const_expr(_oi < len(_inst_order)): + if const_expr(_inst_order[_oi] == 2): + rocdl.sched_vmem(1) + elif const_expr(_inst_order[_oi] == 3): + rocdl.sched_dsrd(1) + + if const_expr(_Aload_num_perK == 0): + rocdl.sched_vmem(1) + rocdl.sched_barrier(0) + + # ---- intra-block split-K offset ---- + if const_expr(_split_k_intra > 1): + bz = gpu.block_id("z") + k_base = bz * arith.index(_k_dim) + else: + k_base = arith.index(0) + + # ---- 
CK-style pipeline: HEAD (scale prefetch) ---- + k0 = k_base + prefetch_x_to_lds(k0, lds_x_pong) + rocdl.sched_barrier(0) + + sc_raw_cur, kps_cur = load_b_scale_raw(k0) + b_v4_cur = load_b_raw(k0) + rocdl.sched_barrier(0) + + _k1 = k_base + arith.index(tile_k) + prefetch_x_to_lds(_k1, lds_x_ping) + rocdl.sched_barrier(0) + + acc = [acc_init] * (num_acc_n * m_repeat) + + rocdl.s_waitcnt(0) + gpu.barrier() + rocdl.sched_barrier(0) + a_cur = preload_a_from_lds(lds_x_pong) + b_sc_cur = extract_b_scales(sc_raw_cur, kps_cur) + gpu.barrier() + rocdl.sched_barrier(0) + + total_tiles = int(_k_dim) // int(tile_k) + pair_iters = max((total_tiles - 2) // 2, 0) + + for pair_i in range_constexpr(pair_iters): + k_iv = k_base + arith.index(pair_i * (tile_k * 2)) + + # ---- Half 2i: scale prefetch -> B_raw -> compute -> extract -> barrier ---- + rocdl.sched_barrier(0) + _k_a2 = k_iv + arith.index(tile_k * 2) + prefetch_x_to_lds(_k_a2, lds_x_pong) + rocdl.sched_barrier(0) + _k_b1 = k_iv + arith.index(tile_k) + sc_raw_nxt, kps_nxt = load_b_scale_raw(_k_b1) + rocdl.sched_barrier(0) + + b_v4_nxt = load_b_raw(_k_b1) + + rocdl.sched_barrier(0) + acc, _ = compute_tile( + acc, b_v4_cur, b_sc_cur, a_cur, + ) + a_next = preload_a_from_lds(lds_x_ping) + rocdl.sched_barrier(0) + b_sc_nxt = extract_b_scales(sc_raw_nxt, kps_nxt) + + rocdl.sched_barrier(0) + _barrier(lgkmcnt=2) + rocdl.sched_barrier(0) + a_cur = a_next + + # ---- Half 2i+1: scale prefetch -> B_raw -> compute -> extract -> barrier ---- + _k_a3 = k_iv + arith.index(tile_k * 3) + prefetch_x_to_lds(_k_a3, lds_x_ping) + rocdl.sched_barrier(0) + + _k_b2 = k_iv + arith.index(tile_k * 2) + sc_raw_cur2, kps_cur2 = load_b_scale_raw(_k_b2) + b_v4_cur2 = load_b_raw(_k_b2) + + rocdl.sched_barrier(0) + acc, _ = compute_tile( + acc, b_v4_nxt, b_sc_nxt, a_cur, + ) + a_next = preload_a_from_lds(lds_x_pong) + b_sc_cur2 = extract_b_scales(sc_raw_cur2, kps_cur2) + + rocdl.sched_barrier(0) + _barrier(lgkmcnt=2) + rocdl.sched_barrier(0) + b_v4_cur, b_sc_cur = b_v4_cur2, b_sc_cur2 + a_cur = a_next + + # ---- TAIL: last 2 tiles (scale prefetch) ---- + k_tail1 = k_base + arith.index(_k_dim) - arith.index(tile_k) + if const_expr(_pad_ku_skip > 0): + sc_raw_tail, kps_tail = load_b_scale_raw(k_tail1, k0_limit=_tail_k0_count) + b_v4_tail = load_b_raw(k_tail1, k0_limit=_tail_k0_count) + else: + sc_raw_tail, kps_tail = load_b_scale_raw(k_tail1) + b_v4_tail = load_b_raw(k_tail1) + + acc, _ = compute_tile( + acc, b_v4_cur, b_sc_cur, a_cur, + ) + if const_expr(_pad_ku_skip > 0): + a_next = preload_a_from_lds(lds_x_ping, ku_limit=_tail_ku) + b_sc_tail = extract_b_scales(sc_raw_tail, kps_tail, ku_limit=_tail_ku) + else: + a_next = preload_a_from_lds(lds_x_ping) + b_sc_tail = extract_b_scales(sc_raw_tail, kps_tail) + + hot_loop_scheduler() + rocdl.s_waitcnt(0) + a_cur = a_next + + acc, epilogue_pf = compute_tile( + acc, b_v4_tail, b_sc_tail, a_cur, + ku_count=_tail_ku if _pad_ku_skip > 0 else k_unroll, + prefetch_epilogue=True, + ) + + # ---- Bias: add to raw accumulators ---- + if const_expr(enable_bias): + _bias_vals = [] + for _ni in range_constexpr(num_acc_n): + _bn = by_n + n_tile_base + arith.index(_ni * 16) + lane_mod_16 + _bias_vals.append( + buffer_ops.buffer_load( + bias_rsrc, expert_off_idx + _bn, + vec_width=1, dtype=f32 + ) + ) + for _mi in range_constexpr(m_repeat): + for _ni in range_constexpr(num_acc_n): + _aidx = _mi * num_acc_n + _ni + _bsplat = vector.from_elements( + vec4_f32, + [_bias_vals[_ni], _bias_vals[_ni], _bias_vals[_ni], _bias_vals[_ni]], + ) + acc[_aidx] = 
arith.addf(acc[_aidx], _bsplat) + + # ---- Epilogue ---- + expert_off = expert_off_idx + mask24_i32 = arith.constant(0xFFFFFF, type=T.i32) + model_i32 = arith.constant(model_dim, type=T.i32) + topk_i32_v = topk_i32 + + zero_i32 = arith.constant(0, type=T.i32) + c2_i32 = arith.constant(2, type=T.i32) + mask_even_i32 = arith.constant(0xFFFFFFFE, type=T.i32) + e_vec = _e_vec + + sw_pf = None + tw_pf = None + if const_expr(epilogue_pf is not None): + sw_pf, tw_pf = epilogue_pf + + # No per-channel weight scale for MXFP4 (scales already applied in dequant). + sw_vals = [arith.constant(1.0, type=T.f32)] * num_acc_n + + def _a16_decode_lds_fused(idx): + """Load fused2 from lds_sorted_cache[idx]; decode (t, s, ts_ok).""" + _fv1 = vector.load_op(vec1_i32, lds_sorted_cache, [idx]) + fused2 = vector.extract(_fv1, static_position=[0], dynamic_position=[]) + t = fused2 & mask24_i32 + s = fused2 >> 24 + t_ok = arith.cmpi(arith.CmpIPredicate.ult, t, tokens_i32) + s_ok = arith.cmpi(arith.CmpIPredicate.ult, s, topk_i32_v) + return fused2, t, s, t_ok & s_ok + + if const_expr(out_is_f32): + c4_i32 = arith.constant(4, type=T.i32) + + def _stage2_row_atomic(*, mi: int, ii: int, row_in_tile, row): + fused2, t2, s2, ts_ok = _a16_decode_lds_fused(row_in_tile) + t2_safe = arith.select(ts_ok, t2, arith.constant(0, type=T.i32)) + sx = arith.select( + ts_ok, + arith.constant(1.0, type=T.f32), + arith.constant(0.0, type=T.f32), + ) + if const_expr(doweight_stage2): + tw_idx = (mi * 4) + ii + if const_expr(tw_pf is not None): + tw = arith.select( + ts_ok, tw_pf[tw_idx], + arith.constant(0.0, type=T.f32), + ) + else: + tw = arith.select( + ts_ok, + buffer_ops.buffer_load( + sorted_w_rsrc, row, vec_width=1, dtype=f32 + ), + arith.constant(0.0, type=T.f32), + ) + idx0 = t2_safe * model_i32 + + for ni in range_constexpr(num_acc_n): + col_g = col_g_list[ni] + acc_idx = mi * num_acc_n + ni + v = vector.extract( + acc[acc_idx], static_position=[ii], dynamic_position=[] + ) + v = v * sx + if const_expr(doweight_stage2): + v = v * tw + col_i32 = arith.index_cast(i32, col_g) + idx_elem = idx0 + col_i32 + byte_off = idx_elem * c4_i32 + rocdl.raw_ptr_buffer_atomic_fadd( + v, out_rsrc, byte_off, zero_i32, zero_i32, + ) + + default_epilog( + arith=arith, + range_constexpr=range_constexpr, + m_repeat=m_repeat, + lane_div_16=lane_div_16, + bx_m=bx_m, + body_row=_stage2_row_atomic, + ) + else: + if lds_out is None: + raise RuntimeError("CShuffle epilogue requires lds_out.") + + out_base_idx = None + if const_expr(out_is_bf16): + _llvm_ptr_ty = ir.Type.parse("!llvm.ptr") + _out_base_ptr = _fly.extract_aligned_pointer_as_index( + _llvm_ptr_ty, arg_out + ) + out_base_idx = arith.index_cast( + T.index, llvm.ptrtoint(T.i64, _out_base_ptr) + ) + + def _row_valid_for(row, ts_ok): + row_i32 = arith.index_cast(i32, row) + return arith.cmpi( + arith.CmpIPredicate.ult, row_i32, num_valid_i32 + ) & ts_ok + + def write_row_to_lds( + *, mi: int, ii: int, row_in_tile, row, + row_base_lds, col_base_local, num_acc_n: int, lds_out, + ): + _, _, _, ts_ok = _a16_decode_lds_fused(row_in_tile) + row_valid = _row_valid_for(row, ts_ok) + + if const_expr(doweight_stage2): + tw_idx = (mi * 4) + ii + if const_expr(tw_pf is not None): + tw = tw_pf[tw_idx] + else: + tw = buffer_ops.buffer_load( + sorted_w_rsrc, row, vec_width=1, dtype=f32 + ) + + for ni in range_constexpr(num_acc_n): + col_local = col_base_local + (ni * 16) + acc_idx = mi * num_acc_n + ni + v = vector.extract( + acc[acc_idx], static_position=[ii], dynamic_position=[] + ) + if 
const_expr(doweight_stage2): + v = v * tw + v_out = arith.trunc_f(out_elem(), v) + lds_idx = row_base_lds + col_local + vec1_out = T.vec(1, out_elem()) + v1 = vector.from_elements(vec1_out, [v_out]) + vector.store(v1, lds_out, [lds_idx], alignment=2) + + def precompute_row(*, row_local, row): + fused2, _, _, ts_ok = _a16_decode_lds_fused(row_local) + return (fused2, _row_valid_for(row, ts_ok)) + + def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): + fused = row_ctx + t = fused & mask24_i32 + s = fused >> 24 + idx0 = t * model_i32 + if not bool(accumulate): + ts = t * topk_i32_v + s + idx0 = ts * model_i32 + col_i32 = arith.index_cast(i32, col_g0) + idx_elem = idx0 + col_i32 + idx_elem_even = idx_elem & mask_even_i32 + if const_expr(out_is_bf16): + if const_expr(bool(accumulate)): + byte_off = idx_elem_even * c2_i32 + byte_off_idx = arith.index_cast(T.index, byte_off) + ptr_addr_idx = out_base_idx + byte_off_idx + out_ptr = buffer_ops.create_llvm_ptr( + ptr_addr_idx, address_space=1 + ) + out_ptr_v = ( + out_ptr._value if hasattr(out_ptr, "_value") else out_ptr + ) + frag_v = frag._value if hasattr(frag, "_value") else frag + llvm.AtomicRMWOp( + llvm.AtomicBinOp.fadd, out_ptr_v, frag_v, + llvm.AtomicOrdering.monotonic, + syncscope="agent", alignment=4, + ) + else: + buffer_ops.buffer_store(frag, out_rsrc, idx_elem_even) + else: + byte_off = idx_elem_even * c2_i32 + if const_expr(bool(accumulate)): + rocdl.raw_ptr_buffer_atomic_fadd( + frag, out_rsrc, byte_off, zero_i32, zero_i32, + ) + else: + buffer_ops.buffer_store(frag, out_rsrc, idx_elem_even) + + c_shuffle_epilog( + arith=arith, vector=vector, gpu=gpu, scf=scf, + range_constexpr=range_constexpr, + tile_m=tile_m, tile_n=tile_n, e_vec=e_vec, + m_repeat=m_repeat, num_acc_n=num_acc_n, + tx=tx, lane_div_16=lane_div_16, lane_mod_16=lane_mod_16, + bx_m=bx_m, by_n=by_n, n_tile_base=n_tile_base, + lds_out=lds_out, + frag_elem_type=(T.bf16 if out_is_bf16 else T.f16), + write_row_to_lds=write_row_to_lds, + precompute_row=precompute_row, + store_pair=store_pair, + ) + return + + # n_idx / expert_off_idx already computed in the shared preamble. + # ---- X gmem->reg prefetch (match preshuffle GEMM mapping) ---- # Prefer 16B buffer-load (dwordx4). If the per-thread byte count isn't divisible by # 16, fall back to 8B (dwordx2) or 4B (dword) loads. For fp16 we require 16B. @@ -3438,14 +5989,7 @@ def load_x_tile(base_k): parts.append(vector.bitcast(vec1_i32, x_vec)) return parts - # tx -> wave/lane (GEMM-style decomposition). - coord_wl = idx2crd(tx, layout_tx_wave_lane) - wave_id = layout_get(coord_wl, 0) - lane_id = layout_get(coord_wl, 1) - coord_l16 = idx2crd(lane_id, layout_lane16) - lane_div_16 = layout_get(coord_l16, 0) - lane_mod_16 = layout_get(coord_l16, 1) - + # tx -> wave/lane decomposition computed in the shared preamble. row_a_lds = lane_mod_16 col_offset_base = lane_div_16 * arith.constant(16, index=True) @@ -3522,7 +6066,7 @@ def load_b_packs_k64(base_k, ku: int, ni: int): + n_intra_list[ni] * arith.constant(_b_stride_nlane, index=True) ) - vec_elems = kpack_bytes // int(b_elem_bytes) + vec_elems = mixed_kpack_bytes // int(b_elem_bytes) b16 = _buffer_load_vec( buffer_ops, vector, @@ -4257,8 +6801,9 @@ def atomic_add_f16x2(val_f16x2, byte_off_i32): # Both accumulate=True (global atomic) and accumulate=False (global store) # need 64-bit addressing to avoid i32 offset overflow when # tokens * model_dim * elem_bytes > INT32_MAX (~150K tokens for model_dim=7168). 
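# A minimal host-side sketch of the fused sorted-row encoding that the A16W4
# epilogue above decodes: the token id sits in bits [23:0] and the top-k slot in
# bits [31:24] of a single i32 (mask24_i32 / >> 24 in _a16_decode_lds_fused).
# Helper names here are illustrative only, not part of the kernel API:
def pack_sorted_row(token_id: int, topk_slot: int) -> int:
    assert 0 <= token_id < (1 << 24) and 0 <= topk_slot < (1 << 8)
    return (topk_slot << 24) | token_id

def unpack_sorted_row(fused: int, tokens: int, topk: int):
    t = fused & 0xFFFFFF                   # token index
    s = fused >> 24                        # top-k slot
    ts_ok = (t < tokens) and (s < topk)    # same validity test as the kernel
    return t, s, ts_ok

# With accumulate=True the kernel atomically adds into output row t; with
# accumulate=False, store_pair writes to row t * topk + s instead.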
- from flydsl._mlir.dialects import fly as _fly - + # NOTE: `_fly` is already imported at module level; do not re-import here + # because that would shadow the closure binding inside this function and + # cause UnboundLocalError when the A16W4 early-return branch reads `_fly`. _llvm_ptr_ty = ir.Type.parse("!llvm.ptr") out_base_ptr = _fly.extract_aligned_pointer_as_index( _llvm_ptr_ty, arg_out @@ -4388,7 +6933,9 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): alignment=_e_vec * out_elem_bytes, ) - _e_vec = 2 if accumulate else min(tile_n // 32, 8) + # NOTE: rename to avoid Python closure conflict with the A16W4 branch + # that reads `_e_vec` from the outer setup scope. + _e_vec_cs = 2 if accumulate else min(tile_n // 32, 8) c_shuffle_epilog( arith=arith, vector=vector, @@ -4397,7 +6944,7 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): range_constexpr=range_constexpr, tile_m=tile_m, tile_n=tile_n, - e_vec=_e_vec, + e_vec=_e_vec_cs, m_repeat=m_repeat, num_acc_n=num_acc_n, tx=tx, @@ -4457,9 +7004,11 @@ def store_pair(*, row_local, row, row_ctx, col_pair0, col_g0, frag): inter_dim_pad, use_cshuffle_epilog, persist_m, - _sort_block_m, - _cu_num if _persistent else 0, + sort_block_m, + b_nt, xcd_swizzle, + waves_per_eu, + split_k_intra, ) @flyc.jit @@ -4497,6 +7046,7 @@ def launch_mixed_moe_gemm2( if const_expr(_persistent): gy = arith.constant(_cu_num, index=True) else: + # For A16W4: persist_m=1 (set in setup), so gy reduces to size_expert_ids_in. _c_pm_l = arith.constant(persist_m, index=True) gy = ( arith.index_cast(ir.IndexType.get(), i32_size_expert_ids_in.ir_value()) @@ -4504,6 +7054,8 @@ def launch_mixed_moe_gemm2( - arith.constant(1, index=True) ) / _c_pm_l + gz = _split_k_intra if is_a16w4 else 1 + moe_gemm2( arg_out, arg_x, @@ -4520,9 +7072,61 @@ def launch_mixed_moe_gemm2( i32_k_in, i32_size_expert_ids_in, ).launch( - grid=(gx, gy, 1), + grid=(gx, gy, gz), block=(256, 1, 1), stream=stream, ) return launch_mixed_moe_gemm2 + +# --------------------------------------------------------------------------- +# Stage 2: A16W4 MXFP4 kernel (BF16 activations x FP4 E2M1 weights) +# --------------------------------------------------------------------------- + +def _decode_e8m0_byte_to_f32(byte_i8, arith_mod): + """Convert a single E8M0 byte (i8) to f32 = 2^(e - 127).""" + c23 = arith_mod.constant(23, type=T.i32) + byte_u32 = arith_mod.extui(T.i32, byte_i8) + scale_bits = arith_mod.shli(byte_u32, c23) + return arith_mod.bitcast(T.f32, scale_bits) + +def compile_a16w4_moe_gemm2( + *, + model_dim: int, + inter_dim: int, + experts: int, + topk: int, + tile_m: int, + tile_n: int, + tile_k: int, + doweight_stage2: bool, + out_dtype: str = "bf16", + use_cshuffle_epilog: bool | None = None, + enable_bias: bool = False, + model_dim_pad: int = 0, + inter_dim_pad: int = 0, + accumulate: bool = True, + waves_per_eu: int = 0, + split_k_intra: int = 1, +): + """Compatibility wrapper for callers that import the legacy A16W4 entry.""" + return compile_mixed_moe_gemm2( + model_dim=model_dim, + inter_dim=inter_dim, + experts=experts, + topk=topk, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + doweight_stage2=doweight_stage2, + a_dtype="bf16", + b_dtype="fp4", + out_dtype=out_dtype, + use_cshuffle_epilog=use_cshuffle_epilog, + enable_bias=enable_bias, + model_dim_pad=model_dim_pad, + inter_dim_pad=inter_dim_pad, + accumulate=accumulate, + waves_per_eu=waves_per_eu, + split_k_intra=split_k_intra, + ) diff --git a/tests/kernels/test_moe_gemm.py 
b/tests/kernels/test_moe_gemm.py index 5796ab75..28ea70df 100644 --- a/tests/kernels/test_moe_gemm.py +++ b/tests/kernels/test_moe_gemm.py @@ -81,6 +81,7 @@ def _pack_shuffled_int8_to_packed_int4_no_perm(x_shuf_i8: torch.Tensor) -> torch MoeGemm2Mode, ) from kernels.mixed_moe_gemm_2stage import ( + compile_a16w4_moe_gemm2, compile_mixed_moe_gemm1, compile_mixed_moe_gemm2, ) @@ -2061,6 +2062,142 @@ def test_moe_stage2_standalone( ) +# --------------------------------------------------------------------------- +# A16W4 stage2 test: BF16 activations x MXFP4 (FP4 E2M1) weights +# --------------------------------------------------------------------------- + +_A16W4_SHAPES = [ + pytest.param(16, 3072, 3072, 128, 4, 32, 256, 256, 1, id="a16w4-s2-small"), + pytest.param(128, 3072, 3072, 128, 4, 64, 256, 256, 1, id="a16w4-s2-medium"), + pytest.param(4, 3072, 3072, 128, 4, 16, 128, 256, 2, id="a16w4-s2-kbatch2"), +] + +_A16W4_SKIP = pytest.mark.skipif( + "gfx95" not in ARCH, + reason=f"A16W4 stage2 requires gfx950+, got {ARCH}", +) + + +@_A16W4_SKIP +@pytest.mark.parametrize( + "tokens, model_dim, inter_dim, experts, topk, tile_m, tile_n, tile_k, split_k_intra", + _A16W4_SHAPES, +) +def test_moe_gemm2_a16w4( + tokens: int, + model_dim: int, + inter_dim: int, + experts: int, + topk: int, + tile_m: int, + tile_n: int, + tile_k: int, + split_k_intra: int, + *, + seed: int = 0, + atol: float = 1.0, + rtol: float = 0.05, +): + """Stage2 correctness test for compile_a16w4_moe_gemm2. + + Activations are BF16 (no FP4 quantisation on the A side). + Weights are MXFP4 E2M1 with E8M0 per-1x32 block scales, + pre-shuffled via shuffle_weight_a16w4 / shuffle_scale_a16w4 from aiter. + Result is compared against torch_moe_gemm2 in FP32. + """ + from aiter.ops.shuffle import shuffle_weight_a16w4, shuffle_scale_a16w4 + from aiter.fused_moe import moe_sorting + + device = torch.device("cuda") + torch.manual_seed(seed) + + # Generate BF16 activations and quantised FP4 weights + dtype = torch.bfloat16 + a2_bf16 = torch.randn((tokens * topk, inter_dim), device=device, dtype=dtype) * 0.2 + w2_fp32 = torch.randn((experts, model_dim, inter_dim), device=device, dtype=torch.float32) / 10 + + # Quantize weights to MXFP4 per-1x32 + try: + from tests.kernels.utils import fp4_utils + except ImportError: + fp4_utils = None + if fp4_utils is None: + pytest.skip("FP4 dependencies not available") + + from aiter import QuantType + import aiter + torch_quant = aiter.get_torch_quant(QuantType.per_1x32) + w2_qt_raw, w2_scale = torch_quant(w2_fp32.to(dtype), quant_dtype=aiter.dtypes.fp4x2) + w2_qt_raw = w2_qt_raw.view(experts, model_dim, inter_dim // 2) + + # Pre-shuffle weights and scales for the kernel; view as uint8 for DLPack. 
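# For debugging, the raw (pre-shuffle) tensors can also be dequantised on the host
# to build a bit-exact reference. A minimal sketch follows; it assumes the fp4x2
# byte stores the even element in its low nibble and the odd element in its high
# nibble, and that w2_scale carries one E8M0 byte per 1x32 block along the last
# axis -- verify against aiter's packing before relying on it. The
# shuffle_weight_a16w4 / shuffle_scale_a16w4 calls below remain the actual test path.
import torch

_E2M1_LUT = torch.tensor(
    [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,           # sign bit clear
     -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0],  # sign bit set
    dtype=torch.float32,
)

def dequant_mxfp4_reference(packed_u8: torch.Tensor, scale_u8: torch.Tensor) -> torch.Tensor:
    """packed_u8: (..., K//2) uint8, scale_u8: (..., K//32) uint8 -> (..., K) fp32."""
    lut = _E2M1_LUT.to(packed_u8.device)
    lo = (packed_u8 & 0xF).long()                      # even K elements
    hi = (packed_u8 >> 4).long()                       # odd K elements
    vals = torch.stack([lut[lo], lut[hi]], dim=-1).flatten(-2)
    scale = torch.pow(2.0, scale_u8.float() - 127.0)   # E8M0: 2^(e - 127)
    return vals * scale.repeat_interleave(32, dim=-1)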
+ w2_qt_shuf = shuffle_weight_a16w4(w2_qt_raw, 16, False).view(torch.uint8).contiguous() + w2_scale_shuf = shuffle_scale_a16w4(w2_scale, experts, False).view(torch.uint8).contiguous() + + # MoE routing: uniform assignment so every token routes to expert 0..topk-1 + topk_ids = torch.zeros((tokens, topk), device=device, dtype=torch.int32) + for k in range(topk): + topk_ids[:, k] = k % experts + topk_weights = torch.ones((tokens, topk), device=device, dtype=torch.float32) / topk + + sorted_ids, sorted_weights, sorted_expert_ids, num_valid_ids, _ = moe_sorting( + topk_ids, topk_weights, experts, model_dim, dtype, tile_m + ) + + # Compile and run kernel (atomic accumulation mode) + exe = compile_a16w4_moe_gemm2( + model_dim=model_dim, + inter_dim=inter_dim, + experts=experts, + topk=topk, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + doweight_stage2=True, + out_dtype="bf16", + accumulate=True, + split_k_intra=split_k_intra, + ) + + out_kernel = torch.zeros((tokens, model_dim), device=device, dtype=dtype) + # NOTE: launcher signature now matches the unified generic mixed-stage2 launcher, + # which accepts arg_scale_x (unused for A16W4 since is_f16_a=True). Pass empty. + _empty_scale_x = torch.empty(0, device=device, dtype=torch.float32) + exe( + out_kernel, + a2_bf16, + w2_qt_shuf, + _empty_scale_x, + w2_scale_shuf, + sorted_ids, + sorted_expert_ids, + sorted_weights, + num_valid_ids, + torch.empty(0, device=device, dtype=dtype), # no bias + tokens, + model_dim, + inter_dim, + sorted_ids.shape[0], + torch.cuda.current_stream(), + ) + torch.cuda.synchronize() + + # Reference: dequantize weights to fp32, compute via torch_moe_gemm2 + ref = torch_moe_gemm2( + a2_q=a2_bf16.view(tokens, topk, inter_dim), + w2_q=w2_fp32, + scale_a2=None, + scale_w2=None, + topk_ids=topk_ids, + topk_weights=topk_weights, + model_dim=model_dim, + doweight_stage2=True, + ) + ref = ref.to(dtype) + + verify_output(out_kernel, ref, atol=atol, rtol=rtol) + + if __name__ == "__main__": torch.set_default_device("cuda") # CLI (mirrors key knobs from aiter/op_tests/test_moe_2stage.py, stage1 subset)
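# The a16w4-s2-kbatch2 case above exercises split_k_intra > 1: the launcher sets
# gridDim.z to split_k_intra and each z-block reduces its own contiguous K slice,
# accumulating partial sums into the pre-zeroed output via atomic float adds
# (hence out_kernel is allocated with torch.zeros and the kernel is compiled with
# accumulate=True). A minimal NumPy model of that reduction, assuming each
# z-block's slice is simply K // split_k_intra:
import numpy as np

def split_k_reference(a: np.ndarray, b: np.ndarray, split_k_intra: int) -> np.ndarray:
    """a: (M, K), b: (K, N); sums per-slice partial GEMMs like the z-blocks do."""
    M, K = a.shape
    N = b.shape[1]
    assert K % split_k_intra == 0
    slice_k = K // split_k_intra
    out = np.zeros((M, N), dtype=np.float32)     # analogue of the pre-zeroed output
    for bz in range(split_k_intra):              # analogue of gpu.block_id("z")
        k0 = bz * slice_k
        out += a[:, k0:k0 + slice_k].astype(np.float32) @ b[k0:k0 + slice_k, :].astype(np.float32)
    return out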