Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion include/flydsl/Dialect/Fly/IR/FlyOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,22 @@ def Fly_MakePtrOp : Fly_Op<"make_ptr", []> {
let results = (outs Fly_Pointer:$result);
}
def Fly_GetDynSharedOp : Fly_Op<"get_dyn_shared", [Pure, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let arguments = (ins);
let summary = "Pointer to the kernel's dynamic shared memory";
let description = [{
Returns a pointer into the kernel's dynamic shared memory region.

By default, the lowering reuses a single ``__dynamic_shared_*`` LLVM
global for all calls in a kernel. When `sym_name` is provided, the
lowering instead emits a distinct external ``[0 x i8]`` LDS global
with that exact name. Multiple named bases all alias the same
runtime LDS region (each starts at offset 0 of the dynamic LDS
area), but their distinct LLVM symbols give the
``fly-attach-lds-alias-scope`` pass the provenance it needs to
attach ``alias_scope``/``noalias`` metadata, which lets AMDGPU's SI
Wait Counter pass elide defensive ``s_waitcnt vmcnt(N)`` between
accesses through different names.
}];
let arguments = (ins OptionalAttr<StrAttr>:$sym_name);
let results = (outs Fly_Pointer:$result);
let assemblyFormat = "`(` `)` attr-dict `:` qualified(type($result))";
}
Expand Down
26 changes: 26 additions & 0 deletions include/flydsl/Dialect/Fly/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -103,4 +103,30 @@ def FlyPromoteRegMemToVectorSSAPass : Pass<"fly-promote-regmem-to-vectorssa"> {
];
}

// Declares the `fly-attach-lds-alias-scope` pass. It is anchored on
// `gpu.module` ops and lists the LLVM dialect as a dependent dialect.
// The `description` text below is emitted into the generated pass
// documentation, so it is user-facing and not an ordinary comment.
def FlyAttachLDSAliasScopePass : Pass<"fly-attach-lds-alias-scope", "::mlir::gpu::GPUModuleOp"> {
let summary = "Attach alias scope metadata to dyn-shared LDS accesses";
let description = [{
Walks every external `[0 x i8] addrspace(3)` LLVM global in the
`gpu.module` (the dyn-shared LDS bases produced by
`fly.get_dyn_shared(sym_name="...")`) and attaches per-symbol
`alias_scopes` / `noalias_scopes` metadata to every load, store,
and `llvm.amdgcn.raw.ptr.buffer.load.lds` call whose addrspace(3)
pointer can be statically traced back to a single global through
`addressof / ptrtoint / add / inttoptr / getelementptr`.

Without this metadata, AMDGPU's `LowerModuleLDS` pass collapses
the named dyn-shared globals into one underlying allocation and
`SIInsertWaitcnts` then conservatively serialises every cross-name
LDS access with `s_waitcnt vmcnt(N)`. The metadata flows through
the merge and lets the SI Wait Counter pass treat distinct-named
accesses as no-alias, restoring static-LDS-class scheduling.

Single-global modules are skipped (no benefit).
}];

// Dialects that must be loaded before this pass runs.
let dependentDialects = [
"LLVM::LLVMDialect"
];
}

#endif // FLY_PASSES
263 changes: 132 additions & 131 deletions kernels/fp8_gemm_4wave.py

Large diffs are not rendered by default.

