From c8748d9732d78e32c50799fb9055867d510b6452 Mon Sep 17 00:00:00 2001 From: Andrea Picciau Date: Tue, 21 Apr 2026 08:01:04 +0000 Subject: [PATCH 1/2] Add BF16xFP4 MoE GEMM stage1 kernel and tests Implements compile_mixed_moe_gemm1 for BF16 activations x FP4 E2M1 weights using mfma_f32_16x16x32_bf16 on gfx950. Key additions: - compute_bf16xfp4_tile: software dequant via v_cvt_scalef32_pk_bf16_fp4, E8M0 scale loading with correct k_mid/m1 byte-shift extraction, and k1_override for the two K32 sub-steps per K64 outer tile. - col_n_valid guard in _stage1_store_row: prevents OOB CTAs (by=1 when tile_n==inter_dim) from writing zero rows into neighbouring token-slot output cells via a race condition on the strided output layout. - swiglu MLIR fix: convert Python float literals to arith.constant values before passing to arith.minimumf/maximumf. - test_moe_gemm1_bf16xfp4: 6 parametrizations (tile_m in {16,32,64} x act in {silu,swiglu}), all passing with logits_diff < 0.002. --- kernels/mixed_moe_gemm_2stage.py | 631 ++++++++++++++++++------- python/flydsl/expr/rocdl/inline_asm.py | 3 +- tests/kernels/test_moe_gemm.py | 100 +++- 3 files changed, 571 insertions(+), 163 deletions(-) diff --git a/kernels/mixed_moe_gemm_2stage.py b/kernels/mixed_moe_gemm_2stage.py index eb23631ad..1f16debb1 100644 --- a/kernels/mixed_moe_gemm_2stage.py +++ b/kernels/mixed_moe_gemm_2stage.py @@ -104,9 +104,9 @@ def compile_mixed_moe_gemm1( gpu_arch = get_hip_arch() allocator = SmemAllocator(None, arch=gpu_arch) - if a_dtype not in ("fp8", "fp16", "int8", "fp4"): + if a_dtype not in ("fp8", "fp16", "int8", "fp4", "bf16"): raise ValueError( - f"a_dtype must be one of ('fp8','fp16','int8','fp4'), got {a_dtype!r}" + f"a_dtype must be one of ('fp8','fp16','int8','fp4','bf16'), got {a_dtype!r}" ) if b_dtype not in ("fp8", "fp16", "int8", "int4", "fp4"): raise ValueError( @@ -114,6 +114,7 @@ def compile_mixed_moe_gemm1( ) is_f16_a = a_dtype == "fp16" + is_bf16_a = a_dtype == "bf16" is_f16_b = b_dtype == "fp16" is_f16 = is_f16_a or is_f16_b @@ -121,13 +122,17 @@ def compile_mixed_moe_gemm1( is_f4_a = a_dtype == "fp4" is_f4_b = b_dtype == "fp4" + # BF16xFP4: W4A16 pattern — keep activations in BF16, dequantize FP4 B → BF16 in software. + # Uses mfma_f32_16x16x32_bf16 (K=32) + v_cvt_scalef32_pk_bf16_fp4. + is_bf16xfp4 = is_bf16_a and is_f4_b + pack_M = 2 pack_N = 2 pack_K = 2 elem_bytes = 1 - a_elem_bytes = 2 if is_f16_a else 1 + a_elem_bytes = 2 if (is_f16_a or is_bf16_a) else 1 b_elem_bytes = 1 tile_k_bytes = int(tile_k) * int(a_elem_bytes) @@ -135,7 +140,7 @@ def compile_mixed_moe_gemm1( cbsz = 0 if is_f8_a else 4 blgp = 4 - # K64-byte micro-step: always 64 bytes per `ku`. For fp16, this is 32 elements (2xK16 MFMA). + # K64-byte micro-step: always 64 bytes per `ku`. For fp16/bf16, this is 32 elements (2xK16 MFMA). 
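+    # e.g. the BF16xFP4 test config below uses tile_k=256 with bf16 activations,
+    # so tile_k_bytes = 256 * 2 = 512 and the divisibility check below passes.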
if (tile_k_bytes % 64) != 0: raise ValueError( f"tile_k_bytes must be divisible by 64, got tile_k_bytes={tile_k_bytes} " @@ -145,7 +150,11 @@ def compile_mixed_moe_gemm1( def _x_lds_elem_type(): - return T.f16 if is_f16_a else T.f8 + if is_f16_a: + return T.f16 + if is_bf16_a: + return T.bf16 + return T.f8 def _out_elem_type(): return T.bf16 if out_dtype == "bf16" else T.f16 @@ -217,7 +226,7 @@ def moe_gemm1( size_expert_ids_in = arith.index_cast(T.index, i32_size_expert_ids_in) tokens_i32_v = i32_tokens_in k_i32_v = i32_k_in - x_elem = T.f16 if is_f16_a else T.f8 + x_elem = _x_lds_elem_type() vec4_f32 = T.vec(4, T.f32) vec4_i32 = T.vec(4, T.i32) vec1_f32 = T.vec(1, T.f32) @@ -240,9 +249,11 @@ def silu(x): _arith_max = getattr(arith, "maximum", None) or getattr(arith, "maximumf") def swiglu(gate, up, alpha=1.702, limit=7.0): - gate = _arith_min(gate, limit) - up = _arith_min(up, limit) - up = _arith_max(up, -limit) + limit_f32 = arith.constant(float(limit), type=T.f32) + nlimit_f32 = arith.constant(-float(limit), type=T.f32) + gate = _arith_min(gate, limit_f32) + up = _arith_min(up, limit_f32) + up = _arith_max(up, nlimit_f32) t = gate * alpha * (-1.4426950408889634) # -log2(e) emu = rocdl.exp2(T.f32, t) @@ -343,7 +354,7 @@ def swiglu(gate, up, alpha=1.702, limit=7.0): x_nbytes = ( tokens_in * (k_in // fx.Index(int(a_elem_vec_pack))) - * fx.Index(int(elem_bytes)) + * fx.Index(int(a_elem_bytes)) ) x_rsrc = buffer_ops.create_buffer_resource( arg_x, max_size=False, num_records_bytes=x_nbytes @@ -369,7 +380,7 @@ def swiglu(gate, up, alpha=1.702, limit=7.0): arg_out, max_size=False, num_records_bytes=out_nbytes_i32 ) - if is_f16_a: + if is_f16_a or is_bf16_a: sx_rsrc = None else: # A1 microscale: [sorted_rows, K/32] e8m0 bytes, packed as i32. @@ -454,10 +465,10 @@ def swiglu(gate, up, alpha=1.702, limit=7.0): # For fp4, 2 elements per byte, so divide by a_elem_vec_pack. c_a_pack = fx.Index(int(a_elem_vec_pack)) c_k_div4 = ( - (k_in // c_a_pack) * fx.Index(int(elem_bytes)) + (k_in // c_a_pack) * fx.Index(int(a_elem_bytes)) ) // arith.index(4) c_k_div4_i32 = arith.index_cast(T.i32, c_k_div4) - tile_k_dwords = (int(tile_k) * int(elem_bytes)) // 4 + tile_k_dwords = (int(tile_k) * int(a_elem_bytes)) // 4 layout_x_tile_div4 = fx.make_layout( (tile_m, tile_k_dwords), stride=(tile_k_dwords, 1) ) @@ -518,7 +529,7 @@ def load_x(idx_i32): For 16B, keep the fast dwordx4 path. For 8B/4B, use byte offsets. """ idx_elem = ( - idx_i32 if elem_bytes == 1 else (idx_i32 * arith.index(2)) + idx_i32 if a_elem_bytes == 1 else (idx_i32 * arith.index(2)) ) return buffer_copy_gmem16_dwordx4( buffer_ops, @@ -527,6 +538,7 @@ def load_x(idx_i32): idx_i32=idx_elem, rsrc=x_rsrc, vec_elems=vec16_elems, + elem_bytes=a_elem_bytes, ) _zero_row_idx = fx.Index(0) @@ -535,7 +547,7 @@ def load_x_tile(base_k): """Prefetch the per-thread X tile portion (gmem -> regs) for a given K base (in elements).""" base_k_div4 = ( (base_k // c_a_pack) - * fx.Index(int(elem_bytes)) + * fx.Index(int(a_elem_bytes)) ) // arith.index(4) zero_x_i32 = arith.constant_vector(0, vec4_i32) parts = [] @@ -607,71 +619,91 @@ def load_x_tile(base_k): up_n_intra_list.append(up_n_intra) # --- B Load Logic (K64) - shared layout with preshuffle GEMM --- - def load_b_packs_k64(base_k, ku: int, n_blk, n_intra): - # K64 micro-step = 2x K32 MFMA steps. Reuse the shared helper. 
- b0 = load_b_pack_k32( - buffer_ops, - arith, - vector, - arg_b=arg_w, - b_rsrc=w_rsrc, - layout_b=layout_b, - base_k=base_k, - ki_step=ku * 2, - n_blk=n_blk, - n_intra=n_intra, - lane_div_16=lane_div_16, - elem_type=_w_elem_type(is_f4_b=is_f4_b, is_f16_b=is_f16), - kpack_bytes=kpack_bytes, - elem_bytes=b_elem_bytes, - unpack_int4=bool(is_int4), - ) - b1 = load_b_pack_k32( - buffer_ops, - arith, - vector, - arg_b=arg_w, - b_rsrc=w_rsrc, - layout_b=layout_b, - base_k=base_k, - ki_step=ku * 2 + 1, - n_blk=n_blk, - n_intra=n_intra, - lane_div_16=lane_div_16, + def _load_b_k32_single(base_k, ki_step, n_blk, n_intra, *, k1_override=None): + """Load one K32 B pack; k1_override forces a specific k1 (BF16×FP4 use).""" + return load_b_pack_k32( + buffer_ops, arith, vector, + arg_b=arg_w, b_rsrc=w_rsrc, layout_b=layout_b, + base_k=base_k, ki_step=ki_step, + n_blk=n_blk, n_intra=n_intra, + lane_div_16=(k1_override if k1_override is not None else lane_div_16), elem_type=_w_elem_type(is_f4_b=is_f4_b, is_f16_b=is_f16), kpack_bytes=kpack_bytes, elem_bytes=b_elem_bytes, unpack_int4=bool(is_int4), ) + + def load_b_packs_k64(base_k, ku: int, n_blk, n_intra): + """Load B packs for one K64 outer micro-step. + + For BF16×FP4 the FP4 preshuffle kpack layout encodes K ranges via both k0 + and k1 coordinates: (k0, k1) → K[k0*128 + k1*32 .. k0*128 + k1*32 + 31]. + With k_unroll=4 covering 256 BF16 = 256 FP4 elements per tile: + ku=0 → (k0_base+0, k1=0) and (k0_base+0, k1=1): K[0..63] + ku=1 → (k0_base+0, k1=2) and (k0_base+0, k1=3): K[64..127] + ku=2 → (k0_base+1, k1=0) and (k0_base+1, k1=1): K[128..191] + ku=3 → (k0_base+1, k1=2) and (k0_base+1, k1=3): K[192..255] + ki_step = ku//2*2 (even) gives k0=k0_base+ku//2, half=0 (bytes[0..7]). + ki_step = ku//2*2+1 (odd) gives same k0, half=1 (bytes[8..15]). + k1 cycles as (ku%2)*2 and (ku%2)*2+1 to select the correct K32 sub-range. + Within each kpack, _select_lane_fp4_i32 picks the 4-byte slice per lane. + + Returns for BF16×FP4: (b_k1a_h0, b_k1a_h1, b_k1b_h0, b_k1b_h1) where + k1a = (ku%2)*2, k1b = (ku%2)*2+1 (MFMA K32 sub-steps 0 and 1). + Returns for all other paths: (b0, b1) [two i64 as before] + """ + if is_bf16xfp4: + # ki_step_h0 gives k0=k0_base+ku//2, half=0 (bytes[0..7] of kpack). + # ki_step_h1 gives same k0, half=1 (bytes[8..15] of kpack). + # k1a/k1b select the two K32-element sub-steps within this K64 outer step. 
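+            # e.g. ku=3 -> ki_step_h0=2, ki_step_h1=3 (k0=k0_base+1), k1a=2, k1b=3:
+            # the two K32 sub-steps covering K[192..255], per the mapping above.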
+            ki_step_h0 = (ku // 2) * 2
+            ki_step_h1 = (ku // 2) * 2 + 1
+            k1a = arith.index((ku % 2) * 2)  # 0 for ku=0,2; 2 for ku=1,3
+            k1b = arith.index((ku % 2) * 2 + 1)  # 1 for ku=0,2; 3 for ku=1,3
+            b_k1a_h0 = _load_b_k32_single(base_k, ki_step_h0, n_blk, n_intra, k1_override=k1a)
+            b_k1a_h1 = _load_b_k32_single(base_k, ki_step_h1, n_blk, n_intra, k1_override=k1a)
+            b_k1b_h0 = _load_b_k32_single(base_k, ki_step_h0, n_blk, n_intra, k1_override=k1b)
+            b_k1b_h1 = _load_b_k32_single(base_k, ki_step_h1, n_blk, n_intra, k1_override=k1b)
+            return (b_k1a_h0, b_k1a_h1, b_k1b_h0, b_k1b_h1)
+        b0 = _load_b_k32_single(base_k, ku * 2, n_blk, n_intra)
+        b1 = _load_b_k32_single(base_k, ku * 2 + 1, n_blk, n_intra)
         return b0, b1
 
     def load_b_tile(base_k):
         gate_b_tile = []
         up_b_tile = []
         for ku in range_constexpr(k_unroll):
-            gate_packs0 = []
-            gate_packs1 = []
-            up_packs0 = []
-            up_packs1 = []
-            for ni in range_constexpr(num_acc_n):
-                gate_b0, gate_b1 = load_b_packs_k64(
-                    base_k,
-                    ku,
-                    gate_n_blk_list[ni],
-                    gate_n_intra_list[ni],
-                )
-                up_b0, up_b1 = load_b_packs_k64(
-                    base_k,
-                    ku,
-                    up_n_blk_list[ni],
-                    up_n_intra_list[ni],
-                )
-                gate_packs0.append(gate_b0)
-                gate_packs1.append(gate_b1)
-                up_packs0.append(up_b0)
-                up_packs1.append(up_b1)
-            gate_b_tile.append((gate_packs0, gate_packs1))
-            up_b_tile.append((up_packs0, up_packs1))
+            if is_bf16xfp4:
+                # Returns 4-tuple per ni: (b_k10_h0, b_k10_h1, b_k11_h0, b_k11_h1)
+                gate_k10_h0, gate_k10_h1, gate_k11_h0, gate_k11_h1 = [], [], [], []
+                up_k10_h0, up_k10_h1, up_k11_h0, up_k11_h1 = [], [], [], []
+                for ni in range_constexpr(num_acc_n):
+                    g = load_b_packs_k64(base_k, ku, gate_n_blk_list[ni], gate_n_intra_list[ni])
+                    u = load_b_packs_k64(base_k, ku, up_n_blk_list[ni], up_n_intra_list[ni])
+                    gate_k10_h0.append(g[0]); gate_k10_h1.append(g[1])
+                    gate_k11_h0.append(g[2]); gate_k11_h1.append(g[3])
+                    up_k10_h0.append(u[0]); up_k10_h1.append(u[1])
+                    up_k11_h0.append(u[2]); up_k11_h1.append(u[3])
+                gate_b_tile.append((gate_k10_h0, gate_k10_h1, gate_k11_h0, gate_k11_h1))
+                up_b_tile.append((up_k10_h0, up_k10_h1, up_k11_h0, up_k11_h1))
+            else:
+                gate_packs0 = []
+                gate_packs1 = []
+                up_packs0 = []
+                up_packs1 = []
+                for ni in range_constexpr(num_acc_n):
+                    gate_b0, gate_b1 = load_b_packs_k64(
+                        base_k, ku, gate_n_blk_list[ni], gate_n_intra_list[ni],
+                    )
+                    up_b0, up_b1 = load_b_packs_k64(
+                        base_k, ku, up_n_blk_list[ni], up_n_intra_list[ni],
+                    )
+                    gate_packs0.append(gate_b0)
+                    gate_packs1.append(gate_b1)
+                    up_packs0.append(up_b0)
+                    up_packs1.append(up_b1)
+                gate_b_tile.append((gate_packs0, gate_packs1))
+                up_b_tile.append((up_packs0, up_packs1))
         return gate_b_tile, up_b_tile
 
     def load_scale(arg_scale, rsrc, scale_info, ku, mni):
@@ -741,6 +773,9 @@ def load_b_scale_tile(base_k):
         return gate_b_scale_tile, up_b_scale_tile
 
     def load_a_scale_tile(base_k):
+        if is_bf16_a:
+            # No A scale for BF16 activations (full precision, no quantization).
+            return []
         a_scale_tile = []
         for ku in range_constexpr(k_unroll_packed):
             for mi in range_constexpr(m_repeat_packed):
@@ -756,6 +791,9 @@ def load_a_scale_tile(base_k):
 
     def prefetch_ab_scale_tile(base_k):
         gate_bs, up_bs = load_b_scale_tile(base_k)
+        if is_bf16xfp4:
+            gate_fp4_bs, up_fp4_bs = load_b_scale_tile_bf16fp4(base_k)
+            return [[], gate_bs, up_bs, gate_fp4_bs, up_fp4_bs]
         return [load_a_scale_tile(base_k), gate_bs, up_bs]
 
     acc_gate = [acc_init] * (num_acc_n * m_repeat)
@@ -779,7 +817,7 @@ def store_x_tile_to_lds(vec_x_in_parts, lds_base):
                 k_blocks16=k_blocks16,
                 lds_base=lds_base,
                 vec_part_i32x4=vec_x_in_parts[i],
-                elem_bytes=elem_bytes,
+                elem_bytes=a_elem_bytes,
             )
 
     # --- A LDS load helper for K64 (load 16B once, extract 2x i64 halves) ---
@@ -790,7 +828,7 @@ def lds_load_packs_k64(curr_row_a_lds, col_base, lds_base):
         )
         col_base_swz = (
             col_base_swz_bytes
-            if elem_bytes == 1
+            if a_elem_bytes == 1
            else (col_base_swz_bytes // arith.index(2))
        )
        idx_a16 = lds_row_major_idx(
@@ -974,6 +1012,285 @@ def pack_i64x4_to_i32x8(x0, x1, x2, x3):
        )
        return gate_list, up_list, epilogue_pf

+    def load_b_scale_tile_bf16fp4(base_k):
+        """Load B scales for the BF16×FP4 path.
+
+        Returns (gate_scales, up_scales), each a flat list of length k_unroll * 2 * num_acc_n.
+        Indexed as [(ku * 2 + ks) * num_acc_n + ni] where ks=0 is K32 sub-step 0
+        and ks=1 is K32 sub-step 1. Each entry is an i32 with the E8M0 scale byte in bits [7:0].
+
+        The e8m0_shuffle layout (after contiguous permute) is:
+        (m0=c_mn/32, k_outer=c_k/128, k_inner=4, n_lane=16, k_mid=2, m1=2)
+        viewed as (c_mn, c_k/16) of uint8, loaded as i32 (4 uint8 per load).
+
+        One i32 holds 4 bytes for (m1={0,1}, k_mid={0,1}):
+        byte 0: m1=0, k_mid=0 (bits [7:0])
+        byte 1: m1=1, k_mid=0 (bits [15:8])
+        byte 2: m1=0, k_mid=1 (bits [23:16])
+        byte 3: m1=1, k_mid=1 (bits [31:24])
+
+        Total byte_shift = k_mid_shift + m1_shift:
+        k_mid_shift = (ku // 2) * 16 (0 for ku<2; 16 for ku>=2)
+        m1_shift = (ni % pack_N) * 8 (0 for even ni; 8 for odd ni within a pack)
+        mni = (expert_off_idx + col_base_for_pack) // 32 where col_base_for_pack uses ni//pack_N.
+        """
+        gate_scales = []
+        up_scales = []
+
+        def _load_scale_k_inner(rsrc, scale_info, scale_ku_total, k_inner_const,
+                                mni, valid, byte_shift_const):
+            n_lane = lane_mod_16
+            k_inner_idx = arith.constant(k_inner_const, index=True)
+            safe_mni = arith.select(valid, mni, fx.Index(0))
+            idx_pack = (
+                safe_mni * scale_info.stride_n0
+                + scale_ku_total * scale_info.stride_k0
+                + k_inner_idx * scale_info.stride_klane
+                + n_lane
+            )
+            raw_i32 = buffer_ops.buffer_load(rsrc, idx_pack, vec_width=1, dtype=T.i32)
+            # Extract the correct E8M0 byte for (k_mid, m1) via right-shift.
+            byte_val = (raw_i32 >> fx.Int32(byte_shift_const)) & fx.Int32(0xFF)
+            return arith.select(valid, byte_val, fx.Int32(0))
+
+        scale_ku_total = base_k  # k_outer dimension; one step per 128-FP4 K-tile
+
+        for ku in range_constexpr(k_unroll):
+            k_inner_k1a = (ku % 2) * 2  # k_inner for MFMA sub-step 0: 0 for even ku, 2 for odd ku
+            k_inner_k1b = (ku % 2) * 2 + 1  # k_inner for MFMA sub-step 1: 1 for even ku, 3 for odd ku
+            k_mid_shift = (ku // 2) * 16  # 0 for k_mid=0 (ku<2); 16 for k_mid=1 (ku>=2)
+            # Append ks=0 scales for all ni, then ks=1 — matches
+            # scale_idx = (ku*2 + ks)*num_acc_n + ni in compute_bf16xfp4_tile.
+            for ni in range_constexpr(num_acc_n):
+                # col_base_pack is aligned to pack_N*16 boundaries (for mni address).
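+                # e.g. with pack_N=2: ni=3, ku=2 -> col_pack_offset=32, m1_shift=8,
+                # k_mid_shift=16, byte_shift=24, i.e. the m1=1/k_mid=1 byte (bits [31:24]).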
+ col_pack_offset = (ni // pack_N) * 16 * pack_N + col_base_pack = out_block_base + out_wave_base + fx.Index(col_pack_offset) + col_valid = arith.cmpi(arith.CmpIPredicate.ult, col_base_pack, inter_idx) + gate_mni = (expert_off_idx + col_base_pack) // fx.Index(32) + up_mni = (expert_off_idx + inter_idx + col_base_pack) // fx.Index(32) + # m1 selects the second half of 16 lanes within the 32-lane pack group. + m1_shift = (ni % pack_N) * 8 + byte_shift = k_mid_shift + m1_shift + gate_scales.append(_load_scale_k_inner( + sw_rsrc, layout_b_scale, scale_ku_total, k_inner_k1a, + gate_mni, col_valid, byte_shift + )) + up_scales.append(_load_scale_k_inner( + sw_rsrc, layout_b_scale, scale_ku_total, k_inner_k1a, + up_mni, col_valid, byte_shift + )) + for ni in range_constexpr(num_acc_n): + col_pack_offset = (ni // pack_N) * 16 * pack_N + col_base_pack = out_block_base + out_wave_base + fx.Index(col_pack_offset) + col_valid = arith.cmpi(arith.CmpIPredicate.ult, col_base_pack, inter_idx) + gate_mni = (expert_off_idx + col_base_pack) // fx.Index(32) + up_mni = (expert_off_idx + inter_idx + col_base_pack) // fx.Index(32) + m1_shift = (ni % pack_N) * 8 + byte_shift = k_mid_shift + m1_shift + gate_scales.append(_load_scale_k_inner( + sw_rsrc, layout_b_scale, scale_ku_total, k_inner_k1b, + gate_mni, col_valid, byte_shift + )) + up_scales.append(_load_scale_k_inner( + sw_rsrc, layout_b_scale, scale_ku_total, k_inner_k1b, + up_mni, col_valid, byte_shift + )) + return gate_scales, up_scales + + def compute_bf16xfp4_tile( + acc_gate_in, + acc_up_in, + gate_b_tile_in, + up_b_tile_in, + lds_base, + *, + a0_prefetch=None, + gate_b_scales_in=None, + up_b_scales_in=None, + prefetch_epilogue: bool = False, + ): + """BF16×FP4 compute tile using mfma_f32_16x16x32_bf16. + + A is loaded from LDS as BF16 (full precision). + B (FP4 E2M1) is dequantized to BF16 in software via v_cvt_scalef32_pk_bf16_fp4 + with the MXFP4 E8M0 block scale. + + gate_b_scales_in / up_b_scales_in: flat lists of length k_unroll * 2 * num_acc_n, + indexed as [(ku*2+ks)*num_acc_n + ni]. Each entry holds the E8M0 scale byte in bits [7:0]. 
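+            E.g. the scale for (ku=1, ks=1, ni=2) sits at flat index (1*2 + 1) * num_acc_n + 2.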
+ """ + gate_list = list(acc_gate_in) + up_list = list(acc_up_in) + + gpu.barrier() + rocdl.sched_barrier(0) + + epilogue_pf = None + if enable_bias and prefetch_epilogue: + gate_bias = [] + up_bias = [] + for ni in range_constexpr(num_acc_n): + global_n = by_n + n_tile_base + ni * 16 + lane_mod_16 + gate_offset = expert_off_idx + global_n + up_offset = expert_off_idx + global_n + inter_dim + gate_bias.append( + buffer_ops.buffer_load( + bias_rsrc, gate_offset, vec_width=1, dtype=T.f32 + ) + ) + up_bias.append( + buffer_ops.buffer_load( + bias_rsrc, up_offset, vec_width=1, dtype=T.f32 + ) + ) + epilogue_pf = (gate_bias, up_bias) + + mfma_fn = rocdl.mfma_f32_16x16x32_bf16 + mfma_res_ty = T.f32x4 + vec2_i32 = T.vec(2, T.i32) + vec8_bf16 = T.vec(8, T.bf16) + i32x4_ty = T.vec(4, T.i32) + + gate_scales = gate_b_scales_in + up_scales = up_b_scales_in + + def _e8m0_i32_to_f32(e8m0_i32): + """Convert E8M0 scalar i32 (low byte) to f32 via bit-shift: 2^(e-127).""" + byte_val = e8m0_i32 & fx.Int32(0xFF) + f32_bits = byte_val << fx.Int32(23) + return arith.bitcast(T.f32, f32_bits) + + def _fp4_i32_to_v8bf16(src_i32, scale_f32): + """Convert 8 FP4 E2M1 nibbles in src_i32 to vec<8, bf16> using 4x cvt_scalef32_pk_bf16_fp4.""" + vec2_bf16 = T.vec(2, T.bf16) + p0 = rocdl.cvt_scalef32_pk_bf16_fp4(vec2_bf16, src_i32, scale_f32, fx.Int32(0)) + p1 = rocdl.cvt_scalef32_pk_bf16_fp4(vec2_bf16, src_i32, scale_f32, fx.Int32(1)) + p2 = rocdl.cvt_scalef32_pk_bf16_fp4(vec2_bf16, src_i32, scale_f32, fx.Int32(2)) + p3 = rocdl.cvt_scalef32_pk_bf16_fp4(vec2_bf16, src_i32, scale_f32, fx.Int32(3)) + # Concatenate 4x vector<2xbf16> → vector<8xbf16> via i32 bitcast. + p0_i32 = vector.bitcast(T.vec(1, T.i32), p0) + p1_i32 = vector.bitcast(T.vec(1, T.i32), p1) + p2_i32 = vector.bitcast(T.vec(1, T.i32), p2) + p3_i32 = vector.bitcast(T.vec(1, T.i32), p3) + v4 = vector.from_elements(i32x4_ty, [ + vector.extract(p0_i32, static_position=[0], dynamic_position=[]), + vector.extract(p1_i32, static_position=[0], dynamic_position=[]), + vector.extract(p2_i32, static_position=[0], dynamic_position=[]), + vector.extract(p3_i32, static_position=[0], dynamic_position=[]), + ]) + return vector.bitcast(vec8_bf16, v4) + + def _i64x2_to_v8bf16(lo, hi): + """Pack two i64s into vector<8xbf16> for mfma_f32_16x16x32_bf16 A/B fragment.""" + v2 = vector.from_elements(T.i64x2, [lo, hi]) + return vector.bitcast(vec8_bf16, v2) + + def _i64_lo_i32(v_i64): + v2 = vector.from_elements(T.vec(1, T.i64), [v_i64]) + v2i32 = vector.bitcast(vec2_i32, v2) + return vector.extract(v2i32, static_position=[0], dynamic_position=[]) + + def _i64_hi_i32(v_i64): + v2 = vector.from_elements(T.vec(1, T.i64), [v_i64]) + v2i32 = vector.bitcast(vec2_i32, v2) + return vector.extract(v2i32, static_position=[1], dynamic_position=[]) + + def _select_lane_fp4_i32(b0_i64, b1_i64): + """Extract this lane's 4-byte (8-nibble) FP4 slice from a 16-byte kpack. + + The FP4 preshuffle kpack (loaded with k1=0 for all lanes) stores 16 bytes: + b0_i64 = bytes [0..7] → nibbles for lane_div_16=0 (bytes 0-3) and 1 (bytes 4-7) + b1_i64 = bytes [8..15] → nibbles for lane_div_16=2 (bytes 8-11) and 3 (bytes 12-15) + Each lane selects its 4-byte slice using lane_div_16: + lane_div_16=0: b0_i64 lo i32 (bytes 0-3) + lane_div_16=1: b0_i64 hi i32 (bytes 4-7) + lane_div_16=2: b1_i64 lo i32 (bytes 8-11) + lane_div_16=3: b1_i64 hi i32 (bytes 12-15) + """ + # Split both i64s into lo/hi i32 pairs. 
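+        # e.g. lane_div_16=2: lane_lt2 is false -> take b1_i64; lane is even -> its lo i32 (bytes 8-11).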
+ b0_lo = _i64_lo_i32(b0_i64) + b0_hi = _i64_hi_i32(b0_i64) + b1_lo = _i64_lo_i32(b1_i64) + b1_hi = _i64_hi_i32(b1_i64) + # lane_div_16 < 2 → use b0; else use b1. + lane_lt2 = arith.cmpi(arith.CmpIPredicate.ult, lane_div_16, arith.index(2)) + src_lo = arith.select(lane_lt2, b0_lo, b1_lo) + src_hi = arith.select(lane_lt2, b0_hi, b1_hi) + # lane_div_16 % 2 == 0 → lo i32; else hi i32. + lane_mod2 = lane_div_16 % arith.index(2) + lane_odd = arith.cmpi(arith.CmpIPredicate.ne, lane_mod2, arith.index(0)) + return arith.select(lane_odd, src_hi, src_lo) + + for ku in range_constexpr(k_unroll): + # Unpack 4-tuple: (k1=0 half=0, k1=0 half=1, k1=1 half=0, k1=1 half=1) + gate_k10_h0, gate_k10_h1, gate_k11_h0, gate_k11_h1 = gate_b_tile_in[ku] + up_k10_h0, up_k10_h1, up_k11_h0, up_k11_h1 = up_b_tile_in[ku] + # col_base must be in bytes (lds_load_packs_k64 passes it to swizzle_xor16 + # which operates at 16B granularity, then divides by a_elem_bytes → elements). + ki64_bytes = arith.index(ku * 64 * a_elem_bytes) + col_base = col_offset_base + ki64_bytes + + for mi in range_constexpr(m_repeat): + mi_val = arith.index(mi * 16) + curr_row_a_lds = row_a_lds + mi_val + + # mfma_f32_16x16x32_bf16: A needs vec<8,bf16> per lane per K32 step. + # Each lds_load_packs_k64 returns 2×i64 = 16B = one K32 MFMA's A. + # K64 outer step = 2 K32 MFMA calls. + # col_base_k1 is 32 BF16 = 64 bytes ahead of col_base. + col_base_k1 = col_base + arith.index(32 * a_elem_bytes) + + if (a0_prefetch is not None) and (ku == 0) and (mi == 0): + # Prefetch covers first K32 step only (K0..K31). + a0_hi0, a0_hi1 = a0_prefetch + a1_hi0, a1_hi1 = lds_load_packs_k64(curr_row_a_lds, col_base_k1, lds_base) + else: + a0_hi0, a0_hi1 = lds_load_packs_k64(curr_row_a_lds, col_base, lds_base) + a1_hi0, a1_hi1 = lds_load_packs_k64(curr_row_a_lds, col_base_k1, lds_base) + + # Pack each pair of i64s into vec<8,bf16> for the MFMA A operand. + av0 = _i64x2_to_v8bf16(a0_hi0, a0_hi1) # K0..K31 + av1 = _i64x2_to_v8bf16(a1_hi0, a1_hi1) # K32..K63 + + for ni in range_constexpr(num_acc_n): + acc_idx = mi * num_acc_n + ni + # Scale list layout: [(ku*2+ks)*num_acc_n + ni] + # ks=0: K32 sub-step 0 (k1=0), ks=1: K32 sub-step 1 (k1=1) + scale_idx0 = (ku * 2 + 0) * num_acc_n + ni + scale_idx1 = (ku * 2 + 1) * num_acc_n + ni + + gate_sf0 = _e8m0_i32_to_f32(gate_scales[scale_idx0]) + gate_sf1 = _e8m0_i32_to_f32(gate_scales[scale_idx1]) + up_sf0 = _e8m0_i32_to_f32(up_scales[scale_idx0]) + up_sf1 = _e8m0_i32_to_f32(up_scales[scale_idx1]) + + # K32 sub-step 0: nibbles from k1=0 kpack (K[ku*64..ku*64+31]). + # k1=0 half=0 → bytes [0..7], half=1 → bytes [8..15]. + # _select_lane_fp4_i32 picks the 4-byte slice for this lane. + gb0_i32 = _select_lane_fp4_i32(gate_k10_h0[ni], gate_k10_h1[ni]) + # K32 sub-step 1: nibbles from k1=1 kpack (K[ku*64+32..ku*64+63]). + gb1_i32 = _select_lane_fp4_i32(gate_k11_h0[ni], gate_k11_h1[ni]) + ub0_i32 = _select_lane_fp4_i32(up_k10_h0[ni], up_k10_h1[ni]) + ub1_i32 = _select_lane_fp4_i32(up_k11_h0[ni], up_k11_h1[ni]) + + # Dequantize B: vec<8,bf16> per K32 step, each with its own scale. 
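+                    # The E8M0 scales (2^(e-127) as f32) feed the scalef32 operand of
+                    # v_cvt_scalef32_pk_bf16_fp4, so the dequantized BF16 already carries the block scale.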
+ gbv0 = _fp4_i32_to_v8bf16(gb0_i32, gate_sf0) # B for K0..K31 + gbv1 = _fp4_i32_to_v8bf16(gb1_i32, gate_sf1) # B for K32..K63 + ubv0 = _fp4_i32_to_v8bf16(ub0_i32, up_sf0) + ubv1 = _fp4_i32_to_v8bf16(ub1_i32, up_sf1) + + # MFMA step 0: K0..K31 + rocdl.sched_barrier(0) + mid_g = mfma_fn(mfma_res_ty, [av0, gbv0, gate_list[acc_idx], 0, 0, 0]) + rocdl.sched_barrier(0) + mid_u = mfma_fn(mfma_res_ty, [av0, ubv0, up_list[acc_idx], 0, 0, 0]) + # MFMA step 1: K32..K63 + rocdl.sched_barrier(0) + gate_list[acc_idx] = mfma_fn(mfma_res_ty, [av1, gbv1, mid_g, 0, 0, 0]) + rocdl.sched_barrier(0) + up_list[acc_idx] = mfma_fn(mfma_res_ty, [av1, ubv1, mid_u, 0, 0, 0]) + + return gate_list, up_list, epilogue_pf + # ---------------- 2-stage pipeline (ping-pong LDS + B tile prefetch) ---------------- lds_tile_elems = fx.Index(tile_m * lds_stride) lds_base_cur = arith.index(0) @@ -1018,9 +1335,7 @@ def hot_loop_scheduler(): x_regs0 = load_x_tile(k0) gate_w0, up_w0 = load_b_tile(k0) - a_scale_pong, gate_bs_pong, up_bs_pong = prefetch_ab_scale_tile( - k0 // 2 - ) + _scale_pong = prefetch_ab_scale_tile(k0 // 2) store_x_tile_to_lds(x_regs0, lds_base_cur) gpu.barrier() @@ -1032,6 +1347,34 @@ def hot_loop_scheduler(): a0_prefetch_pong = None + if is_bf16xfp4: + def _call_compute_tile( + acc_gate_in, acc_up_in, gate_b_tile, up_b_tile, lds_base, + *, a0_prefetch, scale_bundle, prefetch_epilogue=False, + ): + _a, gate_bs, up_bs, gate_fp4_bs, up_fp4_bs = scale_bundle + return compute_bf16xfp4_tile( + acc_gate_in, acc_up_in, gate_b_tile, up_b_tile, lds_base, + a0_prefetch=a0_prefetch, + gate_b_scales_in=gate_fp4_bs, + up_b_scales_in=up_fp4_bs, + prefetch_epilogue=prefetch_epilogue, + ) + else: + def _call_compute_tile( + acc_gate_in, acc_up_in, gate_b_tile, up_b_tile, lds_base, + *, a0_prefetch, scale_bundle, prefetch_epilogue=False, + ): + a_sc, gate_bs, up_bs = scale_bundle + return compute_f8f6f4_tile( + acc_gate_in, acc_up_in, gate_b_tile, up_b_tile, lds_base, + a0_prefetch=a0_prefetch, + a_scale=a_sc, + gate_b_scale=gate_bs, + up_b_scale=up_bs, + prefetch_epilogue=prefetch_epilogue, + ) + if os.environ.get("FLYDSL_STAGE1_EARLY_RETURN", "0") == "1": return @@ -1052,9 +1395,7 @@ def hot_loop_scheduler(): next_k1 = k_iv + tile_k x_regs_ping = load_x_tile(next_k1) gate_w_ping, up_w_ping = load_b_tile(next_k1 // 2) - a_scale_ping, gate_bs_ping, up_bs_ping = ( - prefetch_ab_scale_tile(next_k1 // pack_K // 128) - ) + _scale_ping = prefetch_ab_scale_tile(next_k1 // pack_K // 128) if _skip_compute: store_x_tile_to_lds(x_regs_ping, lds_base_ping) @@ -1063,24 +1404,15 @@ def hot_loop_scheduler(): next_k2 = k_iv + (tile_k * 2) x_regs_pong = load_x_tile(next_k2) gate_w_pong, up_w_pong = load_b_tile(next_k2 // 2) - a_scale_pong, gate_bs_pong, up_bs_pong = ( - prefetch_ab_scale_tile(next_k2 // pack_K // 128) - ) + _scale_pong = prefetch_ab_scale_tile(next_k2 // pack_K // 128) store_x_tile_to_lds(x_regs_pong, lds_base_pong) gpu.barrier() a0_prefetch_pong = None continue - acc_gate, acc_up, _ = compute_f8f6f4_tile( - acc_gate, - acc_up, - gate_w_pong, - up_w_pong, - lds_base_pong, - a0_prefetch=a0_prefetch_pong, - a_scale=a_scale_pong, - gate_b_scale=gate_bs_pong, - up_b_scale=up_bs_pong, + acc_gate, acc_up, _ = _call_compute_tile( + acc_gate, acc_up, gate_w_pong, up_w_pong, lds_base_pong, + a0_prefetch=a0_prefetch_pong, scale_bundle=_scale_pong, ) a0_prefetch_pong = None store_x_tile_to_lds(x_regs_ping, lds_base_ping) @@ -1091,20 +1423,11 @@ def hot_loop_scheduler(): next_k2 = k_iv + (tile_k * 2) x_regs_pong = load_x_tile(next_k2) 
gate_w_pong, up_w_pong = load_b_tile(next_k2 // 2) - a_scale_pong, gate_bs_pong, up_bs_pong = ( - prefetch_ab_scale_tile(next_k2 // pack_K // 128) - ) + _scale_pong = prefetch_ab_scale_tile(next_k2 // pack_K // 128) - acc_gate, acc_up, _ = compute_f8f6f4_tile( - acc_gate, - acc_up, - gate_w_ping, - up_w_ping, - lds_base_ping, - a0_prefetch=a0_prefetch_ping, - a_scale=a_scale_ping, - gate_b_scale=gate_bs_ping, - up_b_scale=up_bs_ping, + acc_gate, acc_up, _ = _call_compute_tile( + acc_gate, acc_up, gate_w_ping, up_w_ping, lds_base_ping, + a0_prefetch=a0_prefetch_ping, scale_bundle=_scale_ping, ) a0_prefetch_ping = None store_x_tile_to_lds(x_regs_pong, lds_base_pong) @@ -1113,36 +1436,20 @@ def hot_loop_scheduler(): a0_prefetch_pong = None if odd_k_tiles: - acc_gate, acc_up, epilogue_pf = compute_f8f6f4_tile( - acc_gate, - acc_up, - gate_w_pong, - up_w_pong, - lds_base_pong, - a0_prefetch=a0_prefetch_pong, - a_scale=a_scale_pong, - gate_b_scale=gate_bs_pong, - up_b_scale=up_bs_pong, + acc_gate, acc_up, epilogue_pf = _call_compute_tile( + acc_gate, acc_up, gate_w_pong, up_w_pong, lds_base_pong, + a0_prefetch=a0_prefetch_pong, scale_bundle=_scale_pong, prefetch_epilogue=True, ) else: k_tail1 = k_in - tile_k x_regs_ping = load_x_tile(k_tail1) gate_w_ping, up_w_ping = load_b_tile(k_tail1 // 2) - a_scale_ping, gate_bs_ping, up_bs_ping = prefetch_ab_scale_tile( - k_tail1 // pack_K // 128 - ) + _scale_ping = prefetch_ab_scale_tile(k_tail1 // pack_K // 128) - acc_gate, acc_up, _ = compute_f8f6f4_tile( - acc_gate, - acc_up, - gate_w_pong, - up_w_pong, - lds_base_pong, - a0_prefetch=a0_prefetch_pong, - a_scale=a_scale_pong, - gate_b_scale=gate_bs_pong, - up_b_scale=up_bs_pong, + acc_gate, acc_up, _ = _call_compute_tile( + acc_gate, acc_up, gate_w_pong, up_w_pong, lds_base_pong, + a0_prefetch=a0_prefetch_pong, scale_bundle=_scale_pong, ) a0_prefetch_pong = None store_x_tile_to_lds(x_regs_ping, lds_base_ping) @@ -1150,16 +1457,9 @@ def hot_loop_scheduler(): a0_prefetch_ping = None - acc_gate, acc_up, epilogue_pf = compute_f8f6f4_tile( - acc_gate, - acc_up, - gate_w_ping, - up_w_ping, - lds_base_ping, - a0_prefetch=a0_prefetch_ping, - a_scale=a_scale_ping, - gate_b_scale=gate_bs_ping, - up_b_scale=up_bs_ping, + acc_gate, acc_up, epilogue_pf = _call_compute_tile( + acc_gate, acc_up, gate_w_ping, up_w_ping, lds_base_ping, + a0_prefetch=a0_prefetch_ping, scale_bundle=_scale_ping, prefetch_epilogue=True, ) @@ -1382,33 +1682,38 @@ def _stage1_store_row(*, mi: int, ii: int, row_in_tile, row): with _if_then(_if_valid): for ni in range_constexpr(num_acc_n): col_i32 = col_i32_list[ni] - acc_idx = mi * num_acc_n + ni - vg = vector.extract( - acc_gate[acc_idx], - static_position=[ii], - dynamic_position=[], + # Guard: only write when the output column is within the inter_dim range. + # by-blocks with by_n >= inter_dim (OOB N-tiles) must not corrupt + # neighboring token-slot output rows with zero/garbage values. 
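+                        # e.g. the BF16xFP4 test uses tile_n == inter_dim == 256, so any second
+                        # N-tile launched by the grid lies entirely outside inter_dim and is skipped here.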
+                        col_n_valid = arith.cmpi(
+                            arith.CmpIPredicate.ult, col_g_list[ni], inter_idx
                         )
-                        vu = vector.extract(
-                            acc_up[acc_idx],
-                            static_position=[ii],
-                            dynamic_position=[],
-                        )
-                        if enable_bias:
-                            gate_bias_list, up_bias_list = epilogue_pf
-                            vg = vg + gate_bias_list[ni]
-                            vu = vu + up_bias_list[ni]
+                        _if_col = scf.IfOp(col_n_valid)
+                        with _if_then(_if_col):
+                            acc_idx = mi * num_acc_n + ni
+                            vg = vector.extract(
+                                acc_gate[acc_idx],
+                                static_position=[ii],
+                                dynamic_position=[],
+                            )
+                            vu = vector.extract(
+                                acc_up[acc_idx],
+                                static_position=[ii],
+                                dynamic_position=[],
+                            )
+                            if enable_bias:
+                                gate_bias_list, up_bias_list = epilogue_pf
+                                vg = vg + gate_bias_list[ni]
+                                vu = vu + up_bias_list[ni]
 
-                        if act == "swiglu":
-                            y = swiglu(vg, vu)
-                        else:
-                            y = silu(vg) * vu
+                            if act == "swiglu":
+                                y = swiglu(vg, vu)
+                            else:
+                                y = silu(vg) * vu
 
-                        if doweight_stage1:
-                            y = y * tw
+                            if doweight_stage1:
+                                y = y * tw
 
-                        y = arith.trunc_f(_out_elem_type(), y)
-                        idx_out = idx0 + col_i32
-                        buffer_ops.buffer_store(y, out_rsrc, idx_out)
+                            y = arith.trunc_f(_out_elem_type(), y)
+                            idx_out = idx0 + col_i32
+                            buffer_ops.buffer_store(y, out_rsrc, idx_out)
 
     mfma_epilog(
         use_cshuffle=False,
@@ -1659,7 +1964,11 @@ def out_elem():
     lds_total_elems = lds_total_bytes if a_elem_bytes == 1 else (lds_total_bytes // 2)
 
     def x_lds_elem():
-        return T.f16 if is_f16_a else T.f8
+        if is_f16_a:
+            return T.f16
+        if is_bf16_a:
+            return T.bf16
+        return T.f8
 
     lds_alloc_bytes = int(lds_total_elems) * int(a_elem_bytes)
     lds_alloc_offset = allocator._align(allocator.ptr, 16)
@@ -1688,7 +1997,7 @@ def moe_gemm2(
         n_in = arith.index_cast(T.index, i32_n_in)
         k_in = arith.index_cast(T.index, i32_k_in)
         size_expert_ids_in = arith.index_cast(T.index, i32_size_expert_ids_in)
-        x_elem = T.f16 if is_f16_a else T.f8
+        x_elem = x_lds_elem()
         vec4_f32 = T.vec(4, T.f32)
         vec4_i32 = T.vec(4, T.i32)
         vec16_elems = 16 if a_elem_bytes == 1 else 8
@@ -2239,7 +2548,7 @@ def lds_load_packs_k64(curr_row_a_lds, col_base, lds_base):
        )
        col_base_swz = (
            col_base_swz_bytes
-            if elem_bytes == 1
+            if a_elem_bytes == 1
            else (col_base_swz_bytes // arith.index(2))
        )
        idx_a16 = lds_row_major_idx(
diff --git a/python/flydsl/expr/rocdl/inline_asm.py b/python/flydsl/expr/rocdl/inline_asm.py
index 78e4facfd..ef664e9b4 100644
--- a/python/flydsl/expr/rocdl/inline_asm.py
+++ b/python/flydsl/expr/rocdl/inline_asm.py
@@ -9,7 +9,8 @@
 MLIR ROCDLOps.td tablegen does not surface them.
 
 TODO: Remove these inline asm wrappers once upstream MLIR adds proper ROCDL
-dialect ops for v_cvt_off_f32_i4 and v_cvt_pk_bf16_f32.
+dialect ops for v_cvt_off_f32_i4, v_cvt_pk_bf16_f32, and
+v_cvt_scalef32_pk_bf16_fp4.
""" diff --git a/tests/kernels/test_moe_gemm.py b/tests/kernels/test_moe_gemm.py index b979b3fea..14804ff88 100644 --- a/tests/kernels/test_moe_gemm.py +++ b/tests/kernels/test_moe_gemm.py @@ -31,7 +31,7 @@ if os.path.isdir(_p) and _p not in sys.path: sys.path.insert(0, _p) -from tests.kernels.test_ref import torch_moe_gemm1, torch_moe_gemm2 +from tests.kernels.test_ref import torch_moe_gemm1, torch_moe_gemm2, _dequant_mxfp4_per_1x32 from tests.utils import pertoken_quant, shuffle_weight, shuffle_scale_for_int4 from tests.test_common import verify_output, run_perftest from flydsl.runtime.device import get_rocm_arch @@ -1959,6 +1959,104 @@ def test_moe_stage2_standalone( ) +@pytest.mark.skipif("gfx95" not in ARCH, reason="BF16×FP4 requires gfx950+ (v_cvt_scalef32_pk_bf16_fp4)") +@pytest.mark.parametrize("tile_m", [16, 32, 64], ids=["tile_m16", "tile_m32", "tile_m64"]) +@pytest.mark.parametrize("act", ["silu", "swiglu"]) +def test_moe_gemm1_bf16xfp4(tile_m: int, act: str): + """BF16×FP4 (W4A16) stage1: BF16 activations, FP4 E2M1 weights, software dequant via + v_cvt_scalef32_pk_bf16_fp4, then mfma_f32_16x16x32_bf16.""" + from tests.kernels.utils import fp4_utils + if fp4_utils is None: + pytest.skip("fp4_utils not available (triton not installed)") + + tokens, model_dim, inter_dim, experts, topk = 64, 512, 256, 4, 2 + tile_n, tile_k = 256, 256 + device = torch.device("cuda") + torch.manual_seed(0) + + x_fp32 = torch.randn((tokens, model_dim), device=device, dtype=torch.float32) * 0.2 + w1_fp32 = torch.randn((experts, 2 * inter_dim, model_dim), device=device, dtype=torch.float32) * 0.2 + + score = torch.randn((tokens, experts), device=device, dtype=torch.float32) + topk_vals, topk_ids = torch.topk(score, k=topk, dim=1) + topk_weights = torch.softmax(topk_vals, dim=1).to(torch.float32) + + routing = build_routing_buffers( + topk_ids=topk_ids, + topk_weights=topk_weights, + experts=experts, + model_dim=model_dim, + tile_m=tile_m, + ) + sorted_token_ids, sorted_weights, sorted_expert_ids, num_valid_ids, _sorted_size, blocks = routing + + # Activations: BF16, no quantization + x_bf16 = x_fp32.to(torch.bfloat16) + + # Weights: quantize W1 to MX FP4 E2M1, preshuffle for the kernel + w1_flat_fp32 = w1_fp32.view(experts * (2 * inter_dim), model_dim) + w1_fp4, w1_scale_raw = _per_1x32_fp4_quant(w1_flat_fp32) + w1_shuffled = shuffle_weight(w1_fp4.view(torch.float4_e2m1fn_x2)) + w_kernel = w1_shuffled.view(torch.uint8).contiguous() + scale_w1_1d = fp4_utils.e8m0_shuffle(w1_scale_raw).view(torch.uint8).contiguous() + + # Scale X: empty (no activation scale for BF16 path) + scale_x_1d = torch.empty((0,), device=device, dtype=torch.float32) + + out = torch.empty((tokens, topk, inter_dim), device=device, dtype=torch.float16) + + exe = compile_mixed_moe_gemm1( + model_dim=model_dim, + inter_dim=inter_dim, + experts=experts, + topk=topk, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + doweight_stage1=False, + a_dtype="bf16", + b_dtype="fp4", + out_dtype="f16", + act=act, + ) + bias_dummy = torch.empty((0,), device=device, dtype=torch.float32) + + def _args(o): + return ( + o, x_bf16, w_kernel, scale_x_1d, scale_w1_1d, + sorted_token_ids, sorted_expert_ids, sorted_weights, + num_valid_ids, bias_dummy, + tokens, inter_dim * 2, model_dim, int(blocks), + torch.cuda.current_stream(), + ) + + compiled_exe = flyc.compile(exe, *_args(out)) + compiled_exe(*_args(out)) + torch.cuda.synchronize() + + # Reference: dequantize W1 FP4→f32, multiply by BF16 activations + w1_q_flat = w1_fp4.view(experts * (2 * 
inter_dim), model_dim // 2)
+    scale_w1_flat = w1_scale_raw.view(experts * (2 * inter_dim), model_dim // 32)
+    w1_dequant_f32 = _dequant_mxfp4_per_1x32(w1_q_flat, scale_w1_flat).view(experts, 2 * inter_dim, model_dim)
+    x_f32 = x_bf16.to(torch.float32)
+
+    ref = torch.zeros((tokens, topk, inter_dim), device=device, dtype=torch.float32)
+    for e in range(experts):
+        mask = topk_ids == e
+        idx = mask.nonzero(as_tuple=False)
+        if idx.numel() == 0:
+            continue
+        t_idx, s_idx = idx[:, 0], idx[:, 1]
+        y2 = torch.nn.functional.linear(x_f32[t_idx], w1_dequant_f32[e])
+        gate, up = y2[:, :inter_dim], y2[:, inter_dim:]
+        if act == "silu":
+            ref[t_idx, s_idx] = (torch.nn.functional.silu(gate) * up).float()
+        else:
+            # Mirror the kernel's clamped SwiGLU (alpha=1.702, limit=7.0); the
+            # (up + 1) term assumes the gpt-oss convention used by the kernel.
+            alpha, limit = 1.702, 7.0
+            g = gate.clamp(max=limit)
+            u = up.clamp(min=-limit, max=limit)
+            glu = g * torch.sigmoid(g * alpha)
+            ref[t_idx, s_idx] = (glu * (u + 1)).float()
+
+    assert verify_output(out.to(torch.float32), ref, rtol=0.15, atol=0.15)
+
+
 if __name__ == "__main__":
     torch.set_default_device("cuda")
     # CLI (mirrors key knobs from aiter/op_tests/test_moe_2stage.py, stage1 subset)

From dd69c1db07fe55f01997a14e1c3ddcc54 Mon Sep 17 00:00:00 2001
From: Andrea Picciau
Date: Tue, 21 Apr 2026 15:42:14 +0000
Subject: [PATCH 2/2] Fix NameError: declare is_bf16_a in compile_mixed_moe_gemm2

Stage 2 x_lds_elem() referenced is_bf16_a which was never declared, causing
a NameError on any fp4 path. Also fixes a_elem_bytes to account for bf16
(2 bytes) alongside fp16.
---
 kernels/mixed_moe_gemm_2stage.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/kernels/mixed_moe_gemm_2stage.py b/kernels/mixed_moe_gemm_2stage.py
index 1f16debb1..52cab3756 100644
--- a/kernels/mixed_moe_gemm_2stage.py
+++ b/kernels/mixed_moe_gemm_2stage.py
@@ -1853,6 +1853,7 @@ def compile_mixed_moe_gemm2(
         )
 
     is_f16_a = a_dtype == "fp16"
+    is_bf16_a = a_dtype == "bf16"
    is_f16_b = b_dtype == "fp16"

    is_f8_a = a_dtype == "fp8"
@@ -1865,7 +1866,7 @@
 
    elem_bytes = 1

-    a_elem_bytes = 2 if is_f16_a else 1
+    a_elem_bytes = 2 if (is_f16_a or is_bf16_a) else 1
    b_elem_bytes = 1

    tile_k_bytes = int(tile_k) * int(a_elem_bytes)
@@ -1873,7 +1874,7 @@
    cbsz = 0 if is_f8_a else 4
    blgp = 4

-    # K64-byte micro-step: always 64 bytes per `ku`. For fp16, this is 32 elements (2xK16 MFMA).
+    # K64-byte micro-step: always 64 bytes per `ku`. For fp16/bf16, this is 32 elements (2xK16 MFMA).
    if (tile_k_bytes % 64) != 0:
        raise ValueError(
            f"tile_k_bytes must be divisible by 64, got tile_k_bytes={tile_k_bytes} "