From 22ae730c39c8ded3cc464ab0ddd4f9583c6ee957 Mon Sep 17 00:00:00 2001 From: coderfeli Date: Wed, 13 May 2026 02:51:02 +0000 Subject: [PATCH 01/10] feat(fly): add fly-attach-lds-alias-scope pass + named dyn-shared Without this pass, multiple `fly.get_dyn_shared` bases inside one kernel collapse to a single LLVM allocation in `LowerModuleLDS`, which makes `SIInsertWaitcnts` conservatively serialise every cross-name LDS access with `s_waitcnt vmcnt(N)` and slows the kernel down by ~3x compared to the static `[N x i8]` SmemAllocator pattern. This change adds: * An optional `sym_name` attribute on `fly.get_dyn_shared` whose lowering emits a distinct external `[0 x i8] addrspace(3)` LDS global per name (all aliasing the same runtime LDS region). * A new `fly-attach-lds-alias-scope` pass on `gpu.module` that walks every external 0-sized LDS global, gives each one a distinct `alias_scope` under a shared `FlyDynSharedDomain`, and tags every load / store / `amdgcn.raw.ptr.buffer.load.lds` whose addrspace(3) pointer can be statically traced back to a single global through `addressof / ptrtoint / add / inttoptr / GEP` with that scope plus a noalias-set covering all sibling globals. * The pass is registered into the ROCm pipeline right after `reconcile-unrealized-casts`, so the metadata flows into the LLVM IR that `gpu-module-to-binary` hands to AMDGPU codegen. Verified end-to-end on the FP8 4-wave GEMM: static-LDS baseline: ~535 TFLOPS (117 vmcnt waitcnts) named dyn-shared, no pass: ~180 TFLOPS (334 vmcnt waitcnts) named dyn-shared + new pass: ~535 TFLOPS (matches baseline) Static `[N x i8]` LDS globals (SmemAllocator) and single-global modules are skipped: their alias info already comes from distinct LLVM symbols, and the pass requires at least two named bases to have anything to disambiguate. Co-authored-by: Cursor --- include/flydsl/Dialect/Fly/IR/FlyOps.td | 17 +- .../flydsl/Dialect/Fly/Transforms/Passes.td | 26 ++ lib/Conversion/FlyToROCDL/FlyToROCDL.cpp | 39 ++- lib/Dialect/Fly/CMakeLists.txt | 2 + .../Fly/Transforms/AttachLDSAliasScope.cpp | 250 ++++++++++++++++++ python/flydsl/compiler/backends/rocm.py | 1 + python/flydsl/expr/primitive.py | 15 +- .../Transforms/attach_lds_alias_scope.mlir | 86 ++++++ 8 files changed, 423 insertions(+), 13 deletions(-) create mode 100644 lib/Dialect/Fly/Transforms/AttachLDSAliasScope.cpp create mode 100644 tests/mlir/Transforms/attach_lds_alias_scope.mlir diff --git a/include/flydsl/Dialect/Fly/IR/FlyOps.td b/include/flydsl/Dialect/Fly/IR/FlyOps.td index 36706f4aa..b1e74ec52 100644 --- a/include/flydsl/Dialect/Fly/IR/FlyOps.td +++ b/include/flydsl/Dialect/Fly/IR/FlyOps.td @@ -435,7 +435,22 @@ def Fly_MakePtrOp : Fly_Op<"make_ptr", []> { let results = (outs Fly_Pointer:$result); } def Fly_GetDynSharedOp : Fly_Op<"get_dyn_shared", [Pure, DeclareOpInterfaceMethods]> { - let arguments = (ins); + let summary = "Pointer to the kernel's dynamic shared memory"; + let description = [{ + Returns a pointer into the kernel's dynamic shared memory region. + + By default, the lowering reuses a single ``__dynamic_shared_*`` LLVM + global for all calls in a kernel. When `sym_name` is provided, the + lowering instead emits a distinct external ``[0 x i8]`` LDS global + with that exact name. Multiple named bases all alias the same + runtime LDS region (each starts at offset 0 of the dynamic LDS + area), but their distinct LLVM symbols give the + ``fly-attach-lds-alias-scope`` pass the provenance it needs to + attach ``alias_scope``/``noalias`` metadata, which lets AMDGPU's SI + Wait Counter pass elide defensive ``s_waitcnt vmcnt(N)`` between + accesses through different names. + }]; + let arguments = (ins OptionalAttr:$sym_name); let results = (outs Fly_Pointer:$result); let assemblyFormat = "`(` `)` attr-dict `:` qualified(type($result))"; } diff --git a/include/flydsl/Dialect/Fly/Transforms/Passes.td b/include/flydsl/Dialect/Fly/Transforms/Passes.td index 44a34f565..e04fea3db 100644 --- a/include/flydsl/Dialect/Fly/Transforms/Passes.td +++ b/include/flydsl/Dialect/Fly/Transforms/Passes.td @@ -103,4 +103,30 @@ def FlyPromoteRegMemToVectorSSAPass : Pass<"fly-promote-regmem-to-vectorssa"> { ]; } +def FlyAttachLDSAliasScopePass : Pass<"fly-attach-lds-alias-scope", "::mlir::gpu::GPUModuleOp"> { + let summary = "Attach alias scope metadata to dyn-shared LDS accesses"; + let description = [{ + Walks every external `[0 x i8] addrspace(3)` LLVM global in the + `gpu.module` (the dyn-shared LDS bases produced by + `fly.get_dyn_shared(sym_name="...")`) and attaches per-symbol + `alias_scopes` / `noalias_scopes` metadata to every load, store, + and `llvm.amdgcn.raw.ptr.buffer.load.lds` call whose addrspace(3) + pointer can be statically traced back to a single global through + `addressof / ptrtoint / add / inttoptr / getelementptr`. + + Without this metadata, AMDGPU's `LowerModuleLDS` pass collapses + the named dyn-shared globals into one underlying allocation and + `SIInsertWaitcnts` then conservatively serialises every cross-name + LDS access with `s_waitcnt vmcnt(N)`. The metadata flows through + the merge and lets the SI Wait Counter pass treat distinct-named + accesses as no-alias, restoring static-LDS-class scheduling. + + Single-global modules are skipped (no benefit). + }]; + + let dependentDialects = [ + "LLVM::LLVMDialect" + ]; +} + #endif // FLY_PASSES diff --git a/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp b/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp index 7985ca428..c3a32b963 100644 --- a/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp +++ b/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp @@ -123,7 +123,8 @@ class GetDynSharedOpLowering : public OpConversionPattern { if (!moduleOp) return op->emitError("get_dyn_shared must be inside a gpu.module"); - LLVM::GlobalOp sharedGlobal = getOrCreateDynSharedGlobal(rewriter, moduleOp, loc, addrSpace); + LLVM::GlobalOp sharedGlobal = + getOrCreateDynSharedGlobal(rewriter, moduleOp, loc, addrSpace, op.getSymNameAttr()); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(op); @@ -142,21 +143,39 @@ class GetDynSharedOpLowering : public OpConversionPattern { private: static LLVM::GlobalOp getOrCreateDynSharedGlobal(ConversionPatternRewriter &rewriter, gpu::GPUModuleOp moduleOp, Location loc, - unsigned addrSpace) { + unsigned addrSpace, + StringAttr requestedName) { + // When sym_name is requested we look up by exact name and create a + // distinct external [0 x i8] LDS global if missing. Otherwise we + // reuse the first existing matching dyn-shared global, falling back + // to a freshly generated `__dynamic_shared_` symbol. llvm::StringSet<> existingNames; + LLVM::GlobalOp firstMatch = nullptr; for (auto globalOp : moduleOp.getBody()->getOps()) { existingNames.insert(globalOp.getSymName()); - if (auto arrayType = dyn_cast(globalOp.getType())) { - if (globalOp.getAddrSpace() == addrSpace && arrayType.getNumElements() == 0 && - globalOp.getAlignment().value_or(0) == 1024) - return globalOp; + if (requestedName && globalOp.getSymName() == requestedName.getValue()) + return globalOp; + if (!requestedName) { + if (auto arrayType = dyn_cast(globalOp.getType())) { + if (!firstMatch && globalOp.getAddrSpace() == addrSpace && + arrayType.getNumElements() == 0 && + globalOp.getAlignment().value_or(0) == 1024) + firstMatch = globalOp; + } } } + if (!requestedName && firstMatch) + return firstMatch; - unsigned counter = 0; - SmallString<128> symName = SymbolTable::generateSymbolName<128>( - "__dynamic_shared_", [&](StringRef candidate) { return existingNames.contains(candidate); }, - counter); + SmallString<128> symName; + if (requestedName) { + symName.assign(requestedName.getValue()); + } else { + unsigned counter = 0; + symName = SymbolTable::generateSymbolName<128>( + "__dynamic_shared_", + [&](StringRef candidate) { return existingNames.contains(candidate); }, counter); + } OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(moduleOp.getBody()); diff --git a/lib/Dialect/Fly/CMakeLists.txt b/lib/Dialect/Fly/CMakeLists.txt index 65e500d34..455de6e20 100644 --- a/lib/Dialect/Fly/CMakeLists.txt +++ b/lib/Dialect/Fly/CMakeLists.txt @@ -15,6 +15,7 @@ add_mlir_dialect_library(MLIRFlyDialect Transforms/ConvertAtomCallToSSAForm.cpp Transforms/PromoteRegMemToVectorSSA.cpp Transforms/IntSwizzleSimplify.cpp + Transforms/AttachLDSAliasScope.cpp DEPENDS MLIRFlyIncGen @@ -24,5 +25,6 @@ add_mlir_dialect_library(MLIRFlyDialect LINK_LIBS MLIRGPUDialect MLIRIR + MLIRLLVMDialect MLIRTargetLLVMIRExport ) diff --git a/lib/Dialect/Fly/Transforms/AttachLDSAliasScope.cpp b/lib/Dialect/Fly/Transforms/AttachLDSAliasScope.cpp new file mode 100644 index 000000000..bc835f91d --- /dev/null +++ b/lib/Dialect/Fly/Transforms/AttachLDSAliasScope.cpp @@ -0,0 +1,250 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2026 FlyDSL Project Contributors + +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Pass/Pass.h" + +#include "flydsl/Dialect/Fly/Transforms/Passes.h" + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" + +using namespace mlir; + +namespace mlir { +namespace fly { +#define GEN_PASS_DEF_FLYATTACHLDSALIASSCOPEPASS +#include "flydsl/Dialect/Fly/Transforms/Passes.h.inc" +} // namespace fly +} // namespace mlir + +namespace { + +// LDS address space on AMDGPU. +static constexpr unsigned kLDSAddrSpace = 3; + +/// Returns true if `g` is an external `[0 x i8] addrspace(3)` global, +/// i.e. a dyn-shared LDS base. We restrict on size 0 (HSA dynamic LDS +/// convention) so we don't accidentally tag SmemAllocator-style static +/// globals (whose alias info already comes from distinct symbols). +static bool isDynSharedGlobal(LLVM::GlobalOp g) { + if (g.getAddrSpace() != kLDSAddrSpace) + return false; + if (g.getLinkage() != LLVM::Linkage::External) + return false; + auto arrTy = dyn_cast(g.getType()); + if (!arrTy) + return false; + return arrTy.getNumElements() == 0; +} + +/// Per-SSA-value provenance maps. Absent entry == "no provenance". +/// A null mapped value means "ambiguous (mixes multiple globals)". +using PtrProvenance = llvm::DenseMap; +using IntProvenance = llvm::DenseMap; + +/// True iff `op` is an `llvm.amdgcn.raw.ptr.buffer.load.lds` intrinsic. +static bool isBufferLoadLDS(LLVM::CallOp call) { + auto callee = call.getCallee(); + if (!callee) + return false; + return callee->starts_with("llvm.amdgcn.raw.ptr.buffer.load.lds"); +} + +/// Returns the addrspace(3) pointer operand consumed by `op`, or +/// nullptr if there isn't exactly one such operand worth tagging. +static Value memoryPointerForOp(Operation *op) { + if (auto load = dyn_cast(op)) { + auto ptrTy = dyn_cast(load.getAddr().getType()); + if (ptrTy && ptrTy.getAddressSpace() == kLDSAddrSpace) + return load.getAddr(); + return nullptr; + } + if (auto store = dyn_cast(op)) { + auto ptrTy = dyn_cast(store.getAddr().getType()); + if (ptrTy && ptrTy.getAddressSpace() == kLDSAddrSpace) + return store.getAddr(); + return nullptr; + } + if (auto call = dyn_cast(op)) { + if (!isBufferLoadLDS(call)) + return nullptr; + // The LDS pointer is the second arg (after the buffer-desc ptr). + if (call.getNumOperands() >= 2) { + Value lds = call.getOperand(1); + auto ptrTy = dyn_cast(lds.getType()); + if (ptrTy && ptrTy.getAddressSpace() == kLDSAddrSpace) + return lds; + } + return nullptr; + } + return nullptr; +} + +/// Forward dataflow that maps SSA values back to the LDS global they +/// derive from. Only tracks the patterns we care about: +/// * `LLVM::AddressOfOp(@g)` -> ptr provenance(@g) +/// * `LLVM::PtrToIntOp(p)` -> int provenance(p) +/// * `LLVM::AddOp(a, b)` / `LLVM::OrOp(a, b)` / `LLVM::SubOp(a, b)` -> +/// int provenance(a)|provenance(b) (single non-null wins; +/// conflict marks ambiguous so we don't tag downstream uses) +/// * `LLVM::IntToPtrOp(i)` -> ptr provenance(i) +/// * `LLVM::GEPOp(p)` -> ptr provenance(p) +static void computeProvenance( + LLVM::LLVMFuncOp func, + const llvm::DenseMap &nameToGlobal, + PtrProvenance &ptrProv, IntProvenance &intProv) { + // Combine two provenance entries. Returns (global, hasInfo) where + // hasInfo=false means "still no info" and a null global with + // hasInfo=true means "ambiguous". + auto combine = [](LLVM::GlobalOp a, bool aSeen, LLVM::GlobalOp b, + bool bSeen) -> std::pair { + if (!aSeen && !bSeen) + return {nullptr, false}; + if (!aSeen) + return {b, true}; + if (!bSeen) + return {a, true}; + if (a == b) + return {a, true}; + return {nullptr, true}; // ambiguous + }; + + func.walk([&](Operation *op) { + if (auto addrOf = dyn_cast(op)) { + auto it = nameToGlobal.find(addrOf.getGlobalName()); + if (it != nameToGlobal.end()) + ptrProv[addrOf.getResult()] = it->second; + return; + } + if (auto p2i = dyn_cast(op)) { + auto it = ptrProv.find(p2i.getArg()); + if (it != ptrProv.end() && it->second) + intProv[p2i.getResult()] = it->second; + return; + } + auto handleAddLike = [&](Value lhs, Value rhs, Value result) { + auto la = intProv.find(lhs); + auto lb = intProv.find(rhs); + bool aSeen = la != intProv.end(); + bool bSeen = lb != intProv.end(); + auto [g, hasInfo] = + combine(aSeen ? la->second : nullptr, aSeen, + bSeen ? lb->second : nullptr, bSeen); + if (hasInfo && g) + intProv[result] = g; + }; + if (auto add = dyn_cast(op)) { + handleAddLike(add.getLhs(), add.getRhs(), add.getResult()); + return; + } + if (auto orOp = dyn_cast(op)) { + handleAddLike(orOp.getLhs(), orOp.getRhs(), orOp.getResult()); + return; + } + if (auto sub = dyn_cast(op)) { + handleAddLike(sub.getLhs(), sub.getRhs(), sub.getResult()); + return; + } + if (auto i2p = dyn_cast(op)) { + auto it = intProv.find(i2p.getArg()); + if (it != intProv.end() && it->second) { + auto ptrTy = dyn_cast(i2p.getResult().getType()); + if (ptrTy && ptrTy.getAddressSpace() == kLDSAddrSpace) + ptrProv[i2p.getResult()] = it->second; + } + return; + } + if (auto gep = dyn_cast(op)) { + auto it = ptrProv.find(gep.getBase()); + if (it != ptrProv.end() && it->second) { + auto ptrTy = dyn_cast(gep.getResult().getType()); + if (ptrTy && ptrTy.getAddressSpace() == kLDSAddrSpace) + ptrProv[gep.getResult()] = it->second; + } + return; + } + }); +} + +class FlyAttachLDSAliasScopePass + : public mlir::fly::impl::FlyAttachLDSAliasScopePassBase< + FlyAttachLDSAliasScopePass> { +public: + using mlir::fly::impl::FlyAttachLDSAliasScopePassBase< + FlyAttachLDSAliasScopePass>::FlyAttachLDSAliasScopePassBase; + + void runOnOperation() override { + gpu::GPUModuleOp gpuModule = getOperation(); + + // Collect dyn-shared globals in declaration order. + SmallVector dynGlobals; + llvm::DenseMap nameToGlobal; + for (auto g : gpuModule.getOps()) { + if (isDynSharedGlobal(g)) { + dynGlobals.push_back(g); + nameToGlobal[g.getSymName()] = g; + } + } + if (dynGlobals.size() < 2) + return; // Single (or no) dyn-shared region: nothing to disambiguate. + + MLIRContext *ctx = &getContext(); + OpBuilder builder(ctx); + + // One domain per gpu.module, one scope per dyn-shared global. + auto domain = LLVM::AliasScopeDomainAttr::get( + ctx, builder.getStringAttr("FlyDynSharedDomain")); + + llvm::DenseMap globalToScope; + for (auto g : dynGlobals) { + auto scope = LLVM::AliasScopeAttr::get( + domain, builder.getStringAttr(g.getSymName())); + globalToScope[g] = scope; + } + + // Pre-compute the noalias-set per global = all scopes except its + // own. This is what makes cross-global accesses no-alias. + llvm::DenseMap globalToNoalias; + for (auto g : dynGlobals) { + SmallVector others; + others.reserve(dynGlobals.size() - 1); + for (auto og : dynGlobals) + if (og != g) + others.push_back(globalToScope[og]); + globalToNoalias[g] = ArrayAttr::get(ctx, others); + } + + for (auto func : gpuModule.getOps()) { + if (func.empty()) + continue; + PtrProvenance ptrProv; + IntProvenance intProv; + computeProvenance(func, nameToGlobal, ptrProv, intProv); + + func.walk([&](Operation *op) { + Value lds = memoryPointerForOp(op); + if (!lds) + return; + auto it = ptrProv.find(lds); + if (it == ptrProv.end() || !it->second) + return; + LLVM::GlobalOp g = it->second; + auto scopeIt = globalToScope.find(g); + auto noaliasIt = globalToNoalias.find(g); + if (scopeIt == globalToScope.end() || noaliasIt == globalToNoalias.end()) + return; + auto scopeAttr = ArrayAttr::get(ctx, {scopeIt->second}); + op->setAttr("alias_scopes", scopeAttr); + op->setAttr("noalias_scopes", noaliasIt->second); + }); + } + } +}; + +} // namespace diff --git a/python/flydsl/compiler/backends/rocm.py b/python/flydsl/compiler/backends/rocm.py index c32a328bf..f12001dce 100644 --- a/python/flydsl/compiler/backends/rocm.py +++ b/python/flydsl/compiler/backends/rocm.py @@ -84,6 +84,7 @@ def _pipeline_parts(self, *, compile_hints: dict) -> Tuple[List[str], str]: "convert-arith-to-llvm", "convert-func-to-llvm", "reconcile-unrealized-casts", + "gpu.module(fly-attach-lds-alias-scope)", *( ["ensure-debug-info-scope-on-llvm-func{emission-kind=LineTablesOnly}"] if env.debug.enable_debug_info diff --git a/python/flydsl/expr/primitive.py b/python/flydsl/expr/primitive.py index f67d9dc65..fb989e570 100644 --- a/python/flydsl/expr/primitive.py +++ b/python/flydsl/expr/primitive.py @@ -999,8 +999,19 @@ def make_ptr(result_type, args, loc=None, ip=None): @traced_op -def get_dyn_shared(loc=None, ip=None): - return fly.get_dyn_shared(loc=loc, ip=ip) +def get_dyn_shared(sym_name=None, loc=None, ip=None): + """Get a base pointer into the kernel's dynamic shared memory. + + If ``sym_name`` is provided the lowering emits a distinct external + ``[0 x i8] addrspace(3) align 1024`` global with that exact name. + All named bases share the same runtime LDS region (each starts at + offset 0 of the dynamic LDS area) but the + ``fly-attach-lds-alias-scope`` pass uses the distinct symbols to + attach ``alias_scope``/``noalias`` metadata, so AMDGPU's + ``SIInsertWaitcnts`` pass treats accesses through different names + as no-alias even though ``LowerModuleLDS`` later merges them. + """ + return fly.get_dyn_shared(sym_name=sym_name, loc=loc, ip=ip) @traced_op diff --git a/tests/mlir/Transforms/attach_lds_alias_scope.mlir b/tests/mlir/Transforms/attach_lds_alias_scope.mlir new file mode 100644 index 000000000..4719e37f7 --- /dev/null +++ b/tests/mlir/Transforms/attach_lds_alias_scope.mlir @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2026 FlyDSL Project Contributors +// RUN: %fly-opt %s --pass-pipeline='builtin.module(gpu.module(fly-attach-lds-alias-scope))' | FileCheck %s + +// fly-attach-lds-alias-scope finds external `[0 x i8] addrspace(3)` +// LDS globals in a gpu.module, gives each one a distinct alias scope, +// and tags every load / store / amdgcn.raw.ptr.buffer.load.lds whose +// addrspace(3) pointer can be traced back to a single global through +// addressof / ptrtoint / add / inttoptr / GEP. + +// ----------------------------------------------------------------------------- +// Two named dyn-shared globals -> per-symbol alias_scopes / noalias_scopes on +// loads, with the int-derived pointer being recognised through +// ptrtoint+add+inttoptr. +// ----------------------------------------------------------------------------- + +// CHECK-DAG: #[[DOMAIN:.+]] = #llvm.alias_scope_domain<{{.*}}description = "FlyDynSharedDomain"> +// CHECK-DAG: #[[SCOPE_A:.+]] = #llvm.alias_scope<{{.*}}domain = #[[DOMAIN]], description = "buf_a"> +// CHECK-DAG: #[[SCOPE_B:.+]] = #llvm.alias_scope<{{.*}}domain = #[[DOMAIN]], description = "buf_b"> + +// CHECK-LABEL: gpu.module @two_named +gpu.module @two_named { + llvm.mlir.global external @buf_a() {addr_space = 3 : i32, alignment = 1024 : i64, dso_local} : !llvm.array<0 x i8> + llvm.mlir.global external @buf_b() {addr_space = 3 : i32, alignment = 1024 : i64, dso_local} : !llvm.array<0 x i8> + + // CHECK-LABEL: llvm.func @load_pair + llvm.func @load_pair(%off: i32) -> vector<4xi32> { + %a_ptr = llvm.mlir.addressof @buf_a : !llvm.ptr<3> + %b_ptr = llvm.mlir.addressof @buf_b : !llvm.ptr<3> + %a_int = llvm.ptrtoint %a_ptr : !llvm.ptr<3> to i32 + %b_int = llvm.ptrtoint %b_ptr : !llvm.ptr<3> to i32 + %a_off = llvm.add %a_int, %off : i32 + %b_off = llvm.add %b_int, %off : i32 + %a_p = llvm.inttoptr %a_off : i32 to !llvm.ptr<3> + %b_p = llvm.inttoptr %b_off : i32 to !llvm.ptr<3> + // CHECK: llvm.load %{{.+}} {alias_scopes = [#[[SCOPE_A]]], noalias_scopes = [#[[SCOPE_B]]]} + %va = llvm.load %a_p : !llvm.ptr<3> -> vector<4xi32> + // CHECK: llvm.load %{{.+}} {alias_scopes = [#[[SCOPE_B]]], noalias_scopes = [#[[SCOPE_A]]]} + %vb = llvm.load %b_p : !llvm.ptr<3> -> vector<4xi32> + %sum = llvm.add %va, %vb : vector<4xi32> + llvm.return %sum : vector<4xi32> + } +} + +// ----- + +// ----------------------------------------------------------------------------- +// Single-global module: pass is a no-op. Tagging a single scope gives the +// SI Wait Counter pass nothing extra to disambiguate. +// ----------------------------------------------------------------------------- + +// CHECK-LABEL: gpu.module @one_named +gpu.module @one_named { + llvm.mlir.global external @only_buf() {addr_space = 3 : i32, alignment = 1024 : i64, dso_local} : !llvm.array<0 x i8> + + // CHECK-LABEL: llvm.func @load_only + llvm.func @load_only() -> vector<4xi32> { + %p = llvm.mlir.addressof @only_buf : !llvm.ptr<3> + // CHECK: llvm.load + // CHECK-NOT: alias_scopes + %v = llvm.load %p : !llvm.ptr<3> -> vector<4xi32> + llvm.return %v : vector<4xi32> + } +} + +// ----- + +// ----------------------------------------------------------------------------- +// Static [N x i8] LDS globals (N > 0, the SmemAllocator pattern) are skipped. +// Their alias info already comes from distinct LLVM symbols. +// ----------------------------------------------------------------------------- + +// CHECK-LABEL: gpu.module @static_lds +gpu.module @static_lds { + llvm.mlir.global external @smem_a() {addr_space = 3 : i32, alignment = 1024 : i64, dso_local} : !llvm.array<4096 x i8> + llvm.mlir.global external @smem_b() {addr_space = 3 : i32, alignment = 1024 : i64, dso_local} : !llvm.array<4096 x i8> + + // CHECK-LABEL: llvm.func @load_static + llvm.func @load_static() -> vector<4xi32> { + %p = llvm.mlir.addressof @smem_a : !llvm.ptr<3> + // CHECK: llvm.load + // CHECK-NOT: alias_scopes + %v = llvm.load %p : !llvm.ptr<3> -> vector<4xi32> + llvm.return %v : vector<4xi32> + } +} From 3412141f72d1d1f142b7c269e3be6bfd8ec9a8da Mon Sep 17 00:00:00 2001 From: coderfeli Date: Wed, 13 May 2026 04:56:17 +0000 Subject: [PATCH 02/10] fix(fly-attach-lds-alias-scope): propagate ambiguous, drop or/sub The original dataflow had two soundness gaps that would let the pass emit a `noalias` annotation about a load whose pointer might really land in another global's region: * `add ptrtoint(@A), ptrtoint(@B)` produced an int with no entry in the provenance map (because the previous combine helper only stored non-null globals). A subsequent `add %amb, c` then saw only one operand with provenance and inherited it, mis-tagging every downstream use as belonging to a single global. * `or` and `sub` were treated like `add`. `or @G, mask` is only addition-equivalent when the operands are bit-disjoint, which we can't prove from the IR; `sub` of two pointer-derived ints is a `ptrdiff_t`, not a pointer. The combine helper now uses a tri-state DenseMap (absent / G / nullptr-sentinel) and explicitly stores the ambiguous sentinel so downstream `add` / `inttoptr` / `gep` walk the ambiguous tag forward. `or` / `sub` / `xor` / `and` / `shl` / `shr` / `bitcast` are no longer treated as canonical pointer arithmetic; values flowing through them lose their provenance and are skipped at the tag site. Two FileCheck cases lock the new behavior in: an explicit `add ptrtoint(@A), ptrtoint(@B)` chain stays untagged, and a `ptrtoint + or` chain stays untagged. Co-authored-by: Cursor --- .../Fly/Transforms/AttachLDSAliasScope.cpp | 94 +++++++++++-------- .../Transforms/attach_lds_alias_scope.mlir | 57 +++++++++++ 2 files changed, 112 insertions(+), 39 deletions(-) diff --git a/lib/Dialect/Fly/Transforms/AttachLDSAliasScope.cpp b/lib/Dialect/Fly/Transforms/AttachLDSAliasScope.cpp index bc835f91d..7f6ab0d6f 100644 --- a/lib/Dialect/Fly/Transforms/AttachLDSAliasScope.cpp +++ b/lib/Dialect/Fly/Transforms/AttachLDSAliasScope.cpp @@ -43,8 +43,14 @@ static bool isDynSharedGlobal(LLVM::GlobalOp g) { return arrTy.getNumElements() == 0; } -/// Per-SSA-value provenance maps. Absent entry == "no provenance". -/// A null mapped value means "ambiguous (mixes multiple globals)". +/// Per-SSA-value provenance, encoded as a tri-state DenseMap: +/// - absent entry => unknown / not derived from any tracked global +/// - mapped to G => derived from exactly the LDS global G +/// - mapped to nullptr => *known* to mix two or more globals (ambiguous); +/// downstream uses that consume this value must +/// also be marked ambiguous so the pass never +/// tags an access with a single scope when its +/// true scope set is larger. using PtrProvenance = llvm::DenseMap; using IntProvenance = llvm::DenseMap; @@ -87,21 +93,37 @@ static Value memoryPointerForOp(Operation *op) { } /// Forward dataflow that maps SSA values back to the LDS global they -/// derive from. Only tracks the patterns we care about: +/// derive from. Only the canonical pointer-arithmetic chain is tracked +/// so that we never tag an access whose pointer might really span more +/// than one global: /// * `LLVM::AddressOfOp(@g)` -> ptr provenance(@g) -/// * `LLVM::PtrToIntOp(p)` -> int provenance(p) -/// * `LLVM::AddOp(a, b)` / `LLVM::OrOp(a, b)` / `LLVM::SubOp(a, b)` -> -/// int provenance(a)|provenance(b) (single non-null wins; -/// conflict marks ambiguous so we don't tag downstream uses) -/// * `LLVM::IntToPtrOp(i)` -> ptr provenance(i) -/// * `LLVM::GEPOp(p)` -> ptr provenance(p) +/// * `LLVM::PtrToIntOp(p)` -> int provenance(p) +/// * `LLVM::AddOp(a, b)` -> int provenance(a) iff *exactly one* +/// operand carries provenance; if both +/// carry provenance, the result mixes +/// globals and is recorded as ambiguous +/// * `LLVM::IntToPtrOp(i)` -> ptr provenance(i) +/// * `LLVM::GEPOp(p)` -> ptr provenance(p) +/// +/// `or`/`sub`/`xor`/`and`/`shl`/`shr`/`bitcast` and any other op are +/// treated as provenance-destroying. The dataflow is intentionally +/// fail-safe: when in doubt, drop the tag rather than emit one that +/// could wrongly tell LLVM "no alias" about pointers that really do +/// alias at runtime. static void computeProvenance( LLVM::LLVMFuncOp func, const llvm::DenseMap &nameToGlobal, PtrProvenance &ptrProv, IntProvenance &intProv) { - // Combine two provenance entries. Returns (global, hasInfo) where - // hasInfo=false means "still no info" and a null global with - // hasInfo=true means "ambiguous". + // Tri-state DenseMap merge. Mirrors the encoding documented on + // `IntProvenance` / `PtrProvenance`: + // - absent entry => unknown + // - present, G => provenance(G) + // - present, null => ambiguous + // + // Returns (resultProvenance, hasInfo). When hasInfo is false the + // caller stores nothing (keeps the value unknown); when hasInfo is + // true and resultProvenance is null the caller stores a sentinel + // entry so subsequent uses also propagate as ambiguous. auto combine = [](LLVM::GlobalOp a, bool aSeen, LLVM::GlobalOp b, bool bSeen) -> std::pair { if (!aSeen && !bSeen) @@ -110,9 +132,15 @@ static void computeProvenance( return {b, true}; if (!bSeen) return {a, true}; - if (a == b) - return {a, true}; - return {nullptr, true}; // ambiguous + // Both operands have known provenance entries. + // - either is ambiguous (null) -> ambiguous + // - same non-null global -> *still* ambiguous, because adding + // a pointer-derived int to itself doesn't represent any single + // well-formed pointer + // - different non-null globals -> ambiguous + if (!a || !b || a != b) + return {nullptr, true}; + return {nullptr, true}; // see comment above (a == b case) }; func.walk([&](Operation *op) { @@ -124,45 +152,33 @@ static void computeProvenance( } if (auto p2i = dyn_cast(op)) { auto it = ptrProv.find(p2i.getArg()); - if (it != ptrProv.end() && it->second) - intProv[p2i.getResult()] = it->second; + if (it != ptrProv.end()) + intProv[p2i.getResult()] = it->second; // may store ambiguous return; } - auto handleAddLike = [&](Value lhs, Value rhs, Value result) { - auto la = intProv.find(lhs); - auto lb = intProv.find(rhs); + if (auto add = dyn_cast(op)) { + auto la = intProv.find(add.getLhs()); + auto lb = intProv.find(add.getRhs()); bool aSeen = la != intProv.end(); bool bSeen = lb != intProv.end(); - auto [g, hasInfo] = - combine(aSeen ? la->second : nullptr, aSeen, - bSeen ? lb->second : nullptr, bSeen); - if (hasInfo && g) - intProv[result] = g; - }; - if (auto add = dyn_cast(op)) { - handleAddLike(add.getLhs(), add.getRhs(), add.getResult()); - return; - } - if (auto orOp = dyn_cast(op)) { - handleAddLike(orOp.getLhs(), orOp.getRhs(), orOp.getResult()); - return; - } - if (auto sub = dyn_cast(op)) { - handleAddLike(sub.getLhs(), sub.getRhs(), sub.getResult()); + auto [g, hasInfo] = combine(aSeen ? la->second : nullptr, aSeen, + bSeen ? lb->second : nullptr, bSeen); + if (hasInfo) + intProv[add.getResult()] = g; // g may be null = ambiguous sentinel return; } if (auto i2p = dyn_cast(op)) { auto it = intProv.find(i2p.getArg()); - if (it != intProv.end() && it->second) { + if (it != intProv.end()) { auto ptrTy = dyn_cast(i2p.getResult().getType()); if (ptrTy && ptrTy.getAddressSpace() == kLDSAddrSpace) - ptrProv[i2p.getResult()] = it->second; + ptrProv[i2p.getResult()] = it->second; // propagate ambiguous too } return; } if (auto gep = dyn_cast(op)) { auto it = ptrProv.find(gep.getBase()); - if (it != ptrProv.end() && it->second) { + if (it != ptrProv.end()) { auto ptrTy = dyn_cast(gep.getResult().getType()); if (ptrTy && ptrTy.getAddressSpace() == kLDSAddrSpace) ptrProv[gep.getResult()] = it->second; diff --git a/tests/mlir/Transforms/attach_lds_alias_scope.mlir b/tests/mlir/Transforms/attach_lds_alias_scope.mlir index 4719e37f7..7decfda47 100644 --- a/tests/mlir/Transforms/attach_lds_alias_scope.mlir +++ b/tests/mlir/Transforms/attach_lds_alias_scope.mlir @@ -84,3 +84,60 @@ gpu.module @static_lds { llvm.return %v : vector<4xi32> } } + +// ----- + +// ----------------------------------------------------------------------------- +// Ambiguous provenance: an `add` whose lhs is `ptrtoint(@A)` and rhs is +// `ptrtoint(@B)` produces an int that simultaneously carries provenance for +// both globals. Anything downstream must NOT be tagged with a single scope, +// otherwise we'd be telling LLVM "no alias to @B" about a load that may very +// well land in @B's region. +// ----------------------------------------------------------------------------- + +// CHECK-LABEL: gpu.module @ambiguous_add +gpu.module @ambiguous_add { + llvm.mlir.global external @amb_a() {addr_space = 3 : i32, alignment = 1024 : i64, dso_local} : !llvm.array<0 x i8> + llvm.mlir.global external @amb_b() {addr_space = 3 : i32, alignment = 1024 : i64, dso_local} : !llvm.array<0 x i8> + + // CHECK-LABEL: llvm.func @load_ambiguous + llvm.func @load_ambiguous(%c: i32) -> vector<4xi32> { + %a = llvm.mlir.addressof @amb_a : !llvm.ptr<3> + %b = llvm.mlir.addressof @amb_b : !llvm.ptr<3> + %ai = llvm.ptrtoint %a : !llvm.ptr<3> to i32 + %bi = llvm.ptrtoint %b : !llvm.ptr<3> to i32 + %amb = llvm.add %ai, %bi : i32 + %off = llvm.add %amb, %c : i32 + %p = llvm.inttoptr %off : i32 to !llvm.ptr<3> + // CHECK: llvm.load + // CHECK-NOT: alias_scopes + %v = llvm.load %p : !llvm.ptr<3> -> vector<4xi32> + llvm.return %v : vector<4xi32> + } +} + +// ----- + +// ----------------------------------------------------------------------------- +// `or`/`sub`/`xor` are NOT canonical pointer arithmetic via int. Even when +// they happen to be equivalent to `add` they can break provenance, so the +// pass refuses to forward through them. +// ----------------------------------------------------------------------------- + +// CHECK-LABEL: gpu.module @nontracked_op +gpu.module @nontracked_op { + llvm.mlir.global external @nt_a() {addr_space = 3 : i32, alignment = 1024 : i64, dso_local} : !llvm.array<0 x i8> + llvm.mlir.global external @nt_b() {addr_space = 3 : i32, alignment = 1024 : i64, dso_local} : !llvm.array<0 x i8> + + // CHECK-LABEL: llvm.func @load_via_or + llvm.func @load_via_or(%mask: i32) -> vector<4xi32> { + %a = llvm.mlir.addressof @nt_a : !llvm.ptr<3> + %ai = llvm.ptrtoint %a : !llvm.ptr<3> to i32 + %off = llvm.or %ai, %mask : i32 + %p = llvm.inttoptr %off : i32 to !llvm.ptr<3> + // CHECK: llvm.load + // CHECK-NOT: alias_scopes + %v = llvm.load %p : !llvm.ptr<3> -> vector<4xi32> + llvm.return %v : vector<4xi32> + } +} From 67c0e762a2fba3dd41a7aa6bac5b36fcd8cbbb07 Mon Sep 17 00:00:00 2001 From: coderfeli Date: Wed, 13 May 2026 05:06:29 +0000 Subject: [PATCH 03/10] test(fly-attach-lds-alias-scope): cover block-arg, deep chain, mixed kernel Add three more lit cases that lock in robustness corners surfaced by the audit: * `phi_block_arg`: a pointer flowing into a block argument (LLVM phi) loses provenance, so the load on the merged value stays untagged regardless of which predecessor branched in. This is important post `convert-scf-to-cf`, where `scf.for` carried values become block arguments. * `deep_chain`: addressof -> gep -> ptrtoint -> add -> add -> inttoptr resolves to the originating global; multi-step add chains correctly forward provenance. * `mixed_dyn_static`: when a `gpu.module` has both dyn-shared `[0 x i8]` globals and SmemAllocator-style `[N x i8]` static globals, only the dyn-shared loads receive scopes; static loads pass through untouched. Co-authored-by: Cursor --- .../Transforms/attach_lds_alias_scope.mlir | 84 +++++++++++++++++++ 1 file changed, 84 insertions(+) diff --git a/tests/mlir/Transforms/attach_lds_alias_scope.mlir b/tests/mlir/Transforms/attach_lds_alias_scope.mlir index 7decfda47..8d9fdf28f 100644 --- a/tests/mlir/Transforms/attach_lds_alias_scope.mlir +++ b/tests/mlir/Transforms/attach_lds_alias_scope.mlir @@ -141,3 +141,87 @@ gpu.module @nontracked_op { llvm.return %v : vector<4xi32> } } + +// ----- + +// ----------------------------------------------------------------------------- +// Pointer flowing through a block argument (LLVM phi) loses provenance: the +// entry to ^bb1 doesn't know which addressof produced %p, so the load must +// stay untagged regardless of which predecessor branched in. +// ----------------------------------------------------------------------------- + +// CHECK-LABEL: gpu.module @phi_block_arg +gpu.module @phi_block_arg { + llvm.mlir.global external @phi_a() {addr_space = 3 : i32, alignment = 1024 : i64, dso_local} : !llvm.array<0 x i8> + llvm.mlir.global external @phi_b() {addr_space = 3 : i32, alignment = 1024 : i64, dso_local} : !llvm.array<0 x i8> + + // CHECK-LABEL: llvm.func @load_via_phi + llvm.func @load_via_phi(%cond: i1) -> vector<4xi32> { + %a = llvm.mlir.addressof @phi_a : !llvm.ptr<3> + %b = llvm.mlir.addressof @phi_b : !llvm.ptr<3> + llvm.cond_br %cond, ^bb1(%a : !llvm.ptr<3>), ^bb1(%b : !llvm.ptr<3>) + ^bb1(%p: !llvm.ptr<3>): + // CHECK: llvm.load + // CHECK-NOT: alias_scopes + %v = llvm.load %p : !llvm.ptr<3> -> vector<4xi32> + llvm.return %v : vector<4xi32> + } +} + +// ----- + +// ----------------------------------------------------------------------------- +// Deep arithmetic chain through gep + add + inttoptr still resolves to the +// originating global. Two named globals so the pass actually runs (single +// global short-circuits). +// ----------------------------------------------------------------------------- + +// CHECK-LABEL: gpu.module @deep_chain +gpu.module @deep_chain { + llvm.mlir.global external @deep_a() {addr_space = 3 : i32, alignment = 1024 : i64, dso_local} : !llvm.array<0 x i8> + llvm.mlir.global external @deep_b() {addr_space = 3 : i32, alignment = 1024 : i64, dso_local} : !llvm.array<0 x i8> + + // CHECK-LABEL: llvm.func @load_deep + llvm.func @load_deep(%c0: i32, %c1: i32) -> vector<4xi32> { + %a = llvm.mlir.addressof @deep_a : !llvm.ptr<3> + %a_gep = llvm.getelementptr %a[1] : (!llvm.ptr<3>) -> !llvm.ptr<3>, i8 + %a_int = llvm.ptrtoint %a_gep : !llvm.ptr<3> to i32 + %a_off1 = llvm.add %a_int, %c0 : i32 + %a_off2 = llvm.add %a_off1, %c1 : i32 + %a_p = llvm.inttoptr %a_off2 : i32 to !llvm.ptr<3> + // CHECK: llvm.load %{{.+}} {alias_scopes = [#{{.*}}], noalias_scopes = [#{{.*}}]} : !llvm.ptr<3> -> vector<4xi32> + %v = llvm.load %a_p : !llvm.ptr<3> -> vector<4xi32> + llvm.return %v : vector<4xi32> + } +} + +// ----- + +// ----------------------------------------------------------------------------- +// Mixed kernel: dyn-shared (gets tagged) and static [N x i8] (skipped) +// coexist. Only the dyn-shared loads carry alias scopes. +// ----------------------------------------------------------------------------- + +// CHECK-LABEL: gpu.module @mixed_dyn_static +gpu.module @mixed_dyn_static { + llvm.mlir.global external @mix_dyn_a() {addr_space = 3 : i32, alignment = 1024 : i64, dso_local} : !llvm.array<0 x i8> + llvm.mlir.global external @mix_dyn_b() {addr_space = 3 : i32, alignment = 1024 : i64, dso_local} : !llvm.array<0 x i8> + llvm.mlir.global external @mix_static() {addr_space = 3 : i32, alignment = 1024 : i64, dso_local} : !llvm.array<4096 x i8> + + // CHECK-LABEL: llvm.func @load_mixed + llvm.func @load_mixed(%off: i32) -> vector<4xi32> { + %da = llvm.mlir.addressof @mix_dyn_a : !llvm.ptr<3> + %da_i = llvm.ptrtoint %da : !llvm.ptr<3> to i32 + %da_o = llvm.add %da_i, %off : i32 + %da_p = llvm.inttoptr %da_o : i32 to !llvm.ptr<3> + // CHECK: llvm.load %{{.+}} {alias_scopes = [#{{.*}}], noalias_scopes = [#{{.*}}]} : !llvm.ptr<3> -> vector<4xi32> + %v_da = llvm.load %da_p : !llvm.ptr<3> -> vector<4xi32> + + %s = llvm.mlir.addressof @mix_static : !llvm.ptr<3> + // CHECK: llvm.load %{{.+}} : !llvm.ptr<3> -> vector<4xi32> + %v_s = llvm.load %s : !llvm.ptr<3> -> vector<4xi32> + + %sum = llvm.add %v_da, %v_s : vector<4xi32> + llvm.return %sum : vector<4xi32> + } +} From 06059b2eb4e455b3b245599dc266a3311c429f3b Mon Sep 17 00:00:00 2001 From: coderfeli Date: Wed, 13 May 2026 06:05:02 +0000 Subject: [PATCH 04/10] refactor(fp8_gemm_4wave): switch LDS to named dyn-shared + alias-scope pass Replace the 8-allocator SmemAllocator scaffolding and the stdlib-memref -> inttoptr adapter in ``_lds_dst_at`` with named ``fx.get_dyn_shared(sym_name=...)`` bases. The ``fly-attach-lds-alias-scope`` pass attaches per-symbol alias scopes so AMDGPU's SI Wait Counter pass treats accesses through different named bases as no-alias, which gives the same scheduling as the previous 8-distinct-static-globals layout. What goes away: * ``SmemAllocator`` / ``SmemPtr`` plumbing and the finalize-in-jit dance in ``launch_gemm``. * ``_lds_dst_at``'s ``extract_aligned_pointer_as_index + index_cast + inttoptr`` adapter -- it existed solely to bridge stdlib ``memref`` to a ``fly.tensor`` view that ``fx.copy`` accepts. With ``fx.get_dyn_shared`` returning a ``fly.ptr`` directly, the bridge collapses to a plain ``inttoptr`` of an i32 base + offset. * ``Vec.load`` of ``vector<16xf8>`` for the LDS->reg path, replaced by ``fx.memref_load_vec`` on a ``vector<4xi32>`` view (16 fp8 = 4 i32) so the lowering also sidesteps the missing LLVM type for ``vector<16xf8>``. The launch wrapper now just declares ``smem=_TOTAL_LDS_BYTES``; the runtime allocates the dyn-shared region. Perf across all parametrized shapes is within run-to-run noise of the pre-refactor static-LDS baseline (~535 / ~1820 / ~2150 / ~2140 TFLOPS on the four shapes). Co-authored-by: Cursor --- kernels/fp8_gemm_4wave.py | 144 ++++++++++++++++++-------------------- 1 file changed, 67 insertions(+), 77 deletions(-) diff --git a/kernels/fp8_gemm_4wave.py b/kernels/fp8_gemm_4wave.py index d9f4fcd06..e88d2ba29 100644 --- a/kernels/fp8_gemm_4wave.py +++ b/kernels/fp8_gemm_4wave.py @@ -12,19 +12,20 @@ the chained Vec(4, f32) accumulator stays on AGPR. The XOR swizzle and the 8-buffer LDS pipeline ping-pong are kept as direct arithmetic to preserve the original kernel's interleaved-cluster scheduling. + +LDS storage uses 8 named ``fx.get_dyn_shared`` bases carved into one +dyn-shared region; the ``fly-attach-lds-alias-scope`` MLIR pass +attaches per-symbol alias scopes so AMDGPU's SI Wait Counter pass +treats cross-buffer accesses as no-alias. """ import flydsl.compiler as flyc import flydsl.expr as fx -from flydsl._mlir.dialects import arith as _arith_dialect from flydsl._mlir.dialects import fly as _fly_dialect from flydsl._mlir.dialects import llvm as _llvm -from flydsl._mlir.dialects import memref as _memref_dialect from flydsl._mlir.dialects.fly_rocdl import TargetAddressSpace as _TgtAS -from flydsl.compiler.kernel_function import CompilationContext -from flydsl.expr import arith, const_expr, range_constexpr +from flydsl.expr import arith, const_expr, range_constexpr, rocdl from flydsl.expr.typing import Vector as Vec -from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr def _divmod(a, b): @@ -77,26 +78,26 @@ def compile_fp8_gemm(*, M: int, N: int, K: int, BLOCK_M: int = 256, BLOCK_N: int _use_interleaved_block = BLOCK_M == 256 and BLOCK_N == 256 - A_lds_cur0_alloc = SmemAllocator(None, "gfx950", "A_lds_cur_0") - A_lds_cur1_alloc = SmemAllocator(None, "gfx950", "A_lds_cur_1") - A_lds_next0_alloc = SmemAllocator(None, "gfx950", "A_lds_next_0") - A_lds_next1_alloc = SmemAllocator(None, "gfx950", "A_lds_next_1") - B_lds_cur0_alloc = SmemAllocator(None, "gfx950", "B_lds_cur_0") - B_lds_cur1_alloc = SmemAllocator(None, "gfx950", "B_lds_cur_1") - B_lds_next0_alloc = SmemAllocator(None, "gfx950", "B_lds_next_0") - B_lds_next1_alloc = SmemAllocator(None, "gfx950", "B_lds_next_1") - a_lds_size = LDS_BLOCK_M * BLOCK_K b_lds_size = LDS_BLOCK_N * BLOCK_K - A_lds_cur0_alloc.ptr = a_lds_size - A_lds_cur1_alloc.ptr = a_lds_size - A_lds_next0_alloc.ptr = a_lds_size - A_lds_next1_alloc.ptr = a_lds_size - B_lds_cur0_alloc.ptr = b_lds_size - B_lds_cur1_alloc.ptr = b_lds_size - B_lds_next0_alloc.ptr = b_lds_size - B_lds_next1_alloc.ptr = b_lds_size + # 8 disjoint sub-buffers carved out of a single dyn-shared LDS region: + # A_lds_cur_{0,1}, A_lds_next_{0,1}, B_lds_cur_{0,1}, B_lds_next_{0,1}. + # ``fx.get_dyn_shared(sym_name=...)`` emits one external [0 x i8] + # addrspace(3) global per name; ``fly-attach-lds-alias-scope`` + # gives each global its own alias scope so the AMDGPU SI Wait + # Counter pass treats cross-name accesses as no-alias. + _LDS_SUBBUFS = [ + ("A_lds_cur_0", 0 * a_lds_size), + ("A_lds_cur_1", 1 * a_lds_size), + ("A_lds_next_0", 2 * a_lds_size), + ("A_lds_next_1", 3 * a_lds_size), + ("B_lds_cur_0", 4 * a_lds_size + 0 * b_lds_size), + ("B_lds_cur_1", 4 * a_lds_size + 1 * b_lds_size), + ("B_lds_next_0", 4 * a_lds_size + 2 * b_lds_size), + ("B_lds_next_1", 4 * a_lds_size + 3 * b_lds_size), + ] + _TOTAL_LDS_BYTES = 4 * a_lds_size + 4 * b_lds_size @flyc.kernel def kernel_gemm( @@ -109,20 +110,25 @@ def kernel_gemm( MfmaAccum_t = Vec.make_type(4, fx.Float32) RT_C_i = Vec.filled(4, 0.0, fx.Float32) F8_IR_t = fx.Float8E4M3FN.ir_type - Vec16_t = Vec.make_type(16, fx.Float8E4M3FN) - - a_cur0 = SmemPtr(A_lds_cur0_alloc.get_base(), 0, F8_IR_t, shape=(a_lds_size,)).get() - a_cur1 = SmemPtr(A_lds_cur1_alloc.get_base(), 0, F8_IR_t, shape=(a_lds_size,)).get() - a_next0 = SmemPtr(A_lds_next0_alloc.get_base(), 0, F8_IR_t, shape=(a_lds_size,)).get() - a_next1 = SmemPtr(A_lds_next1_alloc.get_base(), 0, F8_IR_t, shape=(a_lds_size,)).get() - - b_cur0 = SmemPtr(B_lds_cur0_alloc.get_base(), 0, F8_IR_t, shape=(b_lds_size,)).get() - b_cur1 = SmemPtr(B_lds_cur1_alloc.get_base(), 0, F8_IR_t, shape=(b_lds_size,)).get() - b_next0 = SmemPtr(B_lds_next0_alloc.get_base(), 0, F8_IR_t, shape=(b_lds_size,)).get() - b_next1 = SmemPtr(B_lds_next1_alloc.get_base(), 0, F8_IR_t, shape=(b_lds_size,)).get() _AS_SHARED = 2 - _shared_ptr_ty = fx.PointerType.get(F8_IR_t, _AS_SHARED, 512) + _shared_f8_ptr_ty = fx.PointerType.get(F8_IR_t, _AS_SHARED, 512) + _shared_i32_ptr_ty = fx.PointerType.get(fx.T.i32(), _AS_SHARED, 512) + + # One ptrtoint per named base; per-access offsets are added in i32 + # before ``fx.inttoptr``. ``fly-attach-lds-alias-scope`` traces + # each access back to its base symbol and tags loads / stores / + # buffer_load_lds with the corresponding alias scope. + _lds_int = { + name: fx.ptrtoint(fx.get_dyn_shared(sym_name=name)) + for name, _ in _LDS_SUBBUFS + } + _lds_off = dict(_LDS_SUBBUFS) + + a_cur0, a_cur1 = "A_lds_cur_0", "A_lds_cur_1" + a_next0, a_next1 = "A_lds_next_0", "A_lds_next_1" + b_cur0, b_cur1 = "B_lds_cur_0", "B_lds_cur_1" + b_next0, b_next1 = "B_lds_next_0", "B_lds_next_1" lane_id = fx.thread_idx.x % 64 wave_id = fx.thread_idx.x // 64 @@ -196,35 +202,36 @@ def _compute_lds_swizzle(wave_idx, n_tiles): # state carries the runtime ``soffset`` set to ``k_offset``. g2lds_atom = fx.make_copy_atom(fx.rocdl.BufferCopyLDS128b(), 128) - # LDS dst pointers for ``buffer_load_lds`` go through - # ``extract_aligned_pointer_as_index + add + inttoptr`` to break - # LLVM's alias chain on the LDS sub-buffer symbols; otherwise the - # AMDGPU backend inserts defensive ``s_waitcnt vmcnt(N)`` between - # G->LDS writes and the subsequent ``ds_read``. - def _lds_dst_at(lds_dst_mem, byte_offset_runtime): - base_idx = _memref_dialect.extract_aligned_pointer_as_index(lds_dst_mem) - offset_idx = base_idx + fx.Index(byte_offset_runtime) - offset_i64 = _arith_dialect.index_cast(fx.T.i64(), offset_idx) - lds_ptr = fx.inttoptr(_shared_ptr_ty, offset_i64) - return fx.make_view(lds_ptr, fx.make_layout(1, 1)) - - def _load_lds(gl_src_div, lds_dst_mem, k_offset, gl_offsets, n_tiles): + def _lds_dst_at(name, byte_offset_runtime): + off = _lds_int[name] + fx.Int32(_lds_off[name] + byte_offset_runtime) + ptr = fx.inttoptr(_shared_f8_ptr_ty, off) + return fx.make_view(ptr, fx.make_layout(1, 1)) + + def _load_lds(gl_src_div, name, k_offset, gl_offsets, n_tiles): assert len(gl_offsets) >= n_tiles for step in range_constexpr(n_tiles): src = fx.slice(gl_src_div, (None, fx.Int32(gl_offsets[step]))) - dst = _lds_dst_at(lds_dst_mem, wave_id * 1024 + step * 4096) + dst = _lds_dst_at(name, wave_id * 1024 + step * 4096) fx.copy(g2lds_atom, src, dst, soffset=fx.Int32(k_offset)) - def _load_one_lds(gl_src_div, lds_dst_mem, k_offset, gl_offsets, tile_idx): + def _load_one_lds(gl_src_div, name, k_offset, gl_offsets, tile_idx): assert len(gl_offsets) > tile_idx src = fx.slice(gl_src_div, (None, fx.Int32(gl_offsets[tile_idx]))) - dst = _lds_dst_at(lds_dst_mem, wave_id * 1024 + tile_idx * 4096) + dst = _lds_dst_at(name, wave_id * 1024 + tile_idx * 4096) fx.copy(g2lds_atom, src, dst, soffset=fx.Int32(k_offset)) def _pack_i32x4_i32x8(lo, hi): return lo.shuffle(hi, list(range(8))) - def _load_rt(lds_src, wave_idx, n_tiles): + # 16 fp8 == 4 i32; load via i32-typed ptr to sidestep the missing + # LLVM vector type for vector<16xf8>. + def _vec_load_lds_i32x4(name, fp8_elem_offset): + off = _lds_int[name] + fx.Int32(_lds_off[name] + fp8_elem_offset) + ptr = fx.inttoptr(_shared_i32_ptr_ty, off) + view = fx.make_view(ptr, fx.make_layout(4, 1)) + return Vec(fx.memref_load_vec(view)) + + def _load_rt(name, wave_idx, n_tiles): frag = [] for i in range_constexpr(n_tiles): row = wave_idx * (n_tiles * 16) + i * 16 + lane_id % 16 @@ -232,14 +239,12 @@ def _load_rt(lds_src, wave_idx, n_tiles): for step in range_constexpr(2): col = (lane_id // 16) * 16 + step * 64 r, c = _swizzle_128(row, col) - v = Vec.load(Vec16_t, lds_src, [fx.Index(r * BLOCK_K + c)]) - halves.append(v.bitcast(fx.Int32)) + halves.append(_vec_load_lds_i32x4(name, r * BLOCK_K + c)) frag.append(_pack_i32x4_i32x8(halves[0], halves[1])) return frag - def _load_one_rt(lds_src, lds_swz, row, k): - v = Vec.load(Vec16_t, lds_src, [fx.Index(lds_swz[row][k])]) - return v.bitcast(fx.Int32) + def _load_one_rt(name, lds_swz, row, k): + return _vec_load_lds_i32x4(name, lds_swz[row][k]) def _c_idx(i, j): return i * N_TILES_B + j @@ -525,26 +530,6 @@ def launch_gemm( B_scale: fx.Tensor, stream: fx.Stream, ): - from flydsl._mlir import ir - - A_lds_cur0_alloc.finalized = False - A_lds_cur1_alloc.finalized = False - A_lds_next0_alloc.finalized = False - A_lds_next1_alloc.finalized = False - B_lds_cur0_alloc.finalized = False - B_lds_cur1_alloc.finalized = False - B_lds_next0_alloc.finalized = False - B_lds_next1_alloc.finalized = False - ctx = CompilationContext.get_current() - with ir.InsertionPoint(ctx.gpu_module_body): - A_lds_cur0_alloc.finalize() - A_lds_cur1_alloc.finalize() - A_lds_next0_alloc.finalize() - A_lds_next1_alloc.finalize() - B_lds_cur0_alloc.finalize() - B_lds_cur1_alloc.finalize() - B_lds_next0_alloc.finalize() - B_lds_next1_alloc.finalize() grid_x = (M * N) // (BLOCK_M * BLOCK_N) kernel_gemm( A, @@ -553,6 +538,11 @@ def launch_gemm( A_scale, B_scale, value_attrs={"rocdl.waves_per_eu": 1, "rocdl.flat_work_group_size": "256,256"}, - ).launch(grid=(grid_x, 1, 1), block=(256, 1, 1), stream=stream) + ).launch( + grid=(grid_x, 1, 1), + block=(256, 1, 1), + smem=_TOTAL_LDS_BYTES, + stream=stream, + ) return launch_gemm From 77dedf15592fd9974abf4c04ff650ba1a0e8501d Mon Sep 17 00:00:00 2001 From: coderfeli Date: Wed, 13 May 2026 07:42:14 +0000 Subject: [PATCH 05/10] feat(fp8_gemm_4wave): route non-interleaved MFMAs through fx.gemm The ``_compute_cluster`` path (BLOCK < 256) now spills each per-atom Vec(8,i32)/Vec(4,f32) operand into a register-memref fragment and calls ``fx.gemm`` against a 4-wave 2x2 ``tiled_mma`` instead of emitting ``fly.mma_atom_call_ssa`` directly. ``fly-convert-atom-call-to-ssa-form`` + ``fly-promote-regmem-to-vectorssa`` elide the alloca / store / load round trip cleanly: the resulting LLVM IR has zero ``alloca`` for the fragments and the MFMA call chain stays purely on ``<4 x float>`` SSA, so ISel still maps the accumulator onto AGPR. The interleaved BLOCK==256 path keeps the direct ``fly.mma_atom_call_ssa`` route -- its manual per-atom interleaving with G->LDS / LDS->reg loads is the whole point of the cluster layout, and ``fx.gemm`` would batch the atoms in a way that contradicts that schedule. Perf is within run-to-run noise of baseline on the BLOCK=64 shape (538-544 TFLOPS) and unchanged on BLOCK=256 paths. Co-authored-by: Cursor --- kernels/fp8_gemm_4wave.py | 40 +++++++++++++++++++++++++++++++++++---- 1 file changed, 36 insertions(+), 4 deletions(-) diff --git a/kernels/fp8_gemm_4wave.py b/kernels/fp8_gemm_4wave.py index e88d2ba29..a0b7177c7 100644 --- a/kernels/fp8_gemm_4wave.py +++ b/kernels/fp8_gemm_4wave.py @@ -296,22 +296,54 @@ def _wait_barrier(count): has_side_effects=True, ) - # MFMA via ``fly.mma_atom_call_ssa``. The atom carries scale_a / - # scale_b state (default 0x7F7F7F7F = no scaling). Returns a - # chained Vec(4, f32) SSA so the accumulator stays on AGPR. mma_atom = fx.make_mma_atom(fx.rocdl.cdna4.MFMA_Scale(16, 16, 128, fx.Float8E4M3FN)) + # Non-interleaved path goes through ``fx.gemm``. Each call + # spills the Vec operands into register memref fragments + # (i32x8 for A/B, f32x4 for the accumulator) and pulls the + # accumulator back out; ``fly-convert-atom-call-to-ssa-form`` + + # ``fly-promote-regmem-to-vectorssa`` then elide the alloca / + # store / load round trip and leave a plain + # ``llvm.amdgcn.mfma.scale.f32.16x16x128`` call chained on + # ``<4 x float>`` SSA values, which ISel maps to AGPR. + a_atom_i32_elems = 8 + b_atom_i32_elems = 8 + c_atom_f32_elems = 4 + a_reg_ty = fx.MemRefType.get(fx.T.i32(), fx.LayoutType.get(a_atom_i32_elems, 1), _AS_REG) + b_reg_ty = fx.MemRefType.get(fx.T.i32(), fx.LayoutType.get(b_atom_i32_elems, 1), _AS_REG) + c_reg_ty = fx.MemRefType.get(fx.T.f32(), fx.LayoutType.get(c_atom_f32_elems, 1), _AS_REG) + + # Single-atom tiled_mma over the 4-wave 2x2 layout in (M, N). + tiled_mma_single = fx.make_tiled_mma( + mma_atom, + fx.make_layout((2, 2, 1), (1, 2, 0)), + ) + + # Direct ``fly.mma_atom_call_ssa`` is kept for the interleaved + # BLOCK==256 cluster where the manual MFMA / load schedule + # matters more than the layout-API abstraction. def _mfma(a_val, b_val, c_val): return _fly_dialect.mma_atom_call_ssa([MfmaAccum_t], mma_atom, a_val, b_val, c_val) + def _mfma_fxgemm(a_vec, b_vec, c_vec): + a_mem = fx.memref_alloca(a_reg_ty, fx.make_layout(a_atom_i32_elems, 1)) + b_mem = fx.memref_alloca(b_reg_ty, fx.make_layout(b_atom_i32_elems, 1)) + c_mem = fx.memref_alloca(c_reg_ty, fx.make_layout(c_atom_f32_elems, 1)) + fx.memref_store_vec(a_vec, a_mem) + fx.memref_store_vec(b_vec, b_mem) + fx.memref_store_vec(c_vec, c_mem) + fx.gemm(tiled_mma_single, c_mem, a_mem, b_mem, c_mem) + return Vec(fx.memref_load_vec(c_mem)) + def _mfma_ABt_all(a, b, c): assert len(a) == N_TILES_A assert len(b) == N_TILES_B assert len(c) == N_TILES_A * N_TILES_B + mma = _mfma if const_expr(_use_interleaved_block) else _mfma_fxgemm for i in range_constexpr(N_TILES_A): for j in range_constexpr(N_TILES_B): - c[_c_idx(i, j)] = _mfma(a[i], b[j], c[_c_idx(i, j)]) + c[_c_idx(i, j)] = mma(a[i], b[j], c[_c_idx(i, j)]) return c def _mfma_ABt_one(a, b, c, m, n): From cf97ebe5b75cef49ddaf2cc9fd2a2593b58e08a8 Mon Sep 17 00:00:00 2001 From: coderfeli Date: Wed, 13 May 2026 07:49:31 +0000 Subject: [PATCH 06/10] refactor(fp8_gemm_4wave): unify all MFMAs through fx.gemm Drop the dual mma path. The interleaved BLOCK=256 cluster now routes its per-atom MFMAs through the same ``_mfma`` helper that the non-interleaved BLOCK<256 cluster uses, where each call spills the Vec operands into register-memref fragments and invokes ``fx.gemm`` against the 4-wave 2x2 ``tiled_mma``. ``fly-convert-atom-call-to-ssa-form`` + ``fly-promote-regmem-to-vectorssa`` elide the alloca / store / load round trip for every call site (0 ``alloca`` left in the final LLVM IR), keeping the per-atom accumulator on ``<4 x float>`` SSA values so ISel still maps it to AGPR. The interleaved cluster's load schedule is preserved because ``_mfma_ABt_one`` still gets called one atom at a time between the G->LDS / LDS->reg loads. Drops the now-unused ``MfmaAccum_t`` alias and the ``flydsl._mlir.dialects.fly`` import. Perf across the four parametrized shapes is within run-to-run noise of the pre-unification numbers (537-543 / 1836-1839 / 2137-2172 / 2134-2166 TFLOPS). Co-authored-by: Cursor --- kernels/fp8_gemm_4wave.py | 23 +++++++---------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/kernels/fp8_gemm_4wave.py b/kernels/fp8_gemm_4wave.py index a0b7177c7..38093dbfa 100644 --- a/kernels/fp8_gemm_4wave.py +++ b/kernels/fp8_gemm_4wave.py @@ -21,7 +21,6 @@ import flydsl.compiler as flyc import flydsl.expr as fx -from flydsl._mlir.dialects import fly as _fly_dialect from flydsl._mlir.dialects import llvm as _llvm from flydsl._mlir.dialects.fly_rocdl import TargetAddressSpace as _TgtAS from flydsl.expr import arith, const_expr, range_constexpr, rocdl @@ -107,7 +106,6 @@ def kernel_gemm( A_scale: fx.Tensor, B_scale: fx.Tensor, ): - MfmaAccum_t = Vec.make_type(4, fx.Float32) RT_C_i = Vec.filled(4, 0.0, fx.Float32) F8_IR_t = fx.Float8E4M3FN.ir_type @@ -298,11 +296,11 @@ def _wait_barrier(count): mma_atom = fx.make_mma_atom(fx.rocdl.cdna4.MFMA_Scale(16, 16, 128, fx.Float8E4M3FN)) - # Non-interleaved path goes through ``fx.gemm``. Each call - # spills the Vec operands into register memref fragments - # (i32x8 for A/B, f32x4 for the accumulator) and pulls the - # accumulator back out; ``fly-convert-atom-call-to-ssa-form`` + - # ``fly-promote-regmem-to-vectorssa`` then elide the alloca / + # All MFMAs go through ``fx.gemm``. Each call spills the Vec + # operands into register memref fragments (i32x8 for A/B, + # f32x4 for the accumulator) and pulls the accumulator back + # out; ``fly-convert-atom-call-to-ssa-form`` + + # ``fly-promote-regmem-to-vectorssa`` elide the alloca / # store / load round trip and leave a plain # ``llvm.amdgcn.mfma.scale.f32.16x16x128`` call chained on # ``<4 x float>`` SSA values, which ISel maps to AGPR. @@ -319,13 +317,7 @@ def _wait_barrier(count): fx.make_layout((2, 2, 1), (1, 2, 0)), ) - # Direct ``fly.mma_atom_call_ssa`` is kept for the interleaved - # BLOCK==256 cluster where the manual MFMA / load schedule - # matters more than the layout-API abstraction. - def _mfma(a_val, b_val, c_val): - return _fly_dialect.mma_atom_call_ssa([MfmaAccum_t], mma_atom, a_val, b_val, c_val) - - def _mfma_fxgemm(a_vec, b_vec, c_vec): + def _mfma(a_vec, b_vec, c_vec): a_mem = fx.memref_alloca(a_reg_ty, fx.make_layout(a_atom_i32_elems, 1)) b_mem = fx.memref_alloca(b_reg_ty, fx.make_layout(b_atom_i32_elems, 1)) c_mem = fx.memref_alloca(c_reg_ty, fx.make_layout(c_atom_f32_elems, 1)) @@ -340,10 +332,9 @@ def _mfma_ABt_all(a, b, c): assert len(b) == N_TILES_B assert len(c) == N_TILES_A * N_TILES_B - mma = _mfma if const_expr(_use_interleaved_block) else _mfma_fxgemm for i in range_constexpr(N_TILES_A): for j in range_constexpr(N_TILES_B): - c[_c_idx(i, j)] = mma(a[i], b[j], c[_c_idx(i, j)]) + c[_c_idx(i, j)] = _mfma(a[i], b[j], c[_c_idx(i, j)]) return c def _mfma_ABt_one(a, b, c, m, n): From ac0371393d203d14fe11f38208b3173d17e81ec1 Mon Sep 17 00:00:00 2001 From: coderfeli Date: Wed, 13 May 2026 07:59:51 +0000 Subject: [PATCH 07/10] clean(fp8_gemm_4wave): tighten docstring and per-block comments Update the module docstring to reflect the fx.gemm-based MFMA path, drop the redundant LDS-subbuffer block comment, and trim the rest of the inline comments down to the non-obvious bits. Inline the register-fragment element counts (8/8/4) instead of carrying named constants whose only use was to keep the comments aligned. Net -15 lines and no behavioral change. Co-authored-by: Cursor --- kernels/fp8_gemm_4wave.py | 86 ++++++++++++++++----------------------- 1 file changed, 36 insertions(+), 50 deletions(-) diff --git a/kernels/fp8_gemm_4wave.py b/kernels/fp8_gemm_4wave.py index 38093dbfa..0e013b884 100644 --- a/kernels/fp8_gemm_4wave.py +++ b/kernels/fp8_gemm_4wave.py @@ -6,12 +6,11 @@ Algorithm derived from HipKittens FP8_4wave (https://github.com/HazyResearch/HipKittens/blob/7782744ba1fd259a377a99e2ea8f71384cc80e55/kernels/gemm/fp8fp32/FP8_4wave/4_wave.cu#L1). -Global IO, scale loads, and bf16 stores go through the layout API -(``fx.rocdl.make_buffer_tensor`` + ``fx.copy`` with ``BufferCopyLDS128b`` -/ ``BufferCopy{16,32,128}b``). MFMAs use ``fly.mma_atom_call_ssa`` so -the chained Vec(4, f32) accumulator stays on AGPR. The XOR swizzle and -the 8-buffer LDS pipeline ping-pong are kept as direct arithmetic to -preserve the original kernel's interleaved-cluster scheduling. +Global IO, scale loads, bf16 stores, and the per-atom MFMA all go +through the layout API (``fx.rocdl.make_buffer_tensor`` + ``fx.copy`` ++ ``fx.gemm``). The XOR swizzle and the 8-buffer LDS pipeline are +kept as direct arithmetic to preserve the kernel's interleaved +cluster scheduling. LDS storage uses 8 named ``fx.get_dyn_shared`` bases carved into one dyn-shared region; the ``fly-attach-lds-alias-scope`` MLIR pass @@ -60,8 +59,7 @@ def _xcd_swizzle(num_pid_m, num_pid_n): def compile_fp8_gemm(*, M: int, N: int, K: int, BLOCK_M: int = 256, BLOCK_N: int = 256, use_xcd_remap: bool = True): - # MFMA atom is 16x16x128; 4 waves in a 2x2 config require BLOCK >= 64. - BLOCK_K = 128 + BLOCK_K = 128 # MFMA_Scale 16x16x128 atom; 4-wave 2x2 layout needs BLOCK >= 64. LDS_BLOCK_M = BLOCK_M // 2 LDS_BLOCK_N = BLOCK_N // 2 assert BLOCK_M >= 64 and BLOCK_N >= 64 @@ -69,7 +67,7 @@ def compile_fp8_gemm(*, M: int, N: int, K: int, BLOCK_M: int = 256, BLOCK_N: int N_BLOCKS = N // BLOCK_N K_ITERS = K // BLOCK_K - # Number of 16-row 16x128 tiles per wave per A/B partition. + # 16-row 16x128 atom tiles per wave per A/B partition. N_TILES_A = BLOCK_M // 4 // 16 N_TILES_B = BLOCK_N // 4 // 16 N_ACCUMS = N_TILES_A * N_TILES_B @@ -80,12 +78,10 @@ def compile_fp8_gemm(*, M: int, N: int, K: int, BLOCK_M: int = 256, BLOCK_N: int a_lds_size = LDS_BLOCK_M * BLOCK_K b_lds_size = LDS_BLOCK_N * BLOCK_K - # 8 disjoint sub-buffers carved out of a single dyn-shared LDS region: - # A_lds_cur_{0,1}, A_lds_next_{0,1}, B_lds_cur_{0,1}, B_lds_next_{0,1}. - # ``fx.get_dyn_shared(sym_name=...)`` emits one external [0 x i8] - # addrspace(3) global per name; ``fly-attach-lds-alias-scope`` - # gives each global its own alias scope so the AMDGPU SI Wait - # Counter pass treats cross-name accesses as no-alias. + # 8 disjoint sub-buffers within one dyn-shared region. Each named + # ``fx.get_dyn_shared`` emits a distinct LDS global so the + # ``fly-attach-lds-alias-scope`` pass can give it its own alias + # scope. _LDS_SUBBUFS = [ ("A_lds_cur_0", 0 * a_lds_size), ("A_lds_cur_1", 1 * a_lds_size), @@ -113,10 +109,6 @@ def kernel_gemm( _shared_f8_ptr_ty = fx.PointerType.get(F8_IR_t, _AS_SHARED, 512) _shared_i32_ptr_ty = fx.PointerType.get(fx.T.i32(), _AS_SHARED, 512) - # One ptrtoint per named base; per-access offsets are added in i32 - # before ``fx.inttoptr``. ``fly-attach-lds-alias-scope`` traces - # each access back to its base symbol and tags loads / stores / - # buffer_load_lds with the corresponding alias scope. _lds_int = { name: fx.ptrtoint(fx.get_dyn_shared(sym_name=name)) for name, _ in _LDS_SUBBUFS @@ -143,9 +135,9 @@ def kernel_gemm( B0_gl_offset = (tile_j * BLOCK_N) * K B1_gl_offset = (tile_j * BLOCK_N + LDS_BLOCK_N) * K - # A/B come in as torch.int8 (PyTorch fp8 view restriction); recast - # the buffer-desc pointer's element type to fp8 so typed copy - # atoms (BufferCopyLDS128b) accept them. + # A/B arrive as torch.int8 (PyTorch fp8 view limitation); recast + # the buffer-desc element type to fp8 so BufferCopyLDS128b takes + # them. def _make_fp8_buf_tensor(arg_i8): t_i8 = fx.rocdl.make_buffer_tensor(arg_i8) iter_i8 = fx.get_iter(t_i8) @@ -168,7 +160,7 @@ def _make_fp8_buf_tensor(arg_i8): sa_div = fx.logical_divide(gSA, fx.make_layout(1, 1)) sb_div = fx.logical_divide(gSB, fx.make_layout(1, 1)) - # XOR bits 4..6 of the tile-local linear offset with bits 8..10. + # XOR bits[4..6] of the tile-local linear offset with bits[8..10]. def _swizzle_128(row, col): offset = row * BLOCK_K + col swz = ((offset % (16 * BLOCK_K)) >> 8) << 4 @@ -196,8 +188,8 @@ def _compute_lds_swizzle(wave_idx, n_tiles): lds_swz.append(swz) return lds_swz - # G->LDS atom: 128 bits per thread = 16 fp8 elements. The atom - # state carries the runtime ``soffset`` set to ``k_offset``. + # 128 bits per thread = 16 fp8 elements; soffset carries the + # runtime k offset. g2lds_atom = fx.make_copy_atom(fx.rocdl.BufferCopyLDS128b(), 128) def _lds_dst_at(name, byte_offset_runtime): @@ -221,8 +213,8 @@ def _load_one_lds(gl_src_div, name, k_offset, gl_offsets, tile_idx): def _pack_i32x4_i32x8(lo, hi): return lo.shuffle(hi, list(range(8))) - # 16 fp8 == 4 i32; load via i32-typed ptr to sidestep the missing - # LLVM vector type for vector<16xf8>. + # 16 fp8 == 4 i32; load as i32x4 because LLVM has no + # vector<16xf8>. def _vec_load_lds_i32x4(name, fp8_elem_offset): off = _lds_int[name] + fx.Int32(_lds_off[name] + fp8_elem_offset) ptr = fx.inttoptr(_shared_i32_ptr_ty, off) @@ -247,9 +239,9 @@ def _load_one_rt(name, lds_swz, row, k): def _c_idx(i, j): return i * N_TILES_B + j - # The C++ AddressSpace enum prepends Generic=0, so the Python - # AddressSpace.Register value (2) maps to Shared on the C++ side. - # Pass the C++ integer (3) directly to MemRefType.get. + # C++ AddressSpace enum prepends Generic=0; pass the C++ index + # (3 = Register) directly to MemRefType.get to avoid the + # Python AddressSpace.Register (=2) being read as Shared. _AS_REG = 3 scale_atom_4 = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), fx.Float32) scale_atom_1 = fx.make_copy_atom(fx.rocdl.BufferCopy32b(), fx.Float32) @@ -296,31 +288,25 @@ def _wait_barrier(count): mma_atom = fx.make_mma_atom(fx.rocdl.cdna4.MFMA_Scale(16, 16, 128, fx.Float8E4M3FN)) - # All MFMAs go through ``fx.gemm``. Each call spills the Vec - # operands into register memref fragments (i32x8 for A/B, - # f32x4 for the accumulator) and pulls the accumulator back - # out; ``fly-convert-atom-call-to-ssa-form`` + - # ``fly-promote-regmem-to-vectorssa`` elide the alloca / - # store / load round trip and leave a plain - # ``llvm.amdgcn.mfma.scale.f32.16x16x128`` call chained on - # ``<4 x float>`` SSA values, which ISel maps to AGPR. - a_atom_i32_elems = 8 - b_atom_i32_elems = 8 - c_atom_f32_elems = 4 - a_reg_ty = fx.MemRefType.get(fx.T.i32(), fx.LayoutType.get(a_atom_i32_elems, 1), _AS_REG) - b_reg_ty = fx.MemRefType.get(fx.T.i32(), fx.LayoutType.get(b_atom_i32_elems, 1), _AS_REG) - c_reg_ty = fx.MemRefType.get(fx.T.f32(), fx.LayoutType.get(c_atom_f32_elems, 1), _AS_REG) - - # Single-atom tiled_mma over the 4-wave 2x2 layout in (M, N). + # MFMA goes through fx.gemm. The Vec operands are spilled into + # register-memref fragments around each call; the alloca / + # store / load round trip is folded away by + # ``fly-convert-atom-call-to-ssa-form`` + + # ``fly-promote-regmem-to-vectorssa``, leaving a plain + # ``llvm.amdgcn.mfma.scale.f32.16x16x128`` chained on + # ``<4 x float>`` SSA so ISel keeps the accumulator on AGPR. + a_reg_ty = fx.MemRefType.get(fx.T.i32(), fx.LayoutType.get(8, 1), _AS_REG) + b_reg_ty = fx.MemRefType.get(fx.T.i32(), fx.LayoutType.get(8, 1), _AS_REG) + c_reg_ty = fx.MemRefType.get(fx.T.f32(), fx.LayoutType.get(4, 1), _AS_REG) tiled_mma_single = fx.make_tiled_mma( mma_atom, fx.make_layout((2, 2, 1), (1, 2, 0)), ) def _mfma(a_vec, b_vec, c_vec): - a_mem = fx.memref_alloca(a_reg_ty, fx.make_layout(a_atom_i32_elems, 1)) - b_mem = fx.memref_alloca(b_reg_ty, fx.make_layout(b_atom_i32_elems, 1)) - c_mem = fx.memref_alloca(c_reg_ty, fx.make_layout(c_atom_f32_elems, 1)) + a_mem = fx.memref_alloca(a_reg_ty, fx.make_layout(8, 1)) + b_mem = fx.memref_alloca(b_reg_ty, fx.make_layout(8, 1)) + c_mem = fx.memref_alloca(c_reg_ty, fx.make_layout(4, 1)) fx.memref_store_vec(a_vec, a_mem) fx.memref_store_vec(b_vec, b_mem) fx.memref_store_vec(c_vec, c_mem) @@ -344,7 +330,7 @@ def _mfma_ABt_one(a, b, c, m, n): return c def _interleaved_cluster(lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_tiles_lds, a, b, c): - # 64x64 output via 4x4 MFMAs, with per-tile G→LDS and LDS→reg + # 4x4 MFMAs over 64x64, with per-tile G->LDS and LDS->reg # loads interleaved between MFMAs to hide latency. rt_dst = [] From 54ef57e7cda526d2a6f60628355438415aa97bd0 Mon Sep 17 00:00:00 2001 From: coderfeli Date: Wed, 13 May 2026 08:30:01 +0000 Subject: [PATCH 08/10] docs(fp8_gemm_4wave): note why _lds_dst_at uses ptrtoint not add_offset Now that the LDS base comes from ``fx.get_dyn_shared`` directly (i.e. already a ``fly.ptr``), the obvious cleanup is to replace the ptrtoint + add + inttoptr chain inside ``_lds_dst_at`` with ``fx.add_offset`` + ``fx.recast_iter``. That path was tested and compiled cleanly, but produced a 5-9% perf regression on the BLOCK=256 shapes (5120: -9%, 8192: -5%, 9728: -7%; BLOCK=64 was unchanged). The natural route adds an int_tuple wrapping op and a recast_iter that survives canonicalization, and the back-end then fails to match the common-base + offset idiom the inttoptr form exposes. Keep the inttoptr form and document why so we don't try to "clean it up" again. Co-authored-by: Cursor --- kernels/fp8_gemm_4wave.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/kernels/fp8_gemm_4wave.py b/kernels/fp8_gemm_4wave.py index 0e013b884..b4d3b0862 100644 --- a/kernels/fp8_gemm_4wave.py +++ b/kernels/fp8_gemm_4wave.py @@ -192,6 +192,12 @@ def _compute_lds_swizzle(wave_idx, n_tiles): # runtime k offset. g2lds_atom = fx.make_copy_atom(fx.rocdl.BufferCopyLDS128b(), 128) + # Routed through ptrtoint + add + inttoptr instead of the + # natural fx.add_offset+fx.recast_iter chain: empirically the + # natural route compiles ~5-9% slower on BLOCK=256 shapes + # (probably because the extra int_tuple / recast_iter ops + # survive canonicalization and disrupt the AMDGPU back-end's + # common-base + offset pattern matching). def _lds_dst_at(name, byte_offset_runtime): off = _lds_int[name] + fx.Int32(_lds_off[name] + byte_offset_runtime) ptr = fx.inttoptr(_shared_f8_ptr_ty, off) From c9192dc28f654f8825ed41381a3a7f4dd67baeb9 Mon Sep 17 00:00:00 2001 From: coderfeli Date: Wed, 13 May 2026 08:59:29 +0000 Subject: [PATCH 09/10] refactor(fp8_gemm_4wave): express XOR swizzle as a CoordSwizzle layout Replace the hand-written ``_swizzle_128`` Python helper with one ``CoordSwizzleType`` attribute composed onto two outer layouts: ``_lds_swz_layout`` (row stride = BLOCK_K) for LDS-side accesses and ``_gl_swz_layout`` (row stride = K) for global-side accesses. Every swizzled coord becomes ``fx.crd2idx((row, col), layout)``, unwrapped to a scalar via ``IntTuple.to_py_value()``. The XOR pattern is the same one the manual helper computed: bits[1..3] of dim 0 (row) XOR bits[4..6] of dim 1 (col) written in CoordSwizzle form as ``(mask=3, base_row=1, mode_row=[0], base_col=4, mode_col=[1])``. Perf is within run-to-run noise across all 4 parametrized shapes (BLOCK=64 and BLOCK=256), with a slight microbenchmark uptick on some shapes. Co-authored-by: Cursor --- kernels/fp8_gemm_4wave.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/kernels/fp8_gemm_4wave.py b/kernels/fp8_gemm_4wave.py index b4d3b0862..4d795541a 100644 --- a/kernels/fp8_gemm_4wave.py +++ b/kernels/fp8_gemm_4wave.py @@ -160,20 +160,28 @@ def _make_fp8_buf_tensor(arg_i8): sa_div = fx.logical_divide(gSA, fx.make_layout(1, 1)) sb_div = fx.logical_divide(gSB, fx.make_layout(1, 1)) - # XOR bits[4..6] of the tile-local linear offset with bits[8..10]. - def _swizzle_128(row, col): - offset = row * BLOCK_K + col - swz = ((offset % (16 * BLOCK_K)) >> 8) << 4 - swizzled = offset ^ swz - return swizzled // BLOCK_K, swizzled % BLOCK_K + # XOR 3 bits of dim-0 (row, bit-1 base) with 3 bits of dim-1 + # (col, bit-4 base). Same as the manual + # ((offset>>8)<<4) ^ offset; shared between LDS and global + # access via two outer layouts with different row strides. + _swz_attr = fx.CoordSwizzleType.get(3, 1, [0], 4, [1]) + _swz_shape = (LDS_BLOCK_M, BLOCK_K) + _coord_swz = fx.make_composed_layout( + fx.static(_swz_attr), fx.make_identity_layout(_swz_shape) + ) + _lds_swz_layout = fx.make_composed_layout( + fx.make_layout(_swz_shape, (BLOCK_K, 1)), _coord_swz + ) + _gl_swz_layout = fx.make_composed_layout( + fx.make_layout(_swz_shape, (K, 1)), _coord_swz + ) def _compute_global_swizzle(): offsets = [] for round in range_constexpr(max(N_TILES_A, N_TILES_B)): row = lane_id // 8 + wave_id * 8 + round * 32 col = (lane_id % 8) * 16 - r, c = _swizzle_128(row, col) - offsets.append(r * K + c) + offsets.append(fx.crd2idx((row, col), _gl_swz_layout).to_py_value()) return offsets def _compute_lds_swizzle(wave_idx, n_tiles): @@ -183,8 +191,7 @@ def _compute_lds_swizzle(wave_idx, n_tiles): swz = [] for i in range_constexpr(2): col = (lane_id // 16) * 16 + i * 64 - r, c = _swizzle_128(row, col) - swz.append(r * BLOCK_K + c) + swz.append(fx.crd2idx((row, col), _lds_swz_layout).to_py_value()) lds_swz.append(swz) return lds_swz @@ -234,8 +241,7 @@ def _load_rt(name, wave_idx, n_tiles): halves = [] for step in range_constexpr(2): col = (lane_id // 16) * 16 + step * 64 - r, c = _swizzle_128(row, col) - halves.append(_vec_load_lds_i32x4(name, r * BLOCK_K + c)) + halves.append(_vec_load_lds_i32x4(name, fx.crd2idx((row, col), _lds_swz_layout).to_py_value())) frag.append(_pack_i32x4_i32x8(halves[0], halves[1])) return frag From 9140fe4b02d5dc625a9d6177f2519e4d0c5838a2 Mon Sep 17 00:00:00 2001 From: coderfeli Date: Wed, 13 May 2026 09:05:04 +0000 Subject: [PATCH 10/10] refactor(fp8_gemm_4wave): inline _compute_lds_swizzle into _load_one_rt ``_compute_lds_swizzle`` materialised a per-cluster (n_tiles, 2) table of LDS-swizzled offsets that was then indexed by 8 separate ``_load_one_rt`` calls in ``_interleaved_cluster``. The table was purely a cache of values that ``_load_one_rt`` could compute on demand from ``(wave_idx, n_tiles, row_idx, k)`` -- the trace-time ``range_constexpr`` unrolling already serialises every lookup, so caching them in a Python list buys nothing. Drop the helper and the ``lds_swz`` argument; ``_load_one_rt`` now recomputes ``row`` / ``col`` inline and calls ``fx.crd2idx`` directly. Trace-time output is identical; perf on the BLOCK=256 path is unchanged (8192^3 -> 2140-2165 TFLOPS). Co-authored-by: Cursor --- kernels/fp8_gemm_4wave.py | 34 ++++++++++++---------------------- 1 file changed, 12 insertions(+), 22 deletions(-) diff --git a/kernels/fp8_gemm_4wave.py b/kernels/fp8_gemm_4wave.py index 4d795541a..fc8c48429 100644 --- a/kernels/fp8_gemm_4wave.py +++ b/kernels/fp8_gemm_4wave.py @@ -184,17 +184,6 @@ def _compute_global_swizzle(): offsets.append(fx.crd2idx((row, col), _gl_swz_layout).to_py_value()) return offsets - def _compute_lds_swizzle(wave_idx, n_tiles): - lds_swz = [] - for row_offset in range_constexpr(n_tiles): - row = wave_idx * (n_tiles * 16) + row_offset * 16 + lane_id % 16 - swz = [] - for i in range_constexpr(2): - col = (lane_id // 16) * 16 + i * 64 - swz.append(fx.crd2idx((row, col), _lds_swz_layout).to_py_value()) - lds_swz.append(swz) - return lds_swz - # 128 bits per thread = 16 fp8 elements; soffset carries the # runtime k offset. g2lds_atom = fx.make_copy_atom(fx.rocdl.BufferCopyLDS128b(), 128) @@ -245,8 +234,10 @@ def _load_rt(name, wave_idx, n_tiles): frag.append(_pack_i32x4_i32x8(halves[0], halves[1])) return frag - def _load_one_rt(name, lds_swz, row, k): - return _vec_load_lds_i32x4(name, lds_swz[row][k]) + def _load_one_rt(name, wave_idx, n_tiles, row_idx, k): + row = wave_idx * (n_tiles * 16) + row_idx * 16 + lane_id % 16 + col = (lane_id // 16) * 16 + k * 64 + return _vec_load_lds_i32x4(name, fx.crd2idx((row, col), _lds_swz_layout).to_py_value()) def _c_idx(i, j): return i * N_TILES_B + j @@ -349,48 +340,47 @@ def _interleaved_cluster(lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_sr c = _mfma_ABt_one(a, b, c, 0, 0) c = _mfma_ABt_one(a, b, c, 0, 1) - lds_swz = _compute_lds_swizzle(wave_idx, n_tiles_lds) _load_one_lds(gl_src, lds_dst, k_offset, gl_offsets, 0) - rt_dst_0 = _load_one_rt(lds_src, lds_swz, 0, 0) + rt_dst_0 = _load_one_rt(lds_src, wave_idx, n_tiles_lds, 0, 0) c = _mfma_ABt_one(a, b, c, 0, 2) - rt_dst_1 = _load_one_rt(lds_src, lds_swz, 0, 1) + rt_dst_1 = _load_one_rt(lds_src, wave_idx, n_tiles_lds, 0, 1) rt_dst.append(_pack_i32x4_i32x8(rt_dst_0, rt_dst_1)) c = _mfma_ABt_one(a, b, c, 0, 3) _load_one_lds(gl_src, lds_dst, k_offset, gl_offsets, 1) - rt_dst_0 = _load_one_rt(lds_src, lds_swz, 1, 0) + rt_dst_0 = _load_one_rt(lds_src, wave_idx, n_tiles_lds, 1, 0) c = _mfma_ABt_one(a, b, c, 1, 0) c = _mfma_ABt_one(a, b, c, 1, 1) - rt_dst_1 = _load_one_rt(lds_src, lds_swz, 1, 1) + rt_dst_1 = _load_one_rt(lds_src, wave_idx, n_tiles_lds, 1, 1) rt_dst.append(_pack_i32x4_i32x8(rt_dst_0, rt_dst_1)) c = _mfma_ABt_one(a, b, c, 1, 2) c = _mfma_ABt_one(a, b, c, 1, 3) _load_one_lds(gl_src, lds_dst, k_offset, gl_offsets, 2) - rt_dst_0 = _load_one_rt(lds_src, lds_swz, 2, 0) + rt_dst_0 = _load_one_rt(lds_src, wave_idx, n_tiles_lds, 2, 0) c = _mfma_ABt_one(a, b, c, 2, 0) c = _mfma_ABt_one(a, b, c, 2, 1) - rt_dst_1 = _load_one_rt(lds_src, lds_swz, 2, 1) + rt_dst_1 = _load_one_rt(lds_src, wave_idx, n_tiles_lds, 2, 1) rt_dst.append(_pack_i32x4_i32x8(rt_dst_0, rt_dst_1)) c = _mfma_ABt_one(a, b, c, 2, 2) c = _mfma_ABt_one(a, b, c, 2, 3) _load_one_lds(gl_src, lds_dst, k_offset, gl_offsets, 3) - rt_dst_0 = _load_one_rt(lds_src, lds_swz, 3, 0) + rt_dst_0 = _load_one_rt(lds_src, wave_idx, n_tiles_lds, 3, 0) c = _mfma_ABt_one(a, b, c, 3, 0) c = _mfma_ABt_one(a, b, c, 3, 1) - rt_dst_1 = _load_one_rt(lds_src, lds_swz, 3, 1) + rt_dst_1 = _load_one_rt(lds_src, wave_idx, n_tiles_lds, 3, 1) rt_dst.append(_pack_i32x4_i32x8(rt_dst_0, rt_dst_1)) c = _mfma_ABt_one(a, b, c, 3, 2)