39 changes: 29 additions & 10 deletions lib/Conversion/FlyToROCDL/FlyToROCDL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ class GetDynSharedOpLowering : public OpConversionPattern<GetDynSharedOp> {
if (!moduleOp)
return op->emitError("get_dyn_shared must be inside a gpu.module");

LLVM::GlobalOp sharedGlobal = getOrCreateDynSharedGlobal(rewriter, moduleOp, loc, addrSpace);
LLVM::GlobalOp sharedGlobal =
getOrCreateDynSharedGlobal(rewriter, moduleOp, loc, addrSpace, op.getSymNameAttr());

OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(op);
Expand All @@ -142,21 +143,39 @@ class GetDynSharedOpLowering : public OpConversionPattern<GetDynSharedOp> {
private:
static LLVM::GlobalOp getOrCreateDynSharedGlobal(ConversionPatternRewriter &rewriter,
gpu::GPUModuleOp moduleOp, Location loc,
unsigned addrSpace) {
unsigned addrSpace,
StringAttr requestedName) {
// When sym_name is requested we look up by exact name and create a
// distinct external [0 x i8] LDS global if missing. Otherwise we
// reuse the first existing matching dyn-shared global, falling back
// to a freshly generated `__dynamic_shared_<n>` symbol.
llvm::StringSet<> existingNames;
LLVM::GlobalOp firstMatch = nullptr;
for (auto globalOp : moduleOp.getBody()->getOps<LLVM::GlobalOp>()) {
existingNames.insert(globalOp.getSymName());
if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) {
if (globalOp.getAddrSpace() == addrSpace && arrayType.getNumElements() == 0 &&
globalOp.getAlignment().value_or(0) == 1024)
return globalOp;
if (requestedName && globalOp.getSymName() == requestedName.getValue())
return globalOp;
if (!requestedName) {
if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(globalOp.getType())) {
if (!firstMatch && globalOp.getAddrSpace() == addrSpace &&
arrayType.getNumElements() == 0 &&
globalOp.getAlignment().value_or(0) == 1024)
firstMatch = globalOp;
}
}
}
if (!requestedName && firstMatch)
return firstMatch;

unsigned counter = 0;
SmallString<128> symName = SymbolTable::generateSymbolName<128>(
"__dynamic_shared_", [&](StringRef candidate) { return existingNames.contains(candidate); },
counter);
SmallString<128> symName;
if (requestedName) {
symName.assign(requestedName.getValue());
} else {
unsigned counter = 0;
symName = SymbolTable::generateSymbolName<128>(
"__dynamic_shared_",
[&](StringRef candidate) { return existingNames.contains(candidate); }, counter);
}

OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(moduleOp.getBody());
Expand Down
2 changes: 2 additions & 0 deletions lib/Dialect/Fly/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ add_mlir_dialect_library(MLIRFlyDialect
Transforms/ConvertAtomCallToSSAForm.cpp
Transforms/PromoteRegMemToVectorSSA.cpp
Transforms/IntSwizzleSimplify.cpp
Transforms/AttachLDSAliasScope.cpp

DEPENDS
MLIRFlyIncGen
Expand All @@ -24,5 +25,6 @@ add_mlir_dialect_library(MLIRFlyDialect
LINK_LIBS
MLIRGPUDialect
MLIRIR
MLIRLLVMDialect
MLIRTargetLLVMIRExport
)
266 changes: 266 additions & 0 deletions lib/Dialect/Fly/Transforms/AttachLDSAliasScope.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
// SPDX-License-Identifier: Apache-2.0
// Copyright (c) 2026 FlyDSL Project Contributors

#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMAttrs.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/Pass/Pass.h"

#include "flydsl/Dialect/Fly/Transforms/Passes.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"

using namespace mlir;

namespace mlir {
namespace fly {
#define GEN_PASS_DEF_FLYATTACHLDSALIASSCOPEPASS
#include "flydsl/Dialect/Fly/Transforms/Passes.h.inc"
} // namespace fly
} // namespace mlir

namespace {

/// Numeric LLVM address space used for LDS pointers and globals on AMDGPU.
/// Every helper in this file compares against this value instead of
/// spelling the literal.
constexpr unsigned kLDSAddrSpace{3};

/// Decides whether `g` is a dyn-shared LDS base, i.e. a global that is
/// simultaneously:
///   * placed in address space `kLDSAddrSpace`,
///   * declared with external linkage, and
///   * typed as an LLVM array with zero elements.
///
/// The zero-length requirement follows the HSA dynamic-LDS convention and
/// keeps SmemAllocator-style static globals (which already get alias
/// information from being distinct symbols) out of the tracked set.
static bool isDynSharedGlobal(LLVM::GlobalOp g) {
  if (g.getAddrSpace() != kLDSAddrSpace ||
      g.getLinkage() != LLVM::Linkage::External)
    return false;

  if (auto zeroLenCandidate = dyn_cast<LLVM::LLVMArrayType>(g.getType()))
    return zeroLenCandidate.getNumElements() == 0;
  return false;
}

/// Per-SSA-value provenance, encoded as a tri-state DenseMap:
/// - absent entry => unknown / not derived from any tracked global
/// - mapped to G => derived from exactly the LDS global G
/// - mapped to nullptr => *known* to mix two or more globals (ambiguous);
/// downstream uses that consume this value must
/// also be marked ambiguous so the pass never
/// tags an access with a single scope when its
/// true scope set is larger.
using PtrProvenance = llvm::DenseMap<Value, LLVM::GlobalOp>;
using IntProvenance = llvm::DenseMap<Value, LLVM::GlobalOp>;

/// Reports whether `call` names a callee whose symbol begins with
/// `llvm.amdgcn.raw.ptr.buffer.load.lds`. The comparison is a prefix
/// match, so callee names that carry extra trailing characters are
/// accepted as well. Calls without a callee name yield false.
static bool isBufferLoadLDS(LLVM::CallOp call) {
  if (auto calleeName = call.getCallee())
    return calleeName->starts_with("llvm.amdgcn.raw.ptr.buffer.load.lds");
  return false;
}

/// Picks out the LDS pointer that `op` accesses, if any.
///
///   * `llvm.load` / `llvm.store`: the address operand.
///   * `llvm.call` accepted by `isBufferLoadLDS`: operand #1 (the operand
///     that follows the buffer descriptor), provided the call has at least
///     two operands.
///
/// In every case the candidate is returned only when its type is an LLVM
/// pointer in address space `kLDSAddrSpace`; otherwise, and for every other
/// kind of operation, a null `Value` is returned.
static Value memoryPointerForOp(Operation *op) {
  // Passes `candidate` through when it is an addrspace(3) pointer and
  // replaces it with a null Value when it is not.
  auto keepIfLDS = [](Value candidate) -> Value {
    auto ptrTy = dyn_cast<LLVM::LLVMPointerType>(candidate.getType());
    if (ptrTy && ptrTy.getAddressSpace() == kLDSAddrSpace)
      return candidate;
    return nullptr;
  };

  if (auto load = dyn_cast<LLVM::LoadOp>(op))
    return keepIfLDS(load.getAddr());

  if (auto store = dyn_cast<LLVM::StoreOp>(op))
    return keepIfLDS(store.getAddr());

  auto call = dyn_cast<LLVM::CallOp>(op);
  if (!call || !isBufferLoadLDS(call))
    return nullptr;
  if (call.getNumOperands() < 2)
    return nullptr;
  // The LDS pointer is the second argument, after the buffer descriptor.
  return keepIfLDS(call.getOperand(1));
}

/// Forward dataflow that maps SSA values in `func` back to the dyn-shared
/// LDS global they derive from.
///
/// Only the canonical pointer-arithmetic chain is followed:
///   * `LLVM::AddressOfOp(@g)` -> ptr provenance(@g), when `@g` is a key of
///                                `nameToGlobal`
///   * `LLVM::PtrToIntOp(p)`   -> int provenance(p)
///   * `LLVM::AddOp(a, b)`     -> provenance of the one operand that has an
///                                entry; if *both* operands have an entry
///                                the result is recorded as ambiguous, even
///                                when both name the same global (the sum of
///                                two pointer-derived integers is not a
///                                well-formed pointer into either)
///   * `LLVM::IntToPtrOp(i)`   -> ptr provenance(i), addrspace(3) results only
///   * `LLVM::GEPOp(p)`        -> ptr provenance(p), addrspace(3) results only
///
/// Every other operation (`or`/`sub`/`xor`/`and`/`shl`/`shr`/`bitcast`, ...)
/// records nothing for its results, i.e. provenance is dropped there.
///
/// Both tables use the tri-state encoding: absent = unknown, non-null =
/// that global, null = ambiguous. Ambiguous entries are propagated through
/// the chain so downstream users can refuse to tag.
///
/// Operations are visited once, in pre-order; a value whose definition is
/// visited after one of its uses contributes no provenance to that use.
///
/// NOTE(review): an `add` operand that reaches the add through an untracked
/// op has no entry and is therefore treated like a plain offset, even if it
/// was itself computed from another tracked global. Confirm that producers
/// never build LDS addresses that way, or extend the tracking.
///
/// \param func          function whose body is analysed.
/// \param nameToGlobal  symbol name -> tracked dyn-shared global.
/// \param ptrProv       out: provenance of pointer-typed values.
/// \param intProv       out: provenance of integer-typed values.
static void computeProvenance(
    LLVM::LLVMFuncOp func,
    const llvm::DenseMap<StringRef, LLVM::GlobalOp> &nameToGlobal,
    PtrProvenance &ptrProv, IntProvenance &intProv) {
  // True when `v` is an LLVM pointer living in the LDS address space.
  auto isLDSPointer = [](Value v) -> bool {
    auto ptrTy = dyn_cast<LLVM::LLVMPointerType>(v.getType());
    return ptrTy && ptrTy.getAddressSpace() == kLDSAddrSpace;
  };

  func.walk<WalkOrder::PreOrder>([&](Operation *op) {
    if (auto addrOf = dyn_cast<LLVM::AddressOfOp>(op)) {
      auto known = nameToGlobal.find(addrOf.getGlobalName());
      if (known != nameToGlobal.end())
        ptrProv[addrOf.getResult()] = known->second;
      return;
    }

    if (auto p2i = dyn_cast<LLVM::PtrToIntOp>(op)) {
      auto src = ptrProv.find(p2i.getArg());
      if (src == ptrProv.end())
        return;
      // Copy first; the stored value may be the ambiguous (null) sentinel.
      LLVM::GlobalOp inherited = src->second;
      intProv[p2i.getResult()] = inherited;
      return;
    }

    if (auto add = dyn_cast<LLVM::AddOp>(op)) {
      auto lhs = intProv.find(add.getLhs());
      auto rhs = intProv.find(add.getRhs());
      const bool lhsSeen = lhs != intProv.end();
      const bool rhsSeen = rhs != intProv.end();
      if (!lhsSeen && !rhsSeen)
        return; // Neither side is tracked: leave the result unknown.

      // Both sides tracked -> ambiguous (null), regardless of whether they
      // agree. Exactly one side tracked -> inherit it as-is (which may
      // itself already be the ambiguous sentinel).
      LLVM::GlobalOp merged = nullptr;
      if (lhsSeen != rhsSeen)
        merged = lhsSeen ? lhs->second : rhs->second;
      // `merged` is a local copy, so the insertion below cannot invalidate
      // what it was read from.
      intProv[add.getResult()] = merged;
      return;
    }

    if (auto i2p = dyn_cast<LLVM::IntToPtrOp>(op)) {
      auto src = intProv.find(i2p.getArg());
      if (src == intProv.end() || !isLDSPointer(i2p.getResult()))
        return;
      LLVM::GlobalOp inherited = src->second; // ambiguous propagates too
      ptrProv[i2p.getResult()] = inherited;
      return;
    }

    if (auto gep = dyn_cast<LLVM::GEPOp>(op)) {
      auto base = ptrProv.find(gep.getBase());
      if (base == ptrProv.end() || !isLDSPointer(gep.getResult()))
        return;
      // `base` points into `ptrProv`, the same map we are about to insert
      // into. DenseMap insertion may reallocate and invalidate `base`, so
      // the value must be copied out *before* calling operator[].
      LLVM::GlobalOp inherited = base->second;
      ptrProv[gep.getResult()] = inherited;
      return;
    }
  });
}

/// Pass that tags LDS accesses in a `gpu.module` with per-global alias
/// scopes.
///
/// For every dyn-shared global (see `isDynSharedGlobal`) one alias scope is
/// created inside a single shared domain. Each `llvm.load`, `llvm.store`
/// and buffer-load-to-LDS call whose LDS pointer `computeProvenance` traces
/// to exactly one such global then receives:
///   * `alias_scopes`   = [scope of that global]
///   * `noalias_scopes` = [scopes of all *other* dyn-shared globals]
///
/// Accesses with unknown or ambiguous provenance are left untouched, and
/// modules with fewer than two dyn-shared globals are not modified at all.
class FlyAttachLDSAliasScopePass
    : public mlir::fly::impl::FlyAttachLDSAliasScopePassBase<
          FlyAttachLDSAliasScopePass> {
public:
  using mlir::fly::impl::FlyAttachLDSAliasScopePassBase<
      FlyAttachLDSAliasScopePass>::FlyAttachLDSAliasScopePassBase;

  void runOnOperation() override {
    gpu::GPUModuleOp gpuModule = getOperation();

    // Collect dyn-shared globals in declaration order, plus a by-name index
    // used to resolve `llvm.mlir.addressof` references.
    SmallVector<LLVM::GlobalOp> dynGlobals;
    llvm::DenseMap<StringRef, LLVM::GlobalOp> nameToGlobal;
    for (auto g : gpuModule.getOps<LLVM::GlobalOp>()) {
      if (!isDynSharedGlobal(g))
        continue;
      dynGlobals.push_back(g);
      nameToGlobal[g.getSymName()] = g;
    }
    if (dynGlobals.size() < 2)
      return; // Single (or no) dyn-shared region: nothing to disambiguate.

    MLIRContext *ctx = &getContext();
    OpBuilder builder(ctx);

    // One domain per gpu.module, one scope per dyn-shared global. Scopes
    // are stored in the same order as `dynGlobals`.
    auto domain = LLVM::AliasScopeDomainAttr::get(
        ctx, builder.getStringAttr("FlyDynSharedDomain"));
    SmallVector<LLVM::AliasScopeAttr> scopes;
    scopes.reserve(dynGlobals.size());
    for (auto g : dynGlobals)
      scopes.push_back(LLVM::AliasScopeAttr::get(
          domain, builder.getStringAttr(g.getSymName())));

    // Both attribute values depend only on the global, so build them once
    // here rather than once per tagged access.
    struct ScopeTags {
      ArrayAttr own;    // value for `alias_scopes`
      ArrayAttr others; // value for `noalias_scopes`
    };
    llvm::DenseMap<LLVM::GlobalOp, ScopeTags> tagsFor;
    for (unsigned i = 0, e = dynGlobals.size(); i != e; ++i) {
      // noalias set = every scope except this global's own; this is what
      // makes cross-global accesses no-alias.
      SmallVector<Attribute> others;
      others.reserve(e - 1);
      for (unsigned j = 0; j != e; ++j)
        if (j != i)
          others.push_back(scopes[j]);

      ScopeTags tags;
      tags.own = ArrayAttr::get(ctx, {scopes[i]});
      tags.others = ArrayAttr::get(ctx, others);
      tagsFor[dynGlobals[i]] = tags;
    }

    for (auto func : gpuModule.getOps<LLVM::LLVMFuncOp>()) {
      if (func.empty())
        continue; // Declaration only: no body to analyse.

      PtrProvenance ptrProv;
      IntProvenance intProv;
      computeProvenance(func, nameToGlobal, ptrProv, intProv);

      func.walk([&](Operation *op) {
        Value lds = memoryPointerForOp(op);
        if (!lds)
          return;
        // Skip both "unknown" (no entry) and "ambiguous" (null entry).
        auto prov = ptrProv.find(lds);
        if (prov == ptrProv.end() || !prov->second)
          return;
        auto tags = tagsFor.find(prov->second);
        if (tags == tagsFor.end())
          return;
        op->setAttr("alias_scopes", tags->second.own);
        op->setAttr("noalias_scopes", tags->second.others);
      });
    }
  }
};

} // namespace
Loading
Loading