diff --git a/kernels/dispatch_combine_intranode_kernel.py b/kernels/dispatch_combine_intranode_kernel.py new file mode 100644 index 00000000..47a0533a --- /dev/null +++ b/kernels/dispatch_combine_intranode_kernel.py @@ -0,0 +1,1534 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 FlyDSL Project Contributors + +"""FlyDSL intranode dispatch / combine kernels for expert-parallel MoE. + +This module exposes :func:`make_dispatch_jit` and :func:`make_combine_jit`, +which generate ``@flyc.jit`` launchers wrapping the two intranode kernels +(``ep_dispatch_intranode`` / ``ep_combine_intranode``). The kernels are +implemented mostly with FlyDSL high-level syntax (operator overloading, +``if``/``for`` lowered to ``scf.if``/``scf.for`` by the AST rewriter, +``buffer_load``/``buffer_store`` for global access); a few low-level MLIR +helpers below provide system-scope atomics and pointer-cast intrinsics that +do not yet have a high-level wrapper. +""" + +from __future__ import annotations + +import flydsl.compiler as flyc +import flydsl.expr as fx +import torch + +import mori.ir.flydsl as mori_shmem + +from flydsl.expr import T, arith, const_expr, range_constexpr +from flydsl.expr.buffer_ops import ( + buffer_load, + buffer_store, + create_buffer_resource_from_addr, +) +from flydsl.expr.rocdl import ballot_i64, readlane +from flydsl.expr.typing import Stream +from flydsl.expr import vector +from flydsl.expr.vector import bitcast_i32_to_v2bf16, bitcast_v2bf16_to_i32 +from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr + +# Low-level MLIR escape hatches: system-scope atomics, pointer casts, and +# raw LLVM ops used inside the SIMD micro-kernel helpers +# (``_accum_experts`` / ``_weighted_accum_experts``). +from flydsl._mlir import ir +from flydsl._mlir import ir as _ir +from flydsl._mlir.dialects import llvm as _llvm_d +from flydsl._mlir.ir import IntegerAttr as _IntAttr, IntegerType as _IntTy + + +def _lv_unwrap(v): + """Return the underlying ``ir.Value`` for *v*. + + Accepts an ``ir.Value`` directly, any FlyDSL Numeric wrapper + (Int32/Int64/Float32/..., exposing ``ir_value()`` or the + ``_extract_to_ir_values`` protocol), or a Python ``int`` literal + (materialized as an ``i32`` constant). 
+ """ + if isinstance(v, _ir.Value): + return v + if hasattr(v, "ir_value"): + return v.ir_value() + if hasattr(v, "_extract_to_ir_values"): + vals = v._extract_to_ir_values() + if len(vals) == 1: + return vals[0] + raise ValueError(f"Expected 1 ir.Value, got {len(vals)}") + if isinstance(v, int): + _i32 = _IntTy.get_signless(32) + return _llvm_d.ConstantOp(_i32, _IntAttr.get(_i32, v)).result + raise TypeError(f"Cannot convert {type(v).__name__} to ir.Value") + + +def _to_ptr_global(v): + """Cast an i64 address to ``!llvm.ptr<1>`` (global address space).""" + return _llvm_d.IntToPtrOp( + _llvm_d.PointerType.get(address_space=1), _lv_unwrap(v)).result + + +def store_i32_system(addr_i64, offset, val): + """System-scope monotonic i32 store at ``addr_i64 + offset*4`` bytes.""" + base = _lv_unwrap(addr_i64) + off = _lv_unwrap(offset) + val_ = _lv_unwrap(val) + _i64 = _IntTy.get_signless(64) + _i32 = _IntTy.get_signless(32) + _nuw = _ir.Attribute.parse("#llvm.overflow") + off64 = _llvm_d.ZExtOp(_i64, off).res if off.type == _i32 else off + byte_off = _llvm_d.MulOp( + off64, _llvm_d.ConstantOp(_i64, _IntAttr.get(_i64, 4)).result, _nuw).result + addr = _llvm_d.AddOp(base, byte_off, _nuw).result + gptr = _llvm_d.IntToPtrOp( + _llvm_d.PointerType.get(address_space=1), addr).result + _llvm_d.StoreOp(val_, gptr, alignment=4, + ordering=_llvm_d.AtomicOrdering.monotonic, + syncscope="one-as") + + +def store_i64_global_system(addr_i64, val): + """System-scope monotonic i64 store to ``addr_i64``.""" + gptr = _to_ptr_global(addr_i64) + _llvm_d.StoreOp(_lv_unwrap(val), gptr, alignment=8, + ordering=_llvm_d.AtomicOrdering.monotonic, + syncscope="one-as") + + +def load_i64_global(addr_i64): + """Relaxed global i64 load from ``addr_i64``.""" + ptr = _to_ptr_global(addr_i64) + _i64 = _IntTy.get_signless(64) + return _llvm_d.LoadOp(_i64, ptr, alignment=8).result + + +def atomic_add_global_at(addr_i64, val): + """Monotonic global ``atomic fetch-and-add``; returns the old value.""" + ptr = _to_ptr_global(addr_i64) + return _llvm_d.AtomicRMWOp( + _llvm_d.AtomicBinOp.add, ptr, _lv_unwrap(val), + _llvm_d.AtomicOrdering.monotonic).res + + +# NOTE: explicit ``index``/``i32`` casts that used to wrap every for-loop bound +# and induction-variable in this file have been removed. FlyDSL's +# ``scf_for_dispatch`` accepts i32/Python-int bounds directly and yields an +# i32 IV; ``SmemPtr.{load,store}`` runs each index through +# ``get_index_value`` which materializes/casts to ``index`` on demand. So +# ``arith.index_cast(T.index(), x)`` / ``arith.index_cast(T.i32(), iv)`` +# everywhere was a no-op IR-wise and pure boilerplate. + + +def make_dispatch_kernel( + *, + rank: int, + npes: int, + experts_per_rank: int, + experts_per_token: int, + hidden_dim: int, + hidden_elem_size: int, + max_tok_per_rank: int, + block_num: int, + warp_num_per_block: int, + scale_dim: int = 0, + scale_type_size: int = 0, + enable_std_moe: bool = False, + data_type=None, +): + """Build the intranode dispatch ``@flyc.kernel``. + + Schedules ``cur_tok * experts_per_token`` work items across all + ``block_num * warp_num_per_block`` warps. Each warp: + + 1. resolves the (dest_pe, dest_tok) slot via atomic counters on the + remote rank's ``shmem_tok_off``, + 2. P2P-writes token embedding, weights and indices into the + destination's symmetric shmem buffers, + 3. publishes a "send done" signal to every peer and waits for the + dual signal from each peer so it can finalize ``total_recv``, + 4. 
when ``enable_std_moe`` is set, performs ConvertDispatchOutput + (per-expert packing for the std-MoE expert path). + """ + max_recv = npes * max_tok_per_rank + _is_fp4 = (data_type == torch.float4_e2m1fn_x2) + if _is_fp4: + n_i32 = hidden_dim // 8 # 8 fp4 values per i32 (4 bytes) + nbytes = hidden_dim // 2 # 2 fp4 values per byte + else: + n_i32 = (hidden_dim * hidden_elem_size) // 4 + nbytes = hidden_dim * hidden_elem_size + scale_bytes = scale_dim * scale_type_size + scale_n_i32 = (scale_bytes + 3) // 4 if scale_bytes > 0 else 0 + enable_scales = scale_bytes > 0 + max_tokens_per_expert = npes * max_tok_per_rank # per-expert bucket capacity + + @flyc.kernel + def ep_dispatch_intranode( + addr_inp_tok: fx.Int64, # [cur_tok, hidden_dim] bf16 + addr_idx: fx.Int64, # [cur_tok, k] i32 (token_indices) + addr_wts: fx.Int64, # [cur_tok, k] f32 (weights_buf) + addr_out_tok: fx.Int64, # shmem_out_tok + addr_out_wts: fx.Int64, # shmem_out_wts + addr_out_idx: fx.Int64, # shmem_out_idx + addr_tok_off: fx.Int64, # shmem_tok_off (i32[1]) + addr_recv_num: fx.Int64, # recv_tok_num (i32[npes]) + addr_dest_ctr: fx.Int64, # dest_pe_ctr (i32[npes]) + addr_disp_bar: fx.Int64, # dispatch_bar (i32[1]) + addr_tok_map: fx.Int64, # dest_tok_map (i32[cur_tok*k]) + addr_tis: fx.Int64, # tok_id_to_src (i32[max_recv]) + addr_total_rv: fx.Int64, # total_recv (i32[1]) + # Pre-resolved P2P address arrays (i64[npes]): the remote-side base + # of each symmetric shmem buffer on every peer PE. + addr_p2p_tok_off: fx.Int64, + addr_p2p_tis: fx.Int64, + addr_p2p_out_wts: fx.Int64, + addr_p2p_out_idx: fx.Int64, + addr_p2p_out_tok: fx.Int64, + addr_p2p_recv_num: fx.Int64, + addr_scales: fx.Int64, # input scales buffer + addr_p2p_out_scales: fx.Int64, # i64[npes] P2P addresses of scales buffer + # ── StdMoE ConvertDispatchOutput parameters ── + addr_packed_recv_x: fx.Int64, # expert-major token buffer + addr_packed_recv_count: fx.Int64, # per-expert token count (i32[experts_per_rank]) + addr_packed_recv_src_info: fx.Int64, # source info (i32[experts_per_rank * max_tok_per_expert]) + addr_disp_tok_map: fx.Int64, # slot mapping (i64[max_recv * top_k]) + addr_disp_grid_bar: fx.Int64, # grid barrier (i32[1]) + cur_tok: fx.Int32, # runtime token count for the current batch + ): + tid = fx.thread_idx.x # thread id within the block + bid = fx.block_idx.x # block id within the grid + lane = tid & 63 # lane id within the warp (0..63) + warp = tid >> 6 # warp id within the block + global_warp_id = bid * warp_num_per_block + warp # warp id across the grid + global_warp_num = block_num * warp_num_per_block # total warps in the grid + work_limit = cur_tok * experts_per_token # total (token, k-slot) pairs + _r_idx = create_buffer_resource_from_addr(addr_idx) + _r_wts = create_buffer_resource_from_addr(addr_wts) + _r_tok_map = create_buffer_resource_from_addr(addr_tok_map) + _r_tok_off = create_buffer_resource_from_addr(addr_tok_off) + _r_dest_ctr = create_buffer_resource_from_addr(addr_dest_ctr) + _r_disp_bar = create_buffer_resource_from_addr(addr_disp_bar) + _r_total_rv = create_buffer_resource_from_addr(addr_total_rv) + _r_p2p_tok_off = create_buffer_resource_from_addr(addr_p2p_tok_off) + _r_p2p_tis = create_buffer_resource_from_addr(addr_p2p_tis) + _r_p2p_out_wts = create_buffer_resource_from_addr(addr_p2p_out_wts) + _r_p2p_out_idx = create_buffer_resource_from_addr(addr_p2p_out_idx) + _r_p2p_out_tok = create_buffer_resource_from_addr(addr_p2p_out_tok) + _r_p2p_recv_num = create_buffer_resource_from_addr(addr_p2p_recv_num) + + # Phase 1: 
P2P-scatter tokens to their destination PEs. + # Iteration space: every (src_tok, k_slot) pair, distributed across + # all grid-wide warps. ``k_slot`` is the per-token expert slot index + # (i.e. which of the top-k experts this work-item handles). + for work_idx in range(global_warp_id, work_limit, global_warp_num): + src_tok = (work_idx // experts_per_token) + k_slot = (work_idx % experts_per_token) + # Issue the two idx loads in parallel; divui is deferred so the + # loads do not block on the integer divide. + dest_expert = buffer_load(_r_idx, work_idx, vec_width=1, dtype=T.i32()) + safe_lane = arith.select(lane < k_slot, lane, 0) + lane_expert = buffer_load(_r_idx, src_tok * experts_per_token + safe_lane, vec_width=1, dtype=T.i32()) + dest_pe = (dest_expert // experts_per_rank) + lane_dest_pe = (lane_expert // experts_per_rank) + # Per-lane "is this lane a duplicate destPE assignment for some + # k_slot earlier than the current one?" (sentinel 64 = no). + dup_per_lane = arith.select( + lane_dest_pe == dest_pe, + arith.select(lane < k_slot, lane, 64), + 64) + dup_ballot = ballot_i64(dup_per_lane < 64) + is_dup = dup_ballot != 0 + + # Atomically allocate dest_tok_id on lane 0, then readlane-broadcast. + dest_tok_lane0 = arith.constant(0) + if lane == 0: + if dup_ballot == 0: + dest_tok_lane0 = atomic_add_global_at( + buffer_load(_r_p2p_tok_off, dest_pe, vec_width=1, dtype=T.i64()), + arith.constant(1)) + dest_tok_id = readlane(dest_tok_lane0, 0) + + # Per-(token, k_slot) entry stored into dest_tok_map: encoded + # global slot id, or sentinel ``npes * max_recv`` for dup-slots + # which the combine kernel will treat as "no source". + sentinel_val = npes * max_recv + tok_map_entry = arith.select( + is_dup, + sentinel_val, + dest_pe * max_recv + dest_tok_id) + if lane == 0: + buffer_store(tok_map_entry, _r_tok_map, work_idx) + + if lane == 0: + if dup_ballot == 0: + # Publish the (src_pe, src_lid) origin so the dest PE + # can later route the token back during combine. + src_tok_enc = rank * max_tok_per_rank + src_tok + _r_tis_remote = create_buffer_resource_from_addr( + buffer_load(_r_p2p_tis, dest_pe, vec_width=1, dtype=T.i64())) + buffer_store(src_tok_enc, _r_tis_remote, dest_tok_id) + dest_ctr_addr = addr_dest_ctr + arith.zext_i64(dest_pe) * 4 + atomic_add_global_at(dest_ctr_addr, arith.constant(1)) + + # Each lane writes one (weight, expert_idx) entry to the dest + # PE's symmetric weights / idx buffers, parallel over k_slot. 
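+            # Illustrative (hypothetical k=8): lanes 0..7 each move one
+            # 4 B weight plus one 4 B index, so a single wavefront step
+            # retires the token's whole 64 B routing record.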
+ if lane < experts_per_token: + if dup_ballot == 0: + wt_src_off = src_tok * experts_per_token + lane + wt_val = buffer_load(_r_wts, wt_src_off, vec_width=1, dtype=T.f32()) + idx_val = buffer_load(_r_idx, wt_src_off, vec_width=1, dtype=T.i32()) + dest_slot = dest_tok_id * experts_per_token + lane + _r_wts_remote = create_buffer_resource_from_addr( + buffer_load(_r_p2p_out_wts, dest_pe, vec_width=1, dtype=T.i64())) + buffer_store(arith.bitcast(T.i32(), wt_val), _r_wts_remote, dest_slot) + _r_idx_remote = create_buffer_resource_from_addr( + buffer_load(_r_p2p_out_idx, dest_pe, vec_width=1, dtype=T.i64())) + buffer_store(idx_val, _r_idx_remote, dest_slot) + + if const_expr(enable_scales): + if lane < scale_n_i32: + if dup_ballot == 0: + _r_scales = create_buffer_resource_from_addr(addr_scales) + sc_src_off = src_tok * scale_n_i32 + lane + sc_val = buffer_load(_r_scales, sc_src_off, vec_width=1, dtype=T.i32()) + sc_dst_off = dest_tok_id * scale_n_i32 + lane + _r_sc_remote = create_buffer_resource_from_addr( + buffer_load( + create_buffer_resource_from_addr(_lv_unwrap(addr_p2p_out_scales)), + dest_pe, vec_width=1, dtype=T.i64())) + buffer_store(sc_val, _r_sc_remote, sc_dst_off) + + # Token-embedding scatter: when ``is_dup`` the copy_end equals + # ``lane_i32_off`` and the loop trips zero iterations. + # + # ``lane_i32_off`` - this lane's starting i32 offset (each lane + # owns 4 consecutive i32 = 16 bytes). + # ``chunk_i32_off`` - sliding i32 offset within the token's + # hidden-dim chunk being copied this step. + remote_tok_addr = buffer_load(_r_p2p_out_tok, dest_pe, vec_width=1, dtype=T.i64()) + \ + arith.zext_i64(dest_tok_id) * nbytes + local_tok_addr = addr_inp_tok + arith.zext_i64(src_tok) * nbytes + rsrc_src = create_buffer_resource_from_addr(local_tok_addr) + rsrc_dst = create_buffer_resource_from_addr(remote_tok_addr) + lane_i32_off = lane * 4 + safe_end_i32 = (n_i32 // 512) * 512 # largest multiple of 512 that fits + if const_expr(n_i32 >= 512 and safe_end_i32 > 0): + copy_end_main = arith.select(is_dup, lane_i32_off, safe_end_i32) + for chunk_i32_off in range(lane_i32_off, copy_end_main, 512): + vec_a = buffer_load(rsrc_src, chunk_i32_off, vec_width=4, dtype=T.i32()) + vec_b = buffer_load(rsrc_src, chunk_i32_off + 256, vec_width=4, dtype=T.i32()) + buffer_store(vec_a, rsrc_dst, chunk_i32_off) + buffer_store(vec_b, rsrc_dst, chunk_i32_off + 256) + if const_expr(safe_end_i32 < n_i32): + copy_end_tail = arith.select(is_dup, lane_i32_off, n_i32) + for chunk_i32_off in range(lane_i32_off + safe_end_i32, copy_end_tail, 256): + vec_a = buffer_load(rsrc_src, chunk_i32_off, vec_width=4, dtype=T.i32()) + buffer_store(vec_a, rsrc_dst, chunk_i32_off) + elif const_expr(n_i32 < 512): + copy_end_small = arith.select(is_dup, lane_i32_off, n_i32) + for chunk_i32_off in range(lane_i32_off, copy_end_small, 256): + vec_a = buffer_load(rsrc_src, chunk_i32_off, vec_width=4, dtype=T.i32()) + buffer_store(vec_a, rsrc_dst, chunk_i32_off) + + # Phase 2: grid barrier + publish per-peer token-count signal. + # ``recv_num`` is a symmetric ``i32[npes]`` array: index ``src_pe`` + # on dest holds the count of tokens that ``src_pe`` will send. 
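+        # Signal protocol sketch (hypothetical count): a rank that routed
+        # 17 tokens to peer 3 waits for peer 3's recv_num[rank] slot to
+        # read 0 (i.e. consumed), then stores 17 + 1 = 18; the consumer
+        # undoes the +1 after its wait-until-greater-than-0, so a true
+        # count of 0 stays distinguishable from "not yet written".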
+ fx.barrier() + if tid == 0: + atomic_add_global_at(addr_disp_bar, arith.constant(1)) + + recv_num_local_byte_off = arith.zext_i64(arith.constant(rank)) * 4 + for dest_pe in range(lane, npes, 64): + if global_warp_id == 0: + mori_shmem.int32_wait_until_equals(addr_disp_bar, block_num) + buffer_store(arith.constant(0), _r_disp_bar, 0) + # +1 because 0 is the "unset" sentinel that consumers wait on. + signal_value = buffer_load(_r_dest_ctr, dest_pe, vec_width=1, dtype=T.i32()) + 1 + recv_num_remote_addr = buffer_load( + _r_p2p_recv_num, dest_pe, vec_width=1, dtype=T.i64()) + recv_num_local_byte_off + mori_shmem.int32_wait_until_equals(recv_num_remote_addr, 0) + store_i32_system(recv_num_remote_addr, arith.constant(0), signal_value) + + # Phase 3: wait for each peer's count signal and accumulate total_recv. + for src_pe in range(lane, npes, 64): + if global_warp_id == 0: + recv_num_src_addr = addr_recv_num + arith.zext_i64(src_pe) * 4 + signal_value = mori_shmem.int32_wait_until_greater_than(recv_num_src_addr, 0) + peer_recv_count = signal_value - 1 # undo the +1 sentinel offset + store_i32_system(recv_num_src_addr, arith.constant(0), arith.constant(0)) + atomic_add_global_at(addr_total_rv, peer_recv_count) + buffer_store(arith.constant(0), _r_dest_ctr, src_pe) + + if global_warp_id == 0: + if lane == 0: + buffer_store(arith.constant(0), _r_tok_off, 0) + + # Phase 4: ConvertDispatchOutput (StdMoE). + # Repack received tokens into per-expert buckets indexed by + # ``local_expert_id``. Each (received_tok, k_slot) pair allocates a + # slot in ``packed_recv_x[local_expert_id]`` if the expert is local. + if const_expr(enable_std_moe): + fx.barrier() + if tid == 0: + atomic_add_global_at(addr_disp_grid_bar, arith.constant(1)) + fx.barrier() + if tid == 0: + mori_shmem.int32_wait_until_equals(addr_disp_grid_bar, block_num) + fx.barrier() + + _r_out_idx_local = create_buffer_resource_from_addr(addr_out_idx) + _r_tis_local = create_buffer_resource_from_addr(addr_tis) + _r_out_tok_local = create_buffer_resource_from_addr(addr_out_tok) + total_recv = buffer_load(_r_total_rv, 0, vec_width=1, dtype=T.i32()) + smoe_work_limit = total_recv * experts_per_token + + for smoe_idx in range(global_warp_id, smoe_work_limit, global_warp_num): + smoe_tok_id = (smoe_idx // experts_per_token) + + expert_id = buffer_load(_r_out_idx_local, smoe_idx, vec_width=1, dtype=T.i32()) + local_expert_id = expert_id - rank * experts_per_rank + # MUST be unsigned ``ult``: when ``expert_id`` is NOT this + # rank's expert, ``local_expert_id`` is negative; the + # signed-overload form ``local_expert_id < experts_per_rank`` + # lowers to ``arith.cmpi slt`` and would mis-classify negative + # values as local (-> illegal global access in WarpCopy). + is_local = arith.cmpi(arith.CmpIPredicate.ult, local_expert_id, + arith.constant(experts_per_rank)) + + # Atomically allocate the per-expert packing slot on lane 0. + packed_slot_lane0 = arith.constant(0) + if lane == 0: + if is_local: + count_addr = addr_packed_recv_count + arith.zext_i64(local_expert_id) * 4 + packed_slot_lane0 = atomic_add_global_at(count_addr, arith.constant(1)) + packed_slot = readlane(packed_slot_lane0, 0) + + safe_local_expert = arith.select(is_local, local_expert_id, 0) + # Linear slot in the flat ``packed_recv_x[experts_per_rank, max_tokens_per_expert]`` buffer. 
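+                # Illustrative (hypothetical npes=8, max_tok_per_rank=512,
+                # so max_tokens_per_expert=4096): local_expert_id=2 and
+                # packed_slot=5 -> packed_linear_idx = 2 * 4096 + 5 = 8197.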
+ packed_linear_idx = safe_local_expert * max_tokens_per_expert + packed_slot + slot_val_i64 = arith.select(is_local, + arith.zext_i64(packed_linear_idx), + -1) # false_value materialized as i64 from true_value's type; -1 = not a local expert + if lane == 0: + slot_map_addr = addr_disp_tok_map + arith.zext_i64(smoe_idx) * 8 + store_i64_global_system(slot_map_addr, slot_val_i64) + + if lane == 0: + if is_local: + src_pos_enc = buffer_load(_r_tis_local, smoe_tok_id, + vec_width=1, dtype=T.i32()) + store_i32_system(addr_packed_recv_src_info, + packed_linear_idx, src_pos_enc) + + # WarpCopy token data from shmem_out_tok into the packed + # per-expert buffer at slot ``packed_linear_idx``. + src_tok_base = addr_out_tok + arith.zext_i64(smoe_tok_id) * nbytes + dst_tok_base = addr_packed_recv_x + arith.zext_i64(packed_linear_idx) * nbytes + rsrc_src = create_buffer_resource_from_addr(src_tok_base) + rsrc_dst = create_buffer_resource_from_addr(dst_tok_base) + lane_i32_off = lane * 4 + safe_end_i32 = (n_i32 // 512) * 512 + if n_i32 >= 512 and safe_end_i32 > 0: + copy_end_main = arith.select(is_local, safe_end_i32, lane_i32_off) + for chunk_i32_off in range(lane_i32_off, copy_end_main, 512): + vec_a = buffer_load(rsrc_src, chunk_i32_off, vec_width=4, dtype=T.i32()) + vec_b = buffer_load(rsrc_src, chunk_i32_off + 256, vec_width=4, dtype=T.i32()) + buffer_store(vec_a, rsrc_dst, chunk_i32_off) + buffer_store(vec_b, rsrc_dst, chunk_i32_off + 256) + if safe_end_i32 < n_i32: + copy_end_tail = arith.select(is_local, n_i32, lane_i32_off) + for chunk_i32_off in range(lane_i32_off + safe_end_i32, copy_end_tail, 256): + vec_a = buffer_load(rsrc_src, chunk_i32_off, vec_width=4, dtype=T.i32()) + buffer_store(vec_a, rsrc_dst, chunk_i32_off) + elif n_i32 < 512: + copy_end_small = arith.select(is_local, n_i32, lane_i32_off) + for chunk_i32_off in range(lane_i32_off, copy_end_small, 256): + vec_a = buffer_load(rsrc_src, chunk_i32_off, vec_width=4, dtype=T.i32()) + buffer_store(vec_a, rsrc_dst, chunk_i32_off) + + return ep_dispatch_intranode + + +def make_combine_kernel( + *, + rank: int, + npes: int, + experts_per_rank: int = 0, + experts_per_token: int, + hidden_dim: int, + hidden_elem_size: int, + max_tok_per_rank: int, + block_num: int, + warp_num_per_block: int, + data_type=None, + enable_weights: bool = False, + enable_std_moe: bool = False, + use_p2p_read: bool = False, + skip_stage1: bool = False, + inp_data_type=None, +): + """Build the intranode combine ``@flyc.kernel``. + + Stages: + * Stage 1 - P2P-scatter token contributions (and optionally weights) + from each rank's GEMM2 output buffer into every peer's + ``shmem_comb_inp``. + * Stage 2 - CrossDeviceBarrier so every rank has observed Stage 1 + writes from every peer. + * Stage 3 - local read of ``shmem_comb_inp`` plus per-expert WarpAccum + reducing into ``addr_comb_out``. + * Stage 3b - parallel weight accumulation (when ``enable_weights``). + + Parameters: + skip_stage1: + Compile-out the token half of Stage 1 (P2P scatter / + ConvertCombineInput). The caller is expected to have staged token + bytes into ``shmem_comb_inp`` ahead of the launch (e.g. fused + GEMM2-epilogue P2P scatter). Weight scatter is still emitted when + ``enable_weights`` is set, because the 16B weight writes share the + ROCm IPC fabric with the heavy token writes from the upstream stage + and get silently dropped under contention — the combine kernel + therefore owns weight scatter on a quiet fabric. + inp_data_type: + External input dtype. 
When different from ``data_type`` (currently + only ``bfloat16 + float8_e4m3fn`` is supported) Stage 1 fuses a + bf16 -> fp8 cast inline (``UseFp8DirectCast``-equivalent), and + Stage 3 widens addressing strides for bf16 output writes. + """ + max_recv = npes * max_tok_per_rank + _is_fp4 = (data_type == torch.float4_e2m1fn_x2) + if _is_fp4: + n_i32 = hidden_dim // 8 + nbytes = hidden_dim // 2 + else: + n_i32 = (hidden_dim * hidden_elem_size) // 4 + nbytes = hidden_dim * hidden_elem_size + tok_stride = n_i32 * 4 + + # Mixed-dtype combine: external dtype (kernel input AND output) differs + # from the on-wire/staging dtype used for P2P transport. Currently only + # supports bf16 external + OCP fp8 transport (mori UseFp8DirectCast). + # + # Semantics ("fp8_direct_cast"): + # - kernel reads bf16 input → Stage 1 inline cast bf16→fp8 → P2P fp8 + # - kernel reads fp8 staging → Stage 3 reduce in f32 → cast f32→bf16 + # → kernel writes bf16 output to ``addr_comb_out`` + # + # ``inp_data_type`` is the legacy parameter name; conceptually it now + # represents the external (input/output-shared) dtype. + _xfer_bf16_to_fp8 = ( + inp_data_type is not None and inp_data_type != data_type + and inp_data_type == torch.bfloat16 + and data_type == torch.float8_e4m3fn + ) + if inp_data_type is not None and inp_data_type != data_type and not _xfer_bf16_to_fp8: + raise NotImplementedError( + f"combine_kernel mixed-dtype only supports " + f"inp_data_type=bfloat16 + data_type=float8_e4m3fn, " + f"got inp_data_type={inp_data_type}, data_type={data_type}") + if _xfer_bf16_to_fp8 and enable_std_moe: + raise NotImplementedError( + "combine_kernel mixed-dtype path does not yet support " + "enable_std_moe=True (the std-MoE Stage 1 / Stage 3 use " + "_weighted_accum_experts which has not been retrofitted for " + "asymmetric I/O dtypes)") + + if _xfer_bf16_to_fp8: + # bf16 input stride for Stage 1 source addressing only. The transport + # (P2P-scattered staging) uses ``nbytes`` (= fp8 stride) as before. + # Stage 3 output addressing also uses bf16 stride (= 2 × fp8 stride). + inp_nbytes = hidden_dim * 2 + inp_n_i32 = (hidden_dim * 2) // 4 + # bf16-stride i32 count per token for Stage 3 output offsets. + out_n_i32 = (hidden_dim * 2) // 4 + else: + inp_nbytes = nbytes + inp_n_i32 = n_i32 + out_n_i32 = n_i32 + if _is_fp4: + from flydsl._mlir.dialects import rocdl as _rocdl_d + _v2f32_fp4 = T.VectorType.get([2], T.f32()) + _v8f32_fp4 = T.VectorType.get([8], T.f32()) + + def _to_accum(i32_val): + # ROCDL fp4 lane unpack: i32 (8 packed fp4) -> 4 × vector<2xf32>. + scale_one = arith.constant(1.0, type=T.f32()) + pairs = [ + _rocdl_d.cvt_scalef32_pk_f32_fp4( + res=_v2f32_fp4, src=i32_val, scale=scale_one, + src_sel_index=sel) + for sel in range(4) + ] + # Stitch 4 × v2f32 -> v8f32 via two-stage shuffle. + lo4 = vector.shuffle(pairs[0], pairs[1], [0, 1, 2, 3]) + hi4 = vector.shuffle(pairs[2], pairs[3], [0, 1, 2, 3]) + return vector.shuffle(lo4, hi4, [0, 1, 2, 3, 4, 5, 6, 7]) + + def _from_accum(accum_val): + # Re-pack v8f32 -> i32 via 4 × cvt_scalef32_pk_fp4_f32. 
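+            # Each cvt packs (src0, src1) into one byte (two fp4 nibbles)
+            # of ``old`` at byte position ``dst_sel_index``, so the four
+            # calls below rebuild the full 8-value i32 in place.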
+ _i32_ty = _IntTy.get_signless(32) + scale_one = arith.constant(1.0, type=T.f32()) + old = arith.constant(0, type=_i32_ty) + for sel in range(4): + f_a = vector.extract(accum_val, static_position=[sel * 2]) + f_b = vector.extract(accum_val, static_position=[sel * 2 + 1]) + old = _rocdl_d.cvt_scalef32_pk_fp4_f32( + res=_i32_ty, old_vdst=old, src0=f_a, src1=f_b, + scale=scale_one, dst_sel_index=sel) + return old + + def _zero_accum(): + return arith.constant_vector(0.0, _v8f32_fp4) + + elif hidden_elem_size == 2: # bf16 + def _to_accum(i32_val): + return bitcast_i32_to_v2bf16(i32_val).extf( + T.VectorType.get([2], T.f32())) + def _from_accum(accum_val): + return bitcast_v2bf16_to_i32(accum_val.truncf( + T.VectorType.get([2], T.bf16()))) + def _zero_accum(): + return arith.constant_vector(0.0, T.VectorType.get([2], T.f32())) + elif hidden_elem_size == 4: # f32 + def _to_accum(i32_val): + return arith.bitcast(T.f32(), i32_val) + def _from_accum(accum_val): + return arith.bitcast(T.i32(), accum_val) + def _zero_accum(): + return arith.constant(0.0, type=T.f32()) + elif hidden_elem_size == 1: # fp8 + from flydsl._mlir.dialects import rocdl as _rocdl_d + _is_ocp = (data_type == torch.float8_e4m3fn) + _is_fnuz = (data_type == torch.float8_e4m3fnuz) + _cvt_pk_f32 = _rocdl_d.cvt_pk_f32_fp8 + _cvt_pk_f8 = _rocdl_d.cvt_pk_fp8_f32 + _v2f32_fp8 = T.VectorType.get([2], T.f32()) + _v4f32_fp8 = T.VectorType.get([4], T.f32()) + + def _to_accum(i32_val): + # ROCDL fp8 lane unpack: i32 (4 packed fp8) -> 2 × vector<2xf32>. + lo = _cvt_pk_f32(res=_v2f32_fp8, src=i32_val, word_sel=False) + hi = _cvt_pk_f32(res=_v2f32_fp8, src=i32_val, word_sel=True) + # Concatenate lo|hi -> vector<4xf32> (mask picks lo[0,1], hi[0,1]). + vec = vector.shuffle(lo, hi, [0, 1, 2, 3]) + if _is_fnuz: + vec = vec * 0.5 + return vec + + def _from_accum(accum_val): + _i32_ty = _IntTy.get_signless(32) + if _is_fnuz: + accum_val = accum_val * 2.0 + if const_expr(_xfer_bf16_to_fp8): + # Mixed-dtype path: write bf16 output (8 bytes per lane). + # v4f32 -> v4bf16 (truncf) -> v2i32 (bitcast). Caller stores + # via buffer_store(..., vec_width=2, dtype=T.i32()) at an i32 + # offset doubled relative to fp8 mode (2 i32 = 4 bf16 = 8 B). + _v4bf16 = T.VectorType.get([4], T.bf16()) + _v2i32 = T.VectorType.get([2], _i32_ty) + return vector.bitcast(_v2i32, accum_val.truncf(_v4bf16)) + f0 = vector.extract(accum_val, static_position=[0]) + f1 = vector.extract(accum_val, static_position=[1]) + f2 = vector.extract(accum_val, static_position=[2]) + f3 = vector.extract(accum_val, static_position=[3]) + zero = arith.constant(0, type=_i32_ty) + lo = _cvt_pk_f8(res=_i32_ty, src_a=f0, src_b=f1, + old=zero, word_sel=False) + return _cvt_pk_f8(res=_i32_ty, src_a=f2, src_b=f3, + old=lo, word_sel=True) + + def _zero_accum(): + return arith.constant_vector(0.0, _v4f32_fp8) + else: + raise ValueError(f"Unsupported hidden_elem_size={hidden_elem_size}") + + def _accum_experts(vals, vlds, all_vld): + """Reduce the k per-expert i32 partials into one merged i32. + + Each value is widened via ``_to_accum`` (bf16/fp8/...->f32 vector), + summed in high precision, then narrowed back via ``_from_accum``. + + Args: + vals: per-expert raw i32 values (one per k-slot). + vlds: per-expert i1 validity flags (used iff ``all_vld`` is False). + all_vld: when True, skip the masking and treat every slot as live. 
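+
+        Example (hypothetical bf16, k=2): ``vals = [a, b]``, each an i32
+        packing two bf16 lanes; the result is ``_from_accum(_to_accum(a)
+        + _to_accum(b))``, i.e. an elementwise f32 sum narrowed back to
+        two bf16 lanes.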
+ """ + if all_vld: + acc = _to_accum(vals[0]) + for k_slot in range(1, len(vals)): + acc = acc + _to_accum(vals[k_slot]) + else: + acc = _zero_accum() + for k_slot in range(len(vals)): + widened = _to_accum(vals[k_slot]) + zero = _zero_accum() + vld_raw = _lv_unwrap(vlds[k_slot]) + acc = acc + arith.select(vld_raw, widened, zero) + return _from_accum(acc) + + def _weighted_accum_experts(vals, wts, vlds, all_vld): + """Weighted variant of ``_accum_experts``: ``sum(wt[k] * val[k])``. + + Used by the StdMoE Stage 1 path where the kernel reduces the k + per-expert contributions (each multiplied by the dispatch-time + output weight) into one merged token before the P2P scatter. + """ + _i32ty = _IntTy.get_signless(32) + _f32ty = T.f32() + + if _is_fp4: # fp4 → v8f32 accum + from flydsl._mlir.dialects import rocdl as _rocdl_fp4 + _v2f32 = T.VectorType.get([2], T.f32()) + _v8f32 = T.VectorType.get([8], T.f32()) + scale_one = arith.constant(1.0, type=_f32ty) + acc = arith.constant_vector(0.0, _v8f32) + for j in range(len(vals)): + # ROCDL fp4 lane unpack: i32 (8 packed fp4) -> 4 × vector<2xf32>. + pairs = [ + _rocdl_fp4.cvt_scalef32_pk_f32_fp4( + res=_v2f32, src=vals[j], scale=scale_one, + src_sel_index=sel) + for sel in range(4) + ] + # Stitch 4 × v2f32 -> v8f32 via two-stage shuffle. + lo4 = vector.shuffle(pairs[0], pairs[1], [0, 1, 2, 3]) + hi4 = vector.shuffle(pairs[2], pairs[3], [0, 1, 2, 3]) + vec = vector.shuffle(lo4, hi4, [0, 1, 2, 3, 4, 5, 6, 7]) + w = vec * wts[j] # auto-broadcast scalar to v8f32 + if all_vld: + acc = acc + w + else: + acc = acc + arith.select( + vlds[j], w, arith.constant_vector(0.0, _v8f32)) + # Re-pack v8f32 -> i32 via 4 × cvt_scalef32_pk_fp4_f32. + old = arith.constant(0, type=_i32ty) + for sel in range(4): + f_a = vector.extract(acc, static_position=[sel * 2]) + f_b = vector.extract(acc, static_position=[sel * 2 + 1]) + old = _rocdl_fp4.cvt_scalef32_pk_fp4_f32( + res=_i32ty, old_vdst=old, src0=f_a, src1=f_b, + scale=scale_one, dst_sel_index=sel) + return old + + elif hidden_elem_size == 2: # bf16 → v2f32 accum + _v2bf16 = T.VectorType.get([2], T.bf16()) + _v2f32 = T.VectorType.get([2], T.f32()) + acc = arith.constant_vector(0.0, _v2f32) + for j in range(len(vals)): + # i32 → vector<2xbf16> (shape-changing, llvm.bitcast) + # → vector<2xf32> via arith.extf, then broadcast wt and fma. + vb = bitcast_i32_to_v2bf16(vals[j]) + vf = vb.extf(_v2f32) + w = vf * wts[j] # wts[j] scalar f32, auto-broadcast to v2f32 + if all_vld: + acc = acc + w + else: + acc = acc + arith.select( + vlds[j], w, arith.constant_vector(0.0, _v2f32)) + return bitcast_v2bf16_to_i32(acc.truncf(_v2bf16)) + + elif hidden_elem_size == 4: # f32 → f32 accum + acc = arith.constant(0.0, type=_f32ty) + for j in range(len(vals)): + vf = arith.bitcast(_f32ty, vals[j]) + w = vf * wts[j] + if all_vld: + acc = acc + w + else: + acc = acc + arith.select( + vlds[j], w, arith.constant(0.0, type=_f32ty)) + return arith.bitcast(_i32ty, acc) + + elif hidden_elem_size == 1: # fp8 → v4f32 accum + from flydsl._mlir.dialects import rocdl as _rocdl + _pk_f32 = _rocdl.cvt_pk_f32_fp8 + _pk_f8 = _rocdl.cvt_pk_fp8_f32 + _v2f32 = T.VectorType.get([2], T.f32()) + _v4f32 = T.VectorType.get([4], T.f32()) + acc = arith.constant_vector(0.0, _v4f32) + for j in range(len(vals)): + # ROCDL fp8 lane unpack: i32 (4 packed fp8) -> 2 × vector<2xf32>. 
+ lo = _pk_f32(res=_v2f32, src=vals[j], word_sel=False) + hi = _pk_f32(res=_v2f32, src=vals[j], word_sel=True) + # Concatenate lo|hi into vector<4xf32>: + # mask [0,1,2,3] -> [lo[0], lo[1], hi[0], hi[1]] + vec = vector.shuffle(lo, hi, [0, 1, 2, 3]) + if _is_fnuz: + vec = vec * 0.5 + w = vec * wts[j] # wts[j] scalar f32, auto-broadcast to v4f32 + if all_vld: + acc = acc + w + else: + acc = acc + arith.select( + vlds[j], w, arith.constant_vector(0.0, _v4f32)) + if _is_fnuz: + acc = acc * 2.0 + f0 = vector.extract(acc, static_position=[0]) + f1 = vector.extract(acc, static_position=[1]) + f2 = vector.extract(acc, static_position=[2]) + f3 = vector.extract(acc, static_position=[3]) + zi = arith.constant(0, type=_i32ty) + lo = _pk_f8(res=_i32ty, src_a=f0, src_b=f1, old=zi, word_sel=False) + return _pk_f8(res=_i32ty, src_a=f2, src_b=f3, old=lo, word_sel=True) + + def _log2_if_pow2(v): + """Return ``log2(v)`` if *v* is a positive power of two, else ``None``.""" + if v > 0 and (v & (v - 1)) == 0: + return v.bit_length() - 1 + return None + # Pow2 fast-paths: when ``max_tok_per_rank`` / ``max_recv`` are powers + # of two, decode ``dest_pe / dest_lid`` and ``dest_pe / dtok`` via + # shift + mask instead of integer divide / mod. + _log2_max_tok = _log2_if_pow2(max_tok_per_rank) + _log2_max_recv = _log2_if_pow2(max_recv) + _mask_max_tok = max_tok_per_rank - 1 if _log2_max_tok is not None else None + _mask_max_recv = max_recv - 1 if _log2_max_recv is not None else None + + # Dispatch deduplicates same-PE assignments at runtime: when more than + # one of a token's k experts fall on the same dest_pe, the duplicate + # tok_map slot is encoded as ``dest_pe = npes`` (sentinel). The combine + # accumulator must skip those invalid lanes, which is exactly what the + # ``_maybe_load`` helper below does (equivalent to mori's + # ``EpCombineIntraNodeKernel`` ``srcPtrs[j] = nullptr`` short-circuit). + _use_compaction = True + + weight_bytes = experts_per_token * 4 if enable_weights else 0 + wt_n_i32 = experts_per_token if enable_weights else 0 + + # LDS layout for the P2P-base tables (i64[npes] for tokens, optionally + # i64[npes] for weights). ``SmemAllocator.finalize()`` is called from the + # JIT launcher to publish the layout to the GPU module. 
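+    # Illustrative layout (hypothetical npes=8, allocator starting empty):
+    # token bases occupy LDS bytes [0, 64), and with ``enable_weights`` the
+    # weight bases occupy [64, 128); both tables are 8-byte aligned.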
+ allocator = SmemAllocator(None, arch="gfx942") + p2p_base_offset = allocator._align(allocator.ptr, 8) + p2p_base_size = npes * 8 + allocator.ptr = p2p_base_offset + p2p_base_size + + if enable_weights: + p2p_wt_base_offset = allocator._align(allocator.ptr, 8) + p2p_wt_base_size = npes * 8 + allocator.ptr = p2p_wt_base_offset + p2p_wt_base_size + + + @flyc.kernel + def ep_combine_intranode( + addr_inp_tok: fx.Int64, # inp_tok base (post-expert token buffer) + addr_comb_inp: fx.Int64, # shmem_comb_inp base (symmetric) + addr_comb_out: fx.Int64, # shmem_comb_out base (symmetric) + addr_xdb_mem: fx.Int64, # xdev_bar_mem (u64[npes]) + addr_xdb_flag: fx.Int64, # xdev_bar_flag (u64[1]) + addr_tok_map: fx.Int64, # dest_tok_map (i32[cur_tok*k]) + addr_comb_bar: fx.Int64, # combine_bar (i32[1]) + addr_trecv: fx.Int64, # total_recv_ptr (i32[1]) + addr_tis: fx.Int64, # tok_id_to_src (i32[max_recv], symmetric) + addr_p2p_comb_inp: fx.Int64, # i64[npes] pre-resolved P2P addresses + addr_p2p_xdb_mem: fx.Int64, # i64[npes] pre-resolved P2P addresses + addr_wts_buf: fx.Int64, # combine input weights f32[max_recv*k] + addr_comb_inp_wts: fx.Int64, # shmem weight P2P buffer (symmetric) + addr_comb_out_wts: fx.Int64, # combine output weights f32[max_tok*k] + addr_p2p_comb_inp_wts: fx.Int64, # i64[npes] weight P2P addresses + # ── StdMoE ConvertCombineInput parameters ── + addr_packed_recv_x: fx.Int64, # expert-major token buffer (post-expert) + addr_disp_tok_map: fx.Int64, # dispTokToEpSlotMap (i64[max_recv * top_k]) + addr_disp_out_wts: fx.Int64, # dispatch output weights (f32[max_recv * top_k]) + cur_rank_num_token: fx.Int32, # this PE's output token count (used by Stage 3) + ): + tid = fx.thread_idx.x + bid = fx.block_idx.x + lane = tid & 63 + warp = tid >> 6 + global_warp_id = bid * warp_num_per_block + warp # warp id across the grid + global_warp_num = block_num * warp_num_per_block # total warps in the grid + grid_thread_id = bid * (warp_num_per_block * 64) + tid # grid-wide thread id (used by Stage 2 only) + + # Predicated buffer_load: returns 0 (i32) when vld_flag is false. + # Defined as a nested function so the AST rewriter lowers the Python + # ``if`` to ``scf.if`` for every call site (the rewriter only walks + # function bodies inside ``@flyc.kernel`` and their nested defs). + def _maybe_load(rsrc, offset, vld_flag, **kwargs): + result = arith.constant(0, type=T.i32()) + if vld_flag: + result = buffer_load(rsrc, offset, **kwargs) + return result + + _r_trecv = create_buffer_resource_from_addr(addr_trecv) + _r_xdb_flag = create_buffer_resource_from_addr(addr_xdb_flag) + _r_tis = create_buffer_resource_from_addr(addr_tis) + _r_comb_bar = create_buffer_resource_from_addr(addr_comb_bar) + _r_p2p_comb = create_buffer_resource_from_addr(addr_p2p_comb_inp) + _r_p2p_xdb = create_buffer_resource_from_addr(addr_p2p_xdb_mem) + _rsrc_tok_map = create_buffer_resource_from_addr(addr_tok_map) + + total_recv = buffer_load(_r_trecv, 0, vec_width=1, dtype=T.i32()) + # Per-launch monotonically-incrementing flag value used by Stage 2's + # cross-device barrier (each rank waits to observe this value from + # every peer). + xdb_cur_flag = buffer_load(_r_xdb_flag, 0, vec_width=1, dtype=T.i64()) + + # LDS-resident table of pre-resolved P2P base addresses (i64[npes]). + # Cached once in shared memory so the Stage 1 scatter loop (which + # may visit thousands of tokens per warp) avoids reissuing a global + # load for the same per-peer base on every iteration. 
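+        # Illustrative: with npes=8 the table is only 64 B of LDS, and a
+        # Stage 1 warp that visits a few hundred tokens trades that many
+        # repeated global loads of the same 8 bases for cheap LDS reads.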
+ base_ptr = allocator.get_base() + # NOTE: SmemPtr ops are intentionally written as unbound-class calls + # (``SmemPtr.method(instance, ...)`` rather than ``instance.method(...)``) + # to avoid the upstream ast_rewriter heuristic that treats any + # ``var.method(...)`` inside an scf-lowered if/for as a loop-carried + # variable (which then fails because SmemPtr is not an MLIR Value). + # All ``_lds_p2p_*`` and downstream ``SmemPtr.{get,load,store}`` + # call sites follow the same convention. + _lds_p2p_bases = SmemPtr(base_ptr, p2p_base_offset, T.i64(), + shape=(npes,)) + SmemPtr.get(_lds_p2p_bases) + + if lane < npes: + p2p_base_addr = buffer_load(_r_p2p_comb, lane, vec_width=1, dtype=T.i64()) + SmemPtr.store(_lds_p2p_bases, p2p_base_addr, [lane]) + + if const_expr(enable_weights): + _r_p2p_comb_wt = create_buffer_resource_from_addr(addr_p2p_comb_inp_wts) + _lds_p2p_wt_bases = SmemPtr(base_ptr, p2p_wt_base_offset, T.i64(), + shape=(npes,)) + SmemPtr.get(_lds_p2p_wt_bases) + if lane < npes: + p2p_wt_base_addr = buffer_load(_r_p2p_comb_wt, lane, vec_width=1, dtype=T.i64()) + SmemPtr.store(_lds_p2p_wt_bases, p2p_wt_base_addr, [lane]) + + fx.barrier() + + # Stage 1: P2P scatter / ConvertCombineInput. + # When ``skip_stage1`` is set the entire stage is compile-time + # eliminated; the caller is responsible for having pre-staged the + # equivalent P2P writes into shmem_comb_inp[_wts]. + # + # Common per-token decoding from ``shmem_tok_id_to_src[recv_tok_id]``: + # dest_pe - which peer this token must be combined to + # dest_lid - the per-PE local id ``[0, max_tok_per_rank)`` + n_chunks = nbytes // 16 # 16-byte (4-i32) vector chunks per token + + if const_expr(skip_stage1): + if const_expr(enable_weights): + # Weight-only Stage 1: same as default path but only writes + # the small weight slot (no per-token hidden bytes). Used by + # fused_gemm2_combine to keep weight scatter off the heavy + # token-write fabric. + for recv_tok_id in range(global_warp_id, total_recv, global_warp_num): + dest_tok_enc = buffer_load(_r_tis, recv_tok_id, vec_width=1, dtype=T.i32()) + if const_expr(_log2_max_tok is not None): + dest_pe = dest_tok_enc >> _log2_max_tok + dest_lid = dest_tok_enc & _mask_max_tok + else: + dest_pe = (dest_tok_enc // max_tok_per_rank) + dest_lid = (dest_tok_enc % max_tok_per_rank) + wt_pe_base = SmemPtr.load(_lds_p2p_wt_bases, [dest_pe]) + wt_dest_off = arith.zext_i64( + rank * max_tok_per_rank + dest_lid) * weight_bytes + wt_dest_addr = _lv_unwrap(wt_pe_base) + wt_dest_off + wt_src_addr = _lv_unwrap(addr_wts_buf) + arith.zext_i64(recv_tok_id) * weight_bytes + rsrc_wt_src = create_buffer_resource_from_addr(wt_src_addr) + rsrc_wt_dst = create_buffer_resource_from_addr(wt_dest_addr) + if lane < wt_n_i32: + wt_val = buffer_load(rsrc_wt_src, lane, vec_width=1, dtype=T.i32()) + buffer_store(wt_val, rsrc_wt_dst, lane) + else: + pass + elif const_expr(enable_std_moe): + # Stage 1 StdMoE: read the k-expert partials from + # ``packed_recv_x`` (per-expert buckets), reduce with the + # dispatch-time output weights, and scatter the merged token to + # the destination PE's ``shmem_comb_inp``. 
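+            # Net effect per i32 element i (sketch):
+            #   merged[i] = sum_k wt[k] * packed_recv_x[slot_k][i]
+            # with sentinel slots (slot_k == -1, non-local expert)
+            # contributing zero.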
+ _rsrc_dtm = create_buffer_resource_from_addr(addr_disp_tok_map) + _rsrc_dow = create_buffer_resource_from_addr(addr_disp_out_wts) + smoe_all_vld = False # k-slots may be sentinel (-1) for non-local experts + + for recv_tok_id in range(global_warp_id, total_recv, global_warp_num): + dest_tok_enc = buffer_load(_r_tis, recv_tok_id, vec_width=1, dtype=T.i32()) + if const_expr(_log2_max_tok is not None): + dest_pe = dest_tok_enc >> _log2_max_tok + dest_lid = dest_tok_enc & _mask_max_tok + else: + dest_pe = (dest_tok_enc // max_tok_per_rank) + dest_lid = (dest_tok_enc % max_tok_per_rank) + + if const_expr(use_p2p_read): + # P2P-read mode: write locally; peers will pull from us in Stage 3. + dest_byte_off = arith.zext_i64(recv_tok_id) * nbytes + dest_tok_addr = _lv_unwrap(addr_comb_inp) + dest_byte_off + else: + peer_base = SmemPtr.load(_lds_p2p_bases, [dest_pe]) + dest_byte_off = arith.zext_i64(rank * max_tok_per_rank + dest_lid) * nbytes + dest_tok_addr = _lv_unwrap(peer_base) + dest_byte_off + rsrc_dst = create_buffer_resource_from_addr(dest_tok_addr) + + # Collect resources/valid-flags/weights for each k-expert slot. + expert_rsrcs = [] + expert_vlds = [] + expert_wts = [] + for k_slot in range_constexpr(experts_per_token): + slot_addr = addr_disp_tok_map + arith.zext_i64(recv_tok_id * experts_per_token + k_slot) * 8 + slot_val = load_i64_global(slot_addr) + slot_vld = slot_val != -1 + safe_slot = arith.select(slot_vld, slot_val, 0) + expert_tok_addr = addr_packed_recv_x + safe_slot * nbytes + expert_rsrcs.append(create_buffer_resource_from_addr(expert_tok_addr)) + expert_vlds.append(slot_vld) + wt_k = buffer_load(_rsrc_dow, recv_tok_id * experts_per_token + k_slot, + vec_width=1, dtype=T.f32()) + expert_wts.append(wt_k) + + # Weighted reduce across the k experts, then scatter. + for elem_off in range(lane, n_i32, 64): + expert_vals = [] + for k_slot in range_constexpr(experts_per_token): + expert_vals.append(buffer_load(expert_rsrcs[k_slot], elem_off, + vec_width=1, dtype=T.i32())) + accum = _weighted_accum_experts(expert_vals, expert_wts, + expert_vlds, smoe_all_vld) + buffer_store(accum, rsrc_dst, elem_off) + + if const_expr(enable_weights): + if const_expr(use_p2p_read): + wt_dest_off = arith.zext_i64(recv_tok_id) * weight_bytes + wt_dest_addr = _lv_unwrap(addr_comb_inp_wts) + wt_dest_off + else: + wt_pe_base = SmemPtr.load(_lds_p2p_wt_bases, [dest_pe]) + wt_dest_off = arith.zext_i64( + rank * max_tok_per_rank + dest_lid) * weight_bytes + wt_dest_addr = _lv_unwrap(wt_pe_base) + wt_dest_off + wt_src_addr = _lv_unwrap(addr_wts_buf) + arith.zext_i64(recv_tok_id) * weight_bytes + rsrc_wt_src = create_buffer_resource_from_addr(wt_src_addr) + rsrc_wt_dst = create_buffer_resource_from_addr(wt_dest_addr) + if lane < wt_n_i32: + wt_val = buffer_load(rsrc_wt_src, lane, vec_width=1, dtype=T.i32()) + buffer_store(wt_val, rsrc_wt_dst, lane) + + elif const_expr(use_p2p_read): + # Stage 1 P2P-read mode: every rank writes its post-expert + # tokens into its OWN ``shmem_comb_inp`` slot indexed by + # ``recv_tok_id`` (no remote write). Peers will read these + # buffers cross-device during Stage 3. + dual_end_aligned = (n_chunks // 128) * 128 + for recv_tok_id in range(global_warp_id, total_recv, global_warp_num): + # In mixed-mode (bf16 input → fp8 staging), the source uses + # bf16 stride (inp_nbytes) while the dest uses fp8 stride + # (nbytes); in same-dtype mode the two strides are identical. 
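+                # Illustrative (hypothetical hidden_dim=7168): bf16 source
+                # stride inp_nbytes = 14336 B vs fp8 staging stride
+                # nbytes = 7168 B; in plain bf16 mode both are 14336 B.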
+ src_tok_addr = addr_inp_tok + arith.zext_i64(recv_tok_id) * inp_nbytes + dst_tok_addr = addr_comb_inp + arith.zext_i64(recv_tok_id) * nbytes + rsrc_src = create_buffer_resource_from_addr(src_tok_addr) + rsrc_dst = create_buffer_resource_from_addr(dst_tok_addr) + if const_expr(_xfer_bf16_to_fp8): + # Mixed-dtype Stage 1: load bf16 (2 i32 / lane = 4 bf16 + # elems) → ExtF v4f32 → cvt_pk_fp8_f32 ×2 → store 1 + # fp8 i32 (4 fp8 elems) at staging offset ``elem_off``. + from flydsl._mlir.dialects import rocdl as _rocdl_s1a + _v4bf16_a = T.VectorType.get([4], T.bf16()) + _v4f32_a = T.VectorType.get([4], T.f32()) + _i32t_a = T.i32() + for elem_off in range(lane, n_i32, 64): + bf_pair = buffer_load(rsrc_src, elem_off * 2, + vec_width=2, dtype=T.i32()) + v4f = vector.bitcast(_v4bf16_a, bf_pair).extf(_v4f32_a) + f0 = vector.extract(v4f, static_position=[0]) + f1 = vector.extract(v4f, static_position=[1]) + f2 = vector.extract(v4f, static_position=[2]) + f3 = vector.extract(v4f, static_position=[3]) + zi = arith.constant(0, type=_i32t_a) + lo = _rocdl_s1a.cvt_pk_fp8_f32(res=_i32t_a, src_a=f0, src_b=f1, + old=zi, word_sel=False) + fp8_i32 = _rocdl_s1a.cvt_pk_fp8_f32(res=_i32t_a, src_a=f2, src_b=f3, + old=lo, word_sel=True) + buffer_store(fp8_i32, rsrc_dst, elem_off) + else: + # Same-dtype path: 4-i32 vector copy. ``chunk_idx`` is + # the 16-byte-chunk index this lane is currently + # copying; ``chunk_i32_off`` translates it to i32 elems. + if const_expr(dual_end_aligned >= 128): + for chunk_idx in range(lane, dual_end_aligned, 128): + chunk_i32_off = chunk_idx * 4 + chunk_i32_off_alt = (chunk_idx + 64) * 4 + vec_a = buffer_load(rsrc_src, chunk_i32_off, vec_width=4, dtype=T.i32()) + vec_b = buffer_load(rsrc_src, chunk_i32_off_alt, vec_width=4, dtype=T.i32()) + buffer_store(vec_a, rsrc_dst, chunk_i32_off) + buffer_store(vec_b, rsrc_dst, chunk_i32_off_alt) + if const_expr(dual_end_aligned < n_chunks): + for chunk_idx in range(lane + dual_end_aligned, n_chunks, 64): + chunk_i32_off = chunk_idx * 4 + vec_a = buffer_load(rsrc_src, chunk_i32_off, vec_width=4, dtype=T.i32()) + buffer_store(vec_a, rsrc_dst, chunk_i32_off) + + if const_expr(enable_weights): + for recv_tok_id in range(global_warp_id, total_recv, global_warp_num): + wt_src_addr = _lv_unwrap(addr_wts_buf) + arith.zext_i64(recv_tok_id) * weight_bytes + wt_dst_addr = _lv_unwrap(addr_comb_inp_wts) + arith.zext_i64(recv_tok_id) * weight_bytes + rsrc_wt_src = create_buffer_resource_from_addr(wt_src_addr) + rsrc_wt_dst = create_buffer_resource_from_addr(wt_dst_addr) + if lane < wt_n_i32: + wt_val = buffer_load(rsrc_wt_src, lane, vec_width=1, dtype=T.i32()) + buffer_store(wt_val, rsrc_wt_dst, lane) + + else: + # Stage 1 default mode: P2P-write each received token to the + # destination PE's ``shmem_comb_inp`` at slot (rank, dest_lid). + dual_end_aligned = (n_chunks // 128) * 128 + for recv_tok_id in range(global_warp_id, total_recv, global_warp_num): + dest_tok_enc = buffer_load(_r_tis, recv_tok_id, vec_width=1, dtype=T.i32()) + if const_expr(_log2_max_tok is not None): + dest_pe = dest_tok_enc >> _log2_max_tok + dest_lid = dest_tok_enc & _mask_max_tok + else: + dest_pe = (dest_tok_enc // max_tok_per_rank) + dest_lid = (dest_tok_enc % max_tok_per_rank) + peer_base = SmemPtr.load(_lds_p2p_bases, [dest_pe]) + # Dest stride uses ``nbytes`` (staging dtype, fp8 in mixed mode). 
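+                # Each source rank owns the disjoint window
+                # [rank*max_tok_per_rank, (rank+1)*max_tok_per_rank) of the
+                # peer's buffer, so concurrent writers never collide.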
+ dest_off = arith.zext_i64(rank * max_tok_per_rank + dest_lid) * nbytes + dest_tok_addr = _lv_unwrap(peer_base) + dest_off + # Src stride uses ``inp_nbytes`` (input dtype, bf16 in mixed mode). + src_tok_addr = addr_inp_tok + arith.zext_i64(recv_tok_id) * inp_nbytes + rsrc_src = create_buffer_resource_from_addr(src_tok_addr) + rsrc_dst = create_buffer_resource_from_addr(dest_tok_addr) + if const_expr(_xfer_bf16_to_fp8): + # Mixed-dtype Stage 1: load 2 bf16 i32 (=4 bf16 elems) → + # ExtF v4f32 → cvt_pk_fp8_f32 ×2 → store 1 fp8 i32 (=4 + # fp8 elems). Loop unit is 1 fp8-i32 per lane per step. + from flydsl._mlir.dialects import rocdl as _rocdl_s1b + _v4bf16_b = T.VectorType.get([4], T.bf16()) + _v4f32_b = T.VectorType.get([4], T.f32()) + _i32t_b = T.i32() + for elem_off in range(lane, n_i32, 64): + bf_pair = buffer_load(rsrc_src, elem_off * 2, + vec_width=2, dtype=T.i32()) + v4f = vector.bitcast(_v4bf16_b, bf_pair).extf(_v4f32_b) + f0 = vector.extract(v4f, static_position=[0]) + f1 = vector.extract(v4f, static_position=[1]) + f2 = vector.extract(v4f, static_position=[2]) + f3 = vector.extract(v4f, static_position=[3]) + zi = arith.constant(0, type=_i32t_b) + lo = _rocdl_s1b.cvt_pk_fp8_f32(res=_i32t_b, src_a=f0, src_b=f1, + old=zi, word_sel=False) + fp8_i32 = _rocdl_s1b.cvt_pk_fp8_f32(res=_i32t_b, src_a=f2, src_b=f3, + old=lo, word_sel=True) + buffer_store(fp8_i32, rsrc_dst, elem_off) + else: + if const_expr(dual_end_aligned >= 128): + for chunk_idx in range(lane, dual_end_aligned, 128): + chunk_i32_off = chunk_idx * 4 + chunk_i32_off_alt = (chunk_idx + 64) * 4 + vec_a = buffer_load(rsrc_src, chunk_i32_off, vec_width=4, dtype=T.i32()) + vec_b = buffer_load(rsrc_src, chunk_i32_off_alt, vec_width=4, dtype=T.i32()) + buffer_store(vec_a, rsrc_dst, chunk_i32_off) + buffer_store(vec_b, rsrc_dst, chunk_i32_off_alt) + if const_expr(dual_end_aligned < n_chunks): + for chunk_idx in range(lane + dual_end_aligned, n_chunks, 64): + chunk_i32_off = chunk_idx * 4 + vec_a = buffer_load(rsrc_src, chunk_i32_off, vec_width=4, dtype=T.i32()) + buffer_store(vec_a, rsrc_dst, chunk_i32_off) + + if const_expr(enable_weights): + wt_pe_base = SmemPtr.load(_lds_p2p_wt_bases, [dest_pe]) + wt_dest_off = arith.zext_i64( + rank * max_tok_per_rank + dest_lid) * weight_bytes + wt_dest_addr = _lv_unwrap(wt_pe_base) + wt_dest_off + wt_src_addr = _lv_unwrap(addr_wts_buf) + arith.zext_i64(recv_tok_id) * weight_bytes + rsrc_wt_src = create_buffer_resource_from_addr(wt_src_addr) + rsrc_wt_dst = create_buffer_resource_from_addr(wt_dest_addr) + if lane < wt_n_i32: + wt_val = buffer_load(rsrc_wt_src, lane, vec_width=1, dtype=T.i32()) + buffer_store(wt_val, rsrc_wt_dst, lane) + + # Stage 2: CrossDeviceBarrier. + # Every rank publishes ``xdb_cur_flag`` into every peer's + # ``xdev_bar_mem[rank]`` slot, then waits until every peer's + # corresponding slot in our local xdev_bar_mem hits the same flag. 
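+        # Protocol sketch for one launch (flag value F): rank r stores F
+        # into every peer's xdev_bar_mem[r], bumps its own xdev_bar_flag
+        # to F + 1 for the next launch, then spins until its local
+        # xdev_bar_mem[0..npes) slots all read F.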
+ fx.barrier() + if tid == 0: + atomic_add_global_at(addr_comb_bar, arith.constant(1)) + + if grid_thread_id < npes: + mori_shmem.int32_wait_until_equals(addr_comb_bar, block_num) + buffer_store(arith.constant(0), _r_comb_bar, 0) + xdb_remote_addr = buffer_load(_r_p2p_xdb, grid_thread_id, vec_width=1, dtype=T.i64()) + \ + arith.zext_i64(arith.constant(rank)) * 8 + store_i64_global_system(xdb_remote_addr, xdb_cur_flag) + + if grid_thread_id == 0: + atomic_add_global_at(addr_xdb_flag, arith.constant(1, type=T.i64())) + + if tid < npes: + xdb_peer_slot = addr_xdb_mem + arith.zext_i64(tid) * 8 + mori_shmem.uint64_wait_until_equals(xdb_peer_slot, xdb_cur_flag) + + fx.barrier() + if tid == 0: + buffer_store(arith.constant(0), _r_trecv, 0) + + # Stage 3: local read + WarpAccum. + # Each output token's hidden dimension is split into ``warps_per_tok`` + # partitions; each warp handles one partition (size ``hdim_per_warp``) + # of one output token. Inside the partition, the warp reads the k + # per-expert partials from ``shmem_comb_inp``, accumulates them in + # high-precision (f32) and writes back the merged token to + # ``shmem_comb_out``. + SLC_CACHE = 2 # buffer_load/store ``cache_modifier=SLC`` (system-coherent) + rsrc_out = create_buffer_resource_from_addr(addr_comb_out) + + n_elems = n_i32 + # When ``cur_rank_num_token == 0`` the division below would divide by + # zero; clamp the denominator to 1 (loop won't execute anyway). + safe_token_count = arith.select( + cur_rank_num_token == 0, 1, cur_rank_num_token) + warps_per_tok = (global_warp_num + safe_token_count - 1) // safe_token_count + hdim_per_warp = (n_elems + warps_per_tok - 1) // warps_per_tok + s3_total_work = cur_rank_num_token * warps_per_tok + + for s3_work_idx in range(global_warp_id, s3_total_work, global_warp_num): + tok_id = (s3_work_idx // warps_per_tok) + part_id = (s3_work_idx % warps_per_tok) + hdim_off = part_id * hdim_per_warp + + expert_rsrcs = [] + expert_vlds = [] + + if const_expr(skip_stage1): + # Fused-upstream Stage 3: when ``skip_stage1`` is set the + # caller has plain-stored a per-(tok_id, k_slot) partial into + # ``shmem_comb_inp[(tok_id*k + k_slot) * token_bytes]``. Each + # k_slot is unique; there is no tok_map to decode -- the + # accumulator simply reads ``shmem_comb_inp`` for k_slot in + # [0, k). Unrouted (tok_id, k_slot) slots are zero-initialized + # by the caller and therefore contribute zero to the sum. + for k_slot in range_constexpr(experts_per_token): + slot_idx = tok_id * experts_per_token + k_slot + expert_tok_off = arith.zext_i64(slot_idx) * nbytes + expert_tok_addr = _lv_unwrap(addr_comb_inp + expert_tok_off) + expert_rsrcs.append(create_buffer_resource_from_addr(expert_tok_addr)) + expert_vlds.append(arith.constant(1, type=T.bool())) + eff_all_vld = True + else: + # Baseline Stage 3: decode (peer_pe, dest_lid) from + # ``dest_tok_map[tok_id, 0..k)`` and read the per-(peer_pe, + # dest_lid) slot of ``shmem_comb_inp``. Stage 1 has P2P- + # scattered each (src_pe, src_lid) contribution into that + # slot. Two 4-i32 loads cover the 8 k-slots in one round. 
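+                # (Assumes experts_per_token <= 8: the lo vector covers
+                # k_slots 0-3 and the hi vector k_slots 4-7; smaller k
+                # simply leaves the tail vector lanes unused.)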
+ tm_base_off = tok_id * experts_per_token + tm_vec_lo = buffer_load(_rsrc_tok_map, tm_base_off, vec_width=4, dtype=T.i32()) + tm_vec_hi = buffer_load(_rsrc_tok_map, tm_base_off + 4, vec_width=4, dtype=T.i32()) + + for k_slot in range_constexpr(experts_per_token): + if const_expr(k_slot < 4): + enc_k = vector.extract(tm_vec_lo, static_position=[k_slot]) + else: + enc_k = vector.extract(tm_vec_hi, static_position=[k_slot - 4]) + if const_expr(_log2_max_recv is not None): + dest_pe_k = enc_k >> _log2_max_recv + else: + dest_pe_k = (enc_k // max_recv) + vld_k = dest_pe_k < npes # sentinel = npes + safe_pe = arith.select(vld_k, dest_pe_k, rank) + if const_expr(use_p2p_read): + dtok_global = (enc_k % max_recv) + safe_dtok = arith.select(vld_k, dtok_global, 0) + peer_base = SmemPtr.load(_lds_p2p_bases, [safe_pe]) + expert_tok_off = arith.zext_i64(safe_dtok) * nbytes + expert_tok_addr = _lv_unwrap(peer_base) + expert_tok_off + else: + expert_tok_off = arith.zext_i64(safe_pe * max_tok_per_rank + tok_id) * nbytes + expert_tok_addr = _lv_unwrap(addr_comb_inp + expert_tok_off) + expert_rsrcs.append(create_buffer_resource_from_addr(expert_tok_addr)) + expert_vlds.append(vld_k) + + all_vld = (npes >= experts_per_token) # without compaction, every k_slot must be valid + eff_all_vld = all_vld or _use_compaction + + # Two paths optimised for the per-warp partition size: + # - wide path (hdim_per_warp > 895): step=128 dual or step=256 + # quad unrolled loads, each step covers 256/512/... bytes. + # - narrow path (hdim_per_warp <= 895): plain step=64 loop. + if 895 < hdim_per_warp: + rem_hdim_128 = n_elems - hdim_off + # Effective end of THIS warp's partition, clamped to n_elems. + eff_end_128 = arith.select( + rem_hdim_128 < hdim_per_warp, rem_hdim_128, hdim_per_warp) + + if const_expr(n_i32 % 256 == 0 and warp_num_per_block < 16): + if (hdim_per_warp % 256) < 1: + # Quad-unroll: 4 sub-stores per step (offset 0/256/512/768 B). + quad_end = eff_end_128 - 192 + for ec in range(lane, quad_end, 256): + ec_abs = hdim_off + ec + vals_a, vals_b, vals_c, vals_d = [], [], [], [] + for k_slot in range_constexpr(experts_per_token): + rsrc_k = expert_rsrcs[k_slot] + vld_k = expert_vlds[k_slot] + vals_a.append(_maybe_load(rsrc_k, ec_abs, vld_k, vec_width=1, dtype=T.i32(), cache_modifier=SLC_CACHE)) + vals_b.append(_maybe_load(rsrc_k, ec_abs, vld_k, vec_width=1, dtype=T.i32(), cache_modifier=SLC_CACHE, soffset_bytes=256)) + vals_c.append(_maybe_load(rsrc_k, ec_abs, vld_k, vec_width=1, dtype=T.i32(), cache_modifier=SLC_CACHE, soffset_bytes=512)) + vals_d.append(_maybe_load(rsrc_k, ec_abs, vld_k, vec_width=1, dtype=T.i32(), cache_modifier=SLC_CACHE, soffset_bytes=768)) + acc_a = _accum_experts(vals_a, expert_vlds, eff_all_vld) + acc_b = _accum_experts(vals_b, expert_vlds, eff_all_vld) + acc_c = _accum_experts(vals_c, expert_vlds, eff_all_vld) + acc_d = _accum_experts(vals_d, expert_vlds, eff_all_vld) + if const_expr(_xfer_bf16_to_fp8): + # bf16 output: data is v2i32 (8 B / lane); the + # i32 offset doubles per token and the 4 sub- + # stores use 256->512 byte spacing. 
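+                                # Illustrative: ec_abs = 64 (byte 256 of
+                                # the fp8 staging token) maps to out_off
+                                # = tok_id*out_n_i32 + 128, i.e. byte 512
+                                # of the bf16 output token.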
+ out_off = tok_id * out_n_i32 + ec_abs * 2 + buffer_store(acc_a, rsrc_out, out_off, cache_modifier=SLC_CACHE) + buffer_store(acc_b, rsrc_out, out_off, cache_modifier=SLC_CACHE, soffset_bytes=512) + buffer_store(acc_c, rsrc_out, out_off, cache_modifier=SLC_CACHE, soffset_bytes=1024) + buffer_store(acc_d, rsrc_out, out_off, cache_modifier=SLC_CACHE, soffset_bytes=1536) + else: + out_off = tok_id * n_i32 + ec_abs + buffer_store(acc_a, rsrc_out, out_off, cache_modifier=SLC_CACHE) + buffer_store(acc_b, rsrc_out, out_off, cache_modifier=SLC_CACHE, soffset_bytes=256) + buffer_store(acc_c, rsrc_out, out_off, cache_modifier=SLC_CACHE, soffset_bytes=512) + buffer_store(acc_d, rsrc_out, out_off, cache_modifier=SLC_CACHE, soffset_bytes=768) + else: + # Dual-unroll body + 1-wide tail. + s3_dual_end = (eff_end_128 // 128) * 128 + for ec in range(lane, s3_dual_end, 128): + ec_abs = hdim_off + ec + vals_a, vals_b = [], [] + for k_slot in range_constexpr(experts_per_token): + rsrc_k = expert_rsrcs[k_slot] + vld_k = expert_vlds[k_slot] + vals_a.append(_maybe_load(rsrc_k, ec_abs, vld_k, vec_width=1, dtype=T.i32(), cache_modifier=SLC_CACHE)) + vals_b.append(_maybe_load(rsrc_k, ec_abs, vld_k, vec_width=1, dtype=T.i32(), cache_modifier=SLC_CACHE, soffset_bytes=256)) + acc_a = _accum_experts(vals_a, expert_vlds, eff_all_vld) + acc_b = _accum_experts(vals_b, expert_vlds, eff_all_vld) + if const_expr(_xfer_bf16_to_fp8): + out_off = tok_id * out_n_i32 + ec_abs * 2 + buffer_store(acc_a, rsrc_out, out_off, cache_modifier=SLC_CACHE) + buffer_store(acc_b, rsrc_out, out_off, cache_modifier=SLC_CACHE, soffset_bytes=512) + else: + out_off = tok_id * n_i32 + ec_abs + buffer_store(acc_a, rsrc_out, out_off, cache_modifier=SLC_CACHE) + buffer_store(acc_b, rsrc_out, out_off, cache_modifier=SLC_CACHE, soffset_bytes=256) + for ec in range(lane + s3_dual_end, eff_end_128, 64): + ec_abs = hdim_off + ec + vals_tail = [] + for k_slot in range_constexpr(experts_per_token): + vals_tail.append(_maybe_load(expert_rsrcs[k_slot], ec_abs, expert_vlds[k_slot], vec_width=1, dtype=T.i32(), cache_modifier=SLC_CACHE)) + acc_tail = _accum_experts(vals_tail, expert_vlds, eff_all_vld) + if const_expr(_xfer_bf16_to_fp8): + out_off = tok_id * out_n_i32 + ec_abs * 2 + buffer_store(acc_tail, rsrc_out, out_off, cache_modifier=SLC_CACHE) + else: + out_off = tok_id * n_i32 + ec_abs + buffer_store(acc_tail, rsrc_out, out_off, cache_modifier=SLC_CACHE) + else: + # Narrow path: a single step=64 main loop. + rem_hdim_64 = n_elems - hdim_off + eff_end_64 = arith.select( + rem_hdim_64 < hdim_per_warp, rem_hdim_64, hdim_per_warp) + for ec in range(lane, eff_end_64, 64): + ec_abs = hdim_off + ec + vals_main = [] + for k_slot in range_constexpr(experts_per_token): + vals_main.append(_maybe_load(expert_rsrcs[k_slot], ec_abs, expert_vlds[k_slot], vec_width=1, dtype=T.i32(), cache_modifier=SLC_CACHE)) + acc = _accum_experts(vals_main, expert_vlds, eff_all_vld) + if const_expr(_xfer_bf16_to_fp8): + out_off = tok_id * out_n_i32 + ec_abs * 2 + buffer_store(acc, rsrc_out, out_off, cache_modifier=SLC_CACHE) + else: + out_off = tok_id * n_i32 + ec_abs + buffer_store(acc, rsrc_out, out_off, cache_modifier=SLC_CACHE) + + # Stage 3b: Weight accumulation. + # Each warp handles one output token; lanes 0..k-1 each pull the + # weight value from one k-expert slot's contribution in + # ``shmem_comb_inp_wts`` (or peer-side via P2P-read), then they + # f32-sum across the k slots and write into ``shmem_comb_out_wts``. 
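+        # Reduction sketch (hypothetical k=8): lane j < 8 computes
+        #   out_wts[tok, j] = sum over the k mapped slots of wts[slot, j]
+        # decoding each slot from dest_tok_map; when npes < k the sentinel
+        # (invalid) slots are masked to 0.0 before the add.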
+ if const_expr(enable_weights): + rsrc_out_wts = create_buffer_resource_from_addr(addr_comb_out_wts) + for wt_tok_id in range(global_warp_id, cur_rank_num_token, global_warp_num): + wt_tm_off = wt_tok_id * experts_per_token + wt_tm_vec_lo = buffer_load(_rsrc_tok_map, wt_tm_off, vec_width=4, dtype=T.i32()) + wt_tm_vec_hi = buffer_load(_rsrc_tok_map, wt_tm_off + 4, vec_width=4, dtype=T.i32()) + + if lane < experts_per_token: + wt_acc = arith.constant(0.0, type=T.f32()) + for k_slot in range_constexpr(experts_per_token): + if const_expr(k_slot < 4): + wt_enc = vector.extract(wt_tm_vec_lo, static_position=[k_slot]) + else: + wt_enc = vector.extract(wt_tm_vec_hi, static_position=[k_slot - 4]) + if const_expr(_log2_max_recv is not None): + wt_pe = wt_enc >> _log2_max_recv + else: + wt_pe = (wt_enc // max_recv) + wt_vld = wt_pe < npes + wt_safe_pe = arith.select(wt_vld, wt_pe, rank) + if const_expr(use_p2p_read): + wt_dtok = (wt_enc % max_recv) + wt_safe_dtok = arith.select(wt_vld, wt_dtok, 0) + wt_pe_base = SmemPtr.load(_lds_p2p_wt_bases, [wt_safe_pe]) + wt_src_off = arith.zext_i64(wt_safe_dtok) * weight_bytes + wt_rsrc = create_buffer_resource_from_addr( + wt_pe_base + wt_src_off) + else: + wt_src_off = arith.zext_i64( + wt_safe_pe * max_tok_per_rank + wt_tok_id) * weight_bytes + wt_rsrc = create_buffer_resource_from_addr( + addr_comb_inp_wts + wt_src_off) + wt_val = buffer_load(wt_rsrc, lane, vec_width=1, dtype=T.f32()) + if const_expr(npes >= experts_per_token): + wt_acc = wt_acc + wt_val + else: + wt_acc = wt_acc + arith.select(wt_vld, wt_val, 0.0) + wt_out_off = wt_tok_id * experts_per_token + lane + buffer_store(wt_acc, rsrc_out_wts, wt_out_off) + + ep_combine_intranode._allocator = allocator + return ep_combine_intranode + + +def make_dispatch_jit(*, rank, npes, experts_per_rank, experts_per_token, + hidden_dim, max_tok_per_rank, block_num, + warp_num_per_block, data_type, + scale_dim=0, scale_type_size=0, + enable_std_moe=False): + hidden_elem_size = torch.tensor([], dtype=data_type).element_size() + kernel = make_dispatch_kernel( + rank=rank, npes=npes, + experts_per_rank=experts_per_rank, + experts_per_token=experts_per_token, + hidden_dim=hidden_dim, + hidden_elem_size=hidden_elem_size, + max_tok_per_rank=max_tok_per_rank, + block_num=block_num, + warp_num_per_block=warp_num_per_block, + scale_dim=scale_dim, + scale_type_size=scale_type_size, + enable_std_moe=enable_std_moe, + data_type=data_type, + ) + + # Closure variables that participate in the JIT cache key. The launcher + # closes over them so that two ``@flyc.jit`` invocations with different + # configs produce distinct cached entries. 
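+    # A minimal sketch of that caching idea in plain Python (hypothetical
+    # helper, not used anywhere): a decorator whose cache is keyed on the
+    # wrapped function's closure cells, so distinct configs map to distinct
+    # compiled entries. The ``_key_*`` assignments below feed the real
+    # ``@flyc.jit`` machinery in the same way. Assumes hashable cell contents
+    # (ints/bools here).
+    def _sketch_closure_keyed_cache(fn, _cache={}):
+        def wrapper(*args, **kwargs):
+            cells = fn.__closure__ or ()
+            key = (fn.__qualname__,) + tuple(c.cell_contents for c in cells)
+            if key not in _cache:
+                _cache[key] = fn  # stand-in for the expensive compile step
+            return _cache[key](*args, **kwargs)
+        return wrapper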
+ _key_rank, _key_npes, _key_block_num = rank, npes, block_num + _key_warp_per_block = warp_num_per_block + _key_max_tok = max_tok_per_rank + _key_std_moe = enable_std_moe + + @flyc.jit + def dispatch_launch( + addr_inp_tok: fx.Int64, addr_idx: fx.Int64, addr_wts: fx.Int64, + addr_out_tok: fx.Int64, addr_out_wts: fx.Int64, addr_out_idx: fx.Int64, + addr_tok_off: fx.Int64, addr_recv_num: fx.Int64, + addr_dest_ctr: fx.Int64, addr_disp_bar: fx.Int64, + addr_tok_map: fx.Int64, addr_tis: fx.Int64, + addr_total_rv: fx.Int64, + addr_p2p_tok_off: fx.Int64, addr_p2p_tis: fx.Int64, + addr_p2p_out_wts: fx.Int64, addr_p2p_out_idx: fx.Int64, + addr_p2p_out_tok: fx.Int64, addr_p2p_recv_num: fx.Int64, + addr_scales: fx.Int64, addr_p2p_out_scales: fx.Int64, + addr_packed_recv_x: fx.Int64, addr_packed_recv_count: fx.Int64, + addr_packed_recv_src_info: fx.Int64, addr_disp_tok_map: fx.Int64, + addr_disp_grid_bar: fx.Int64, + cur_tok: fx.Int32, + stream: Stream = Stream(None), + ): + _ = (_key_rank, _key_npes, _key_block_num, _key_warp_per_block, + _key_max_tok, _key_std_moe) + kernel(addr_inp_tok, addr_idx, addr_wts, + addr_out_tok, addr_out_wts, addr_out_idx, + addr_tok_off, addr_recv_num, addr_dest_ctr, + addr_disp_bar, addr_tok_map, addr_tis, + addr_total_rv, + addr_p2p_tok_off, addr_p2p_tis, + addr_p2p_out_wts, addr_p2p_out_idx, + addr_p2p_out_tok, addr_p2p_recv_num, + addr_scales, addr_p2p_out_scales, + addr_packed_recv_x, addr_packed_recv_count, + addr_packed_recv_src_info, addr_disp_tok_map, + addr_disp_grid_bar, + cur_tok).launch( + grid=(block_num, 1, 1), + block=(warp_num_per_block * 64, 1, 1), + stream=stream, + ) + + return dispatch_launch + + +def make_combine_jit(*, rank, npes, experts_per_rank=0, experts_per_token, + hidden_dim, max_tok_per_rank, block_num, + warp_num_per_block, data_type, + enable_weights=False, enable_std_moe=False, + use_p2p_read=False, skip_stage1=False, + inp_data_type=None): + hidden_elem_size = torch.tensor([], dtype=data_type).element_size() + kernel = make_combine_kernel( + rank=rank, npes=npes, + experts_per_rank=experts_per_rank, + experts_per_token=experts_per_token, + hidden_dim=hidden_dim, + hidden_elem_size=hidden_elem_size, + max_tok_per_rank=max_tok_per_rank, + block_num=block_num, + warp_num_per_block=warp_num_per_block, + data_type=data_type, + enable_weights=enable_weights, + enable_std_moe=enable_std_moe, + use_p2p_read=use_p2p_read, + skip_stage1=skip_stage1, + inp_data_type=inp_data_type, + ) + + # Closure variables that participate in the JIT cache key. The launcher + # closes over them so two ``@flyc.jit`` invocations with different + # configs produce distinct cached entries. 
+    _key_rank, _key_npes, _key_block_num = rank, npes, block_num
+    _key_warp_per_block = warp_num_per_block
+    _key_max_tok = max_tok_per_rank
+    _key_weights = enable_weights
+    _key_std_moe = enable_std_moe
+    _key_p2p_read = use_p2p_read
+    _key_skip_s1 = skip_stage1
+    _key_inp_dtype = str(inp_data_type) if inp_data_type is not None else "none"
+    _allocator = kernel._allocator
+
+    @flyc.jit
+    def combine_launch(
+        addr_inp_tok: fx.Int64, addr_comb_inp: fx.Int64,
+        addr_comb_out: fx.Int64, addr_xdb_mem: fx.Int64,
+        addr_xdb_flag: fx.Int64, addr_tok_map: fx.Int64,
+        addr_comb_bar: fx.Int64, addr_trecv: fx.Int64,
+        addr_tis: fx.Int64,
+        addr_p2p_comb_inp: fx.Int64, addr_p2p_xdb_mem: fx.Int64,
+        addr_wts_buf: fx.Int64,
+        addr_comb_inp_wts: fx.Int64, addr_comb_out_wts: fx.Int64,
+        addr_p2p_comb_inp_wts: fx.Int64,
+        addr_packed_recv_x: fx.Int64, addr_disp_tok_map: fx.Int64,
+        addr_disp_out_wts: fx.Int64,
+        cur_rank_num_token: fx.Int32,
+        stream: Stream = Stream(None),
+    ):
+        _ = (_key_rank, _key_npes, _key_block_num, _key_warp_per_block,
+             _key_max_tok, _key_weights, _key_std_moe, _key_p2p_read,
+             _key_skip_s1, _key_inp_dtype)
+        from flydsl.compiler.kernel_function import CompilationContext
+        from flydsl._mlir import ir
+        # Re-finalize the shared-memory allocator inside this compilation's
+        # GPU module so the combine kernel's LDS globals are emitted into the
+        # module being built now, not a stale one from a previous compile.
+        _allocator.finalized = False
+        ctx = CompilationContext.get_current()
+        with ir.InsertionPoint(ctx.gpu_module_body):
+            _allocator.finalize()
+
+        kernel(addr_inp_tok, addr_comb_inp, addr_comb_out,
+               addr_xdb_mem, addr_xdb_flag, addr_tok_map,
+               addr_comb_bar, addr_trecv, addr_tis,
+               addr_p2p_comb_inp, addr_p2p_xdb_mem,
+               addr_wts_buf, addr_comb_inp_wts,
+               addr_comb_out_wts, addr_p2p_comb_inp_wts,
+               addr_packed_recv_x, addr_disp_tok_map,
+               addr_disp_out_wts,
+               cur_rank_num_token).launch(
+            grid=(block_num, 1, 1),
+            block=(warp_num_per_block * 64, 1, 1),
+            stream=stream,
+        )
+
+    return combine_launch
diff --git a/kernels/dispatch_combine_intranode_op.py b/kernels/dispatch_combine_intranode_op.py
new file mode 100644
index 00000000..d866bcdd
--- /dev/null
+++ b/kernels/dispatch_combine_intranode_op.py
@@ -0,0 +1,638 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright (c) 2025 FlyDSL Project Contributors
+
+"""FlyDSL DispatchCombine IntraNode operator wrapper."""
+from __future__ import annotations
+
+import os
+from dataclasses import dataclass
+
+import torch
+import torch.distributed as dist
+import flydsl.compiler as flyc
+import flydsl.expr as fx
+import mori.shmem as ms
+from mori.shmem import mori_shmem_create_tensor
+
+from .dispatch_combine_intranode_kernel import (
+    make_dispatch_jit,
+    make_combine_jit,
+)
+
+
+@dataclass
+class FlyDSLDispatchCombineConfig:
+    rank: int
+    world_size: int
+    hidden_dim: int
+    max_num_inp_token_per_rank: int
+    num_experts_per_rank: int
+    num_experts_per_token: int
+    data_type: torch.dtype = torch.bfloat16
+    warp_num_per_block: int = 16
+    block_num: int = 80
+    chip: str = "gfx950"
+    scale_dim: int = 0
+    scale_type_size: int = 0
+    enable_std_moe: bool = False
+    use_external_inp_buf: bool = True
+    quant_type: str = "none"
+
+    @property
+    def is_fp4(self):
+        return self.data_type == torch.float4_e2m1fn_x2
+
+    @property
+    def elem_size(self):
+        return torch.tensor([], dtype=self.data_type).element_size()
+
+    @property
+    def token_bytes(self):
+        if self.is_fp4:
+            return self.hidden_dim // 2
+        return self.hidden_dim * self.elem_size
+
+    @property
+    def token_view_dim(self):
+        if self.is_fp4:
+            return self.hidden_dim // 2
+        return self.hidden_dim
+
+    @property
+    def block_dim(self):
+        return self.warp_num_per_block * 64
+
+    @property
+    def max_recv(self):
return self.world_size * self.max_num_inp_token_per_rank + + @property + def scale_bytes(self): + return self.scale_dim * self.scale_type_size + + +class FlyDSLDispatchCombineIntraNodeOp: + + def __init__(self, config): + self.cfg = config + self._dev = torch.device("cuda", config.rank) + r = config.rank + + self._alloc_buffers() + ms.shmem_barrier_all() + + npes = config.world_size + self._p2p_tok_off = torch.zeros(npes, dtype=torch.int64, device=self._dev) + self._p2p_tis = torch.zeros(npes, dtype=torch.int64, device=self._dev) + self._p2p_out_wts = torch.zeros(npes, dtype=torch.int64, device=self._dev) + self._p2p_out_idx = torch.zeros(npes, dtype=torch.int64, device=self._dev) + self._p2p_out_tok = torch.zeros(npes, dtype=torch.int64, device=self._dev) + self._p2p_recv_num = torch.zeros(npes, dtype=torch.int64, device=self._dev) + self._p2p_out_scales = torch.zeros(npes, dtype=torch.int64, device=self._dev) + for pe in range(npes): + self._p2p_tok_off[pe] = ms.shmem_ptr_p2p(self.shmem_tok_off.data_ptr(), r, pe) + self._p2p_tis[pe] = ms.shmem_ptr_p2p(self.shmem_tok_id_to_src.data_ptr(), r, pe) + self._p2p_out_wts[pe] = ms.shmem_ptr_p2p(self.shmem_disp_out_wts.data_ptr(), r, pe) + self._p2p_out_idx[pe] = ms.shmem_ptr_p2p(self.shmem_disp_out_idx.data_ptr(), r, pe) + self._p2p_out_tok[pe] = ms.shmem_ptr_p2p(self.shmem_disp_out_tok.data_ptr(), r, pe) + self._p2p_recv_num[pe] = ms.shmem_ptr_p2p(self.shmem_recv_tok_num.data_ptr(), r, pe) + self._p2p_out_scales[pe] = ms.shmem_ptr_p2p(self.shmem_out_scales.data_ptr(), r, pe) + + self._p2p_comb_inp = torch.zeros(npes, dtype=torch.int64, device=self._dev) + self._p2p_comb_inp_wts = torch.zeros(npes, dtype=torch.int64, device=self._dev) + self._p2p_xdb_mem = torch.zeros(npes, dtype=torch.int64, device=self._dev) + for pe in range(npes): + self._p2p_comb_inp[pe] = ms.shmem_ptr_p2p(self.shmem_comb_inp_tok.data_ptr(), r, pe) + self._p2p_comb_inp_wts[pe] = ms.shmem_ptr_p2p(self.shmem_comb_inp_wts.data_ptr(), r, pe) + self._p2p_xdb_mem[pe] = ms.shmem_ptr_p2p(self.shmem_xdev_bar_mem.data_ptr(), r, pe) + + _disp_wpb = config.warp_num_per_block + self._disp_fn = make_dispatch_jit( + rank=r, npes=config.world_size, + experts_per_rank=config.num_experts_per_rank, + experts_per_token=config.num_experts_per_token, + hidden_dim=config.hidden_dim, + max_tok_per_rank=config.max_num_inp_token_per_rank, + block_num=config.block_num, + warp_num_per_block=_disp_wpb, + data_type=config.data_type, + scale_dim=config.scale_dim, + scale_type_size=config.scale_type_size, + enable_std_moe=config.enable_std_moe, + ) + + _use_fp8_cast = (config.quant_type == "fp8_direct_cast" and config.data_type == torch.bfloat16) + _comb_dtype = torch.float8_e4m3fn if _use_fp8_cast else config.data_type + # Mixed-dtype Stage 1 (mori UseFp8DirectCast equivalent): when + # _use_fp8_cast is on, the user feeds bf16 input to ``combine()`` and + # the kernel performs an inline bf16 → fp8 cast in Stage 1 before P2P + # scatter. This avoids an extra ~12μs ``input.to(fp8).contiguous()`` + # PyTorch elementwise kernel that would otherwise sit on the cudagraph + # critical path. Wrapper-side allocation/views remain fp8-stride. 
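+        # In calling-convention terms (illustrative only, not executed here):
+        #
+        #     # without fp8_direct_cast -- the caller pays for the cast:
+        #     op.combine(x_bf16.to(torch.float8_e4m3fn).contiguous(), ...)
+        #
+        #     # with fp8_direct_cast -- bf16 goes straight in; Stage 1 casts
+        #     # to fp8 on the fly before the P2P scatter:
+        #     op.combine(x_bf16, ...)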
+        _comb_inp_dt = torch.bfloat16 if _use_fp8_cast else None
+        self._comb_fn = make_combine_jit(
+            rank=r, npes=config.world_size,
+            experts_per_rank=config.num_experts_per_rank,
+            experts_per_token=config.num_experts_per_token,
+            hidden_dim=config.hidden_dim,
+            max_tok_per_rank=config.max_num_inp_token_per_rank,
+            block_num=config.block_num,
+            warp_num_per_block=_disp_wpb,
+            data_type=_comb_dtype,
+            enable_weights=True,
+            enable_std_moe=config.enable_std_moe,
+            use_p2p_read=not config.use_external_inp_buf,
+            inp_data_type=_comb_inp_dt,
+        )
+        self._use_fp8_cast = _use_fp8_cast
+
+        # The barrier flag must start at 1; otherwise the first
+        # wait_until_equals(slot, 0) would be satisfied immediately.
+        self._xdev_flag = torch.ones(1, dtype=torch.int64, device=self._dev)
+
+        self._fx_out_tok = fx.Int64(self.shmem_disp_out_tok.data_ptr())
+        self._fx_out_wts = fx.Int64(self.shmem_disp_out_wts.data_ptr())
+        self._fx_out_idx = fx.Int64(self.shmem_disp_out_idx.data_ptr())
+        self._fx_tok_off = fx.Int64(self.shmem_tok_off.data_ptr())
+        self._fx_recv_num = fx.Int64(self.shmem_recv_tok_num.data_ptr())
+        self._fx_dest_ctr = fx.Int64(self.dest_pe_ctr.data_ptr())
+        self._fx_disp_bar = fx.Int64(self.disp_bar.data_ptr())
+        self._fx_tok_map = fx.Int64(self.dest_tok_map.data_ptr())
+        self._fx_tis = fx.Int64(self.shmem_tok_id_to_src.data_ptr())
+        self._fx_total_rv = fx.Int64(self.total_recv.data_ptr())
+        # Fixed addresses for combine
+        self._fx_comb_inp = fx.Int64(self.shmem_comb_inp_tok.data_ptr())
+        self._fx_comb_out = fx.Int64(self.shmem_comb_out_tok.data_ptr())
+        self._fx_xdb_mem = fx.Int64(self.shmem_xdev_bar_mem.data_ptr())
+        self._fx_xdev_flag = fx.Int64(self._xdev_flag.data_ptr())
+        self._fx_comb_bar = fx.Int64(self.comb_bar.data_ptr())
+        self._fx_trecv = fx.Int64(self.total_recv.data_ptr())
+        self._fx_p2p_tok_off = fx.Int64(self._p2p_tok_off.data_ptr())
+        self._fx_p2p_tis = fx.Int64(self._p2p_tis.data_ptr())
+        self._fx_p2p_out_wts = fx.Int64(self._p2p_out_wts.data_ptr())
+        self._fx_p2p_out_idx = fx.Int64(self._p2p_out_idx.data_ptr())
+        self._fx_p2p_out_tok = fx.Int64(self._p2p_out_tok.data_ptr())
+        self._fx_p2p_recv_num = fx.Int64(self._p2p_recv_num.data_ptr())
+        self._fx_p2p_out_scales = fx.Int64(self._p2p_out_scales.data_ptr())
+        self._fx_out_scales = fx.Int64(self.shmem_out_scales.data_ptr())
+        self._fx_p2p_comb_inp = fx.Int64(self._p2p_comb_inp.data_ptr())
+        self._fx_p2p_comb_inp_wts = fx.Int64(self._p2p_comb_inp_wts.data_ptr())
+        self._fx_p2p_xdb_mem = fx.Int64(self._p2p_xdb_mem.data_ptr())
+        self._fx_comb_inp_wts = fx.Int64(self.shmem_comb_inp_wts.data_ptr())
+        self._fx_comb_out_wts = fx.Int64(self.shmem_comb_out_wts.data_ptr())
+        self._fx_packed_recv_count = fx.Int64(self.packed_recv_count.data_ptr())
+        self._fx_packed_recv_src_info = fx.Int64(self.packed_recv_src_info.data_ptr())
+        self._fx_disp_tok_map = fx.Int64(self.disp_tok_to_ep_slot_map.data_ptr())
+        self._fx_disp_grid_bar = fx.Int64(self.disp_grid_bar.data_ptr())
+        self._fx_disp_out_wts = fx.Int64(self.shmem_disp_out_wts.data_ptr())
+
+        self._disp_compiled = None
+        self._comb_compiled = None
+        # skip_stage1 variant of the combine kernel, used by the
+        # fused_gemm2_combine op: there the fused kernel has already
+        # P2P-written tokens / weights into shmem_comb_inp[_wts], so combine
+        # only runs Stage 2 (CrossDeviceBarrier) + Stage 3 (local
+        # weighted-accum).
+        self._comb_no_s1_fn = None
+        self._comb_no_s1_compiled = None
+
+    def _alloc_buffers(self):
+        cfg = self.cfg
+        npes = cfg.world_size
+        k = cfg.num_experts_per_token
+        mt = cfg.max_num_inp_token_per_rank
+        mr = cfg.max_recv  # npes * mt
+        hdim = cfg.hidden_dim
+        esz = cfg.elem_size  # bytes per element
+
+        tb = cfg.token_bytes
+        tok_i16_mr = (mr * tb + 1) // 2
+        tok_i16_mt = (mt * tb + 1) // 2
+
+        # Symmetric shmem buffers
+        self.shmem_disp_out_tok = mori_shmem_create_tensor((tok_i16_mr,), torch.int16)
+        self.shmem_disp_out_wts = mori_shmem_create_tensor((mr * k,), torch.float32)
+        self.shmem_disp_out_idx = mori_shmem_create_tensor((mr * k,), torch.int32)
+        scale_total = mr * cfg.scale_bytes if cfg.scale_bytes > 0 else 1
+        self.shmem_out_scales = mori_shmem_create_tensor((scale_total,), torch.int8)
+        self.shmem_tok_off = mori_shmem_create_tensor((1,), torch.int32)
+        self.shmem_recv_tok_num = mori_shmem_create_tensor((npes,), torch.int32)
+        self.shmem_tok_id_to_src = mori_shmem_create_tensor((mr,), torch.int32)
+        self.shmem_comb_inp_tok = mori_shmem_create_tensor((tok_i16_mr,), torch.int16)
+        self.shmem_comb_out_tok = mori_shmem_create_tensor((tok_i16_mt,), torch.int16)
+        self.shmem_comb_inp_wts = mori_shmem_create_tensor((mr * k,), torch.float32)
+        self.shmem_comb_out_wts = mori_shmem_create_tensor((mt * k,), torch.float32)
+        self.shmem_xdev_bar_mem = mori_shmem_create_tensor((npes,), torch.int64)
+
+        # mori_shmem_create_tensor goes through shmem_malloc and hands back
+        # uninitialized raw memory. On the fused MoE-GEMM2 + EP-Combine path,
+        # GEMM2's epilogue decodes dest_pe / dest_lid from
+        # shmem_tok_id_to_src; out-of-range garbage there triggers an LDS
+        # OOB → writes to arbitrary global addresses → corrupted control
+        # state. Explicitly zero every symmetric buffer the combine path
+        # reads or writes directly, which guarantees:
+        # - shmem_tok_id_to_src[t] decodes to (pe=0, lid=0) for any t never
+        #   written by dispatch, so the P2P scatter degrades into a "safe,
+        #   side-effect-free" write (many writers to the same slot)
+        # - shmem_xdev_bar_mem starts at 0, so CrossDeviceBarrier's first
+        #   wait never sees a stale value (it relies on cur_flag increasing
+        #   monotonically)
+        # - shmem_comb_inp_{tok,wts} start at 0, so combine_no_stage1 never
+        #   reads garbage while accumulating in Stage 3
+        self.shmem_tok_id_to_src.zero_()
+        self.shmem_comb_inp_tok.zero_()
+        self.shmem_comb_inp_wts.zero_()
+        self.shmem_xdev_bar_mem.zero_()
+
+        # Local device buffers
+        self.dest_pe_ctr = torch.zeros(npes, dtype=torch.int32, device=self._dev)
+        self.disp_bar = torch.zeros(1, dtype=torch.int32, device=self._dev)
+        self.comb_bar = torch.zeros(1, dtype=torch.int32, device=self._dev)
+        self.total_recv = torch.zeros(1, dtype=torch.int32, device=self._dev)
+        sentinel = cfg.world_size * mr
+        self.dest_tok_map = torch.full(
+            (mt * k,), sentinel, dtype=torch.int32, device=self._dev)
+
+        # StdMoE buffers
+        if cfg.enable_std_moe:
+            epr = cfg.num_experts_per_rank
+            max_tok_per_expert = mr  # world_size * max_num_inp_token_per_rank
+            self.packed_recv_count = torch.zeros(
+                epr, dtype=torch.int32, device=self._dev)
+            self.packed_recv_src_info = torch.zeros(
+                epr * max_tok_per_expert, dtype=torch.int32, device=self._dev)
+            self.disp_tok_to_ep_slot_map = torch.full(
+                (mr * k,), -1, dtype=torch.int64, device=self._dev)
+            self.disp_grid_bar = torch.zeros(
+                1, dtype=torch.int32, device=self._dev)
+        else:
+            self.packed_recv_count = torch.zeros(1, dtype=torch.int32, device=self._dev)
+            self.packed_recv_src_info = torch.zeros(1, dtype=torch.int32, device=self._dev)
+            self.disp_tok_to_ep_slot_map = torch.zeros(1, dtype=torch.int64, device=self._dev)
+            self.disp_grid_bar = torch.zeros(1, dtype=torch.int32, device=self._dev)
+
+    def barrier(self):
+        ms.shmem_barrier_all()
+
+    def reset(self):
+        self.barrier()
+
+    def dispatch(self, input, weights, scales, indices,
+                 packed_recv_x=None,
+                 block_num=-1, rdma_block_num=-1, warp_per_block=-1):
+        cfg = self.cfg
+        cur_tok = input.shape[0]
+        stream = torch.cuda.current_stream()
+        inp_c = input if input.is_contiguous() else input.contiguous()
+        wts_c = weights if weights.is_contiguous() else weights.contiguous()
+        idx_c = indices if 
(indices.dtype == torch.int32 and indices.is_contiguous()) \ + else indices.to(torch.int32).contiguous() + + sc_ptr = scales.data_ptr() if scales is not None else 0 + prx_ptr = packed_recv_x.data_ptr() if packed_recv_x is not None else 0 + + if cfg.enable_std_moe: + self.packed_recv_count.zero_() + + _std_args = ( + self._fx_packed_recv_count if cfg.enable_std_moe else fx.Int64(0), + self._fx_packed_recv_src_info, + self._fx_disp_tok_map, + self._fx_disp_grid_bar, + ) + + if self._disp_compiled is None: + args = ( + fx.Int64(inp_c.data_ptr()), + fx.Int64(idx_c.data_ptr()), + fx.Int64(wts_c.data_ptr()), + self._fx_out_tok, + self._fx_out_wts, + self._fx_out_idx, + self._fx_tok_off, + self._fx_recv_num, + self._fx_dest_ctr, + self._fx_disp_bar, + self._fx_tok_map, + self._fx_tis, + self._fx_total_rv, + self._fx_p2p_tok_off, + self._fx_p2p_tis, + self._fx_p2p_out_wts, + self._fx_p2p_out_idx, + self._fx_p2p_out_tok, + self._fx_p2p_recv_num, + fx.Int64(sc_ptr), + self._fx_p2p_out_scales, + fx.Int64(prx_ptr), + *_std_args, + cur_tok, + stream, + ) + self._disp_compiled = flyc.compile(self._disp_fn, *args) + else: + self._disp_compiled( + inp_c.data_ptr(), + idx_c.data_ptr(), + wts_c.data_ptr(), + self._fx_out_tok, + self._fx_out_wts, + self._fx_out_idx, + self._fx_tok_off, + self._fx_recv_num, + self._fx_dest_ctr, + self._fx_disp_bar, + self._fx_tok_map, + self._fx_tis, + self._fx_total_rv, + self._fx_p2p_tok_off, + self._fx_p2p_tis, + self._fx_p2p_out_wts, + self._fx_p2p_out_idx, + self._fx_p2p_out_tok, + self._fx_p2p_recv_num, + sc_ptr, + self._fx_p2p_out_scales, + prx_ptr, + *_std_args, + cur_tok, + stream, + ) + + mr = cfg.max_recv + hdim = cfg.hidden_dim + k = cfg.num_experts_per_token + + out_tok = self.shmem_disp_out_tok.view(torch.int8)[ + :mr * cfg.token_bytes].view(cfg.data_type).view(mr, cfg.token_view_dim) + out_wts = self.shmem_disp_out_wts.view(mr, k) + out_idx = self.shmem_disp_out_idx.view(mr, k) + out_scales = None + if cfg.scale_bytes > 0: + out_scales = self.shmem_out_scales[:mr * cfg.scale_bytes].view( + mr, cfg.scale_dim * cfg.scale_type_size) + + result = (out_tok, out_wts, out_scales, out_idx, self.total_recv) + if cfg.enable_std_moe: + epr = cfg.num_experts_per_rank + result = result + ( + self.packed_recv_count[:epr], + self.packed_recv_src_info, + ) + return result + + def combine(self, input, weights, indices, + packed_recv_x=None, cur_tok=None, + block_num=-1, rdma_block_num=-1, warp_per_block=-1, + use_external_inp_buf=-1, call_reset=False): + cfg = self.cfg + stream = torch.cuda.current_stream() + + # In _use_fp8_cast mode, the combine kernel does the bf16 → fp8 cast + # inline in Stage 1 (mori UseFp8DirectCast equivalent), so the wrapper + # passes bf16 input straight through. Skipping the PyTorch-level + # ``.to(fp8).contiguous()`` saves ~12μs per iter on the cudagraph + # critical path. + inp_c = input if input.is_contiguous() else input.contiguous() + _cur_tok = cur_tok if cur_tok is not None else cfg.max_num_inp_token_per_rank + + wts_ptr = self.shmem_disp_out_wts.data_ptr() if weights is None else weights.data_ptr() + + _prx_ref = None + if self._use_fp8_cast and packed_recv_x is not None: + # std-MoE expert-major buffer (`packed_recv_x`) is produced in bf16 + # by the upstream pipeline; downstream Stage 1 reads it in fp8 + # dtype, so we still cast here. This branch is independent from + # the regular combine input path above. 
+            _prx_ref = packed_recv_x.view(torch.bfloat16).to(torch.float8_e4m3fn).contiguous()
+            prx_ptr = _prx_ref.data_ptr()
+        else:
+            prx_ptr = packed_recv_x.data_ptr() if packed_recv_x is not None else 0
+
+        _std_args_comb = (
+            fx.Int64(prx_ptr),
+            self._fx_disp_tok_map,
+            self._fx_disp_out_wts,
+        )
+
+        if self._comb_compiled is None:
+            args = (
+                fx.Int64(inp_c.data_ptr()),
+                self._fx_comb_inp,
+                self._fx_comb_out,
+                self._fx_xdb_mem,
+                self._fx_xdev_flag,
+                self._fx_tok_map,
+                self._fx_comb_bar,
+                self._fx_trecv,
+                self._fx_tis,
+                self._fx_p2p_comb_inp,
+                self._fx_p2p_xdb_mem,
+                fx.Int64(wts_ptr),
+                self._fx_comb_inp_wts,
+                self._fx_comb_out_wts,
+                self._fx_p2p_comb_inp_wts,
+                *_std_args_comb,
+                _cur_tok,
+                stream,
+            )
+            self._comb_compiled = flyc.compile(self._comb_fn, *args)
+        else:
+            self._comb_compiled(
+                inp_c.data_ptr(),
+                self._fx_comb_inp,
+                self._fx_comb_out,
+                self._fx_xdb_mem,
+                self._fx_xdev_flag,
+                self._fx_tok_map,
+                self._fx_comb_bar,
+                self._fx_trecv,
+                self._fx_tis,
+                self._fx_p2p_comb_inp,
+                self._fx_p2p_xdb_mem,
+                wts_ptr,
+                self._fx_comb_inp_wts,
+                self._fx_comb_out_wts,
+                self._fx_p2p_comb_inp_wts,
+                prx_ptr,
+                self._fx_disp_tok_map,
+                self._fx_disp_out_wts,
+                _cur_tok,
+                stream,
+            )
+
+        mt = cfg.max_num_inp_token_per_rank
+        hdim = cfg.hidden_dim
+        k = cfg.num_experts_per_token
+
+        # fp8_direct_cast contract: external dtype is bf16 on both ends; the
+        # combine kernel itself writes bf16 to ``shmem_comb_out_tok`` (Stage 3
+        # _from_accum casts v4f32 → v4bf16 inline), so we view the buffer as
+        # bf16 directly with no extra PyTorch-level cast on the critical path.
+        out_tok = self.shmem_comb_out_tok.view(torch.int8)[
+            :mt * cfg.token_bytes].view(cfg.data_type).view(mt, cfg.token_view_dim)
+        out_wts = self.shmem_comb_out_wts.view(mt, k)
+
+        if call_reset:
+            self.reset()
+        return out_tok, out_wts
+
+    def combine_no_stage1(self, input, weights, indices,
+                          packed_recv_x=None, cur_tok=None,
+                          call_reset=False,
+                          enable_weights: bool = True):
+        """Stage1-skipped variant of ``combine``.
+
+        Semantics: skip the P2P scatter (an external fused kernel has already
+        written the data into shmem_comb_inp[_wts]) and run only Stage 2
+        (CrossDeviceBarrier) + Stage 3 (local weighted-accum).
+
+        Parameters
+        ----------
+        enable_weights
+            ``True`` (default) keeps compatibility with the current
+            fused-with-weight pipeline: the combine kernel retains Stage 1's
+            weight scatter and Stage 3b's weight accumulate. The weight
+            scatter stays inside the combine kernel (rather than in the
+            upstream fused GEMM2 epilogue) because 16 B small writes issued
+            concurrently with the upstream token P2P are silently dropped by
+            the ROCm IPC fabric; they must be issued by the combine kernel on
+            an otherwise-idle ("static") fabric.
+            ``False`` serves the weight-free fused path (the fused MoE
+            upstream has already consumed the weights; the combine side needs
+            no out_wts): the weight scatter + Stage 3b are fully DCE'd,
+            saving ~3-5 us. The two variants use separate JIT cache entries
+            and never pollute each other.
+
+        Contract: before the call, the fused kernel must guarantee:
+        - shmem_comb_inp_tok holds every token this PE should receive
+          (slotted by max_tok_per_rank)
+        - shmem_comb_inp_wts holds the matching weights (only needed when
+          enable_weights=True)
+        - total_recv has been finalized by dispatch (Stage 3 reads it as
+          cur_rank_num_token)
+        """
+        cfg = self.cfg
+        stream = torch.cuda.current_stream()
+
+        # When skip_stage1=True (the only mode this method ever compiles for),
+        # the combine kernel does NOT read inp_c — Stage 1 is bypassed and the
+        # kernel reads from shmem_comb_inp_tok directly (already populated by
+        # the upstream fused GEMM2 epilogue P2P scatter). The fp8 cast
+        # (.to(fp8) + .contiguous()) is a ~12us elementwise kernel that gets
+        # captured by cudagraph and sits serially on the chain critical path
+        # for nothing, so the fallback cast below only fires when the caller
+        # passed a non-fp8 tensor; the fused op wrapper normally pre-casts in
+        # the GEMM2 epilogue.
+        if self._use_fp8_cast and input.dtype != torch.float8_e4m3fn:
+            inp_c = input.to(torch.float8_e4m3fn).contiguous()
+        else:
+            inp_c = input if input.is_contiguous() else input.contiguous()
+        _cur_tok = cur_tok if cur_tok is not None else cfg.max_num_inp_token_per_rank
+
+        wts_ptr = self.shmem_disp_out_wts.data_ptr() if weights is None else weights.data_ptr()
+
+        _prx_ref = None
+        if self._use_fp8_cast and packed_recv_x is not None:
+            _prx_ref = packed_recv_x.view(torch.bfloat16).to(torch.float8_e4m3fn).contiguous()
+            prx_ptr = _prx_ref.data_ptr()
+        else:
+            prx_ptr = packed_recv_x.data_ptr() if packed_recv_x is not None else 0
+
+        # The JIT cache is split by enable_weights (two compiled artifacts).
+        # The legacy self._comb_no_s1_fn / _compiled scalars are upgraded to
+        # dict[bool, fn] here.
+        if not isinstance(self._comb_no_s1_fn, dict):
+            self._comb_no_s1_fn = {}
+            self._comb_no_s1_compiled = {}
+
+        if enable_weights not in self._comb_no_s1_fn:
+            from .dispatch_combine_intranode_kernel import make_combine_jit
+            _use_fp8_cast = self._use_fp8_cast
+            _comb_dtype = torch.float8_e4m3fn if _use_fp8_cast else cfg.data_type
+            # Mixed-dtype contract for fp8_direct_cast: external dtype = bf16,
+            # transport dtype = fp8. Stage 3 _from_accum casts f32 → bf16
+            # inline, so the kernel writes bf16 directly to shmem_comb_out_tok
+            # and the wrapper does NOT need a post .to(bf16) cast.
+            _comb_inp_dt = torch.bfloat16 if _use_fp8_cast else None
+            # enable_weights=False path (fused MoE needs no out_wts): the
+            # weight scatter + Stage 3b weight accumulate are both DCE'd at
+            # the const_expr, saving ~3-5 us.
+            # enable_weights=True path (compatible with fused-with-weight):
+            # even with skip_stage1=True the combine kernel still runs the
+            # weight scatter, because 16 B small writes concurrent with token
+            # P2P on the same fabric are silently dropped; they must be issued
+            # by the combine kernel on the idle fabric.
+            self._comb_no_s1_fn[enable_weights] = make_combine_jit(
+                rank=cfg.rank, npes=cfg.world_size,
+                experts_per_rank=cfg.num_experts_per_rank,
+                experts_per_token=cfg.num_experts_per_token,
+                hidden_dim=cfg.hidden_dim,
+                max_tok_per_rank=cfg.max_num_inp_token_per_rank,
+                block_num=cfg.block_num,
+                warp_num_per_block=cfg.warp_num_per_block,
+                data_type=_comb_dtype,
+                enable_weights=bool(enable_weights),
+                enable_std_moe=cfg.enable_std_moe,
+                use_p2p_read=not cfg.use_external_inp_buf,
+                skip_stage1=True,
+                inp_data_type=_comb_inp_dt,
+            )
+
+        if enable_weights not in self._comb_no_s1_compiled:
+            args = (
+                fx.Int64(inp_c.data_ptr()),
+                self._fx_comb_inp,
+                self._fx_comb_out,
+                self._fx_xdb_mem,
+                self._fx_xdev_flag,
+                self._fx_tok_map,
+                self._fx_comb_bar,
+                self._fx_trecv,
+                self._fx_tis,
+                self._fx_p2p_comb_inp,
+                self._fx_p2p_xdb_mem,
+                fx.Int64(wts_ptr),
+                self._fx_comb_inp_wts,
+                self._fx_comb_out_wts,
+                self._fx_p2p_comb_inp_wts,
+                fx.Int64(prx_ptr),
+                self._fx_disp_tok_map,
+                self._fx_disp_out_wts,
+                _cur_tok,
+                stream,
+            )
+            self._comb_no_s1_compiled[enable_weights] = flyc.compile(
+                self._comb_no_s1_fn[enable_weights], *args
+            )
+        else:
+            self._comb_no_s1_compiled[enable_weights](
+                inp_c.data_ptr(),
+                self._fx_comb_inp,
+                self._fx_comb_out,
+                self._fx_xdb_mem,
+                self._fx_xdev_flag,
+                self._fx_tok_map,
+                self._fx_comb_bar,
+                self._fx_trecv,
+                self._fx_tis,
+                self._fx_p2p_comb_inp,
+                self._fx_p2p_xdb_mem,
+                wts_ptr,
+                self._fx_comb_inp_wts,
+                self._fx_comb_out_wts,
+                self._fx_p2p_comb_inp_wts,
+                prx_ptr,
+                self._fx_disp_tok_map,
+                self._fx_disp_out_wts,
+                _cur_tok,
+                stream,
+            )
+
+        mt = cfg.max_num_inp_token_per_rank
+        hdim = cfg.hidden_dim
+        k = cfg.num_experts_per_token
+
+        # fp8_direct_cast contract: combine kernel writes bf16 to
+ # ``shmem_comb_out_tok`` directly (see ``combine`` above for details). + out_tok = self.shmem_comb_out_tok.view(torch.int8)[ + :mt * cfg.token_bytes].view(cfg.data_type).view(mt, cfg.token_view_dim) + out_wts = self.shmem_comb_out_wts.view(mt, k) + + if call_reset: + self.reset() + return out_tok, out_wts + + def get_dispatch_src_token_pos(self): + torch.cuda.synchronize() + n = int(self.total_recv[0].item()) + return self.shmem_tok_id_to_src[:n].clone() + + def get_registered_combine_input_buffer(self, dtype, hidden_dim=-1): + h = hidden_dim if hidden_dim > 0 else self.cfg.token_view_dim + dt = dtype if dtype is not None else self.cfg.data_type + return self.shmem_comb_inp_tok.view(torch.int8).view(dt).view(-1, h) diff --git a/python/flydsl/compiler/ast_rewriter.py b/python/flydsl/compiler/ast_rewriter.py index 54daa5ca..6b73a9e4 100644 --- a/python/flydsl/compiler/ast_rewriter.py +++ b/python/flydsl/compiler/ast_rewriter.py @@ -275,7 +275,14 @@ def _visit_stmt_block(self, stmts): return new_stmts def visit_FunctionDef(self, node: ast.FunctionDef): - if getattr(node, _ASTREWRITE_MARKER, False): + # ``_ASTREWRITE_MARKER`` is set by ReplaceIfWithDispatch / + # InsertEmptyYieldForSCFFor on the synthetic then/else/body functions + # they generate. It records *which* transformer created the node so + # only that transformer skips re-visiting -- other passes still need + # to recurse into the synthetic function body (e.g. so a ``for`` loop + # generated inside an if-then gets lowered to scf.for_dispatch). + marker = getattr(node, _ASTREWRITE_MARKER, False) + if marker is True or marker == type(self).__name__: return node with self.symbol_scopes.function_scope(): @@ -797,7 +804,7 @@ def _state_return_node(): decorator_list=[], type_params=[], ) - setattr(then_func, _ASTREWRITE_MARKER, True) + setattr(then_func, _ASTREWRITE_MARKER, type(self).__name__) then_func = ast.copy_location(then_func, node) then_func = ast.fix_missing_locations(then_func) @@ -839,7 +846,7 @@ def _state_return_node(): decorator_list=[], type_params=[], ) - setattr(else_func, _ASTREWRITE_MARKER, True) + setattr(else_func, _ASTREWRITE_MARKER, type(self).__name__) else_func = ast.copy_location(else_func, node) else_func = ast.fix_missing_locations(else_func) dispatch_args.append(ast.Name(else_name, ctx=ast.Load())) @@ -861,7 +868,7 @@ def _state_return_node(): decorator_list=[], type_params=[], ) - setattr(else_func, _ASTREWRITE_MARKER, True) + setattr(else_func, _ASTREWRITE_MARKER, type(self).__name__) else_func = ast.copy_location(else_func, node) else_func = ast.fix_missing_locations(else_func) dispatch_args.append(ast.Name(else_name, ctx=ast.Load())) diff --git a/python/flydsl/expr/arith.py b/python/flydsl/expr/arith.py index 38403f6c..114857c1 100644 --- a/python/flydsl/expr/arith.py +++ b/python/flydsl/expr/arith.py @@ -65,3 +65,74 @@ def cmpf(predicate, lhs, rhs, **kwargs): + + +@traced_op +def divui(lhs, rhs, **kwargs): + """Unsigned integer divide accepting DSL types and Python int constants. + + Generates ``arith.divui`` (efficient ``udiv`` on AMD GPU). + + Args: + lhs: Dividend (ArithValue, ir.Value, or DSL Numeric). + rhs: Divisor (ArithValue, ir.Value, DSL Numeric, or Python int). + """ + lhs_v = _to_raw(lhs) + if isinstance(rhs, int): + rhs_v = _to_raw(constant(rhs, type=lhs_v.type)) + else: + rhs_v = _to_raw(rhs) + return _mlir_arith.DivUIOp(lhs_v, rhs_v, **kwargs).result + + +@traced_op +def remui(lhs, rhs, **kwargs): + """Unsigned integer remainder accepting DSL types and Python int constants. 
+ + Generates ``arith.remui`` (efficient ``urem`` on AMD GPU). + + Args: + lhs: Dividend (ArithValue, ir.Value, or DSL Numeric). + rhs: Divisor (ArithValue, ir.Value, DSL Numeric, or Python int). + """ + lhs_v = _to_raw(lhs) + if isinstance(rhs, int): + rhs_v = _to_raw(constant(rhs, type=lhs_v.type)) + else: + rhs_v = _to_raw(rhs) + return _mlir_arith.RemUIOp(lhs_v, rhs_v, **kwargs).result + + +def zext_i64(val): + """Zero-extend integer value to i64, idempotent if already i64. + + Returns ArithValue for use in arithmetic expressions. + """ + from .._mlir.extras import types as T + v = _to_raw(val) + i64 = T.i64() + if v.type == i64: + return v + return _mlir_arith.ExtUIOp(i64, v).result + + +@traced_op +def select_by_index(index_val, values): + """Select one of *values* by integer *index_val* via chained ``arith.select``. + + Equivalent to a compile-time switch: returns ``values[index_val]``. + + Args: + index_val: Integer index (i32 ``ir.Value``). + values: List of ``ir.Value`` to select from. + + Returns: + The selected ``ir.Value``. + """ + out = values[0] + for i in range(1, len(values)): + pred = _mlir_arith.CmpIOp( + _mlir_arith.CmpIPredicate.eq, index_val, constant(i, type=index_val.type) + ).result + out = _mlir_arith.SelectOp(pred, values[i], out).result + return out diff --git a/python/flydsl/expr/rocdl/__init__.py b/python/flydsl/expr/rocdl/__init__.py index abf94119..cd074a39 100644 --- a/python/flydsl/expr/rocdl/__init__.py +++ b/python/flydsl/expr/rocdl/__init__.py @@ -519,3 +519,31 @@ def ds_bpermute(res, index, src, **kw): def readfirstlane(res, src, **kw): return _ods_readfirstlane(res=res, src=_to_ir(src), **kw) + + + +def ballot_i64(cond, *, loc=None): + """Warp ballot returning 64-bit lane mask, with auto i1 coercion.""" + from ..._mlir.ir import IntegerType + from ..._mlir.dialects import llvm as _llvm, rocdl as _rocdl + + pred = _to_ir(cond) + i1 = IntegerType.get_signless(1) + if pred.type != i1: + pred = _llvm.TruncOp(i1, pred).result + i64 = IntegerType.get_signless(64) + return _rocdl.BallotOp(i64, pred, loc=loc).result + + +def readlane(val, lane, *, loc=None): + """Read a value from a specific warp lane, accepting Python int for *lane*.""" + from ..._mlir.ir import IntegerType, IntegerAttr + from ..._mlir.dialects import rocdl as _rocdl, arith as _arith + + src = _to_ir(val) + i32 = IntegerType.get_signless(32) + if isinstance(lane, int): + lane_v = _arith.ConstantOp(i32, IntegerAttr.get(i32, lane)).result + else: + lane_v = _to_ir(lane) + return _rocdl.ReadlaneOp(i32, src, lane_v, loc=loc).result diff --git a/python/flydsl/expr/vector.py b/python/flydsl/expr/vector.py index c9742cf3..4d7fcb98 100644 --- a/python/flydsl/expr/vector.py +++ b/python/flydsl/expr/vector.py @@ -118,3 +118,34 @@ def bitcast(result_type, source, *, loc=None, ip=None): loc=loc, ip=ip, ).result + + + +# Scalar <-> vector bitcast (requires llvm.BitcastOp). +# arith.bitcast and vector.BitCastOp do not support shape changes +# (e.g. i32 <-> vector<2xbf16>); llvm.BitcastOp is required. + +def bitcast_i32_to_v2bf16(val, *, loc=None): + """Bitcast i32 scalar to vector<2xbf16> (bit-identical reinterpretation). + + Used to reinterpret a packed i32 load result as two bf16 elements. + """ + from . 
import arith as _arith_ext
+    from .._mlir.dialects import llvm as _llvm
+    from .._mlir.extras import types as _T
+
+    v2bf16 = _T.VectorType.get([2], _T.bf16())
+    return _llvm.BitcastOp(v2bf16, _arith_ext.unwrap(val, loc=loc), loc=loc).res
+
+
+def bitcast_v2bf16_to_i32(val, *, loc=None):
+    """Bitcast vector<2xbf16> to i32 (bit-identical reinterpretation).
+
+    Used to pack two bf16 accumulator results into an i32 for store.
+    """
+    from . import arith as _arith_ext
+    from .._mlir.dialects import llvm as _llvm
+    from .._mlir.ir import IntegerType
+
+    i32 = IntegerType.get_signless(32)
+    return _llvm.BitcastOp(i32, _arith_ext.unwrap(val, loc=loc), loc=loc).res
diff --git a/tests/kernels/test_profiler_dispatch_combine.py b/tests/kernels/test_profiler_dispatch_combine.py
new file mode 100644
index 00000000..68ca91c9
--- /dev/null
+++ b/tests/kernels/test_profiler_dispatch_combine.py
@@ -0,0 +1,1211 @@
+# SPDX-License-Identifier: Apache-2.0
+# Copyright (c) 2025 FlyDSL Project Contributors
+
+"""
+Performance tests for the FlyDSL and mori-reference dispatch/combine kernels.
+
+Two orthogonal dimensions combine freely:
+  --mode        measurement: profile (torch.profiler) | bench (CUDA Event timing)
+  --cudagraph   execution: absent = eager mode | present = CUDAGraph capture+replay
+
+Four combinations:
+  1. profile + eager    : torch.profiler collects kernel + E2E + CPU time of eager runs
+  2. bench   + eager    : CUDA Event timing of eager dispatch/combine (no profiler overhead)
+  3. profile + cudagraph: torch.profiler collects kernel time inside CUDAGraph replay
+  4. bench   + cudagraph: CUDA Event timing of CUDAGraph replay (zero Python launch overhead)
+
+Launching (torchrun or plain python both work):
+  # profile + eager (default)
+  python tests/kernels/test_profiler_dispatch_combine.py --max-tokens 512
+
+  # bench + eager
+  python tests/kernels/test_profiler_dispatch_combine.py --mode bench
+
+  # bench + cudagraph
+  python tests/kernels/test_profiler_dispatch_combine.py --mode bench --cudagraph
+
+  # profile + cudagraph
+  python tests/kernels/test_profiler_dispatch_combine.py --mode profile --cudagraph
+
+  # FlyDSL only
+  python tests/kernels/test_profiler_dispatch_combine.py --bench-op flydsl
+"""
+from __future__ import annotations
+
+import argparse
+import json
+import os
+import sys
+
+import torch
+import torch.distributed as dist
+from torch.profiler import ProfilerActivity, profile, record_function
+
+os.environ.setdefault("MORI_SHMEM_HEAP_SIZE", "16G")
+
+# ── dtype map ──
+DTYPE_MAP = {
+    "bf16": torch.bfloat16,
+    "f32": torch.float32,
+    "fp8_ocp": torch.float8_e4m3fn,
+    "fp8_fnuz": torch.float8_e4m3fnuz,
+    "fp4": torch.float4_e2m1fn_x2,
+}
+
+MORI_KERNEL_SUFFIX = {
+    "bf16": "bf16",
+    "f32": "f32",
+    "fp8_ocp": "fp8_ocp",
+    "fp8_fnuz": "fp8_fnuz",
+    "fp4": "fp4",
+}
+
+_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
+for _p in [_ROOT, "/home/yashao/FlyDSL/python", "/home/yashao/mori/python"]:
+    if _p not in sys.path:
+        sys.path.insert(0, _p)
+
+import mori.shmem as ms
+from kernels.dispatch_combine_intranode_op import (
+    FlyDSLDispatchCombineConfig,
+    FlyDSLDispatchCombineIntraNodeOp,
+)
+
+
+# ─── Distributed initialization ──────────────────────────────────────────────
+def setup_distributed(rank, world_size, master_port=29600):
+    if "LOCAL_RANK" not in os.environ:
+        os.environ.update({
+            "LOCAL_RANK": str(rank), "RANK": str(rank),
+            "WORLD_SIZE": str(world_size),
+            "MASTER_ADDR": "localhost",
+            "MASTER_PORT": str(master_port),
+        })
+    local_rank = int(os.environ.get("LOCAL_RANK", rank))
+    torch.cuda.set_device(local_rank)
+    dev = torch.device("cuda", local_rank)
+    dist.init_process_group(
+        backend="cpu:gloo,cuda:nccl",
+        rank=rank, world_size=world_size, device_id=dev,
+    )
+    import torch._C._distributed_c10d as c10d
+    c10d._register_process_group("default", dist.group.WORLD)
+    ms.shmem_torch_process_group_init("default")
+    return local_rank, world_size
+
+
+def cleanup():
+    try:
+        ms.shmem_finalize()
+    except Exception:
+        pass
+    if dist.is_initialized():
+        try:
+            dist.barrier()
+        except Exception:
+            pass
+        dist.destroy_process_group()
+
+
+_MORI_SUPPORTED_DTYPES = {torch.bfloat16, torch.float32, torch.float8_e4m3fn, torch.float4_e2m1fn_x2}
+
+
+def build_mori_ref(rank, world_size, cfg,
+                   block_num: int = None, warp_per_block: int = None):
+    if cfg.data_type not in _MORI_SUPPORTED_DTYPES:
+        raise RuntimeError(f"mori does not support dtype {cfg.data_type} on this platform")
+    from mori.ops.dispatch_combine import EpDispatchCombineConfig, EpDispatchCombineOp
+    elem = torch.tensor([], dtype=cfg.data_type).element_size()
+    mcfg = EpDispatchCombineConfig(
+        data_type=cfg.data_type,
+        rank=rank, world_size=world_size,
+        hidden_dim=cfg.hidden_dim,
+        scale_dim=cfg.num_experts_per_token, scale_type_size=4,
+        max_token_type_size=elem,
+        max_num_inp_token_per_rank=cfg.max_num_inp_token_per_rank,
+        num_experts_per_rank=cfg.num_experts_per_rank,
+        num_experts_per_token=cfg.num_experts_per_token,
+        warp_num_per_block=warp_per_block if warp_per_block is not None else cfg.warp_num_per_block,
+        block_num=block_num if block_num is not None else cfg.block_num,
+        gpu_per_node=world_size,
+        use_external_inp_buf=cfg.use_external_inp_buf,
+        quant_type=cfg.quant_type,
+    )
+    return EpDispatchCombineOp(mcfg)
+
+
+def _save_profile_json(prof, out_path: str, rank: int, op_tag: str, meta: dict):
+    """Serialize profiler results into a JSON file.
+
+    JSON structure:
+    {
+      "meta": {op_tag, rank, max_tokens, hidden_dim, k, world_size, ...},
+      "kernel_stats": [ {name, calls, cuda_time_avg_us, cpu_time_avg_us}, ... 
] + } + """ + rows = [] + for evt in prof.key_averages(): + rows.append({ + "name": evt.key, + "calls": evt.count, + "cuda_time_avg_us": round(evt.device_time, 2), + "cuda_time_total_us": round(evt.device_time * evt.count, 2), + "cpu_time_avg_us": round(evt.cpu_time, 2), + "cpu_time_total_us": round(evt.cpu_time * evt.count, 2), + }) + # 按 GPU time 降序 + rows.sort(key=lambda r: r["cuda_time_total_us"], reverse=True) + + payload = { + "meta": {**meta, "op": op_tag, "rank": rank}, + "kernel_stats": rows, + } + os.makedirs(os.path.dirname(out_path), exist_ok=True) + with open(out_path, "w") as f: + json.dump(payload, f, indent=2, ensure_ascii=False) + + trace_path = out_path.replace(".json", "_trace.json") + prof.export_chrome_trace(trace_path) + + +def _allreduce_stats(prof, op_tag: str, rank: int, world_size: int, + dev: torch.device, dtype_key: str = "bf16", + quant_type: str = "none", + use_p2p_read: bool = False) -> dict: + """从本卡 profiler 提取关键指标,跨卡 all_reduce 后返回 avg/min/max 字典。 + + 采集 6 项指标(顺序固定,打包成 float64 tensor 做 all_reduce): + 0: dispatch GPU kernel time (μs/call) + 1: combine GPU kernel time (μs/call) + 2: dispatch record_function CUDA time (μs/call) + 3: combine record_function CUDA time (μs/call) + 4: dispatch record_function CPU time (μs/call) + 5: combine record_function CPU time (μs/call) + """ + msuf = MORI_KERNEL_SUFFIX.get(dtype_key, "bf16") + _cast_suf = "_fp8cast" if (quant_type == "fp8_direct_cast" and not use_p2p_read) else "" + _p2p_suf = "_p2p" if use_p2p_read else "_nop2p" + if op_tag == "flydsl": + d_kernel = "ep_dispatch_intranode_0" + c_kernel = "ep_combine_intranode_0" + else: + d_kernel = f"EpDispatchIntraNodeKernel_{msuf}" + c_kernel = f"EpCombineIntraNodeKernel_{msuf}{_p2p_suf}{_cast_suf}" + d_label = f"{op_tag}::dispatch" + c_label = f"{op_tag}::combine" + + ev = {e.key: e for e in prof.key_averages()} + + def gpu_us(key): + e = ev.get(key) + return e.device_time if (e and e.count) else 0.0 + + def cpu_us(key): + e = ev.get(key) + return e.cpu_time if (e and e.count) else 0.0 + + local = torch.tensor([ + gpu_us(d_kernel), gpu_us(c_kernel), + gpu_us(d_label), gpu_us(c_label), + cpu_us(d_label), cpu_us(c_label), + ], dtype=torch.float64, device=dev) + + s = local.clone(); dist.all_reduce(s, op=dist.ReduceOp.SUM) + mx = local.clone(); dist.all_reduce(mx, op=dist.ReduceOp.MAX) + mn = local.clone(); dist.all_reduce(mn, op=dist.ReduceOp.MIN) + avg = s / world_size + + keys = ["dispatch_gpu", "combine_gpu", + "dispatch_cuda_e2e", "combine_cuda_e2e", + "dispatch_cpu_e2e", "combine_cpu_e2e"] + return {k: {"avg": avg[i].item(), "min": mn[i].item(), "max": mx[i].item()} + for i, k in enumerate(keys)} + + +def _print_aggregated(stats: dict, op_tag: str, world_size: int, meta: dict): + """rank 0 打印全卡聚合统计。""" + sep = "=" * 72 + print(f"\n{sep}") + print(f" {op_tag.upper()} EP={world_size} bs={meta['max_tokens']} " + f"h={meta['hidden_dim']} k={meta['k']} ({meta['iters']} iters)") + print(f" 所有 {world_size} 张卡的 avg / min / max(μs/call)") + print(sep) + hdr = f" {'指标':<36} {'avg':>8} {'min':>8} {'max':>8}" + print(hdr) + print(f" {'-'*60}") + + rows = [ + ("[Device] dispatch kernel GPU time", "dispatch_gpu"), + ("[Device] combine kernel GPU time", "combine_gpu"), + ("[E2E] dispatch CUDA time (含sync)", "dispatch_cuda_e2e"), + ("[E2E] combine CUDA time (含sync)", "combine_cuda_e2e"), + ("[Host] dispatch CPU time", "dispatch_cpu_e2e"), + ("[Host] combine CPU time", "combine_cpu_e2e"), + ] + for label, key in rows: + v = stats[key] + print(f" {label:<36} {v['avg']:>8.1f} 
{v['min']:>8.1f} {v['max']:>8.1f}") + print() + + +def _allreduce_cudagraph_stats_from_key_averages( + prof, op_tag: str, rank: int, world_size: int, + dev: torch.device, dtype_key: str = "bf16", + quant_type: str = "none", + use_p2p_read: bool = False) -> dict: + """从 key_averages() 提取指标(仅含 active 阶段数据),跨卡 all_reduce。 + + 采集 4 项: + 0: dispatch kernel GPU time + 1: combine kernel GPU time + 2: cudagraph_replay CUDA E2E time + 3: cudagraph_replay CPU E2E time + """ + msuf = MORI_KERNEL_SUFFIX.get(dtype_key, "bf16") + _cast_suf = "_fp8cast" if (quant_type == "fp8_direct_cast" and not use_p2p_read) else "" + _p2p_suf = "_p2p" if use_p2p_read else "_nop2p" + if op_tag == "flydsl": + d_kernel = "ep_dispatch_intranode_0" + c_kernel = "ep_combine_intranode_0" + else: + d_kernel = f"EpDispatchIntraNodeKernel_{msuf}" + c_kernel = f"EpCombineIntraNodeKernel_{msuf}{_p2p_suf}{_cast_suf}" + cg_label = f"{op_tag}::cudagraph_replay" + + ev = {e.key: e for e in prof.key_averages()} + + def gpu_us(key): + e = ev.get(key) + return e.device_time if (e and e.count) else 0.0 + + def cpu_us(key): + e = ev.get(key) + return e.cpu_time if (e and e.count) else 0.0 + + local = torch.tensor([ + gpu_us(d_kernel), gpu_us(c_kernel), + gpu_us(cg_label), cpu_us(cg_label), + ], dtype=torch.float64, device=dev) + + s = local.clone(); dist.all_reduce(s, op=dist.ReduceOp.SUM) + mx = local.clone(); dist.all_reduce(mx, op=dist.ReduceOp.MAX) + mn = local.clone(); dist.all_reduce(mn, op=dist.ReduceOp.MIN) + avg = s / world_size + + keys = ["dispatch_gpu", "combine_gpu", "replay_cuda_e2e", "replay_cpu_e2e"] + return {k: {"avg": avg[i].item(), "min": mn[i].item(), "max": mx[i].item()} + for i, k in enumerate(keys)} + + +def _cudagraph_stats_from_trace(trace_path: str, op_tag: str, + rank: int, world_size: int, + dev: torch.device, + active_iters: int, skip_first: int = 5, + dtype_key: str = "bf16", + quant_type: str = "none", + use_p2p_read: bool = False) -> dict: + """从 chrome trace JSON 手动统计 kernel 性能,跳过前 skip_first 次 active 调用。 + + 流程:解析 trace → 按时间排序取最后 active_iters 个事件 → 丢弃前 skip_first 个 → 跨卡聚合。 + """ + with open(trace_path) as f: + tr = json.load(f) + + msuf = MORI_KERNEL_SUFFIX.get(dtype_key, "bf16") + _cast_suf = "_fp8cast" if (quant_type == "fp8_direct_cast" and not use_p2p_read) else "" + _p2p_suf = "_p2p" if use_p2p_read else "_nop2p" + if op_tag == "flydsl": + d_name, c_name = "ep_dispatch_intranode_0", "ep_combine_intranode_0" + else: + d_name = f"EpDispatchIntraNodeKernel_{msuf}" + c_name = f"EpCombineIntraNodeKernel_{msuf}{_p2p_suf}{_cast_suf}" + cg_name = f"{op_tag}::cudagraph_replay" + + kernel_events = [e for e in tr["traceEvents"] if e.get("cat") == "kernel"] + d_all = sorted([e for e in kernel_events if d_name in e.get("name", "")], + key=lambda e: e["ts"]) + c_all = sorted([e for e in kernel_events if c_name in e.get("name", "")], + key=lambda e: e["ts"]) + cg_all = sorted([e for e in tr["traceEvents"] + if e.get("cat") == "gpu_user_annotation" + and cg_name in e.get("name", "")], + key=lambda e: e["ts"]) + + d_active = [e["dur"] for e in d_all[-active_iters:]] + c_active = [e["dur"] for e in c_all[-active_iters:]] + cg_active = [e["dur"] for e in cg_all[-active_iters:]] + + d_valid = d_active[skip_first:] + c_valid = c_active[skip_first:] + cg_valid = cg_active[skip_first:] + + valid_n = len(d_valid) + if rank == 0: + print(f"[trace-stats] {op_tag}: trace 中 dispatch={len(d_all)} combine={len(c_all)} 个事件," + f"取最后 {active_iters} 个 active,跳过前 {skip_first},有效 {valid_n} 个") + + d_avg = sum(d_valid) / valid_n if 
valid_n else 0.0 + c_avg = sum(c_valid) / valid_n if valid_n else 0.0 + cg_avg = sum(cg_valid) / len(cg_valid) if cg_valid else 0.0 + + local = torch.tensor([d_avg, c_avg, cg_avg, 0.0], + dtype=torch.float64, device=dev) + s = local.clone(); dist.all_reduce(s, op=dist.ReduceOp.SUM) + mx = local.clone(); dist.all_reduce(mx, op=dist.ReduceOp.MAX) + mn = local.clone(); dist.all_reduce(mn, op=dist.ReduceOp.MIN) + avg = s / world_size + + keys = ["dispatch_gpu", "combine_gpu", "replay_cuda_e2e", "replay_cpu_e2e"] + return {k: {"avg": avg[i].item(), "min": mn[i].item(), "max": mx[i].item()} + for i, k in enumerate(keys)} + + +def _print_cudagraph_aggregated(stats: dict, op_tag: str, world_size: int, meta: dict, + active_iters: int = None): + """rank 0 打印 cudagraph profiler 全卡聚合统计。""" + n = active_iters if active_iters is not None else meta['iters'] + sep = "=" * 72 + print(f"\n{sep}") + print(f" {op_tag.upper()} [CUDAGraph+Profiler] EP={world_size} bs={meta['max_tokens']} " + f"h={meta['hidden_dim']} k={meta['k']} ({n} iters)") + print(f" 所有 {world_size} 张卡的 avg / min / max(μs/call)") + print(sep) + hdr = f" {'指标':<36} {'avg':>8} {'min':>8} {'max':>8}" + print(hdr) + print(f" {'-'*60}") + + rows = [ + ("[Device] dispatch kernel GPU time", "dispatch_gpu"), + ("[Device] combine kernel GPU time", "combine_gpu"), + ("[E2E] replay CUDA time (含sync)", "replay_cuda_e2e"), + ("[Host] replay CPU time", "replay_cpu_e2e"), + ] + for label, key in rows: + v = stats[key] + print(f" {label:<36} {v['avg']:>8.1f} {v['min']:>8.1f} {v['max']:>8.1f}") + print() + + +def _make_profiler(active_iters: int = None, prof_warmup: int = 10): + """创建 profiler。 + + 使用 schedule 让前 (1 + prof_warmup) 步不做/轻量追踪, + 减少 ROCTracer 在多 GPU P2P shmem 场景下的累积压力。 + """ + kwargs = dict( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + record_shapes=False, + with_stack=False, + ) + if active_iters is not None and active_iters > 0: + kwargs["schedule"] = torch.profiler.schedule( + wait=1, warmup=prof_warmup, active=active_iters, repeat=1, + ) + return profile(**kwargs) + + +# ─── bench 模式:不用 profiler,用 CUDA Event 计时 ──────────────────────────── +def bench_op(op, op_tag: str, inp, wts, idx, wc_buf, k, + rank: int, world_size: int, dev: torch.device, + warmup: int, iters: int, meta: dict, + scales=None, packed_recv_x=None): + """无 profiler 的纯计时模式,输出 dispatch / combine 的 GPU 耗时(avg/min/max)。""" + _dkw = dict(packed_recv_x=packed_recv_x) if packed_recv_x is not None else {} + _ckw = dict(packed_recv_x=packed_recv_x) if packed_recv_x is not None else {} + ms.shmem_barrier_all() + if rank == 0: + print(f"\n[bench] {op_tag} 预热 {warmup} 轮...") + for _ in range(warmup): + op.reset() + ret = op.dispatch(inp, wts, scales, idx, **_dkw) + op.combine(ret[0], None, ret[3], **_ckw) + torch.cuda.synchronize() + dist.barrier() + + if rank == 0: + print(f"[bench] {op_tag} 计时 {iters} 轮...") + + d_events = [(torch.cuda.Event(enable_timing=True), + torch.cuda.Event(enable_timing=True)) for _ in range(iters)] + c_events = [(torch.cuda.Event(enable_timing=True), + torch.cuda.Event(enable_timing=True)) for _ in range(iters)] + + for i in range(iters): + # op.reset() + dist.barrier() + + d_events[i][0].record() + ret = op.dispatch(inp, wts, scales, idx, **_dkw) + d_events[i][1].record() + + dist.barrier() + + c_events[i][0].record() + op.combine(ret[0], None, ret[3], **_ckw) + c_events[i][1].record() + + torch.cuda.synchronize() + d_list = [d_events[i][0].elapsed_time(d_events[i][1]) * 1000 for i in range(iters)] + c_list = 
[c_events[i][0].elapsed_time(c_events[i][1]) * 1000 for i in range(iters)] + + # 全卡聚合 avg / min / max + local = torch.tensor([ + sum(d_list) / len(d_list), min(d_list), max(d_list), + sum(c_list) / len(c_list), min(c_list), max(c_list), + ], dtype=torch.float64, device=dev) + s = local.clone(); dist.all_reduce(s, op=dist.ReduceOp.SUM) + mx = local.clone(); dist.all_reduce(mx, op=dist.ReduceOp.MAX) + mn = local.clone(); dist.all_reduce(mn, op=dist.ReduceOp.MIN) + avg_d = (s[0] / world_size).item(); mn_d = mn[0].item(); mx_d = mx[2].item() + avg_c = (s[3] / world_size).item(); mn_c = mn[3].item(); mx_c = mx[5].item() + + if rank == 0: + sep = "=" * 68 + tag = (f"{op_tag.upper()} EP={meta['world_size']} bs={meta['max_tokens']} " + f"h={meta['hidden_dim']} k={meta['k']} ({iters} iters)") + print(f"\n{sep}\n {tag}\n 所有 {world_size} 张卡的 avg / min / max(μs/call)\n{sep}") + print(f" {'指标':<36} {'avg':>8} {'min':>8} {'max':>8}") + print(f" {'-'*58}") + print(f" {'[E2E] dispatch CUDA time':<36} {avg_d:>8.1f} {mn_d:>8.1f} {mx_d:>8.1f}") + print(f" {'[E2E] combine CUDA time':<36} {avg_c:>8.1f} {mn_c:>8.1f} {mx_c:>8.1f}") + print() + + +# ─── cudagraph 模式:CUDA Graph capture + replay 计时 ───────────────────────── +def _cudagraph_capture_flydsl(op, inp, wts, idx, wc_buf, capture_stream, + scales=None, packed_recv_x=None): + """FlyDSL:录制 dispatch+combine 到 CUDA Graph。 + + dispatch/combine 均返回全尺寸 tensor(无 .item()、无动态切片)。 + 需要先 eager 调用一次触发 flyc.compile() JIT 编译(编译过程使用 + default stream,不能在 capture 期间执行),之后 capture 中仅录制 + 已编译的 kernel launch。 + """ + _dkw = dict(packed_recv_x=packed_recv_x) if packed_recv_x is not None else {} + _ckw = dict(packed_recv_x=packed_recv_x) if packed_recv_x is not None else {} + op.reset() + ret = op.dispatch(inp, wts, scales, idx, **_dkw) + op.combine(ret[0], None, ret[3], **_ckw) + + op.barrier() + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g, stream=capture_stream): + ret = op.dispatch(inp, wts, scales, idx, **_dkw) + op.combine(ret[0], None, ret[3], **_ckw) + return g, capture_stream + + +def _cudagraph_capture_mori(op, inp, wts, idx, wc_buf, capture_stream, + scales=None, packed_recv_x=None): + """Mori 专用:直接在 graph capture 中录制 dispatch+combine。 + + Mori 的 dispatch 在 capture 模式下返回真实 tensor,combine kernel + 从 HBM 读取 totalRecvTokenNum,无需 pre-capture eager call。 + 参考 mori/tests/python/ops/bench_dispatch_combine.py stress_graph 写法。 + """ + ms.shmem_barrier_all() + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g, stream=capture_stream): + ret = op.dispatch(inp, wts, None, idx) + op.combine(ret[0], None, ret[3]) + return g, capture_stream + + +def cudagraph_op(op, op_tag: str, inp, wts, idx, wc_buf, k, + rank: int, world_size: int, dev: torch.device, + warmup: int, iters: int, meta: dict, + scales=None, packed_recv_x=None): + """CUDA Graph 模式:capture dispatch+combine kernel,replay 计时。""" + capture_stream = torch.cuda.Stream() + if op_tag == "flydsl": + g, cs = _cudagraph_capture_flydsl( + op, inp, wts, idx, wc_buf, capture_stream, + scales=scales, packed_recv_x=packed_recv_x) + else: + g, cs = _cudagraph_capture_mori( + op, inp, wts, idx, wc_buf, capture_stream, + scales=scales, packed_recv_x=packed_recv_x) + + if rank == 0: + print(f"\n[cudagraph] {op_tag} capture done") + + # replay warmup(HIP graph 冷启动 + GPU 缓存预热) + replay_warmup = 10 + if rank == 0: + print(f"[cudagraph] replay warmup {replay_warmup} 轮 + 计时 {iters} 轮(no-reset)...") + for _ in range(replay_warmup): + g.replay() + torch.cuda.synchronize() + + # 计时:预分配 event pairs,循环结束后统一 sync + events = 
[(torch.cuda.Event(enable_timing=True), + torch.cuda.Event(enable_timing=True)) for _ in range(iters)] + + for i in range(iters): + events[i][0].record() + g.replay() + events[i][1].record() + + torch.cuda.synchronize() + gpu_times = [events[i][0].elapsed_time(events[i][1]) * 1000 for i in range(iters)] + + # per-replay 诊断 + per_replay_t = torch.tensor(gpu_times, dtype=torch.float64, device=dev) + all_per_replay = [torch.zeros_like(per_replay_t) for _ in range(world_size)] + dist.all_gather(all_per_replay, per_replay_t) + + local = torch.tensor([ + sum(gpu_times) / len(gpu_times), min(gpu_times), max(gpu_times), + ], dtype=torch.float64, device=dev) + s = local.clone(); dist.all_reduce(s, op=dist.ReduceOp.SUM) + mx = local.clone(); dist.all_reduce(mx, op=dist.ReduceOp.MAX) + mn = local.clone(); dist.all_reduce(mn, op=dist.ReduceOp.MIN) + avg_g = (s[0] / world_size).item(); mn_g = mn[0].item(); mx_g = mx[2].item() + + if rank == 0: + sep = "=" * 68 + tag = (f"{op_tag.upper()} [CUDAGraph] EP={meta['world_size']} " + f"bs={meta['max_tokens']} h={meta['hidden_dim']} k={meta['k']} " + f"({iters} replays)") + print(f"\n{sep}\n {tag}\n 所有 {world_size} 张卡的 avg / min / max(μs/call)\n{sep}") + print(f" {'指标':<36} {'avg':>8} {'min':>8} {'max':>8}") + print(f" {'-'*58}") + print(f" {'[GPU] dispatch+combine (event)':<36} {avg_g:>8.1f} {mn_g:>8.1f} {mx_g:>8.1f}") + + print(f"\n Per-replay GPU time (μs) — all {world_size} ranks:") + hdr = f" {'replay':>6}" + "".join(f" {'R'+str(r):>8}" for r in range(world_size)) + f" {'max':>8}" + print(hdr) + mat = torch.stack(all_per_replay) + for i in range(iters): + vals = [mat[r, i].item() for r in range(world_size)] + mx_i = max(vals) + row = f" {i:>6}" + "".join(f" {v:>8.1f}" for v in vals) + f" {mx_i:>8.1f}" + if mx_i > avg_g * 3: + row += " ← SPIKE" + print(row) + print() + + +# ─── 单算子 profiler 采集 ────────────────────────────────────────────────────── +def profile_op(op, op_tag: str, inp, wts, idx, wc_buf, k, + rank: int, world_size: int, dev: torch.device, + iters: int, out_dir: str, meta: dict, + scales=None, packed_recv_x=None, + dtype_key: str = "bf16", + quant_type: str = "none", + use_p2p_read: bool = False): + """对单个算子(FlyDSL 或 mori)独立 profiling,保存 JSON 并打印全卡聚合统计。 + + 使用 schedule(wait=1, warmup=10, active=iters) 让 ROCTracer 在前 11 步 + 不做/轻量追踪,减少与多 GPU P2P shmem 操作的冲突。 + """ + ms.shmem_barrier_all() + prof_warmup = 10 + total_steps = iters + 1 + prof_warmup # wait=1 + warmup=prof_warmup + active=iters + if rank == 0: + print(f"\n[profiler] {op_tag} 开始采集({iters} 轮 active + {1 + prof_warmup} 轮 ramp-up)...") + + _dkw = dict(packed_recv_x=packed_recv_x) if packed_recv_x is not None else {} + _ckw = dict(packed_recv_x=packed_recv_x) if packed_recv_x is not None else {} + with _make_profiler(active_iters=iters, prof_warmup=prof_warmup) as prof: + for step in range(total_steps): + # with record_function(f"{op_tag}::reset"): + # op.reset() + dist.barrier() + + with record_function(f"{op_tag}::dispatch"): + ret = op.dispatch(inp, wts, scales, idx, **_dkw) + + dist.barrier() + + with record_function(f"{op_tag}::combine"): + op.combine(ret[0], None, ret[3], **_ckw) + + # dist.barrier() + + prof.step() + + # 保存 JSON:每张卡各自保存,文件名含 op_tag 和 rank + out_path = os.path.join(out_dir, f"{op_tag}_rank{rank}.json") + _save_profile_json(prof, out_path, rank, op_tag, meta) + if rank == 0: + print(f"[profiler] {op_tag} trace → {out_path}") + + # 跨卡聚合统计(all_reduce),rank 0 打印 + agg_stats = _allreduce_stats(prof, op_tag, rank, world_size, dev, + dtype_key=dtype_key, 
+
+
+# ─── profile + cudagraph mode ─────────────────────────────────────────────────
+def profile_cudagraph_op(op, op_tag: str, inp, wts, idx, wc_buf, k,
+                         rank: int, world_size: int, dev: torch.device,
+                         warmup: int, iters: int, out_dir: str, meta: dict,
+                         scales=None, packed_recv_x=None,
+                         dtype_key: str = "bf16",
+                         quant_type: str = "none",
+                         use_p2p_read: bool = False):
+    """Profile CUDAGraph replays with torch.profiler; save JSON, print all-rank stats.
+
+    Flow: eager warmup → graph capture → replay warmup → profiler-wrapped replays.
+    """
+    ms.shmem_barrier_all()
+
+    capture_stream = torch.cuda.Stream()
+    if op_tag == "flydsl":
+        g, cs = _cudagraph_capture_flydsl(
+            op, inp, wts, idx, wc_buf, capture_stream,
+            scales=scales, packed_recv_x=packed_recv_x)
+    else:
+        g, cs = _cudagraph_capture_mori(
+            op, inp, wts, idx, wc_buf, capture_stream,
+            scales=scales, packed_recv_x=packed_recv_x)
+
+    if rank == 0:
+        print(f"\n[profile+cudagraph] {op_tag} capture done")
+
+    # replay warmup (HIP graph cold start + GPU cache warm-up)
+    replay_warmup = 10
+    for _ in range(replay_warmup):
+        g.replay()
+    torch.cuda.synchronize()
+
+    prof_warmup = 5
+    active_iters = iters
+    skip_first = 5
+    valid_iters = max(active_iters - skip_first, 1)
+    total_steps = 1 + prof_warmup + active_iters  # wait=1 + warmup + active
+    if rank == 0:
+        print(f"[profile+cudagraph] {op_tag} scheduled profiler: "
+              f"warmup={prof_warmup}, active={active_iters}, "
+              f"dropping first {skip_first} replays, {valid_iters} valid (no-reset)...")
+
+    with _make_profiler(active_iters=active_iters, prof_warmup=prof_warmup) as prof:
+        for step in range(total_steps):
+            with record_function(f"{op_tag}::cudagraph_replay"):
+                g.replay()
+            prof.step()
+
+    out_path = os.path.join(out_dir, f"{op_tag}_cudagraph_rank{rank}.json")
+    _save_profile_json(prof, out_path, rank, op_tag, meta)
+    trace_path = out_path.replace(".json", "_trace.json")
+    if rank == 0:
+        print(f"[profile+cudagraph] {op_tag} trace → {trace_path}")
+
+    agg_stats = _cudagraph_stats_from_trace(
+        trace_path, op_tag, rank, world_size, dev,
+        active_iters=active_iters, skip_first=skip_first,
+        dtype_key=dtype_key, quant_type=quant_type,
+        use_p2p_read=use_p2p_read)
+    if rank == 0:
+        _print_cudagraph_aggregated(agg_stats, op_tag, world_size, meta,
+                                    active_iters=valid_iters)
+    return prof
+
+
+# ─── verify mode: correctness validation ──────────────────────────────────────
+VERIFY_TOL = {
+    "f32": {"atol": 1e-5, "rtol": 1e-4},
+    "bf16": {"atol": 1e-2, "rtol": 1e-2},
+    "fp8_ocp": {"atol": 1e-1, "rtol": 5e-2},
+    "fp8_fnuz": {"atol": 1e-1, "rtol": 5e-2},
+    "fp4": {"atol": 5e-1, "rtol": 1e-1},
+}
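+
+# torch.allclose(a, b, atol=..., rtol=...) accepts elementwise when
+# |a - b| <= atol + rtol * |b|, so each row above reads as "absolute floor
+# plus relative slack". Worked example for the bf16 entry
+# (atol=1e-2, rtol=1e-2): a reference value b = 2.0 tolerates a deviation of
+# up to 1e-2 + 1e-2 * 2.0 = 0.03.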
+
+
+def _check_close(name, a, b, atol, rtol, rank, cast_to=None):
+    """Compare two tensors and print PASS/FAIL."""
+    if cast_to is not None:
+        a, b = a.to(cast_to), b.to(cast_to)
+    ok = torch.allclose(a, b, atol=atol, rtol=rtol)
+    max_diff = (a.float() - b.float()).abs().max().item()
+    status = "PASS" if ok else "FAIL"
+    if rank == 0:
+        print(f" [{status}] {name:40s} max_diff={max_diff:.6g} atol={atol} rtol={rtol}")
+    return ok
+
+
+def _check_exact(name, a, b, rank):
+    """Compare two tensors for exact equality."""
+    ok = torch.equal(a, b)
+    if not ok:
+        diff_count = (a != b).sum().item()
+        status = "FAIL"
+    else:
+        diff_count = 0
+        status = "PASS"
+    if rank == 0:
+        print(f" [{status}] {name:40s} diff_elements={diff_count}")
+    return ok
+
+
+def verify_self(op_fly, inp, wts, idx, k,
+                rank, world_size, dev, dtype_key, cfg):
+    """FlyDSL self-check when mori is unavailable.
+
+    dispatch → combine → check the output against the analytic expectation:
+    with k distinct destination PEs and no combine weights, each token is
+    accumulated once per copy, so the output should ≈ k * input; in StdMoE
+    weighted mode the dispatch weights sum to 1, so it should ≈ input.
+    """
+    tol = VERIFY_TOL.get(dtype_key, VERIFY_TOL["bf16"])
+    if cfg.quant_type == "fp8_direct_cast":
+        tol = {"atol": 2.0 * k, "rtol": 0.5}
+    all_pass = True
+
+    if rank == 0:
+        print(f"\n{'='*65}")
+        print(f" VERIFY (self-check, mori unavailable) dtype={dtype_key} "
+              f"EP={world_size} bs={inp.shape[0]} h={cfg.hidden_dim} k={k}")
+        print(f"{'='*65}")
+
+    op_fly.reset()
+    ms.shmem_barrier_all()
+
+    packed_recv_x = None
+    if cfg.enable_std_moe:
+        epr = cfg.num_experts_per_rank
+        mr = cfg.max_recv
+        _prx_nbytes = epr * mr * cfg.token_bytes
+        packed_recv_x = torch.zeros(
+            _prx_nbytes, dtype=torch.uint8, device=dev
+        ).view(cfg.data_type).view(epr * mr, cfg.token_view_dim)
+
+    scales = None
+    if cfg.scale_dim > 0 and cfg.scale_type_size > 0:
+        _sc_bytes = cfg.scale_dim * cfg.scale_type_size
+        scales = torch.randn(inp.shape[0], _sc_bytes // 4,
+                             dtype=torch.float32, device=dev).contiguous()
+        scales = scales.view(torch.uint8).view(inp.shape[0], _sc_bytes)
+
+    ret_f = op_fly.dispatch(inp, wts, scales, idx,
+                            packed_recv_x=packed_recv_x)
+    torch.cuda.synchronize()
+    dist.barrier()
+
+    if rank == 0:
+        tr = ret_f[4].item()
+        print(f"\n total_recv = {tr}")
+
+    cout_f = op_fly.combine(ret_f[0], None, ret_f[3],
+                            packed_recv_x=packed_recv_x)
+    torch.cuda.synchronize()
+    dist.barrier()
+
+    mt = cfg.max_num_inp_token_per_rank
+    f_tok = cout_f[0][:mt]
+
+    if cfg.enable_std_moe:
+        scale_factor = 1
+        check_label = "out_tok vs inp (StdMoE weighted)"
+    else:
+        scale_factor = k
+        check_label = "out_tok vs k*inp"
+
+    if rank == 0:
+        print(f"\n ── Self-check: combine output vs {'inp' if scale_factor == 1 else 'k*input'} ──")
+        if cfg.data_type == torch.float4_e2m1fn_x2:
+            if k == 1 and not cfg.enable_std_moe:
+                ok = torch.equal(f_tok.view(torch.uint8), inp.view(torch.uint8))
+                status = "PASS" if ok else "FAIL"
+                print(f" [{status}] out_tok vs inp (byte-level, k=1)")
+                all_pass &= ok
+            else:
+                print(f" [SKIP] fp4 numeric check not supported "
+                      f"(k={k}, std_moe={cfg.enable_std_moe})")
+        else:
+            cast_to = torch.float32 if cfg.data_type in (
+                torch.float8_e4m3fn, torch.float8_e4m3fnuz) else None
+            try:
+                expected = (inp.float() * scale_factor).to(cfg.data_type)
+                all_pass &= _check_close(
+                    check_label, f_tok, expected,
+                    tol["atol"], tol["rtol"], rank, cast_to=cast_to)
+            except Exception as e:
+                has_nan = torch.isnan(f_tok.float()).any().item()
+                has_inf = torch.isinf(f_tok.float()).any().item()
+                print(f" [INFO] Self-check exception (NaN={has_nan}, Inf={has_inf}): {e}")
+                all_pass &= (not has_nan and not has_inf)
+
+    if rank == 0:
+        result = "ALL PASS" if all_pass else "SOME FAILED"
+        print(f"\n >>> {result} <<<\n")
+    return all_pass
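+
+# Worked example of the self-check expectation (illustrative numbers): with
+# k = 4 distinct destination PEs and no combine weights, a token x is
+# dispatched 4 times and each copy is accumulated once on the way back, so
+# the combined output is 4 * x. With StdMoE weights w_1..w_4 summing to 1,
+# it is w_1*x + ... + w_4*x = x, which is why scale_factor is 1 on that path.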
+ """ + tol = VERIFY_TOL.get(dtype_key, VERIFY_TOL["bf16"]) + all_pass = True + + if rank == 0: + print(f"\n{'='*65}") + print(f" VERIFY dtype={dtype_key} EP={world_size} bs={inp.shape[0]} " + f"h={cfg.hidden_dim} k={k}") + print(f"{'='*65}") + + # ── Dispatch ── + op_fly.reset(); op_mori.reset() + ms.shmem_barrier_all() + + ret_f = op_fly.dispatch(inp, wts, None, idx) + ret_m = op_mori.dispatch(inp, wts, None, idx) + torch.cuda.synchronize() + + tr_f = ret_f[4].clone(); tr_m = ret_m[4].clone() + dist.barrier() + if rank == 0: + print("\n ── Dispatch 对比(仅 total_recv,token 排列因原子序不同) ──") + + all_pass &= _check_exact("total_recv", tr_f, tr_m, rank) + + # ── Combine ── + ms.shmem_barrier_all() + cout_f = op_fly.combine(ret_f[0], None, ret_f[3]) + cout_m = op_mori.combine(ret_m[0], None, ret_m[3]) + torch.cuda.synchronize() + dist.barrier() + + if rank == 0: + print("\n ── Combine 输出对比 ──") + + mt = cfg.max_num_inp_token_per_rank + cast_to = torch.float32 if cfg.data_type in ( + torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float4_e2m1fn_x2) else None + + f_tok = cout_f[0][:mt] if cout_f[0] is not None else None + m_tok = cout_m[0][:mt] if cout_m[0] is not None else None + + # Diagnostic: compare both outputs with expected k*input (skip for packed types) + if rank == 0 and f_tok is not None and m_tok is not None: + try: + expected = (inp.float() * k).to(cfg.data_type) + f_vs_exp = (f_tok.float() - expected.float()).abs().max().item() + m_vs_exp = (m_tok.float() - expected.float()).abs().max().item() + f_vs_m = (f_tok.float() - m_tok.float()).abs().max().item() + print(f" [DIAG] fly vs k*inp: {f_vs_exp:.4f} mori vs k*inp: {m_vs_exp:.4f} fly vs mori: {f_vs_m:.4f}") + except Exception: + pass + + if f_tok is not None and m_tok is not None: + all_pass &= _check_close( + "out_tok[:mt]", f_tok, m_tok, + tol["atol"], tol["rtol"], rank, cast_to=cast_to) + elif rank == 0: + print(f" [SKIP] out_tok: fly={f_tok is not None}, mori={m_tok is not None}") + + f_wts = cout_f[1] if (len(cout_f) > 1 and cout_f[1] is not None) else None + m_wts = cout_m[1] if (len(cout_m) > 1 and cout_m[1] is not None) else None + if f_wts is not None and m_wts is not None: + all_pass &= _check_close( + "out_wts[:mt]", f_wts[:mt], m_wts[:mt], + 1e-4, 1e-3, rank) + elif rank == 0: + print(f" [SKIP] out_wts: fly={f_wts is not None}, mori={m_wts is not None}") + + if rank == 0: + result = "ALL PASS" if all_pass else "SOME FAILED" + print(f"\n >>> {result} <<<\n") + return all_pass + + +# ─── 主逻辑 ─────────────────────────────────────────────────────────────────── +def run_profiler(rank, world_size, args): + dev = torch.device("cuda", rank) + k = args.k + cur_tok = args.max_tokens + n_exp = world_size * args.num_experts_per_rank + + _dtype = DTYPE_MAP.get(args.dtype, torch.bfloat16) + cfg = FlyDSLDispatchCombineConfig( + rank=rank, world_size=world_size, + hidden_dim=args.hidden_dim, + max_num_inp_token_per_rank=cur_tok, + num_experts_per_rank=args.num_experts_per_rank, + num_experts_per_token=k, + data_type=_dtype, + warp_num_per_block=args.warp_per_block, + block_num=args.block_num, + chip=args.chip, + use_external_inp_buf=args.use_external_inp_buf, + enable_std_moe=args.enable_std_moe, + scale_dim=args.scale_dim, + scale_type_size=args.scale_type_size, + quant_type=args.quant_type, + ) + + mori_bn = args.mori_block_num if args.mori_block_num > 0 else cfg.block_num + mori_wpb = args.mori_warp_per_block if args.mori_warp_per_block > 0 else cfg.warp_num_per_block + meta = dict( + world_size=world_size, max_tokens=cur_tok, + 
+
+
+# ─── Main logic ───────────────────────────────────────────────────────────────
+def run_profiler(rank, world_size, args):
+    dev = torch.device("cuda", rank)
+    k = args.k
+    cur_tok = args.max_tokens
+    n_exp = world_size * args.num_experts_per_rank
+
+    _dtype = DTYPE_MAP.get(args.dtype, torch.bfloat16)
+    cfg = FlyDSLDispatchCombineConfig(
+        rank=rank, world_size=world_size,
+        hidden_dim=args.hidden_dim,
+        max_num_inp_token_per_rank=cur_tok,
+        num_experts_per_rank=args.num_experts_per_rank,
+        num_experts_per_token=k,
+        data_type=_dtype,
+        warp_num_per_block=args.warp_per_block,
+        block_num=args.block_num,
+        chip=args.chip,
+        use_external_inp_buf=args.use_external_inp_buf,
+        enable_std_moe=args.enable_std_moe,
+        scale_dim=args.scale_dim,
+        scale_type_size=args.scale_type_size,
+        quant_type=args.quant_type,
+    )
+
+    mori_bn = args.mori_block_num if args.mori_block_num > 0 else cfg.block_num
+    mori_wpb = args.mori_warp_per_block if args.mori_warp_per_block > 0 else cfg.warp_num_per_block
+    meta = dict(
+        world_size=world_size, max_tokens=cur_tok,
+        hidden_dim=cfg.hidden_dim, k=k,
+        num_experts_per_rank=args.num_experts_per_rank,
+        warmup=args.warmup, iters=args.iters,
+        flydsl_block_num=cfg.block_num,
+        flydsl_warp_per_block=cfg.warp_num_per_block,
+        mori_block_num=mori_bn,
+        mori_warp_per_block=mori_wpb,
+        use_external_inp_buf=cfg.use_external_inp_buf,
+        enable_std_moe=cfg.enable_std_moe,
+        scale_dim=cfg.scale_dim,
+        scale_type_size=cfg.scale_type_size,
+        quant_type=cfg.quant_type,
+    )
+
+    # output directory: <output_dir>/ep{world_size}_bs{cur_tok}/
+    out_dir = os.path.join(args.output_dir, f"ep{world_size}_bs{cur_tok}")
+    os.makedirs(out_dir, exist_ok=True)
+
+    # ── Build the ops ──────────────────────────────────────────────────────────
+    if rank == 0:
+        print(f"\n{'='*65}")
+        print(f"[profiler] EP={world_size}, bs={cur_tok}, h={cfg.hidden_dim}, k={k}")
+        print(f"{'='*65}")
+        print("[profiler] building FlyDSL...")
+    op_fly = FlyDSLDispatchCombineIntraNodeOp(cfg)
+
+    op_ref = None
+    if args.compare and not cfg.enable_std_moe:
+        # None (rather than the cfg values) lets build_mori_ref fall back to
+        # mori's own tuned defaults.
+        mori_bn = args.mori_block_num if args.mori_block_num > 0 else None
+        mori_wpb = args.mori_warp_per_block if args.mori_warp_per_block > 0 else None
+        bn_str = mori_bn if mori_bn else cfg.block_num
+        wpb_str = mori_wpb if mori_wpb else cfg.warp_num_per_block
+        if rank == 0:
+            print(f"[profiler] building mori ref (block_num={bn_str}, warp_per_block={wpb_str})...")
+        try:
+            op_ref = build_mori_ref(rank, world_size, cfg,
+                                    block_num=mori_bn, warp_per_block=mori_wpb)
+        except Exception as e:
+            if rank == 0:
+                print(f"[warn] mori ref unavailable: {e}")
+    elif cfg.enable_std_moe and rank == 0:
+        print("[info] StdMoE mode: skipping mori ref, using self-check verification")
+    ms.shmem_barrier_all()
+
+    # ── Prepare inputs (fixed seed; FlyDSL and mori see identical inputs) ──────
+    torch.manual_seed(42 + rank)
+    if cfg.data_type == torch.float4_e2m1fn_x2:
+        inp = torch.randint(0, 256, (cur_tok, cfg.hidden_dim // 2), dtype=torch.uint8, device=dev).view(torch.float4_e2m1fn_x2)
+    elif cfg.data_type in (torch.float8_e4m3fn, torch.float8_e4m3fnuz):
+        inp = torch.randn(cur_tok, cfg.hidden_dim, dtype=torch.bfloat16, device=dev).to(cfg.data_type)
+    else:
+        inp = torch.randn(cur_tok, cfg.hidden_dim, dtype=cfg.data_type, device=dev)
+    wts = torch.rand(cur_tok, k, dtype=torch.float32, device=dev)
+    wts = wts / wts.sum(-1, keepdim=True)
+    epr = args.num_experts_per_rank
+    idx = torch.zeros(cur_tok, k, dtype=torch.int32, device=dev)
+    if args.mode == "verify" and k <= world_size:
+        # Ensure each token's k experts go to k DISTINCT PEs.
+        # FlyDSL dispatch deduplicates same-PE assignments, mori does not.
+        for t in range(cur_tok):
+            pes = torch.randperm(world_size, device=dev)[:k]
+            for j in range(k):
+                idx[t, j] = pes[j] * epr + torch.randint(0, epr, (1,), device=dev)
+    else:
+        for t in range(cur_tok):
+            idx[t] = torch.randperm(n_exp, device=dev)[:k]
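+
+    # Optional sanity check for the verify-mode index layout above: every
+    # token's k experts should land on k distinct PEs. Hypothetical helper,
+    # shown for illustration only (not called anywhere in this script):
+    #
+    #   def _assert_distinct_pes(idx, experts_per_rank):
+    #       pes = idx // experts_per_rank          # expert id -> owning PE
+    #       for row in pes.tolist():
+    #           assert len(set(row)) == len(row), f"duplicate PE in {row}"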
+
+    # Pre-allocate the combine weight buffer (shared by FlyDSL and mori;
+    # avoids launching extra GPU kernels inside the timing window)
+    max_recv = world_size * cur_tok
+    wc_buf = torch.full((max_recv, k), 1.0 / k, dtype=torch.float32, device=dev)
+
+    # ── Build scales / packed_recv_x (shared by all modes) ─────────────────────
+    packed_recv_x = None
+    if cfg.enable_std_moe:
+        _prx_nbytes = cfg.num_experts_per_rank * cfg.max_recv * cfg.token_bytes
+        packed_recv_x = torch.zeros(
+            _prx_nbytes, dtype=torch.uint8, device=dev
+        ).view(cfg.data_type).view(
+            cfg.num_experts_per_rank * cfg.max_recv, cfg.token_view_dim)
+
+    scales = None
+    if cfg.scale_dim > 0 and cfg.scale_type_size > 0:
+        _sc_bytes = cfg.scale_dim * cfg.scale_type_size
+        scales = torch.randn(cur_tok, _sc_bytes // 4,
+                             dtype=torch.float32, device=dev).contiguous()
+        scales = scales.view(torch.uint8).view(cur_tok, _sc_bytes)
+
+    # profile+eager mode needs an explicit warmup here; the other three
+    # mode × cudagraph combinations handle warmup inside their own functions
+    do_warmup = (args.mode == "profile" and not args.cudagraph)
+
+    if do_warmup:
+        if rank == 0:
+            print(f"[setup] warming up FlyDSL for {args.warmup} rounds...")
+        for _ in range(args.warmup):
+            op_fly.reset()
+            ret = op_fly.dispatch(inp, wts, scales, idx,
+                                  packed_recv_x=packed_recv_x)
+            op_fly.combine(ret[0], None, ret[3],
+                           packed_recv_x=packed_recv_x)
+        torch.cuda.synchronize()
+
+        if op_ref is not None:
+            if rank == 0:
+                print(f"[setup] warming up mori ref for {args.warmup} rounds...")
+            for _ in range(args.warmup):
+                op_ref.reset()
+                ret_r = op_ref.dispatch(inp, wts, None, idx)
+                op_ref.combine(ret_r[0], None, ret_r[3])
+            torch.cuda.synchronize()
+
+    ms.shmem_barrier_all()
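+
+    # Quick reference for the four execution paths selected below
+    # (mode × --cudagraph), derived from the branches that follow:
+    #
+    #   bench   + eager     -> bench_op             (CUDA-event timing)
+    #   bench   + cudagraph -> cudagraph_op         (graph replay timing)
+    #   profile + eager     -> profile_op           (torch.profiler trace)
+    #   profile + cudagraph -> profile_cudagraph_op (profiler over replays)
+    #   verify (either)     -> verify_op / verify_self, then early return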
print(f"\n[profiler] 全部结果已保存到: {out_dir}/") + + elif args.mode == "profile" and args.cudagraph: + _p2p = not args.use_external_inp_buf + if test_flydsl: + profile_cudagraph_op(op_fly, "flydsl", inp, wts, idx, wc_buf, k, + rank, world_size, dev, args.warmup, args.iters, + out_dir, meta, + scales=scales, packed_recv_x=packed_recv_x, + dtype_key=args.dtype, quant_type=args.quant_type, + use_p2p_read=_p2p) + if test_mori: + ms.shmem_barrier_all() + profile_cudagraph_op(op_ref, "mori", inp, wts, idx, wc_buf, k, + rank, world_size, dev, args.warmup, args.iters, + out_dir, meta, + dtype_key=args.dtype, quant_type=args.quant_type, + use_p2p_read=_p2p) + if rank == 0: + print(f"\n[profiler] 全部结果已保存到: {out_dir}/") + + +# ─── Worker / 命令行入口 ────────────────────────────────────────────────────── +def _worker(rank, world_size, args, master_port): + setup_distributed(rank, world_size, master_port) + try: + run_profiler(rank, world_size, args) + except Exception as e: + import traceback as tb + print(f"[rank {rank}] ERROR: {e}") + tb.print_exc() + finally: + cleanup() + + +def _parse_args(): + p = argparse.ArgumentParser(description="torch.profiler 分析 dispatch/combine") + p.add_argument("--world-size", type=int, default=8) + p.add_argument("--max-tokens", type=int, default=512) + p.add_argument("--hidden-dim", type=int, default=7168) + p.add_argument("--num-experts-per-rank", type=int, default=32) + p.add_argument("--k", type=int, default=8) + p.add_argument("--block-num", type=int, default=80) + p.add_argument("--warp-per-block", type=int, default=4) + p.add_argument("--mori-block-num", type=int, default=0, + help="mori 专用 block_num(0=与FlyDSL相同,mori默认最优=80)") + p.add_argument("--mori-warp-per-block", type=int, default=0, + help="mori 专用 warp_per_block(0=与FlyDSL相同,mori默认最优=8)") + p.add_argument("--chip", type=str, default="gfx950") + p.add_argument("--dtype", type=str, default="bf16", + choices=list(DTYPE_MAP.keys()), + help="数据类型(默认 bf16)") + p.add_argument("--warmup", type=int, default=5, + help="预热轮次(不进 profiler,确保 JIT 编译完成)") + p.add_argument("--iters", type=int, default=5, + help="profiler 采集轮次") + p.add_argument("--output-dir", type=str, default="dispatch_profile", + help="JSON 输出根目录(相对当前目录),子目录按 ep{ws}_bs{tok} 命名") + p.add_argument("--port", type=int, default=29800) + p.add_argument("--no-compare", dest="compare", action="store_false") + # ── 模式选择 ────────────────────────────────────────────────────────────── + p.add_argument("--mode", choices=["profile", "bench", "verify"], default="profile", + help="测量方式:profile=torch.profiler 采集(默认); bench=CUDA Event 计时; verify=正确性验证") + p.add_argument("--cudagraph", action="store_true", + help="使用 CUDAGraph capture+replay 执行(默认 eager)") + p.add_argument("--bench-op", choices=["flydsl", "mori", "both"], default="both", + help="测哪个算子(默认 both)") + # ── 功能开关 ────────────────────────────────────────────────────────────── + p.add_argument("--no-external-inp-buf", dest="use_external_inp_buf", + action="store_false", default=True, + help="使用 P2P Read combine 变体(默认使用 external inp buf)") + p.add_argument("--enable-std-moe", action="store_true", default=False, + help="启用 Standard MoE Adapt 模式") + p.add_argument("--scale-dim", type=int, default=0, + help="Scale 张量维度(0=不使用 scale)") + p.add_argument("--scale-type-size", type=int, default=0, + help="Scale 类型大小(字节,0=不使用 scale)") + p.add_argument("--quant-type", type=str, default="none", + choices=["none", "fp8_direct_cast"], + help="量化类型(none=默认,fp8_direct_cast=FP8直接转换combine)") + p.set_defaults(compare=True) + return 
+
+
+def _parse_args():
+    p = argparse.ArgumentParser(description="torch.profiler analysis of dispatch/combine")
+    p.add_argument("--world-size", type=int, default=8)
+    p.add_argument("--max-tokens", type=int, default=512)
+    p.add_argument("--hidden-dim", type=int, default=7168)
+    p.add_argument("--num-experts-per-rank", type=int, default=32)
+    p.add_argument("--k", type=int, default=8)
+    p.add_argument("--block-num", type=int, default=80)
+    p.add_argument("--warp-per-block", type=int, default=4)
+    p.add_argument("--mori-block-num", type=int, default=0,
+                   help="mori-specific block_num (0 = same as FlyDSL; mori's tuned default is 80)")
+    p.add_argument("--mori-warp-per-block", type=int, default=0,
+                   help="mori-specific warp_per_block (0 = same as FlyDSL; mori's tuned default is 8)")
+    p.add_argument("--chip", type=str, default="gfx950")
+    p.add_argument("--dtype", type=str, default="bf16",
+                   choices=list(DTYPE_MAP.keys()),
+                   help="data type (default bf16)")
+    p.add_argument("--warmup", type=int, default=5,
+                   help="warmup rounds (outside the profiler; ensures JIT compilation finishes)")
+    p.add_argument("--iters", type=int, default=5,
+                   help="profiler collection rounds")
+    p.add_argument("--output-dir", type=str, default="dispatch_profile",
+                   help="JSON output root (relative to cwd); subdirectories named ep{ws}_bs{tok}")
+    p.add_argument("--port", type=int, default=29800)
+    p.add_argument("--no-compare", dest="compare", action="store_false")
+    # ── Mode selection ─────────────────────────────────────────────────────────
+    p.add_argument("--mode", choices=["profile", "bench", "verify"], default="profile",
+                   help="measurement mode: profile=torch.profiler collection (default); "
+                        "bench=CUDA-event timing; verify=correctness validation")
+    p.add_argument("--cudagraph", action="store_true",
+                   help="run via CUDAGraph capture+replay (default: eager)")
+    p.add_argument("--bench-op", choices=["flydsl", "mori", "both"], default="both",
+                   help="which op(s) to benchmark (default: both)")
+    # ── Feature switches ───────────────────────────────────────────────────────
+    p.add_argument("--no-external-inp-buf", dest="use_external_inp_buf",
+                   action="store_false", default=True,
+                   help="use the P2P-read combine variant (default: external input buffer)")
+    p.add_argument("--enable-std-moe", action="store_true", default=False,
+                   help="enable Standard MoE Adapt mode")
+    p.add_argument("--scale-dim", type=int, default=0,
+                   help="scale tensor dimension (0 = no scales)")
+    p.add_argument("--scale-type-size", type=int, default=0,
+                   help="scale element size in bytes (0 = no scales)")
+    p.add_argument("--quant-type", type=str, default="none",
+                   choices=["none", "fp8_direct_cast"],
+                   help="quantization type (none = default; fp8_direct_cast = FP8 direct-cast combine)")
+    p.set_defaults(compare=True)
+    return p.parse_args()
+
+
+def main():
+    args = _parse_args()
+    if "LOCAL_RANK" in os.environ:
+        rank = int(os.environ["LOCAL_RANK"])
+        world_size = int(os.environ.get("WORLD_SIZE", args.world_size))
+        _worker(rank, world_size, args, master_port=args.port)
+    else:
+        ws = min(args.world_size, torch.cuda.device_count())
+        if ws < args.world_size:
+            print(f"[warn] available GPUs={torch.cuda.device_count()}, "
+                  f"adjusting world_size: {args.world_size} → {ws}")
+        torch.multiprocessing.spawn(
+            _worker, args=(ws, args, args.port),
+            nprocs=ws, join=True,
+        )
+
+
+if __name__ == "__main__":
+    main()