Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
280 changes: 280 additions & 0 deletions kernels/mfma_preshuffle_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,3 +881,283 @@ def load_b_raw_w4a16_groupwise(
def unpack_b_w4a16_groupwise(packed32, scale_val, arith, vector, use_gfx950_cvt=False):
    """Phase 2 of W4A16 groupwise: unpack + scale + convert to bf16.

    Thin wrapper over ``unpack_b_w4a16`` that forwards the per-group
    *scale_val* as a keyword argument; see that function for the actual
    unpack/convert logic and the meaning of *use_gfx950_cvt*.
    """
    return unpack_b_w4a16(packed32, arith, vector, scale_val=scale_val, use_gfx950_cvt=use_gfx950_cvt)


def _cvt_scalef32_pk_bf16_fp4(packed_i32, scale_f32, byte_idx, arith, vector):
    """GFX950 hardware: v_cvt_scalef32_pk_bf16_fp4.

    Emits the ``llvm.amdgcn.cvt.scalef32.pk.bf16.fp4`` intrinsic, which
    decodes the two FP4 E2M1 nibbles held in byte *byte_idx* of
    *packed_i32*, multiplies both by *scale_f32*, and yields the pair as
    2 packed bf16; the result is reinterpreted and returned as one i32.

    One instruction replaces ~36 VALU of the software path.
    """
    from flydsl._mlir.dialects import llvm

    byte_sel = arith.constant(byte_idx, type=T.i32)
    pair_bf16 = llvm.call_intrinsic(
        T.vec(2, T.bf16),
        "llvm.amdgcn.cvt.scalef32.pk.bf16.fp4",
        [packed_i32, scale_f32, byte_sel],
        [], [],
    )
    # Re-bitcast <2 x bf16> -> <1 x i32>, then pull out the scalar.
    pair_as_i32 = vector.bitcast(T.vec(1, T.i32), pair_bf16)
    return vector.extract(pair_as_i32, static_position=[0], dynamic_position=[])


def _fp4x4_in_i32_to_bf16x4_i64(packed4, arith, vector, scale_f32=None):
    """Convert 4 FP4 E2M1 nibbles (in 4 bytes of i32) to 4 bf16 packed as i64.

    Each byte of *packed4* holds one nibble in bits [3:0]:
        bit[3] = sign, bits[2:1] = exponent (bias=1), bit[0] = mantissa.

    Unsigned value table (3-bit index):
        000->0.0, 001->0.5, 010->1.0, 011->1.5,
        100->2.0, 101->3.0, 110->4.0, 111->6.0

    *scale_f32*, when provided, is an f32 E8M0 block-scale multiplied
    into every element before truncation to bf16.
    """
    i32x1_t = T.vec(1, T.i32)
    i32x2_t = T.i32x2
    i8x4_t = T.i8x4
    i64x1_t = T.vec(1, T.i64)

    # Reinterpret the 4 input bytes as 4 separate i8 lanes.
    as_bytes = vector.bitcast(i8x4_t, vector.from_elements(i32x1_t, [packed4]))

    k1 = arith.constant(1, type=T.i32)
    k3 = arith.constant(3, type=T.i32)
    k7 = arith.constant(7, type=T.i32)
    k22 = arith.constant(22, type=T.i32)
    k23 = arith.constant(23, type=T.i32)
    k31 = arith.constant(31, type=T.i32)
    k126 = arith.constant(126, type=T.i32)
    k_zero = arith.constant(0, type=T.i32)
    k_half_bits = arith.constant(0x3F000000, type=T.i32)  # IEEE-754 bits of 0.5f

    def decode_lane(lane):
        """Decode the nibble in byte *lane* to an (optionally scaled) f32 value."""
        nib = arith.extui(
            T.i32, vector.extract(as_bytes, static_position=[lane], dynamic_position=[])
        )
        sign = arith.andi(arith.shrui(nib, k3), k1)
        magnitude = arith.andi(nib, k7)
        exp = arith.shrui(magnitude, k1)
        mant = arith.andi(magnitude, k1)

        # Normal case: rebias E2M1 exponent (bias 1) to f32 (bias 127),
        # placing the single mantissa bit at f32 mantissa bit 22.
        normal_bits = arith.ori(
            arith.shli(arith.addi(exp, k126), k23),
            arith.shli(mant, k22),
        )

        # 000 is exact zero; 001 is the lone subnormal, worth exactly 0.5.
        zero_p = arith.cmpi(arith.CmpIPredicate.eq, magnitude, k_zero)
        subnorm_p = arith.cmpi(arith.CmpIPredicate.eq, magnitude, k1)
        bits = arith.select(
            zero_p, k_zero,
            arith.select(subnorm_p, k_half_bits, normal_bits),
        )
        bits = arith.ori(bits, arith.shli(sign, k31))

        val = arith.bitcast(T.f32, bits)
        return val * scale_f32 if scale_f32 is not None else val

    elems = [decode_lane(lane) for lane in range(4)]

    # bf16 truncation = keep the high 16 bits of the f32; pack two bf16 per
    # i32 word and glue the two words into one i64.
    k16 = arith.constant(16, type=T.i32)
    k_hi16_mask = arith.constant(0xFFFF0000, type=T.i32)
    raw = [arith.bitcast(T.i32, e) for e in elems]
    word_lo = arith.shrui(raw[0], k16) | (raw[1] & k_hi16_mask)
    word_hi = arith.shrui(raw[2], k16) | (raw[3] & k_hi16_mask)

    pair = vector.from_elements(i32x2_t, [word_lo, word_hi])
    as_i64 = vector.bitcast(i64x1_t, pair)
    return vector.extract(as_i64, static_position=[0], dynamic_position=[])


def load_b_raw_mxfp4(
    buffer_ops,
    arith,
    vector,
    *,
    arg_b,
    b_rsrc,
    layout_b,
    base_k: ir.Value,
    ku: int,
    n_blk: ir.Value,
    n_intra: ir.Value,
    lane_div_16: ir.Value,
    elem_type: ir.Type,
    kpack_bytes: int = 16,
):
    """Load 4 bytes of packed FP4 from a kpack=16 preshuffle layout.

    Addressing for kpack=16 (``shuffle_weight_a16w4`` format):
    - Layout shape: ``(n0, k0, klane=4, nlane=16, kpack=16)``
    - The A-side LDS has klane stride = 8 bf16 elements, advancing
      by 32 bf16 per ku step. B must match: each klane loads 4 bytes
      (8 FP4 = 8 K elements) at K_start = base_k + ku*32 + lane*8.
    - In the preshuffle layout this maps to:
        k0         = base_k//128 + ku//4
        klane_hw   = ku % 4            (compile-time)
        kpack_byte = lane_div_16*4     (runtime)

    Returns a single i32 containing 4 packed bytes (8 FP4 nibbles).
    """
    if kpack_bytes != 16:
        raise ValueError(f"MXFP4 requires kpack_bytes=16, got {kpack_bytes!r}")

    c_k_per_k0 = arith.constant(128, index=True)
    c_bytes_per_sublane = arith.constant(4, index=True)

    # Split K into the preshuffle coordinates; ku contributions are
    # compile-time constants.
    k0_coord = (base_k // c_k_per_k0) + arith.constant(ku // 4, index=True)
    klane_coord = arith.constant(ku % 4, index=True)
    sublane_byte = lane_div_16 * c_bytes_per_sublane

    pack_coord = (n_blk, k0_coord, klane_coord, n_intra, arith.constant(0, index=True))
    byte_idx = crd2idx(pack_coord, layout_b) + sublane_byte

    raw4 = _buffer_load_vec(
        buffer_ops,
        vector,
        b_rsrc,
        byte_idx,
        elem_type=elem_type,
        vec_elems=4,
        elem_bytes=1,
        offset_in_bytes=True,
    )
    # Collapse the 4-byte vector into one scalar i32.
    return vector.extract(
        vector.bitcast(T.vec(1, T.i32), raw4),
        static_position=[0],
        dynamic_position=[],
    )


def load_b_raw_mxfp4_dwordx4(
    buffer_ops,
    arith,
    vector,
    *,
    arg_b,
    b_rsrc,
    layout_b,
    base_k: "ir.Value",
    n_blk: "ir.Value",
    n_intra: "ir.Value",
    lane_div_16: "ir.Value",
    elem_type: "ir.Type",
    kpack_bytes: int = 16,
    cache_modifier: int = 0,
):
    """Load 16 bytes (vec4_i32) of packed FP4 via buffer_load_dwordx4.

    CK-style addressing: klane = lane_div_16, loading the full kpack
    for the thread's sub-lane. Returns vec4_i32 where i32[j] contains
    8 FP4 elements for kIter j.

    Layout: ``(n0, k0, klane=4, nlane=16, kpack=16)``

    *cache_modifier* is forwarded to the underlying buffer load so
    callers can control the cache policy of the dwordx4 load.
    """
    if kpack_bytes != 16:
        raise ValueError(f"MXFP4 requires kpack_bytes=16, got {kpack_bytes!r}")

    c128 = arith.constant(128, index=True)
    k0 = base_k // c128

    coord_pack = (n_blk, k0, lane_div_16, n_intra, arith.constant(0, index=True))
    idx_pack = crd2idx(coord_pack, layout_b)

    b16 = _buffer_load_vec(
        buffer_ops,
        vector,
        b_rsrc,
        idx_pack,
        elem_type=elem_type,
        vec_elems=16,
        elem_bytes=1,
        offset_in_bytes=True,
        # BUGFIX: cache_modifier was accepted but silently dropped, so
        # callers passing a non-default policy had no effect; forward it.
        cache_modifier=cache_modifier,
    )
    return vector.bitcast(T.vec(4, T.i32), b16)


def _unpack_b_mxfp4_bf16_hw(packed32, arith, vector, scale_f32):
    """Hardware fast-path: 4 x v_cvt_scalef32_pk_bf16_fp4."""
    pair_t = T.i32x2
    wide_t = T.vec(1, T.i64)

    def convert_two_bytes(lo_byte, hi_byte):
        """Decode bytes *lo_byte*/*hi_byte* (2 FP4 each) and glue into one i64."""
        lo = _cvt_scalef32_pk_bf16_fp4(packed32, scale_f32, lo_byte, arith, vector)
        hi = _cvt_scalef32_pk_bf16_fp4(packed32, scale_f32, hi_byte, arith, vector)
        glued = vector.bitcast(wide_t, vector.from_elements(pair_t, [lo, hi]))
        return vector.extract(glued, static_position=[0], dynamic_position=[])

    # Bytes 0..1 feed the first mfma operand, bytes 2..3 the second.
    return (convert_two_bytes(0, 1), convert_two_bytes(2, 3))


def _unpack_b_mxfp4_bf16_sw(packed32, arith, vector, scale_f32):
    """Software fallback for non-GFX950 targets.

    Spreads the 8 nibbles of *packed32* into two i32 words (one nibble
    per byte each) and decodes each word with the pure-ALU FP4 path.
    """
    nib_mask = arith.constant(0x0F, type=T.i32)
    shift_c = {s: arith.constant(s, type=T.i32) for s in (4, 8, 12, 16, 20, 24, 28)}

    # nibbles[i] = i-th FP4 nibble of packed32, isolated in bits [3:0].
    nibbles = [packed32 & nib_mask]
    for s in (4, 8, 12, 16, 20, 24, 28):
        nibbles.append(arith.shrui(packed32, shift_c[s]) & nib_mask)

    def spread(start):
        """Pack nibbles[start:start+4] into one i32, one nibble per byte."""
        a, b, c, d = nibbles[start:start + 4]
        return (
            a
            | arith.shli(b, shift_c[8])
            | arith.shli(c, shift_c[16])
            | arith.shli(d, shift_c[24])
        )

    lo_word = spread(0)
    hi_word = spread(4)
    return (
        _fp4x4_in_i32_to_bf16x4_i64(lo_word, arith, vector, scale_f32=scale_f32),
        _fp4x4_in_i32_to_bf16x4_i64(hi_word, arith, vector, scale_f32=scale_f32),
    )


def unpack_b_mxfp4_bf16(packed32, arith, vector, scale_f32=None, use_hw_cvt=True):
    """Unpack 8 FP4 E2M1 nibbles (packed in i32) to 2 x i64 (8 bf16).

    Each byte of *packed32* holds two FP4 nibbles: low nibble = K_even,
    high nibble = K_even+1. For ``mfma_f32_16x16x16bf16_1k`` the B
    operand needs 4 consecutive K values per i64. So we unpack the
    lower 2 bytes (4 consecutive nibbles) into b0 and the upper 2 bytes
    into b1.

    *scale_f32* is the decoded E8M0 block-scale (as f32).

    When *use_hw_cvt* is True (default), uses the GFX950 hardware
    instruction ``v_cvt_scalef32_pk_bf16_fp4`` which converts 2 FP4
    nibbles to 2 bf16 (with scale) in a single VALU cycle. This
    replaces ~144 VALU of the software fallback with 4 instructions.

    Returns ``(b0, b1)`` -- two i64 values, each containing 4 bf16 for
    one ``mfma_f32_16x16x16bf16_1k`` call.
    """
    # The hardware path needs a scale operand, so an absent scale always
    # routes through the software decoder.
    if not use_hw_cvt or scale_f32 is None:
        return _unpack_b_mxfp4_bf16_sw(packed32, arith, vector, scale_f32)
    return _unpack_b_mxfp4_bf16_hw(packed32, arith, vector, scale_f32)
Loading
Loading