diff --git a/include/flydsl/Dialect/Fly/IR/FlyOps.td b/include/flydsl/Dialect/Fly/IR/FlyOps.td index 36706f4a..b1e74ec5 100644 --- a/include/flydsl/Dialect/Fly/IR/FlyOps.td +++ b/include/flydsl/Dialect/Fly/IR/FlyOps.td @@ -435,7 +435,22 @@ def Fly_MakePtrOp : Fly_Op<"make_ptr", []> { let results = (outs Fly_Pointer:$result); } def Fly_GetDynSharedOp : Fly_Op<"get_dyn_shared", [Pure, DeclareOpInterfaceMethods]> { - let arguments = (ins); + let summary = "Pointer to the kernel's dynamic shared memory"; + let description = [{ + Returns a pointer into the kernel's dynamic shared memory region. + + By default, the lowering reuses a single ``__dynamic_shared_*`` LLVM + global for all calls in a kernel. When `sym_name` is provided, the + lowering instead emits a distinct external ``[0 x i8]`` LDS global + with that exact name. Multiple named bases all alias the same + runtime LDS region (each starts at offset 0 of the dynamic LDS + area), but their distinct LLVM symbols give the + ``fly-attach-lds-alias-scope`` pass the provenance it needs to + attach ``alias_scope``/``noalias`` metadata, which lets AMDGPU's SI + Wait Counter pass elide defensive ``s_waitcnt vmcnt(N)`` between + accesses through different names. 
+ }]; + let arguments = (ins OptionalAttr:$sym_name); let results = (outs Fly_Pointer:$result); let assemblyFormat = "`(` `)` attr-dict `:` qualified(type($result))"; } diff --git a/include/flydsl/Dialect/Fly/Transforms/Passes.td b/include/flydsl/Dialect/Fly/Transforms/Passes.td index 44a34f56..e04fea3d 100644 --- a/include/flydsl/Dialect/Fly/Transforms/Passes.td +++ b/include/flydsl/Dialect/Fly/Transforms/Passes.td @@ -103,4 +103,30 @@ def FlyPromoteRegMemToVectorSSAPass : Pass<"fly-promote-regmem-to-vectorssa"> { ]; } +def FlyAttachLDSAliasScopePass : Pass<"fly-attach-lds-alias-scope", "::mlir::gpu::GPUModuleOp"> { + let summary = "Attach alias scope metadata to dyn-shared LDS accesses"; + let description = [{ + Walks every external `[0 x i8] addrspace(3)` LLVM global in the + `gpu.module` (the dyn-shared LDS bases produced by + `fly.get_dyn_shared(sym_name="...")`) and attaches per-symbol + `alias_scopes` / `noalias_scopes` metadata to every load, store, + and `llvm.amdgcn.raw.ptr.buffer.load.lds` call whose addrspace(3) + pointer can be statically traced back to a single global through + `addressof / ptrtoint / add / inttoptr / getelementptr`. + + Without this metadata, AMDGPU's `LowerModuleLDS` pass collapses + the named dyn-shared globals into one underlying allocation and + `SIInsertWaitcnts` then conservatively serialises every cross-name + LDS access with `s_waitcnt vmcnt(N)`. The metadata flows through + the merge and lets the SI Wait Counter pass treat distinct-named + accesses as no-alias, restoring static-LDS-class scheduling. + + Single-global modules are skipped (no benefit). 
+ }]; + + let dependentDialects = [ + "LLVM::LLVMDialect" + ]; +} + #endif // FLY_PASSES diff --git a/kernels/fp8_gemm_4wave.py b/kernels/fp8_gemm_4wave.py index d9f4fcd0..fc8c4842 100644 --- a/kernels/fp8_gemm_4wave.py +++ b/kernels/fp8_gemm_4wave.py @@ -6,25 +6,24 @@ Algorithm derived from HipKittens FP8_4wave (https://github.com/HazyResearch/HipKittens/blob/7782744ba1fd259a377a99e2ea8f71384cc80e55/kernels/gemm/fp8fp32/FP8_4wave/4_wave.cu#L1). -Global IO, scale loads, and bf16 stores go through the layout API -(``fx.rocdl.make_buffer_tensor`` + ``fx.copy`` with ``BufferCopyLDS128b`` -/ ``BufferCopy{16,32,128}b``). MFMAs use ``fly.mma_atom_call_ssa`` so -the chained Vec(4, f32) accumulator stays on AGPR. The XOR swizzle and -the 8-buffer LDS pipeline ping-pong are kept as direct arithmetic to -preserve the original kernel's interleaved-cluster scheduling. +Global IO, scale loads, bf16 stores, and the per-atom MFMA all go +through the layout API (``fx.rocdl.make_buffer_tensor`` + ``fx.copy`` ++ ``fx.gemm``). The XOR swizzle and the 8-buffer LDS pipeline are +kept as direct arithmetic to preserve the kernel's interleaved +cluster scheduling. + +LDS storage uses 8 named ``fx.get_dyn_shared`` bases carved into one +dyn-shared region; the ``fly-attach-lds-alias-scope`` MLIR pass +attaches per-symbol alias scopes so AMDGPU's SI Wait Counter pass +treats cross-buffer accesses as no-alias. 
""" import flydsl.compiler as flyc import flydsl.expr as fx -from flydsl._mlir.dialects import arith as _arith_dialect -from flydsl._mlir.dialects import fly as _fly_dialect from flydsl._mlir.dialects import llvm as _llvm -from flydsl._mlir.dialects import memref as _memref_dialect from flydsl._mlir.dialects.fly_rocdl import TargetAddressSpace as _TgtAS -from flydsl.compiler.kernel_function import CompilationContext -from flydsl.expr import arith, const_expr, range_constexpr +from flydsl.expr import arith, const_expr, range_constexpr, rocdl from flydsl.expr.typing import Vector as Vec -from flydsl.utils.smem_allocator import SmemAllocator, SmemPtr def _divmod(a, b): @@ -60,8 +59,7 @@ def _xcd_swizzle(num_pid_m, num_pid_n): def compile_fp8_gemm(*, M: int, N: int, K: int, BLOCK_M: int = 256, BLOCK_N: int = 256, use_xcd_remap: bool = True): - # MFMA atom is 16x16x128; 4 waves in a 2x2 config require BLOCK >= 64. - BLOCK_K = 128 + BLOCK_K = 128 # MFMA_Scale 16x16x128 atom; 4-wave 2x2 layout needs BLOCK >= 64. LDS_BLOCK_M = BLOCK_M // 2 LDS_BLOCK_N = BLOCK_N // 2 assert BLOCK_M >= 64 and BLOCK_N >= 64 @@ -69,7 +67,7 @@ def compile_fp8_gemm(*, M: int, N: int, K: int, BLOCK_M: int = 256, BLOCK_N: int N_BLOCKS = N // BLOCK_N K_ITERS = K // BLOCK_K - # Number of 16-row 16x128 tiles per wave per A/B partition. + # 16-row 16x128 atom tiles per wave per A/B partition. 
N_TILES_A = BLOCK_M // 4 // 16 N_TILES_B = BLOCK_N // 4 // 16 N_ACCUMS = N_TILES_A * N_TILES_B @@ -77,26 +75,24 @@ def compile_fp8_gemm(*, M: int, N: int, K: int, BLOCK_M: int = 256, BLOCK_N: int _use_interleaved_block = BLOCK_M == 256 and BLOCK_N == 256 - A_lds_cur0_alloc = SmemAllocator(None, "gfx950", "A_lds_cur_0") - A_lds_cur1_alloc = SmemAllocator(None, "gfx950", "A_lds_cur_1") - A_lds_next0_alloc = SmemAllocator(None, "gfx950", "A_lds_next_0") - A_lds_next1_alloc = SmemAllocator(None, "gfx950", "A_lds_next_1") - B_lds_cur0_alloc = SmemAllocator(None, "gfx950", "B_lds_cur_0") - B_lds_cur1_alloc = SmemAllocator(None, "gfx950", "B_lds_cur_1") - B_lds_next0_alloc = SmemAllocator(None, "gfx950", "B_lds_next_0") - B_lds_next1_alloc = SmemAllocator(None, "gfx950", "B_lds_next_1") - a_lds_size = LDS_BLOCK_M * BLOCK_K b_lds_size = LDS_BLOCK_N * BLOCK_K - A_lds_cur0_alloc.ptr = a_lds_size - A_lds_cur1_alloc.ptr = a_lds_size - A_lds_next0_alloc.ptr = a_lds_size - A_lds_next1_alloc.ptr = a_lds_size - B_lds_cur0_alloc.ptr = b_lds_size - B_lds_cur1_alloc.ptr = b_lds_size - B_lds_next0_alloc.ptr = b_lds_size - B_lds_next1_alloc.ptr = b_lds_size + # 8 disjoint sub-buffers within one dyn-shared region. Each named + # ``fx.get_dyn_shared`` emits a distinct LDS global so the + # ``fly-attach-lds-alias-scope`` pass can give it its own alias + # scope. 
+ _LDS_SUBBUFS = [ + ("A_lds_cur_0", 0 * a_lds_size), + ("A_lds_cur_1", 1 * a_lds_size), + ("A_lds_next_0", 2 * a_lds_size), + ("A_lds_next_1", 3 * a_lds_size), + ("B_lds_cur_0", 4 * a_lds_size + 0 * b_lds_size), + ("B_lds_cur_1", 4 * a_lds_size + 1 * b_lds_size), + ("B_lds_next_0", 4 * a_lds_size + 2 * b_lds_size), + ("B_lds_next_1", 4 * a_lds_size + 3 * b_lds_size), + ] + _TOTAL_LDS_BYTES = 4 * a_lds_size + 4 * b_lds_size @flyc.kernel def kernel_gemm( @@ -106,23 +102,23 @@ def kernel_gemm( A_scale: fx.Tensor, B_scale: fx.Tensor, ): - MfmaAccum_t = Vec.make_type(4, fx.Float32) RT_C_i = Vec.filled(4, 0.0, fx.Float32) F8_IR_t = fx.Float8E4M3FN.ir_type - Vec16_t = Vec.make_type(16, fx.Float8E4M3FN) - a_cur0 = SmemPtr(A_lds_cur0_alloc.get_base(), 0, F8_IR_t, shape=(a_lds_size,)).get() - a_cur1 = SmemPtr(A_lds_cur1_alloc.get_base(), 0, F8_IR_t, shape=(a_lds_size,)).get() - a_next0 = SmemPtr(A_lds_next0_alloc.get_base(), 0, F8_IR_t, shape=(a_lds_size,)).get() - a_next1 = SmemPtr(A_lds_next1_alloc.get_base(), 0, F8_IR_t, shape=(a_lds_size,)).get() + _AS_SHARED = 2 + _shared_f8_ptr_ty = fx.PointerType.get(F8_IR_t, _AS_SHARED, 512) + _shared_i32_ptr_ty = fx.PointerType.get(fx.T.i32(), _AS_SHARED, 512) - b_cur0 = SmemPtr(B_lds_cur0_alloc.get_base(), 0, F8_IR_t, shape=(b_lds_size,)).get() - b_cur1 = SmemPtr(B_lds_cur1_alloc.get_base(), 0, F8_IR_t, shape=(b_lds_size,)).get() - b_next0 = SmemPtr(B_lds_next0_alloc.get_base(), 0, F8_IR_t, shape=(b_lds_size,)).get() - b_next1 = SmemPtr(B_lds_next1_alloc.get_base(), 0, F8_IR_t, shape=(b_lds_size,)).get() + _lds_int = { + name: fx.ptrtoint(fx.get_dyn_shared(sym_name=name)) + for name, _ in _LDS_SUBBUFS + } + _lds_off = dict(_LDS_SUBBUFS) - _AS_SHARED = 2 - _shared_ptr_ty = fx.PointerType.get(F8_IR_t, _AS_SHARED, 512) + a_cur0, a_cur1 = "A_lds_cur_0", "A_lds_cur_1" + a_next0, a_next1 = "A_lds_next_0", "A_lds_next_1" + b_cur0, b_cur1 = "B_lds_cur_0", "B_lds_cur_1" + b_next0, b_next1 = "B_lds_next_0", "B_lds_next_1" lane_id = 
fx.thread_idx.x % 64 wave_id = fx.thread_idx.x // 64 @@ -139,9 +135,9 @@ def kernel_gemm( B0_gl_offset = (tile_j * BLOCK_N) * K B1_gl_offset = (tile_j * BLOCK_N + LDS_BLOCK_N) * K - # A/B come in as torch.int8 (PyTorch fp8 view restriction); recast - # the buffer-desc pointer's element type to fp8 so typed copy - # atoms (BufferCopyLDS128b) accept them. + # A/B arrive as torch.int8 (PyTorch fp8 view limitation); recast + # the buffer-desc element type to fp8 so BufferCopyLDS128b takes + # them. def _make_fp8_buf_tensor(arg_i8): t_i8 = fx.rocdl.make_buffer_tensor(arg_i8) iter_i8 = fx.get_iter(t_i8) @@ -164,89 +160,91 @@ def _make_fp8_buf_tensor(arg_i8): sa_div = fx.logical_divide(gSA, fx.make_layout(1, 1)) sb_div = fx.logical_divide(gSB, fx.make_layout(1, 1)) - # XOR bits 4..6 of the tile-local linear offset with bits 8..10. - def _swizzle_128(row, col): - offset = row * BLOCK_K + col - swz = ((offset % (16 * BLOCK_K)) >> 8) << 4 - swizzled = offset ^ swz - return swizzled // BLOCK_K, swizzled % BLOCK_K + # XOR 3 bits of dim-0 (row, bit-1 base) with 3 bits of dim-1 + # (col, bit-4 base). Same as the manual + # ((offset>>8)<<4) ^ offset; shared between LDS and global + # access via two outer layouts with different row strides. 
+ _swz_attr = fx.CoordSwizzleType.get(3, 1, [0], 4, [1]) + _swz_shape = (LDS_BLOCK_M, BLOCK_K) + _coord_swz = fx.make_composed_layout( + fx.static(_swz_attr), fx.make_identity_layout(_swz_shape) + ) + _lds_swz_layout = fx.make_composed_layout( + fx.make_layout(_swz_shape, (BLOCK_K, 1)), _coord_swz + ) + _gl_swz_layout = fx.make_composed_layout( + fx.make_layout(_swz_shape, (K, 1)), _coord_swz + ) def _compute_global_swizzle(): offsets = [] for round in range_constexpr(max(N_TILES_A, N_TILES_B)): row = lane_id // 8 + wave_id * 8 + round * 32 col = (lane_id % 8) * 16 - r, c = _swizzle_128(row, col) - offsets.append(r * K + c) + offsets.append(fx.crd2idx((row, col), _gl_swz_layout).to_py_value()) return offsets - def _compute_lds_swizzle(wave_idx, n_tiles): - lds_swz = [] - for row_offset in range_constexpr(n_tiles): - row = wave_idx * (n_tiles * 16) + row_offset * 16 + lane_id % 16 - swz = [] - for i in range_constexpr(2): - col = (lane_id // 16) * 16 + i * 64 - r, c = _swizzle_128(row, col) - swz.append(r * BLOCK_K + c) - lds_swz.append(swz) - return lds_swz - - # G->LDS atom: 128 bits per thread = 16 fp8 elements. The atom - # state carries the runtime ``soffset`` set to ``k_offset``. + # 128 bits per thread = 16 fp8 elements; soffset carries the + # runtime k offset. g2lds_atom = fx.make_copy_atom(fx.rocdl.BufferCopyLDS128b(), 128) - # LDS dst pointers for ``buffer_load_lds`` go through - # ``extract_aligned_pointer_as_index + add + inttoptr`` to break - # LLVM's alias chain on the LDS sub-buffer symbols; otherwise the - # AMDGPU backend inserts defensive ``s_waitcnt vmcnt(N)`` between - # G->LDS writes and the subsequent ``ds_read``. 
- def _lds_dst_at(lds_dst_mem, byte_offset_runtime): - base_idx = _memref_dialect.extract_aligned_pointer_as_index(lds_dst_mem) - offset_idx = base_idx + fx.Index(byte_offset_runtime) - offset_i64 = _arith_dialect.index_cast(fx.T.i64(), offset_idx) - lds_ptr = fx.inttoptr(_shared_ptr_ty, offset_i64) - return fx.make_view(lds_ptr, fx.make_layout(1, 1)) - - def _load_lds(gl_src_div, lds_dst_mem, k_offset, gl_offsets, n_tiles): + # Routed through ptrtoint + add + inttoptr instead of the + # natural fx.add_offset+fx.recast_iter chain: empirically the + # natural route compiles ~5-9% slower on BLOCK=256 shapes + # (probably because the extra int_tuple / recast_iter ops + # survive canonicalization and disrupt the AMDGPU back-end's + # common-base + offset pattern matching). + def _lds_dst_at(name, byte_offset_runtime): + off = _lds_int[name] + fx.Int32(_lds_off[name] + byte_offset_runtime) + ptr = fx.inttoptr(_shared_f8_ptr_ty, off) + return fx.make_view(ptr, fx.make_layout(1, 1)) + + def _load_lds(gl_src_div, name, k_offset, gl_offsets, n_tiles): assert len(gl_offsets) >= n_tiles for step in range_constexpr(n_tiles): src = fx.slice(gl_src_div, (None, fx.Int32(gl_offsets[step]))) - dst = _lds_dst_at(lds_dst_mem, wave_id * 1024 + step * 4096) + dst = _lds_dst_at(name, wave_id * 1024 + step * 4096) fx.copy(g2lds_atom, src, dst, soffset=fx.Int32(k_offset)) - def _load_one_lds(gl_src_div, lds_dst_mem, k_offset, gl_offsets, tile_idx): + def _load_one_lds(gl_src_div, name, k_offset, gl_offsets, tile_idx): assert len(gl_offsets) > tile_idx src = fx.slice(gl_src_div, (None, fx.Int32(gl_offsets[tile_idx]))) - dst = _lds_dst_at(lds_dst_mem, wave_id * 1024 + tile_idx * 4096) + dst = _lds_dst_at(name, wave_id * 1024 + tile_idx * 4096) fx.copy(g2lds_atom, src, dst, soffset=fx.Int32(k_offset)) def _pack_i32x4_i32x8(lo, hi): return lo.shuffle(hi, list(range(8))) - def _load_rt(lds_src, wave_idx, n_tiles): + # 16 fp8 == 4 i32; load as i32x4 because LLVM has no + # vector<16xf8>. 
+ def _vec_load_lds_i32x4(name, fp8_elem_offset): + off = _lds_int[name] + fx.Int32(_lds_off[name] + fp8_elem_offset) + ptr = fx.inttoptr(_shared_i32_ptr_ty, off) + view = fx.make_view(ptr, fx.make_layout(4, 1)) + return Vec(fx.memref_load_vec(view)) + + def _load_rt(name, wave_idx, n_tiles): frag = [] for i in range_constexpr(n_tiles): row = wave_idx * (n_tiles * 16) + i * 16 + lane_id % 16 halves = [] for step in range_constexpr(2): col = (lane_id // 16) * 16 + step * 64 - r, c = _swizzle_128(row, col) - v = Vec.load(Vec16_t, lds_src, [fx.Index(r * BLOCK_K + c)]) - halves.append(v.bitcast(fx.Int32)) + halves.append(_vec_load_lds_i32x4(name, fx.crd2idx((row, col), _lds_swz_layout).to_py_value())) frag.append(_pack_i32x4_i32x8(halves[0], halves[1])) return frag - def _load_one_rt(lds_src, lds_swz, row, k): - v = Vec.load(Vec16_t, lds_src, [fx.Index(lds_swz[row][k])]) - return v.bitcast(fx.Int32) + def _load_one_rt(name, wave_idx, n_tiles, row_idx, k): + row = wave_idx * (n_tiles * 16) + row_idx * 16 + lane_id % 16 + col = (lane_id // 16) * 16 + k * 64 + return _vec_load_lds_i32x4(name, fx.crd2idx((row, col), _lds_swz_layout).to_py_value()) def _c_idx(i, j): return i * N_TILES_B + j - # The C++ AddressSpace enum prepends Generic=0, so the Python - # AddressSpace.Register value (2) maps to Shared on the C++ side. - # Pass the C++ integer (3) directly to MemRefType.get. + # C++ AddressSpace enum prepends Generic=0; pass the C++ index + # (3 = Register) directly to MemRefType.get to avoid the + # Python AddressSpace.Register (=2) being read as Shared. _AS_REG = 3 scale_atom_4 = fx.make_copy_atom(fx.rocdl.BufferCopy128b(), fx.Float32) scale_atom_1 = fx.make_copy_atom(fx.rocdl.BufferCopy32b(), fx.Float32) @@ -291,13 +289,32 @@ def _wait_barrier(count): has_side_effects=True, ) - # MFMA via ``fly.mma_atom_call_ssa``. The atom carries scale_a / - # scale_b state (default 0x7F7F7F7F = no scaling). Returns a - # chained Vec(4, f32) SSA so the accumulator stays on AGPR. 
mma_atom = fx.make_mma_atom(fx.rocdl.cdna4.MFMA_Scale(16, 16, 128, fx.Float8E4M3FN)) - def _mfma(a_val, b_val, c_val): - return _fly_dialect.mma_atom_call_ssa([MfmaAccum_t], mma_atom, a_val, b_val, c_val) + # MFMA goes through fx.gemm. The Vec operands are spilled into + # register-memref fragments around each call; the alloca / + # store / load round trip is folded away by + # ``fly-convert-atom-call-to-ssa-form`` + + # ``fly-promote-regmem-to-vectorssa``, leaving a plain + # ``llvm.amdgcn.mfma.scale.f32.16x16x128`` chained on + # ``<4 x float>`` SSA so ISel keeps the accumulator on AGPR. + a_reg_ty = fx.MemRefType.get(fx.T.i32(), fx.LayoutType.get(8, 1), _AS_REG) + b_reg_ty = fx.MemRefType.get(fx.T.i32(), fx.LayoutType.get(8, 1), _AS_REG) + c_reg_ty = fx.MemRefType.get(fx.T.f32(), fx.LayoutType.get(4, 1), _AS_REG) + tiled_mma_single = fx.make_tiled_mma( + mma_atom, + fx.make_layout((2, 2, 1), (1, 2, 0)), + ) + + def _mfma(a_vec, b_vec, c_vec): + a_mem = fx.memref_alloca(a_reg_ty, fx.make_layout(8, 1)) + b_mem = fx.memref_alloca(b_reg_ty, fx.make_layout(8, 1)) + c_mem = fx.memref_alloca(c_reg_ty, fx.make_layout(4, 1)) + fx.memref_store_vec(a_vec, a_mem) + fx.memref_store_vec(b_vec, b_mem) + fx.memref_store_vec(c_vec, c_mem) + fx.gemm(tiled_mma_single, c_mem, a_mem, b_mem, c_mem) + return Vec(fx.memref_load_vec(c_mem)) def _mfma_ABt_all(a, b, c): assert len(a) == N_TILES_A @@ -316,55 +333,54 @@ def _mfma_ABt_one(a, b, c, m, n): return c def _interleaved_cluster(lds_dst, gl_src, k_offset, gl_offsets, wave_idx, lds_src, n_tiles_lds, a, b, c): - # 64x64 output via 4x4 MFMAs, with per-tile G→LDS and LDS→reg + # 4x4 MFMAs over 64x64, with per-tile G->LDS and LDS->reg # loads interleaved between MFMAs to hide latency. 
rt_dst = [] c = _mfma_ABt_one(a, b, c, 0, 0) c = _mfma_ABt_one(a, b, c, 0, 1) - lds_swz = _compute_lds_swizzle(wave_idx, n_tiles_lds) _load_one_lds(gl_src, lds_dst, k_offset, gl_offsets, 0) - rt_dst_0 = _load_one_rt(lds_src, lds_swz, 0, 0) + rt_dst_0 = _load_one_rt(lds_src, wave_idx, n_tiles_lds, 0, 0) c = _mfma_ABt_one(a, b, c, 0, 2) - rt_dst_1 = _load_one_rt(lds_src, lds_swz, 0, 1) + rt_dst_1 = _load_one_rt(lds_src, wave_idx, n_tiles_lds, 0, 1) rt_dst.append(_pack_i32x4_i32x8(rt_dst_0, rt_dst_1)) c = _mfma_ABt_one(a, b, c, 0, 3) _load_one_lds(gl_src, lds_dst, k_offset, gl_offsets, 1) - rt_dst_0 = _load_one_rt(lds_src, lds_swz, 1, 0) + rt_dst_0 = _load_one_rt(lds_src, wave_idx, n_tiles_lds, 1, 0) c = _mfma_ABt_one(a, b, c, 1, 0) c = _mfma_ABt_one(a, b, c, 1, 1) - rt_dst_1 = _load_one_rt(lds_src, lds_swz, 1, 1) + rt_dst_1 = _load_one_rt(lds_src, wave_idx, n_tiles_lds, 1, 1) rt_dst.append(_pack_i32x4_i32x8(rt_dst_0, rt_dst_1)) c = _mfma_ABt_one(a, b, c, 1, 2) c = _mfma_ABt_one(a, b, c, 1, 3) _load_one_lds(gl_src, lds_dst, k_offset, gl_offsets, 2) - rt_dst_0 = _load_one_rt(lds_src, lds_swz, 2, 0) + rt_dst_0 = _load_one_rt(lds_src, wave_idx, n_tiles_lds, 2, 0) c = _mfma_ABt_one(a, b, c, 2, 0) c = _mfma_ABt_one(a, b, c, 2, 1) - rt_dst_1 = _load_one_rt(lds_src, lds_swz, 2, 1) + rt_dst_1 = _load_one_rt(lds_src, wave_idx, n_tiles_lds, 2, 1) rt_dst.append(_pack_i32x4_i32x8(rt_dst_0, rt_dst_1)) c = _mfma_ABt_one(a, b, c, 2, 2) c = _mfma_ABt_one(a, b, c, 2, 3) _load_one_lds(gl_src, lds_dst, k_offset, gl_offsets, 3) - rt_dst_0 = _load_one_rt(lds_src, lds_swz, 3, 0) + rt_dst_0 = _load_one_rt(lds_src, wave_idx, n_tiles_lds, 3, 0) c = _mfma_ABt_one(a, b, c, 3, 0) c = _mfma_ABt_one(a, b, c, 3, 1) - rt_dst_1 = _load_one_rt(lds_src, lds_swz, 3, 1) + rt_dst_1 = _load_one_rt(lds_src, wave_idx, n_tiles_lds, 3, 1) rt_dst.append(_pack_i32x4_i32x8(rt_dst_0, rt_dst_1)) c = _mfma_ABt_one(a, b, c, 3, 2) @@ -525,26 +541,6 @@ def launch_gemm( B_scale: fx.Tensor, stream: fx.Stream, ): - from 
flydsl._mlir import ir - - A_lds_cur0_alloc.finalized = False - A_lds_cur1_alloc.finalized = False - A_lds_next0_alloc.finalized = False - A_lds_next1_alloc.finalized = False - B_lds_cur0_alloc.finalized = False - B_lds_cur1_alloc.finalized = False - B_lds_next0_alloc.finalized = False - B_lds_next1_alloc.finalized = False - ctx = CompilationContext.get_current() - with ir.InsertionPoint(ctx.gpu_module_body): - A_lds_cur0_alloc.finalize() - A_lds_cur1_alloc.finalize() - A_lds_next0_alloc.finalize() - A_lds_next1_alloc.finalize() - B_lds_cur0_alloc.finalize() - B_lds_cur1_alloc.finalize() - B_lds_next0_alloc.finalize() - B_lds_next1_alloc.finalize() grid_x = (M * N) // (BLOCK_M * BLOCK_N) kernel_gemm( A, @@ -553,6 +549,11 @@ def launch_gemm( A_scale, B_scale, value_attrs={"rocdl.waves_per_eu": 1, "rocdl.flat_work_group_size": "256,256"}, - ).launch(grid=(grid_x, 1, 1), block=(256, 1, 1), stream=stream) + ).launch( + grid=(grid_x, 1, 1), + block=(256, 1, 1), + smem=_TOTAL_LDS_BYTES, + stream=stream, + ) return launch_gemm diff --git a/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp b/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp index 7985ca42..c3a32b96 100644 --- a/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp +++ b/lib/Conversion/FlyToROCDL/FlyToROCDL.cpp @@ -123,7 +123,8 @@ class GetDynSharedOpLowering : public OpConversionPattern { if (!moduleOp) return op->emitError("get_dyn_shared must be inside a gpu.module"); - LLVM::GlobalOp sharedGlobal = getOrCreateDynSharedGlobal(rewriter, moduleOp, loc, addrSpace); + LLVM::GlobalOp sharedGlobal = + getOrCreateDynSharedGlobal(rewriter, moduleOp, loc, addrSpace, op.getSymNameAttr()); OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(op); @@ -142,21 +143,39 @@ class GetDynSharedOpLowering : public OpConversionPattern { private: static LLVM::GlobalOp getOrCreateDynSharedGlobal(ConversionPatternRewriter &rewriter, gpu::GPUModuleOp moduleOp, Location loc, - unsigned addrSpace) { + unsigned addrSpace, + StringAttr 
requestedName) { + // When sym_name is requested we look up by exact name and create a + // distinct external [0 x i8] LDS global if missing. Otherwise we + // reuse the first existing matching dyn-shared global, falling back + // to a freshly generated `__dynamic_shared_` symbol. llvm::StringSet<> existingNames; + LLVM::GlobalOp firstMatch = nullptr; for (auto globalOp : moduleOp.getBody()->getOps()) { existingNames.insert(globalOp.getSymName()); - if (auto arrayType = dyn_cast(globalOp.getType())) { - if (globalOp.getAddrSpace() == addrSpace && arrayType.getNumElements() == 0 && - globalOp.getAlignment().value_or(0) == 1024) - return globalOp; + if (requestedName && globalOp.getSymName() == requestedName.getValue()) + return globalOp; + if (!requestedName) { + if (auto arrayType = dyn_cast(globalOp.getType())) { + if (!firstMatch && globalOp.getAddrSpace() == addrSpace && + arrayType.getNumElements() == 0 && + globalOp.getAlignment().value_or(0) == 1024) + firstMatch = globalOp; + } } } + if (!requestedName && firstMatch) + return firstMatch; - unsigned counter = 0; - SmallString<128> symName = SymbolTable::generateSymbolName<128>( - "__dynamic_shared_", [&](StringRef candidate) { return existingNames.contains(candidate); }, - counter); + SmallString<128> symName; + if (requestedName) { + symName.assign(requestedName.getValue()); + } else { + unsigned counter = 0; + symName = SymbolTable::generateSymbolName<128>( + "__dynamic_shared_", + [&](StringRef candidate) { return existingNames.contains(candidate); }, counter); + } OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(moduleOp.getBody()); diff --git a/lib/Dialect/Fly/CMakeLists.txt b/lib/Dialect/Fly/CMakeLists.txt index 65e500d3..455de6e2 100644 --- a/lib/Dialect/Fly/CMakeLists.txt +++ b/lib/Dialect/Fly/CMakeLists.txt @@ -15,6 +15,7 @@ add_mlir_dialect_library(MLIRFlyDialect Transforms/ConvertAtomCallToSSAForm.cpp Transforms/PromoteRegMemToVectorSSA.cpp 
Transforms/IntSwizzleSimplify.cpp + Transforms/AttachLDSAliasScope.cpp DEPENDS MLIRFlyIncGen @@ -24,5 +25,6 @@ add_mlir_dialect_library(MLIRFlyDialect LINK_LIBS MLIRGPUDialect MLIRIR + MLIRLLVMDialect MLIRTargetLLVMIRExport ) diff --git a/lib/Dialect/Fly/Transforms/AttachLDSAliasScope.cpp b/lib/Dialect/Fly/Transforms/AttachLDSAliasScope.cpp new file mode 100644 index 00000000..7f6ab0d6 --- /dev/null +++ b/lib/Dialect/Fly/Transforms/AttachLDSAliasScope.cpp @@ -0,0 +1,266 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2026 FlyDSL Project Contributors + +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Dialect/LLVMIR/LLVMAttrs.h" +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/Pass/Pass.h" + +#include "flydsl/Dialect/Fly/Transforms/Passes.h" + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" + +using namespace mlir; + +namespace mlir { +namespace fly { +#define GEN_PASS_DEF_FLYATTACHLDSALIASSCOPEPASS +#include "flydsl/Dialect/Fly/Transforms/Passes.h.inc" +} // namespace fly +} // namespace mlir + +namespace { + +// LDS address space on AMDGPU. +static constexpr unsigned kLDSAddrSpace = 3; + +/// Returns true if `g` is an external `[0 x i8] addrspace(3)` global, +/// i.e. a dyn-shared LDS base. We restrict on size 0 (HSA dynamic LDS +/// convention) so we don't accidentally tag SmemAllocator-style static +/// globals (whose alias info already comes from distinct symbols). 
+static bool isDynSharedGlobal(LLVM::GlobalOp g) {
+  if (g.getAddrSpace() != kLDSAddrSpace)
+    return false;
+  if (g.getLinkage() != LLVM::Linkage::External)
+    return false;
+  auto arrTy = dyn_cast<LLVM::LLVMArrayType>(g.getType());
+  if (!arrTy)
+    return false;
+  return arrTy.getNumElements() == 0;
+}
+
+/// Per-SSA-value provenance, encoded as a tri-state DenseMap:
+///   - absent entry      => unknown / not derived from any tracked global
+///   - mapped to G       => derived from exactly the LDS global G
+///   - mapped to nullptr => *known* to mix two or more globals (ambiguous);
+///                          downstream uses that consume this value must
+///                          also be marked ambiguous so the pass never
+///                          tags an access with a single scope when its
+///                          true scope set is larger.
+using PtrProvenance = llvm::DenseMap<Value, LLVM::GlobalOp>;
+using IntProvenance = llvm::DenseMap<Value, LLVM::GlobalOp>;
+
+/// True iff `op` is an `llvm.amdgcn.raw.ptr.buffer.load.lds` intrinsic.
+static bool isBufferLoadLDS(LLVM::CallOp call) {
+  auto callee = call.getCallee();
+  if (!callee)
+    return false;
+  return callee->starts_with("llvm.amdgcn.raw.ptr.buffer.load.lds");
+}
+
+/// Returns the addrspace(3) pointer operand consumed by `op`, or
+/// nullptr if there isn't exactly one such operand worth tagging.
+static Value memoryPointerForOp(Operation *op) {
+  if (auto load = dyn_cast<LLVM::LoadOp>(op)) {
+    auto ptrTy = dyn_cast<LLVM::LLVMPointerType>(load.getAddr().getType());
+    if (ptrTy && ptrTy.getAddressSpace() == kLDSAddrSpace)
+      return load.getAddr();
+    return nullptr;
+  }
+  if (auto store = dyn_cast<LLVM::StoreOp>(op)) {
+    auto ptrTy = dyn_cast<LLVM::LLVMPointerType>(store.getAddr().getType());
+    if (ptrTy && ptrTy.getAddressSpace() == kLDSAddrSpace)
+      return store.getAddr();
+    return nullptr;
+  }
+  if (auto call = dyn_cast<LLVM::CallOp>(op)) {
+    if (!isBufferLoadLDS(call))
+      return nullptr;
+    // The LDS pointer is the second arg (after the buffer-desc ptr).
+    if (call.getNumOperands() >= 2) {
+      Value lds = call.getOperand(1);
+      auto ptrTy = dyn_cast<LLVM::LLVMPointerType>(lds.getType());
+      if (ptrTy && ptrTy.getAddressSpace() == kLDSAddrSpace)
+        return lds;
+    }
+    return nullptr;
+  }
+  return nullptr;
+}
+
+/// Forward dataflow that maps SSA values back to the LDS global they
+/// derive from. Only the canonical pointer-arithmetic chain is tracked
+/// so that we never tag an access whose pointer might really span more
+/// than one global:
+///   * `LLVM::AddressOfOp(@g)` -> ptr provenance(@g)
+///   * `LLVM::PtrToIntOp(p)`   -> int provenance(p)
+///   * `LLVM::AddOp(a, b)`     -> int provenance(a) iff *exactly one*
+///                                operand carries provenance; if both
+///                                carry provenance, the result mixes
+///                                globals and is recorded as ambiguous
+///   * `LLVM::IntToPtrOp(i)`   -> ptr provenance(i)
+///   * `LLVM::GEPOp(p)`        -> ptr provenance(p)
+///
+/// `or`/`sub`/`xor`/`and`/`shl`/`shr`/`bitcast` and any other op are
+/// treated as provenance-destroying. The dataflow is intentionally
+/// fail-safe: when in doubt, drop the tag rather than emit one that
+/// could wrongly tell LLVM "no alias" about pointers that really do
+/// alias at runtime.
+static void computeProvenance(
+    LLVM::LLVMFuncOp func,
+    const llvm::DenseMap<StringRef, LLVM::GlobalOp> &nameToGlobal,
+    PtrProvenance &ptrProv, IntProvenance &intProv) {
+  // Tri-state DenseMap merge. Mirrors the encoding documented on
+  // `IntProvenance` / `PtrProvenance`:
+  //   - absent entry  => unknown
+  //   - present, G    => provenance(G)
+  //   - present, null => ambiguous
+  //
+  // Returns (resultProvenance, hasInfo). When hasInfo is false the
+  // caller stores nothing (keeps the value unknown); when hasInfo is
+  // true and resultProvenance is null the caller stores a sentinel
+  // entry so subsequent uses also propagate as ambiguous.
+  auto combine = [](LLVM::GlobalOp a, bool aSeen, LLVM::GlobalOp b,
+                    bool bSeen) -> std::pair<LLVM::GlobalOp, bool> {
+    if (!aSeen && !bSeen)
+      return {nullptr, false};
+    if (!aSeen)
+      return {b, true};
+    if (!bSeen)
+      return {a, true};
+    // Both operands carry provenance, so the sum is always ambiguous:
+    //   - either is ambiguous (null) -> ambiguous
+    //   - same non-null global -> *still* ambiguous, because adding
+    //     a pointer-derived int to itself doesn't represent any single
+    //     well-formed pointer
+    //   - different non-null globals -> ambiguous
+    return {nullptr, true};
+  };
+
+  func.walk([&](Operation *op) {
+    if (auto addrOf = dyn_cast<LLVM::AddressOfOp>(op)) {
+      auto it = nameToGlobal.find(addrOf.getGlobalName());
+      if (it != nameToGlobal.end())
+        ptrProv[addrOf.getResult()] = it->second;
+      return;
+    }
+    if (auto p2i = dyn_cast<LLVM::PtrToIntOp>(op)) {
+      auto it = ptrProv.find(p2i.getArg());
+      if (it != ptrProv.end())
+        intProv[p2i.getResult()] = it->second; // may store ambiguous
+      return;
+    }
+    if (auto add = dyn_cast<LLVM::AddOp>(op)) {
+      auto la = intProv.find(add.getLhs());
+      auto lb = intProv.find(add.getRhs());
+      bool aSeen = la != intProv.end();
+      bool bSeen = lb != intProv.end();
+      auto [g, hasInfo] = combine(aSeen ? la->second : nullptr, aSeen,
+                                  bSeen ? lb->second : nullptr, bSeen);
+      if (hasInfo)
+        intProv[add.getResult()] = g; // g may be null = ambiguous sentinel
+      return;
+    }
+    if (auto i2p = dyn_cast<LLVM::IntToPtrOp>(op)) {
+      auto it = intProv.find(i2p.getArg());
+      if (it != intProv.end()) {
+        auto ptrTy = dyn_cast<LLVM::LLVMPointerType>(i2p.getResult().getType());
+        if (ptrTy && ptrTy.getAddressSpace() == kLDSAddrSpace)
+          ptrProv[i2p.getResult()] = it->second; // propagate ambiguous too
+      }
+      return;
+    }
+    if (auto gep = dyn_cast<LLVM::GEPOp>(op)) {
+      auto it = ptrProv.find(gep.getBase());
+      if (it != ptrProv.end()) {
+        // Copy first: ptrProv[...] below may grow the map and invalidate it.
+        LLVM::GlobalOp base = it->second;
+        auto ptrTy = dyn_cast<LLVM::LLVMPointerType>(gep.getResult().getType());
+        if (ptrTy && ptrTy.getAddressSpace() == kLDSAddrSpace)
+          ptrProv[gep.getResult()] = base;
+      }
+      return;
+    }
+  });
+}
+
+class FlyAttachLDSAliasScopePass
+    : public mlir::fly::impl::FlyAttachLDSAliasScopePassBase<
+          FlyAttachLDSAliasScopePass> {
+public:
+  using mlir::fly::impl::FlyAttachLDSAliasScopePassBase<
+      FlyAttachLDSAliasScopePass>::FlyAttachLDSAliasScopePassBase;
+
+  void runOnOperation() override {
+    gpu::GPUModuleOp gpuModule = getOperation();
+
+    // Collect dyn-shared globals in declaration order.
+    SmallVector<LLVM::GlobalOp> dynGlobals;
+    llvm::DenseMap<StringRef, LLVM::GlobalOp> nameToGlobal;
+    for (auto g : gpuModule.getOps<LLVM::GlobalOp>()) {
+      if (isDynSharedGlobal(g)) {
+        dynGlobals.push_back(g);
+        nameToGlobal[g.getSymName()] = g;
+      }
+    }
+    if (dynGlobals.size() < 2)
+      return; // Single (or no) dyn-shared region: nothing to disambiguate.
+
+    MLIRContext *ctx = &getContext();
+    OpBuilder builder(ctx);
+
+    // One domain per gpu.module, one scope per dyn-shared global.
+    auto domain = LLVM::AliasScopeDomainAttr::get(
+        ctx, builder.getStringAttr("FlyDynSharedDomain"));
+
+    llvm::DenseMap<LLVM::GlobalOp, LLVM::AliasScopeAttr> globalToScope;
+    for (auto g : dynGlobals) {
+      auto scope = LLVM::AliasScopeAttr::get(
+          domain, builder.getStringAttr(g.getSymName()));
+      globalToScope[g] = scope;
+    }
+
+    // Pre-compute the noalias-set per global = all scopes except its
+    // own. This is what makes cross-global accesses no-alias.
+    llvm::DenseMap<LLVM::GlobalOp, ArrayAttr> globalToNoalias;
+    for (auto g : dynGlobals) {
+      SmallVector<Attribute> others;
+      others.reserve(dynGlobals.size() - 1);
+      for (auto og : dynGlobals)
+        if (og != g)
+          others.push_back(globalToScope[og]);
+      globalToNoalias[g] = ArrayAttr::get(ctx, others);
+    }
+
+    for (auto func : gpuModule.getOps<LLVM::LLVMFuncOp>()) {
+      if (func.empty())
+        continue;
+      PtrProvenance ptrProv;
+      IntProvenance intProv;
+      computeProvenance(func, nameToGlobal, ptrProv, intProv);
+
+      func.walk([&](Operation *op) {
+        Value lds = memoryPointerForOp(op);
+        if (!lds)
+          return;
+        auto it = ptrProv.find(lds);
+        if (it == ptrProv.end() || !it->second)
+          return;
+        LLVM::GlobalOp g = it->second;
+        auto scopeIt = globalToScope.find(g);
+        auto noaliasIt = globalToNoalias.find(g);
+        if (scopeIt == globalToScope.end() || noaliasIt == globalToNoalias.end())
+          return;
+        auto scopeAttr = ArrayAttr::get(ctx, {scopeIt->second});
+        op->setAttr("alias_scopes", scopeAttr);
+        op->setAttr("noalias_scopes", noaliasIt->second);
+      });
+    }
+  }
+};
+
+} // namespace
diff --git a/python/flydsl/compiler/backends/rocm.py b/python/flydsl/compiler/backends/rocm.py index c32a328b..f12001dc 100644 --- a/python/flydsl/compiler/backends/rocm.py +++ b/python/flydsl/compiler/backends/rocm.py @@ -84,6 +84,7 @@ def _pipeline_parts(self, *, compile_hints: dict) -> Tuple[List[str], str]: "convert-arith-to-llvm", "convert-func-to-llvm", "reconcile-unrealized-casts", + "gpu.module(fly-attach-lds-alias-scope)", *( ["ensure-debug-info-scope-on-llvm-func{emission-kind=LineTablesOnly}"] if env.debug.enable_debug_info diff --git a/python/flydsl/expr/primitive.py b/python/flydsl/expr/primitive.py index f67d9dc6..fb989e57 100644 --- a/python/flydsl/expr/primitive.py +++ b/python/flydsl/expr/primitive.py @@ -999,8 +999,19 @@ def make_ptr(result_type, args, loc=None, ip=None): @traced_op -def get_dyn_shared(loc=None, ip=None): - return fly.get_dyn_shared(loc=loc, ip=ip) +def get_dyn_shared(sym_name=None, loc=None, ip=None): + """Get a base pointer into 
the kernel's dynamic shared memory. + + If ``sym_name`` is provided the lowering emits a distinct external + ``[0 x i8] addrspace(3) align 1024`` global with that exact name. + All named bases share the same runtime LDS region (each starts at + offset 0 of the dynamic LDS area) but the + ``fly-attach-lds-alias-scope`` pass uses the distinct symbols to + attach ``alias_scope``/``noalias`` metadata, so AMDGPU's + ``SIInsertWaitcnts`` pass treats accesses through different names + as no-alias even though ``LowerModuleLDS`` later merges them. + """ + return fly.get_dyn_shared(sym_name=sym_name, loc=loc, ip=ip) @traced_op diff --git a/tests/mlir/Transforms/attach_lds_alias_scope.mlir b/tests/mlir/Transforms/attach_lds_alias_scope.mlir new file mode 100644 index 00000000..8d9fdf28 --- /dev/null +++ b/tests/mlir/Transforms/attach_lds_alias_scope.mlir @@ -0,0 +1,227 @@ +// SPDX-License-Identifier: Apache-2.0 +// Copyright (c) 2026 FlyDSL Project Contributors +// RUN: %fly-opt %s --pass-pipeline='builtin.module(gpu.module(fly-attach-lds-alias-scope))' | FileCheck %s + +// fly-attach-lds-alias-scope finds external `[0 x i8] addrspace(3)` +// LDS globals in a gpu.module, gives each one a distinct alias scope, +// and tags every load / store / amdgcn.raw.ptr.buffer.load.lds whose +// addrspace(3) pointer can be traced back to a single global through +// addressof / ptrtoint / add / inttoptr / GEP. + +// ----------------------------------------------------------------------------- +// Two named dyn-shared globals -> per-symbol alias_scopes / noalias_scopes on +// loads, with the int-derived pointer being recognised through +// ptrtoint+add+inttoptr. 
+// -----------------------------------------------------------------------------
+
+// CHECK-DAG: #[[DOMAIN:.+]] = #llvm.alias_scope_domain<{{.*}}description = "FlyDynSharedDomain">
+// CHECK-DAG: #[[SCOPE_A:.+]] = #llvm.alias_scope<{{.*}}domain = #[[DOMAIN]], description = "buf_a">
+// CHECK-DAG: #[[SCOPE_B:.+]] = #llvm.alias_scope<{{.*}}domain = #[[DOMAIN]], description = "buf_b">
+
+// CHECK-LABEL: gpu.module @two_named
+gpu.module @two_named {
+  llvm.mlir.global external @buf_a() {addr_space = 3 : i32, alignment = 1024 : i64, dso_local} : !llvm.array<0 x i8>
+  llvm.mlir.global external @buf_b() {addr_space = 3 : i32, alignment = 1024 : i64, dso_local} : !llvm.array<0 x i8>
+
+  // CHECK-LABEL: llvm.func @load_pair
+  llvm.func @load_pair(%delta: i32) -> vector<4xi32> {
+    %base_a = llvm.mlir.addressof @buf_a : !llvm.ptr<3>
+    %addr_a = llvm.ptrtoint %base_a : !llvm.ptr<3> to i32
+    %shifted_a = llvm.add %addr_a, %delta : i32
+    %ptr_a = llvm.inttoptr %shifted_a : i32 to !llvm.ptr<3>
+    %base_b = llvm.mlir.addressof @buf_b : !llvm.ptr<3>
+    %addr_b = llvm.ptrtoint %base_b : !llvm.ptr<3> to i32
+    %shifted_b = llvm.add %addr_b, %delta : i32
+    %ptr_b = llvm.inttoptr %shifted_b : i32 to !llvm.ptr<3>
+    // CHECK: llvm.load %{{.+}} {alias_scopes = [#[[SCOPE_A]]], noalias_scopes = [#[[SCOPE_B]]]}
+    %val_a = llvm.load %ptr_a : !llvm.ptr<3> -> vector<4xi32>
+    // CHECK: llvm.load %{{.+}} {alias_scopes = [#[[SCOPE_B]]], noalias_scopes = [#[[SCOPE_A]]]}
+    %val_b = llvm.load %ptr_b : !llvm.ptr<3> -> vector<4xi32>
+    %total = llvm.add %val_a, %val_b : vector<4xi32>
+    llvm.return %total : vector<4xi32>
+  }
+}
+
+// -----
+
+// -----------------------------------------------------------------------------
+// A module with only one dyn-shared global is left untouched: a lone scope
+// gives the SI Wait Counter pass nothing to tell apart.
+// -----------------------------------------------------------------------------
+
+// CHECK-LABEL: gpu.module @one_named
+gpu.module @one_named {
+  llvm.mlir.global external @only_buf() {addr_space = 3 : i32, alignment = 1024 : i64, dso_local} : !llvm.array<0 x i8>
+
+  // CHECK-LABEL: llvm.func @load_only
+  llvm.func @load_only() -> vector<4xi32> {
+    %base = llvm.mlir.addressof @only_buf : !llvm.ptr<3>
+    // CHECK: llvm.load
+    // CHECK-NOT: alias_scopes
+    %val = llvm.load %base : !llvm.ptr<3> -> vector<4xi32>
+    llvm.return %val : vector<4xi32>
+  }
+}
+
+// -----
+
+// -----------------------------------------------------------------------------
+// Sized [N x i8] LDS globals (N > 0, the SmemAllocator pattern) are ignored;
+// distinct LLVM symbols already provide their alias information.
+// -----------------------------------------------------------------------------
+
+// CHECK-LABEL: gpu.module @static_lds
+gpu.module @static_lds {
+  llvm.mlir.global external @smem_a() {addr_space = 3 : i32, alignment = 1024 : i64, dso_local} : !llvm.array<4096 x i8>
+  llvm.mlir.global external @smem_b() {addr_space = 3 : i32, alignment = 1024 : i64, dso_local} : !llvm.array<4096 x i8>
+
+  // CHECK-LABEL: llvm.func @load_static
+  llvm.func @load_static() -> vector<4xi32> {
+    %smem = llvm.mlir.addressof @smem_a : !llvm.ptr<3>
+    // CHECK: llvm.load
+    // CHECK-NOT: alias_scopes
+    %val = llvm.load %smem : !llvm.ptr<3> -> vector<4xi32>
+    llvm.return %val : vector<4xi32>
+  }
+}
+
+// -----
+
+// -----------------------------------------------------------------------------
+// Conflicting provenance: adding `ptrtoint(@A)` to `ptrtoint(@B)` yields an
+// integer derived from both globals at once. Nothing computed from it may be
+// tagged with one scope; doing so would assert "no alias with @B" for a load
+// that could just as well fall inside @B's region, so the load has to stay
+// untagged.
+// -----------------------------------------------------------------------------
+
+// CHECK-LABEL: gpu.module @ambiguous_add
+gpu.module @ambiguous_add {
+  llvm.mlir.global external @amb_a() {addr_space = 3 : i32, alignment = 1024 : i64, dso_local} : !llvm.array<0 x i8>
+  llvm.mlir.global external @amb_b() {addr_space = 3 : i32, alignment = 1024 : i64, dso_local} : !llvm.array<0 x i8>
+
+  // CHECK-LABEL: llvm.func @load_ambiguous
+  llvm.func @load_ambiguous(%extra: i32) -> vector<4xi32> {
+    %base_a = llvm.mlir.addressof @amb_a : !llvm.ptr<3>
+    %addr_a = llvm.ptrtoint %base_a : !llvm.ptr<3> to i32
+    %base_b = llvm.mlir.addressof @amb_b : !llvm.ptr<3>
+    %addr_b = llvm.ptrtoint %base_b : !llvm.ptr<3> to i32
+    %both = llvm.add %addr_a, %addr_b : i32
+    %shifted = llvm.add %both, %extra : i32
+    %ptr = llvm.inttoptr %shifted : i32 to !llvm.ptr<3>
+    // CHECK: llvm.load
+    // CHECK-NOT: alias_scopes
+    %val = llvm.load %ptr : !llvm.ptr<3> -> vector<4xi32>
+    llvm.return %val : vector<4xi32>
+  }
+}
+
+// -----
+
+// -----------------------------------------------------------------------------
+// `or`/`sub`/`xor` do not count as canonical integer pointer arithmetic.
+// They may coincide with `add` in some cases yet can still break provenance,
+// so the pass never forwards provenance through them.
+// -----------------------------------------------------------------------------
+
+// CHECK-LABEL: gpu.module @nontracked_op
+gpu.module @nontracked_op {
+  llvm.mlir.global external @nt_a() {addr_space = 3 : i32, alignment = 1024 : i64, dso_local} : !llvm.array<0 x i8>
+  llvm.mlir.global external @nt_b() {addr_space = 3 : i32, alignment = 1024 : i64, dso_local} : !llvm.array<0 x i8>
+
+  // CHECK-LABEL: llvm.func @load_via_or
+  llvm.func @load_via_or(%bits: i32) -> vector<4xi32> {
+    %base = llvm.mlir.addressof @nt_a : !llvm.ptr<3>
+    %addr = llvm.ptrtoint %base : !llvm.ptr<3> to i32
+    %merged = llvm.or %addr, %bits : i32
+    %ptr = llvm.inttoptr %merged : i32 to !llvm.ptr<3>
+    // CHECK: llvm.load
+    // CHECK-NOT: alias_scopes
+    %val = llvm.load %ptr : !llvm.ptr<3> -> vector<4xi32>
+    llvm.return %val : vector<4xi32>
+  }
+}
+
+// -----
+
+// -----------------------------------------------------------------------------
+// A pointer routed through a block argument (an LLVM phi) has no provenance:
+// on entry to ^bb1 it is unknown which addressof produced %p, so the load must
+// remain untagged whichever predecessor branched in.
+// -----------------------------------------------------------------------------
+
+// CHECK-LABEL: gpu.module @phi_block_arg
+gpu.module @phi_block_arg {
+  llvm.mlir.global external @phi_a() {addr_space = 3 : i32, alignment = 1024 : i64, dso_local} : !llvm.array<0 x i8>
+  llvm.mlir.global external @phi_b() {addr_space = 3 : i32, alignment = 1024 : i64, dso_local} : !llvm.array<0 x i8>
+
+  // CHECK-LABEL: llvm.func @load_via_phi
+  llvm.func @load_via_phi(%pick: i1) -> vector<4xi32> {
+    %from_a = llvm.mlir.addressof @phi_a : !llvm.ptr<3>
+    %from_b = llvm.mlir.addressof @phi_b : !llvm.ptr<3>
+    llvm.cond_br %pick, ^bb1(%from_a : !llvm.ptr<3>), ^bb1(%from_b : !llvm.ptr<3>)
+  ^bb1(%p: !llvm.ptr<3>):
+    // CHECK: llvm.load
+    // CHECK-NOT: alias_scopes
+    %val = llvm.load %p : !llvm.ptr<3> -> vector<4xi32>
+    llvm.return %val : vector<4xi32>
+  }
+}
+
+// -----
+
+// -----------------------------------------------------------------------------
+// A long gep + add + inttoptr chain is still traced back to the global it
+// started from. The module declares two named globals so the pass really runs
+// (with a single global it short-circuits).
+// -----------------------------------------------------------------------------
+
+// CHECK-LABEL: gpu.module @deep_chain
+gpu.module @deep_chain {
+  llvm.mlir.global external @deep_a() {addr_space = 3 : i32, alignment = 1024 : i64, dso_local} : !llvm.array<0 x i8>
+  llvm.mlir.global external @deep_b() {addr_space = 3 : i32, alignment = 1024 : i64, dso_local} : !llvm.array<0 x i8>
+
+  // CHECK-LABEL: llvm.func @load_deep
+  llvm.func @load_deep(%step0: i32, %step1: i32) -> vector<4xi32> {
+    %base = llvm.mlir.addressof @deep_a : !llvm.ptr<3>
+    %next = llvm.getelementptr %base[1] : (!llvm.ptr<3>) -> !llvm.ptr<3>, i8
+    %addr = llvm.ptrtoint %next : !llvm.ptr<3> to i32
+    %hop0 = llvm.add %addr, %step0 : i32
+    %hop1 = llvm.add %hop0, %step1 : i32
+    %ptr = llvm.inttoptr %hop1 : i32 to !llvm.ptr<3>
+    // CHECK: llvm.load %{{.+}} {alias_scopes = [#{{.*}}], noalias_scopes = [#{{.*}}]} : !llvm.ptr<3> -> vector<4xi32>
+    %val = llvm.load %ptr : !llvm.ptr<3> -> vector<4xi32>
+    llvm.return %val : vector<4xi32>
+  }
+}
+
+// -----
+
+// -----------------------------------------------------------------------------
+// Kernel mixing dyn-shared globals (tagged) with a static [N x i8] global
+// (skipped): alias scopes appear on the dyn-shared loads only.
+// -----------------------------------------------------------------------------
+
+// CHECK-LABEL: gpu.module @mixed_dyn_static
+gpu.module @mixed_dyn_static {
+  llvm.mlir.global external @mix_dyn_a() {addr_space = 3 : i32, alignment = 1024 : i64, dso_local} : !llvm.array<0 x i8>
+  llvm.mlir.global external @mix_dyn_b() {addr_space = 3 : i32, alignment = 1024 : i64, dso_local} : !llvm.array<0 x i8>
+  llvm.mlir.global external @mix_static() {addr_space = 3 : i32, alignment = 1024 : i64, dso_local} : !llvm.array<4096 x i8>
+
+  // CHECK-LABEL: llvm.func @load_mixed
+  llvm.func @load_mixed(%off: i32) -> vector<4xi32> {
+    %da = llvm.mlir.addressof @mix_dyn_a : !llvm.ptr<3>
+    %da_i = llvm.ptrtoint %da : !llvm.ptr<3> to i32
+    %da_o = llvm.add %da_i, %off : i32
+    %da_p = llvm.inttoptr %da_o : i32 to !llvm.ptr<3>
+    // CHECK: llvm.load %{{.+}} {alias_scopes = [#{{.*}}], noalias_scopes = [#{{.*}}]} : !llvm.ptr<3> -> vector<4xi32>
+    %v_da = llvm.load %da_p : !llvm.ptr<3> -> vector<4xi32>
+    // Static load must stay untagged: match the operand as one space-free token so the pattern cannot span an attribute dict.
+    %s = llvm.mlir.addressof @mix_static : !llvm.ptr<3>
+    // CHECK: llvm.load %{{[^ ]+}} : !llvm.ptr<3> -> vector<4xi32>
+    %v_s = llvm.load %s : !llvm.ptr<3> -> vector<4xi32>
+
+    %sum = llvm.add %v_da, %v_s : vector<4xi32>
+    llvm.return %sum : vector<4xi32>
+  }
+}