
Fly lds alias scope #519

Closed

coderfeli wants to merge 10 commits into main from fly-lds-alias-scope


Conversation

@coderfeli (Collaborator)

Motivation

Technical Details

Test Plan

Test Result

Submission Checklist

coderfeli and others added 10 commits May 13, 2026 08:02
Without this pass, multiple `fly.get_dyn_shared` bases inside one
kernel collapse to a single LLVM allocation in `LowerModuleLDS`,
which makes `SIInsertWaitcnts` conservatively serialise every
cross-name LDS access with `s_waitcnt vmcnt(N)` and slows the
kernel down by ~3x compared to the static `[N x i8]` SmemAllocator
pattern.

This change adds:

* An optional `sym_name` attribute on `fly.get_dyn_shared` whose
  lowering emits a distinct external `[0 x i8] addrspace(3)` LDS
  global per name (all aliasing the same runtime LDS region); usage
  is sketched after this list.
* A new `fly-attach-lds-alias-scope` pass on `gpu.module` that walks
  every external 0-sized LDS global, gives each one a distinct
  `alias_scope` under a shared `FlyDynSharedDomain`, and tags every
  load / store / `amdgcn.raw.ptr.buffer.load.lds` whose addrspace(3)
  pointer can be statically traced back to a single global through
  `addressof / ptrtoint / add / inttoptr / GEP` with that scope plus
  a noalias-set covering all sibling globals.
* The pass is registered into the ROCm pipeline right after
  `reconcile-unrealized-casts`, so the metadata flows into the LLVM
  IR that `gpu-module-to-binary` hands to AMDGPU codegen.
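
To make the `sym_name` bullet concrete, a hypothetical usage sketch (the exact `fx.get_dyn_shared` call shape and the symbol names are inferred from this description, not taken from the real API):

```python
# Two named dyn-shared bases in one kernel. Each sym_name lowers to its own
# external [0 x i8] addrspace(3) global, so the new pass can give each base
# a distinct alias_scope under the shared FlyDynSharedDomain.
a_base = fx.get_dyn_shared(sym_name="smem_a")  # hypothetical name -> @smem_a
b_base = fx.get_dyn_shared(sym_name="smem_b")  # same runtime LDS region

# After the pass, accesses through a_base vs. b_base carry no-alias scopes,
# so SIInsertWaitcnts stops serialising them with s_waitcnt vmcnt(N).
```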

Verified end-to-end on the FP8 4-wave GEMM:
  static-LDS baseline:                ~535 TFLOPS (117 vmcnt waitcnts)
  named dyn-shared, no pass:          ~180 TFLOPS (334 vmcnt waitcnts)
  named dyn-shared + new pass:        ~535 TFLOPS  (matches baseline)

Static `[N x i8]` LDS globals (SmemAllocator) and single-global
modules are skipped: their alias info already comes from distinct
LLVM symbols, and the pass requires at least two named bases to
have anything to disambiguate.

Co-authored-by: Cursor <cursoragent@cursor.com>
The original dataflow had two soundness gaps that would let the pass
emit a `noalias` annotation on a load whose pointer might really
land in another global's region:

  * `add ptrtoint(@A), ptrtoint(@b)` produced an int with no entry
    in the provenance map (because the previous combine helper only
    stored non-null globals). A subsequent `add %amb, c` then saw
    only one operand with provenance and inherited it, mis-tagging
    every downstream use as belonging to a single global.

  * `or` and `sub` were treated like `add`. `or @g, mask` is only
    addition-equivalent when the operands are bit-disjoint, which we
    can't prove from the IR; `sub` of two pointer-derived ints is a
    `ptrdiff_t`, not a pointer.

The combine helper now uses a tri-state DenseMap (absent / G /
nullptr-sentinel) and explicitly stores the ambiguous sentinel so
downstream `add` / `inttoptr` / `gep` walk the ambiguous tag forward.
`or` / `sub` / `xor` / `and` / `shl` / `shr` / `bitcast` are no
longer treated as canonical pointer arithmetic; values flowing
through them lose their provenance and are skipped at the tag site.
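
A plain-Python model of that tri-state join (illustrative only; the real helper is a DenseMap over LLVM values, and every name below is invented):

```python
AMBIGUOUS = object()  # sentinel: value mixes provenance from more than one global

def combine_provenance(prov, val, lhs, rhs):
    """Tri-state join over the provenance map: absent / single global / AMBIGUOUS."""
    a, b = prov.get(lhs), prov.get(rhs)
    if a is None and b is None:
        return                                 # neither operand pointer-derived: stay absent
    if (a is not None and b is not None and a != b) or AMBIGUOUS in (a, b):
        prov[val] = AMBIGUOUS                  # e.g. add ptrtoint(@A), ptrtoint(@B)
    else:
        prov[val] = a if a is not None else b  # one provenance source: inherit it

# The old bug: @A + @B got *no* entry, so a later `add %amb, c` saw one-sided
# provenance and inherited it. Storing the sentinel explicitly fixes that:
prov = {"%pA": "@A", "%pB": "@B"}
combine_provenance(prov, "%amb", "%pA", "%pB")     # -> AMBIGUOUS, stored explicitly
combine_provenance(prov, "%amb2", "%amb", "%c0")   # add %amb, c walks the tag forward
assert prov["%amb"] is prov["%amb2"] is AMBIGUOUS  # both stay untagged at the tag site
```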

Two FileCheck cases lock the new behavior in: an explicit
`add ptrtoint(@A), ptrtoint(@b)` chain stays untagged, and a
`ptrtoint + or` chain stays untagged.

Co-authored-by: Cursor <cursoragent@cursor.com>
…kernel

Add three more lit cases that lock in robustness corners surfaced
by the audit:

  * `phi_block_arg`: a pointer flowing into a block argument (LLVM
    phi) loses provenance, so the load on the merged value stays
    untagged regardless of which predecessor branched in. This is
    important post `convert-scf-to-cf`, where `scf.for` carried
    values become block arguments.
  * `deep_chain`: addressof -> gep -> ptrtoint -> add -> add ->
    inttoptr resolves to the originating global; multi-step add
    chains correctly forward provenance.
  * `mixed_dyn_static`: when a `gpu.module` has both dyn-shared
    `[0 x i8]` globals and SmemAllocator-style `[N x i8]` static
    globals, only the dyn-shared loads receive scopes; static loads
    pass through untouched.

Co-authored-by: Cursor <cursoragent@cursor.com>
…e pass

Replace the 8-allocator SmemAllocator scaffolding and the
stdlib-memref -> inttoptr adapter in ``_lds_dst_at`` with named
``fx.get_dyn_shared(sym_name=...)`` bases. The
``fly-attach-lds-alias-scope`` pass attaches per-symbol alias scopes
so AMDGPU's ``SIInsertWaitcnts`` pass treats accesses through different
named bases as no-alias, which gives the same scheduling as the
previous 8-distinct-static-globals layout.

What goes away:
* ``SmemAllocator`` / ``SmemPtr`` plumbing and the finalize-in-jit
  dance in ``launch_gemm``.
* ``_lds_dst_at``'s ``extract_aligned_pointer_as_index +
  index_cast + inttoptr`` adapter -- it existed solely to bridge
  stdlib ``memref<? x f8>`` to a ``fly.tensor`` view that
  ``fx.copy`` accepts. With ``fx.get_dyn_shared`` returning a
  ``fly.ptr`` directly, the bridge collapses to a plain
  ``inttoptr`` of an i32 base + offset.
* ``Vec.load`` of ``vector<16xf8>`` for the LDS->reg path, replaced
  by ``fx.memref_load_vec`` on a ``vector<4xi32>`` view (16 fp8 = 4
  i32, see the sketch below) so the lowering also sidesteps the
  missing LLVM type for ``vector<16xf8>``.
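
The bit-width arithmetic behind that last bullet, as a runnable check (the ``fx.memref_load_vec`` call shape itself is not reproduced here):

```python
# Why a vector<4xi32> view can carry one vector<16xf8> fragment: both are
# 128 bits, so fx.memref_load_vec on the i32 view moves the same bytes while
# never materialising a vector<16xf8> type in LLVM IR.
assert 16 * 8 == 4 * 32 == 128  # bits per LDS->reg fragment
```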

The launch wrapper now just declares ``smem=_TOTAL_LDS_BYTES``; the
runtime allocates the dyn-shared region.

Perf across all parametrized shapes is within run-to-run noise of
the pre-refactor static-LDS baseline (~535 / ~1820 / ~2150 / ~2140
TFLOPS on the four shapes).

Co-authored-by: Cursor <cursoragent@cursor.com>
The ``_compute_cluster`` path (BLOCK < 256) now spills each per-atom
Vec(8,i32)/Vec(4,f32) operand into a register-memref fragment and
calls ``fx.gemm`` against a 4-wave 2x2 ``tiled_mma`` instead of
emitting ``fly.mma_atom_call_ssa`` directly.

``fly-convert-atom-call-to-ssa-form`` + ``fly-promote-regmem-to-vectorssa``
elide the alloca / store / load round trip cleanly: the resulting
LLVM IR has zero ``alloca`` for the fragments and the MFMA call
chain stays purely on ``<4 x float>`` SSA, so ISel still maps the
accumulator onto AGPR.
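
Roughly, the shape of this path (a sketch; every identifier below is reconstructed from this description rather than the real flydsl API):

```python
def _mfma(acc_mem, a_vec, b_vec, tiled_mma):
    # Sketch: spill the per-atom Vec(8, i32) / Vec(4, f32) operands into
    # register-memref fragments, then let fx.gemm drive the 4-wave 2x2
    # tiled_mma. fly-convert-atom-call-to-ssa-form +
    # fly-promote-regmem-to-vectorssa later erase the spills, so the final
    # LLVM IR keeps the accumulator on <4 x float> SSA (and hence AGPR).
    a_mem = spill_to_reg_memref(a_vec)  # hypothetical helper
    b_mem = spill_to_reg_memref(b_vec)  # hypothetical helper
    fx.gemm(tiled_mma, a_mem, b_mem, acc_mem)
```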

The interleaved BLOCK==256 path keeps the direct
``fly.mma_atom_call_ssa`` route -- its manual per-atom interleaving
with G->LDS / LDS->reg loads is the whole point of the cluster
layout, and ``fx.gemm`` would batch the atoms in a way that
contradicts that schedule.

Perf is within run-to-run noise of baseline on the BLOCK=64 shape
(538-544 TFLOPS) and unchanged on BLOCK=256 paths.

Co-authored-by: Cursor <cursoragent@cursor.com>
Drop the dual mma path. The interleaved BLOCK=256 cluster now
routes its per-atom MFMAs through the same ``_mfma`` helper that
the non-interleaved BLOCK<256 cluster uses, where each call spills
the Vec operands into register-memref fragments and invokes
``fx.gemm`` against the 4-wave 2x2 ``tiled_mma``.

``fly-convert-atom-call-to-ssa-form`` + ``fly-promote-regmem-to-vectorssa``
elide the alloca / store / load round trip for every call site (0
``alloca`` left in the final LLVM IR), keeping the per-atom
accumulator on ``<4 x float>`` SSA values so ISel still maps it to
AGPR. The interleaved cluster's load schedule is preserved because
``_mfma_ABt_one`` still gets called one atom at a time between the
G->LDS / LDS->reg loads.

Drops the now-unused ``MfmaAccum_t`` alias and the
``flydsl._mlir.dialects.fly`` import.

Perf across the four parametrized shapes is within run-to-run
noise of the pre-unification numbers (537-543 / 1836-1839 /
2137-2172 / 2134-2166 TFLOPS).

Co-authored-by: Cursor <cursoragent@cursor.com>
Update the module docstring to reflect the fx.gemm-based MFMA path,
drop the redundant LDS-subbuffer block comment, and trim the rest
of the inline comments down to the non-obvious bits. Inline the
register-fragment element counts (8/8/4) instead of carrying named
constants whose only use was to keep the comments aligned. Net -15
lines and no behavioral change.

Co-authored-by: Cursor <cursoragent@cursor.com>
Now that the LDS base comes from ``fx.get_dyn_shared`` directly
(i.e. already a ``fly.ptr``), the obvious cleanup is to replace the
ptrtoint + add + inttoptr chain inside ``_lds_dst_at`` with
``fx.add_offset`` + ``fx.recast_iter``. That path was tested and
compiled cleanly, but produced a 5-9% perf regression on the
BLOCK=256 shapes (5120: -9%, 8192: -5%, 9728: -7%; BLOCK=64 was
unchanged). The natural route adds an int_tuple wrapping op and a
recast_iter that survives canonicalization, and the back-end then
fails to match the common-base + offset idiom the inttoptr form
exposes.

Keep the inttoptr form and document why so we don't try to "clean
it up" again.
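
For the record, the retained shape (helper spellings like ``fx.ptrtoint`` are assumptions; only the op sequence comes from the text above):

```python
def _lds_dst_at(lds_base, byte_off):
    # Deliberately the ptrtoint + add + inttoptr chain: the back-end matches
    # this common-base + offset idiom, and the fx.add_offset / fx.recast_iter
    # route cost 5-9% on the BLOCK=256 shapes.
    addr = fx.ptrtoint(lds_base) + byte_off  # i32 arithmetic on the named base
    return fx.inttoptr(addr)                 # back to a pointer fx.copy accepts
```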

Co-authored-by: Cursor <cursoragent@cursor.com>
Replace the hand-written ``_swizzle_128`` Python helper with one
``CoordSwizzleType`` attribute composed onto two outer layouts:
``_lds_swz_layout`` (row stride = BLOCK_K) for LDS-side accesses
and ``_gl_swz_layout`` (row stride = K) for global-side accesses.
Every swizzled coord becomes ``fx.crd2idx((row, col), layout)``,
unwrapped to a scalar via ``IntTuple.to_py_value()``.

The XOR pattern is the same one the manual helper computed:

  bits[1..3] of dim 0 (row) XOR bits[4..6] of dim 1 (col)

written in CoordSwizzle form as ``(mask=3, base_row=1,
mode_row=[0], base_col=4, mode_col=[1])``.
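
A plain-Python reference for that XOR pattern (a model of the math, not the DSL op; reading ``mask=3`` as a three-bit mask):

```python
def swizzle_128_ref(row, col):
    """bits [1..3] of the row index flip bits [4..6] of the column index."""
    return col ^ (((row >> 1) & 0b111) << 4)

assert swizzle_128_ref(0, 37) == 37        # row bits [1..3] clear: col untouched
assert swizzle_128_ref(2, 0) == 1 << 4     # row bit 1 set -> flips col bit 4
```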

Perf is within run-to-run noise across all 4 parametrized shapes
(BLOCK=64 and BLOCK=256), with a slight microbenchmark uptick on
some shapes.

Co-authored-by: Cursor <cursoragent@cursor.com>
``_compute_lds_swizzle`` materialised a per-cluster (n_tiles, 2)
table of LDS-swizzled offsets that was then indexed by 8 separate
``_load_one_rt`` calls in ``_interleaved_cluster``. The table was
purely a cache of values that ``_load_one_rt`` could compute on
demand from ``(wave_idx, n_tiles, row_idx, k)`` -- the trace-time
``range_constexpr`` unrolling already serialises every lookup, so
caching them in a Python list buys nothing.

Drop the helper and the ``lds_swz`` argument; ``_load_one_rt`` now
recomputes ``row`` / ``col`` inline and calls
``fx.crd2idx`` directly. Trace-time output is identical; perf on
the BLOCK=256 path is unchanged (8192^3 -> 2140-2165 TFLOPS).
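
On demand, that lookup is roughly (the row reconstruction is an assumption; the ``fx.crd2idx`` + ``to_py_value()`` shape comes from the swizzle commit above):

```python
# Inside _load_one_rt: no cached table; recompute the swizzled LDS offset
# from (wave_idx, n_tiles, row_idx, k) at trace time.
row = wave_idx * n_tiles + row_idx                         # assumed mapping
off = fx.crd2idx((row, k), _lds_swz_layout).to_py_value()  # swizzled scalar offset
```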

Co-authored-by: Cursor <cursoragent@cursor.com>
@coderfeli coderfeli marked this pull request as draft May 14, 2026 00:36
@coderfeli coderfeli closed this May 16, 2026