From 826bda01990ff100749895249e792d3f2b9c04c7 Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Tue, 15 Apr 2025 09:16:20 +0000 Subject: [PATCH 01/44] [FA] 4-stage FA pipeliner 4-stage FA experiment Cluster assignment --- .../Dialect/TritonGPU/Transforms/Schedule.h | 5 + .../Transforms/Pipeliner/Schedule.cpp | 19 ++- third_party/amd/backend/compiler.py | 16 +++ .../TritonAMDGPUTransforms/StreamPipeline.cpp | 132 ++++++++++++++---- 4 files changed, 139 insertions(+), 33 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/Transforms/Schedule.h b/include/triton/Dialect/TritonGPU/Transforms/Schedule.h index 9aae78062324..e34e41e54c19 100644 --- a/include/triton/Dialect/TritonGPU/Transforms/Schedule.h +++ b/include/triton/Dialect/TritonGPU/Transforms/Schedule.h @@ -115,6 +115,11 @@ class CoarseSchedule { bool insertDepsOfOp(Operation *op, int stage, CoarseSchedule::Cluster cluster, bool includeArg, bool insertIfEarlier = false); + bool insertDepsOfOp( + Operation *op, bool includeArg, bool insertIfEarlier, + llvm::function_ref(Operation *)> + getStageClusterForOp); + void erase(Operation *op) { opToStageAndCluster.erase(op); } int count(Operation *op) { return opToStageAndCluster.count(op); } diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp index aafbd5e8e8ac..3227af618ff2 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/Schedule.cpp @@ -53,8 +53,17 @@ bool tt::CoarseSchedule::insertMinimum(Operation *op, int stage, bool tt::CoarseSchedule::insertDepsOfOp(Operation *op, int stage, tt::CoarseSchedule::Cluster cluster, bool includeArg, bool insertIfEarlier) { - auto tryInsert = [&](Operation *op, int stage, - tt::CoarseSchedule::Cluster cluster) { + auto func = [=](Operation *) { return std::make_pair(stage, cluster); }; + return insertDepsOfOp(op, includeArg, insertIfEarlier, func); +} + +bool tt::CoarseSchedule::insertDepsOfOp( + Operation *op, bool includeArg, bool insertIfEarlier, + llvm::function_ref(Operation *)> + getStageAndClusterForOp) { + auto tryInsert = [&insertIfEarlier, + this](Operation *op, int stage, + tt::CoarseSchedule::Cluster cluster) { if (!insertIfEarlier) return insertIfAbsent(op, stage, cluster); return insertMinimum(op, stage, cluster); @@ -78,9 +87,11 @@ bool tt::CoarseSchedule::insertDepsOfOp(Operation *op, int stage, } Operation *defOp = v.getDefiningOp(); if (defOp && defOp->getBlock() == op->getBlock()) { - if (tryInsert(defOp, stage, cluster)) { + auto [defStage, defCluster] = getStageAndClusterForOp(defOp); + if (tryInsert(defOp, defStage, defCluster)) { inserted = true; - insertDepsOfOp(defOp, stage, cluster, includeArg, insertIfEarlier); + insertDepsOfOp(defOp, includeArg, insertIfEarlier, + getStageAndClusterForOp); } } } diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index cfdc5e9f0134..8bf06547d33e 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -236,7 +236,18 @@ def make_ttgir(mod, metadata, options): if options.schedule_hint == "local-prefetch": global_prefetch = local_prefetch = 1 + # passes.ttgpuir.add_pipeline(pm, options.num_stages, False) amd.passes.ttgpuir.add_stream_pipeline(pm, options.num_stages, global_prefetch, local_prefetch, use_async_copy) + + if False: + pm.run(mod) + with open("mid.mlir", "w") as f: + print(mod, file=f) + context = mod.context + mod = ir.parse_mlir_module("mod.mlir", context) + mod.context = context + pm = ir.pass_manager(mod.context) + if use_async_copy: amd.passes.ttgpuir.add_coalesce_async_copy(pm, options.arch) passes.common.add_canonicalizer(pm) @@ -397,6 +408,11 @@ def make_amdgcn(src, metadata, options): if knobs.amd.dump_amdgcn: print("// -----// AMDGCN Dump //----- //") print(amdgcn) + # if amdgcn.find("_attn_fwd") + # with open("out.amdgcn", "r") as f: + # amdgcn = f.read() + with open("out.amdgcn", "w") as f: + f.write(amdgcn) return amdgcn @staticmethod diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp index fc9aa0a7cafa..4b95e2f2f930 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp @@ -113,7 +113,7 @@ class StreamPipeliner { SCHED_LOCAL_STORE, SCHED_LOCAL_LOAD, SCHED_COMPUTE, - SCHED_ASYNC_WAIT, + // SCHED_ASYNC_WAIT, SCHED_SIZE }; @@ -127,8 +127,10 @@ class StreamPipeliner { stages[SCHED_GLOBAL_LOAD] = 0; stages[SCHED_LOCAL_STORE] = _globalPrefetch; stages[SCHED_LOCAL_LOAD] = lastStage - _localPrefetch; + // AsyncWait should be in same stage as the LocalLoad + // stages[SCHED_ASYNC_WAIT] = stages[SCHED_LOCAL_LOAD]; stages[SCHED_COMPUTE] = lastStage; - stages[SCHED_ASYNC_WAIT] = stages[SCHED_LOCAL_LOAD]; + // stages[SCHED_ASYNC_WAIT] = stages[SCHED_LOCAL_LOAD]; options.supportDynamicLoops = true; options.peelEpilogue = true; @@ -178,6 +180,7 @@ class StreamPipeliner { int stages[SCHED_SIZE]; // Cluster for each SchedType Op std::array clusters; + std::array clusterVec; // Scheduling clusters tt::CoarseSchedule schedule; @@ -219,13 +222,14 @@ LogicalResult StreamPipeliner::initSchedule(int maxIndirectionLevel) { bool pairedGlobalLoadLocalStore = stages[SCHED_LOCAL_STORE] == 0; stages[SCHED_LOCAL_STORE] += maxIndirectionLevel; - LDBG( - "Stage schedule:" << " GLOBAL_LOAD stage = " << stages[SCHED_GLOBAL_LOAD] - << ", LOCAL_STORE stage = " << stages[SCHED_LOCAL_STORE] - << ", LOCAL_LOAD stage = " << stages[SCHED_LOCAL_LOAD] - << ", COMPUTE stage = " << stages[SCHED_COMPUTE] - << ", ASYNC_WAIT stage = " << stages[SCHED_ASYNC_WAIT] - << "; total = " << numStages); + LDBG("Stage schedule:" + << " GLOBAL_LOAD stage = " << stages[SCHED_GLOBAL_LOAD] + << ", LOCAL_STORE stage = " << stages[SCHED_LOCAL_STORE] + << ", LOCAL_LOAD stage = " << stages[SCHED_LOCAL_LOAD] + << ", COMPUTE stage = " + << stages[SCHED_COMPUTE] + // << ", ASYNC_WAIT stage = " << stages[SCHED_ASYNC_WAIT] + << "; total = " << numStages); if (stages[SCHED_LOCAL_STORE] >= numStages || stages[SCHED_LOCAL_STORE] > stages[SCHED_LOCAL_LOAD]) { @@ -242,6 +246,7 @@ LogicalResult StreamPipeliner::initSchedule(int maxIndirectionLevel) { if (useAsyncCopy) { numBuffers += 1; } + numBuffers = 2; LDBG("deduced max shared memory buffer number = " << numBuffers); @@ -282,15 +287,27 @@ LogicalResult StreamPipeliner::initSchedule(int maxIndirectionLevel) { } // Make assignments - std::array clusterVec; - std::generate(clusterVec.begin(), clusterVec.end(), - [&]() { return schedule.clusters.newAtBack(); }); + // std::array clusterVec; + // std::generate(clusterVec.begin(), clusterVec.end(), + // [&]() { return schedule.clusters.newAtBack(); }); + clusterVec = {schedule.clusters.newAtBack(), schedule.clusters.newAtBack(), + schedule.clusters.newAtBack(), schedule.clusters.newAtBack()}; clusters[SCHED_GLOBAL_LOAD] = clusterVec[globalLoadCluster]; clusters[SCHED_LOCAL_STORE] = clusterVec[localStoreCluster]; clusters[SCHED_LOCAL_LOAD] = clusterVec[localLoadCluster]; clusters[SCHED_COMPUTE] = clusterVec[computeCluster]; - clusters[SCHED_ASYNC_WAIT] = clusterVec[asyncWaitCluster]; + // clusters[SCHED_ASYNC_WAIT] = clusterVec[asyncWaitCluster]; + + // ATTENTION 4-stage + clusters[SCHED_GLOBAL_LOAD] = clusterVec[2]; + clusters[SCHED_LOCAL_STORE] = clusterVec[1]; + clusters[SCHED_LOCAL_LOAD] = clusterVec[1]; + clusters[SCHED_COMPUTE] = clusterVec[0]; + + // Always have ASYNC_WAIT as the first cluster because we want it at the top + // of the schedule block + // clusters[SCHED_ASYNC_WAIT] = schedule.clusters.newAtFront(); LDBG("Cluster schedule:" << " GLOBAL_LOAD cluster = " << globalLoadCluster << ", LOCAL_STORE cluster = " << localStoreCluster @@ -355,26 +372,32 @@ bool StreamPipeliner::createAsyncCopy(tt::LoadOp loadOp, Value alloc, ttg::AsyncWaitOp waitOp = builder.create(loc, commitOp->getResult(0), 0); + // scheduleOp(waitOp, SCHED_ASYNC_WAIT); // Create local load which consumes the async token from the AsyncWait auto sharedLoad = builder.create(loc, loadOp.getType(), viewLoad, waitOp); auto [loadStage, loadCluster] = schedule[loadOp]; + auto localLoadStage = loadStage == 0 ? 1 : 3; + auto localLoadCluster = loadStage == 0 ? 3 : 1; + schedule.erase(loadOp); // Schedule new ops schedule.insert(copyOp, loadStage, loadCluster); // Place ttg.async_commit_group op following AsyncCopyGlobalToLocal so the // later UpdateAsyncWaitCount pass can deduce better waitcnts schedule.insert(commitOp, loadStage, loadCluster); - // If the LocalLoads are scheduled to a later stage than AsyncCopy we need to - // place the AsyncCopy prefetches after the AsyncWaits which create a barrier - // to ensure all warps are finished reading the shared buffer we will write - // into. This is done by scheduling AsyncWait as the first cluster. - // If AsyncCopy and LocalLoads are in the same stage we do not assign a - // schdule so they are placed before the LocalLoads - if (loadStage != stages[SCHED_LOCAL_LOAD]) - scheduleOp(waitOp, SCHED_ASYNC_WAIT); + // If the LocalLoads are scheduled to a later stage than AsyncCopy we need + // to place the AsyncCopy prefetches after the AsyncWaits which create a + // barrier to ensure all warps are finished reading the shared buffer we + // will write into. This is done by scheduling AsyncWait as the first + // cluster. If AsyncCopy and LocalLoads are in the same stage we do not + // assign a schdule so they are placed before the LocalLoads + schedule.insert(sharedLoad, localLoadStage, clusterVec[localLoadCluster]); + + // if (loadStage != stages[SCHED_LOCAL_LOAD]) + // scheduleOp(waitOp, SCHED_ASYNC_WAIT); if (stages[SCHED_LOCAL_LOAD] != stages[SCHED_COMPUTE]) scheduleOp(sharedLoad, SCHED_LOCAL_LOAD); @@ -703,21 +726,31 @@ LogicalResult StreamPipeliner::scheduleLoads(DenseSet &rootUsers) { llvm::divideCeil(numStages - 2, maxIndirectionLevel + 1); LDBG("stagesBetweenLoads = " << stagesBetweenLoads); + // Assign stages to the loads. + // FA: + // Load1: Stage=0, cluster=1 + // Load2: Stage=1, cluster=3 + int i{}; + for (auto [loadOp, indLevel, _] : loadOpToIndLevelAndUse) { + int stage = (maxIndirectionLevel - indLevel) * stagesBetweenLoads; + if (schedule.count(loadOp) > 0) + continue; + // scheduleOp(loadOp, SCHED_GLOBAL_LOAD, stage); + schedule.insert(loadOp, i, clusterVec[i == 0 ? 1 : 3]); + i++; + } + // Put the root uses of the loads in the last stage. for (auto &[loadOp, dist, use] : loadOpToIndLevelAndUse) { // Non-LoadOp(s) are the (final) root uses of all LoadOp(s). if (!isa(use)) { - scheduleOp(use, SCHED_COMPUTE); + auto loadStage = schedule[loadOp].first; + schedule.insert(use, loadStage + 2, clusterVec[loadStage == 0 ? 0 : 2]); + // scheduleOp(use, SCHED_COMPUTE); rootUsers.insert(use); } } - // Assign stages to the loads. - for (auto [loadOp, indLevel, _] : loadOpToIndLevelAndUse) { - int stage = (maxIndirectionLevel - indLevel) * stagesBetweenLoads; - scheduleOp(loadOp, SCHED_GLOBAL_LOAD, stage); - } - // Calculate distance from the load to the use. for (auto [loadOp, _, use] : loadOpToIndLevelAndUse) { loadToInfo[loadOp].distToUse = schedule[use].first - schedule[loadOp].first; @@ -745,7 +778,29 @@ void StreamPipeliner::scheduleDependencies() { for (auto [op, stage_, cluster] : opsInOrder) { if (stage_ != stage) continue; - schedule.insertDepsOfOp(op, stage, cluster, false); + auto depCluster = cluster; + LDBG("Stage: " << stage); + bool override = false; + if (llvm::isa(op) && stage == 3) { + LDBG("Update sched to 0"); + depCluster = clusterVec[0]; + override = true; + } + + auto moveStages = [this, stage, cluster = cluster, + depCluster = depCluster, override](Operation *op) { + LDBG("Schedule Op: " << *op); + if (llvm::isa(op)) { + LDBG("Is a cvt layout\n"); + return std::make_pair(stage, cluster); + } + if (override) { + LDBG("Override to 0!"); + // return std::make_pair(stage, clusterVec[0]); + } + return std::make_pair(stage, depCluster); + }; + schedule.insertDepsOfOp(op, false, false, moveStages); } } } @@ -919,6 +974,25 @@ LogicalResult StreamPipeliner::preprocessLoopAndBuildSchedule() { // Convert the loads into shared memory allocations and loads from them. createStreamOps(); + LLVM_DEBUG({ + LDBG("Coarse schedule with replaced laod ops:"); + schedule.dump(); + }); + + // Schedule reductions + int c = 2; + for (auto reduceOp : forOp.getBody()->getOps()) { + schedule.insert(reduceOp, c, clusterVec[c == 2 ? 2 : 0]); + c++; + } + + for (auto exp2Op : forOp.getBody()->getOps()) { + schedule.insert(exp2Op, 2, clusterVec[2]); + } + LLVM_DEBUG({ + LDBG("Coarse schedule after schedule reduction:"); + schedule.dump(); + }); scheduleDependencies(); LLVM_DEBUG({ From c35e297d7ff847c21467b9ba91d4b0c580125002 Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Tue, 15 Apr 2025 17:28:17 +0000 Subject: [PATCH 02/44] [FA] Add FA scripts --- fa/flash-attention.py | 2136 +++++++++++++++++++++++++++++++++ fa/model_configs.json | 42 + fa/utils/__init__.py | 0 fa/utils/benchmark_utils.py | 71 ++ fa/utils/rocprof_benchmark.py | 59 + fa/utils/rotary_embedding.py | 283 +++++ fa/utils/sglang_ref.py | 619 ++++++++++ 7 files changed, 3210 insertions(+) create mode 100644 fa/flash-attention.py create mode 100644 fa/model_configs.json create mode 100644 fa/utils/__init__.py create mode 100644 fa/utils/benchmark_utils.py create mode 100644 fa/utils/rocprof_benchmark.py create mode 100644 fa/utils/rotary_embedding.py create mode 100644 fa/utils/sglang_ref.py diff --git a/fa/flash-attention.py b/fa/flash-attention.py new file mode 100644 index 000000000000..3c982482309a --- /dev/null +++ b/fa/flash-attention.py @@ -0,0 +1,2136 @@ +""" +Fused Attention +=============== + +This is a Triton implementation of the Flash Attention v2 algorithm +See https://tridao.me/publications/flash2/flash2.pdf + +Credits: +AMD Triton kernels team +OpenAI kernel team + +Currently only the forward kernel is supported, and contains these features: + +1) Fwd with causal masking +2) Arbitrary Q and KV sequence lengths +3) Arbitrary head sizes +4) Multi and grouped query attention +5) Variable sequence lengths +6) ALiBi and matrix bias + +""" + +import argparse +import subprocess +import pytest +import sys +import torch + +import triton +import triton.language as tl +from utils.benchmark_utils import get_available_models, get_model_configs + + +class MetaData(): + cu_seqlens_q = None + cu_seqlens_k = None + max_seqlens_q = 0 + max_seqlens_k = 0 + bias = None + alibi_slopes = None + causal = False + persistent = None + num_contexts = 0 + varlen = False + int8 = False + layout = None + dropout_p, return_encoded_softmax = 0.0, False + + def __init__(self, sm_scale=1.0): + self.sm_scale = sm_scale + + def set_varlen_params(self, cu_seqlens_q, cu_seqlens_k): + self.varlen = True + self.layout = 'thd' + self.cu_seqlens_q = cu_seqlens_q + self.cu_seqlens_k = cu_seqlens_k + # Without "varlen", there should still be one sequence. + assert len(cu_seqlens_q) >= 2 + assert len(cu_seqlens_q) == len(cu_seqlens_k) + self.num_contexts = len(cu_seqlens_q) - 1 + for i in range(0, self.num_contexts): + self.max_seqlens_q = max(cu_seqlens_q[i + 1].item() - cu_seqlens_q[i].item(), self.max_seqlens_q) + self.max_seqlens_k = max(cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item(), self.max_seqlens_k) + + def set_persistent(self, persistent): + self.persistent = persistent + + def set_int8_params(self, q_descale, k_descale, v_descale, p_scale, p_descale): + self.int8 = True + self.q_descale = q_descale + self.k_descale = k_descale + self.v_descale = v_descale + self.p_scale = p_scale + self.p_descale = p_descale + self.use_p_scale = (p_scale is not None) and (p_descale is not None) and (v_descale is not None) + self.int8_kv = (q_descale is None) and (k_descale is not None) and (v_descale is not None) + + def need_bias(self, bias, batch, nheads, seqlen_q, seqlen_k): + assert bias.is_cuda + assert bias.dim() == 4 + assert bias.shape[0] == 1 + assert bias.shape[2:] == (seqlen_q, seqlen_k) + self.bias = bias + + def need_alibi(self, alibi_slopes, batch, nheads): + assert alibi_slopes.is_cuda + assert alibi_slopes.dim() == 2 + assert alibi_slopes.shape[0] == batch + assert alibi_slopes.shape[1] == nheads + self.alibi_slopes = alibi_slopes + + def need_causal(self): + self.causal = True + + def need_dropout(self, dropout_p, return_encoded_softmax): + self.dropout_p = dropout_p + self.return_encoded_softmax = return_encoded_softmax + + def check_args(self, q, k, v, o): + assert q.dim() == k.dim() and q.dim() == v.dim() + + batch, nheads_q, nheads_k, head_size = get_shape_from_layout(q, k, self) + if self.varlen: + assert q.dim() == 3 + assert self.cu_seqlens_q is not None + assert self.cu_seqlens_k is not None + assert len(self.cu_seqlens_q) == len(self.cu_seqlens_k) + # TODO: Remove once bias is supported with varlen + assert self.bias is None + # TODO:Remove once dropout is supported with varlen + assert self.dropout_p == 0.0 + assert not self.return_encoded_softmax + else: + assert q.dim() == 4 + assert self.max_seqlens_q > 0 and self.max_seqlens_k > 0 + assert self.cu_seqlens_q is None and self.cu_seqlens_k is None + assert k.shape == v.shape + assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] + # TODO: Change assert if we support qkl f8 and v f16 + if self.int8: + if self.int8_kv: + assert v.dtype == k.dtype and k.dtype == torch.int8 + assert q.dtype != k.dtype + assert (self.v_descale is not None) and (self.k_descale is not None) + else: + assert q.dtype == k.dtype and q.dtype == v.dtype and q.dtype == torch.int8 + assert (self.q_descale is not None) and (self.k_descale is not None) and (self.v_descale is not None) + if self.use_p_scale: + assert (self.p_scale is not None) and (self.p_descale is not None) + else: + assert q.dtype == k.dtype and q.dtype == v.dtype + assert head_size <= 256 + assert o.shape == q.shape + assert (nheads_q % nheads_k) == 0 + assert self.layout is not None + assert self.layout == 'thd' or not self.varlen + + +@triton.jit +def cdiv_fn(x, y): + return (x + y - 1) // y + + +@triton.jit +def max_fn(x, y): + return tl.math.max(x, y) + + +@triton.jit +def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): + ms = tl.arange(0, m) + ns = tl.arange(0, n) + return philox_offset + ms[:, None] * stride + ns[None, :] + + +@triton.jit +def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride).to(tl.uint32) + # TODO: use tl.randint for better performance + return tl.rand(philox_seed, rng_offsets) + + +@triton.jit +def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride) + rng_keep = rng_output > dropout_p + return rng_keep + + +# Convenience function to load with optional boundary checks. +# "First" is the major dim, "second" is the minor dim. +@triton.jit +def load_fn(ptrs, offset_first, offset_second, boundary_first, boundary_second): + if offset_first is not None and offset_second is not None: + mask = (offset_first[:, None] < boundary_first) & \ + (offset_second[None, :] < boundary_second) + tensor = tl.load(ptrs, mask=mask, other=0.0) + elif offset_first is not None: + mask = offset_first[:, None] < boundary_first + tensor = tl.load(ptrs, mask=mask, other=0.0) + elif offset_second is not None: + mask = offset_second[None, :] < boundary_second + tensor = tl.load(ptrs, mask=mask, other=0.0) + else: + tensor = tl.load(ptrs) + return tensor + + +@triton.jit +def print_gpu(prefix, val=None): + if (tl.program_id(0) == 0) and ((tl.program_id(1) == 0) and (tl.program_id(2) == 0)): + if val is not None: + tl.device_print(prefix, val) + else: + tl.device_print(prefix) + + +@triton.jit +def compute_alibi_block(alibi_slope, seqlen_q, seqlen_k, offs_m, offs_n, transpose=False): + # when seqlen_k and seqlen_q are different we want the diagonal to stick to the bottom right of the attention matrix + # for casual mask we want something like this where (1 is kept and 0 is masked) + # seqlen_q = 2 and seqlen_k = 5 + # 1 1 1 1 0 + # 1 1 1 1 1 + # seqlen_q = 5 and seqlen_k = 2 + # 0 0 + # 0 0 + # 0 0 + # 1 0 + # 1 1 + # for alibi the diagonal is 0 indicating no penalty for attending to that spot and increasing penalty for attending further from the diagonal + # e.g. alibi_slope = 1, seqlen_q = 2, seqlen_k = 5, offs_m = [0, 1, 2, 3], offs_n = [0, 1, 2, 3, 4], transpose = False + # 1. offs_m[:,None] = [[0], + # [1], + # 2. offs_m[:,None] + seqlen_k = [[5], + # [6], + # 3. offs_m[:,None] + seqlen_k - seqlen_q = [[3], + # [4], + # 4. offs_m[:,None] + seqlen_k - seqlen_q - offs_n[None,:] = [[3], - [[0, 1, 2, 3, 4]] = [[ 3, 2, 1, 0,-1], + # [4], [ 4, 3, 2, 1, 0]] + # 5. -1 * alibi_slope * tl.abs(relative_pos_block) = [[ -3, -2, -1, 0,-1], + # [ -4, -3, -2, -1, 0]], + relative_pos_block = offs_m[:, None] + seqlen_k - seqlen_q - offs_n[None, :] + alibi_block = -1 * alibi_slope * tl.abs(relative_pos_block) + if transpose: + return alibi_block.T + else: + return alibi_block + + +def compute_alibi_tensor(alibi_slopes, seqlen_q, seqlen_k): + q_idx = torch.arange(seqlen_q, dtype=torch.int32, device="cuda").unsqueeze(-1) # (N_CTX_Q, 1) + k_idx = torch.arange(seqlen_k, dtype=torch.int32, device="cuda").unsqueeze(0) # (1, N_CTX_K) + relative_pos = torch.abs(q_idx + seqlen_k - seqlen_q - k_idx) # (N_CTX_Q, N_CTX_K) + return -1 * alibi_slopes.unsqueeze(-1).unsqueeze(-1) * relative_pos # (Z, H, N_CTX_Q, N_CTX_K) + + +@triton.jit +def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, start_m, + actual_seqlen_k, actual_seqlen_q, dropout_p, philox_seed, batch_philox_offset, encoded_sm_ptrs, + block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, q_descale, + k_descale, v_descale, p_scale, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, PADDED_HEAD: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, + QK_SCALE: tl.constexpr, INT8_GEMM: tl.constexpr, USE_P_SCALE: tl.constexpr, INT8_KV: tl.constexpr): + # loop over k, v, and update accumulator + for start_n in range(block_min, block_max, BLOCK_N): + # For padded blocks, we will overrun the tensor size if + # we load all BLOCK_N. For others, the blocks are all within range. + if MASK_STEPS: + k_offs_n = start_n + tl.arange(0, BLOCK_N) + else: + k_offs_n = None + k_offs_k = None if not PADDED_HEAD else tl.arange(0, BLOCK_DMODEL) + k = load_fn(k_ptrs, k_offs_k, k_offs_n, ACTUAL_BLOCK_DMODEL, actual_seqlen_k) + if PRE_LOAD_V: + # We can use the same offsets as k, just with dims transposed. + v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + # We start from end of seqlen_k so only the first iteration would need + # to be checked for padding if it is not a multiple of block_n + # TODO: This can be optimized to only be true for the padded block. + if MASK_STEPS: + # If this is the last block / iteration, we want to + # mask if the sequence length is not a multiple of block size + # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn. + # last step might get wasted but that is okay. check if this masking works For + # that case. + if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): + boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32) + size_n = start_n + OFFS_N[None, :] + mask = size_n < boundary_m[:, None] + qk = tl.where(mask, qk, float("-inf")) + if IS_CAUSAL: + causal_boundary = start_n + offs_n_causal + causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] + qk = tl.where(causal_mask, qk, float("-inf")) + # -- compute qk ---- + if INT8_GEMM: + qk += ((((tl.dot(q, k).to(tl.float32) * q_descale)) * k_descale) * QK_SCALE) + else: + if INT8_KV: + k = (k * k_descale).to(q.type.element_ty) + qk += (tl.dot(q, k) * QK_SCALE) + + if bias_ptrs is not None: + bias_offs_n = start_n + tl.arange(0, BLOCK_N) if MASK_STEPS else None + bias = load_fn(bias_ptrs, OFFS_M, bias_offs_n, actual_seqlen_q, actual_seqlen_k) + # While bias is added after multiplying qk with sm_scale, + # our optimization to use 2^x instead of e^x results in an additional + # scale factor of log2(e) which we must also multiply the bias with. + qk += (bias * 1.44269504089) + + if alibi_slope is not None: + # Compute the global position of each token within the sequence + global_m_positions = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + global_n_positions = start_n + tl.arange(0, BLOCK_N) + alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q, actual_seqlen_k, global_m_positions, + global_n_positions) + qk += (alibi_block * 1.44269504089) # scale factor of log2(e) + + # softmax + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] + p = tl.math.exp2(qk) + + # CAVEAT: Must update l_ij before applying dropout + l_ij = tl.sum(p, 1) + if ENABLE_DROPOUT: + philox_offset = batch_philox_offset + start_m * BLOCK_M * actual_seqlen_k + start_n - BLOCK_N + keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, actual_seqlen_k) + if RETURN_ENCODED_SOFTMAX: + tl.store(encoded_sm_ptrs, tl.where(keep, p, -p).to(encoded_sm_ptrs.type.element_ty)) + p = tl.where(keep, p, 0.0) + elif RETURN_ENCODED_SOFTMAX: + tl.store(encoded_sm_ptrs, p.to(encoded_sm_ptrs.type.element_ty)) + # -- update output accumulator -- + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + if not PRE_LOAD_V: + v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL) + # -- update m_i and l_i + l_i = l_i * alpha + l_ij + # update m_i and l_i + m_i = m_ij + + if INT8_GEMM: + if USE_P_SCALE: + p = (p * p_scale).to(tl.int8) + # They are all int8 + acc += tl.dot(p, v) + else: + # v is in int8 but p is not, we want the gemm in p's type + acc += tl.dot(p, v.to(p.type.element_ty)) + else: + if INT8_KV: + v = (v * v_descale).to(p.type.element_ty) + acc += tl.dot(p.to(v.type.element_ty), v) + + k_ptrs += BLOCK_N * stride_kn + v_ptrs += BLOCK_N * stride_vk + if bias_ptrs is not None: + bias_ptrs += BLOCK_N * stride_bn + if RETURN_ENCODED_SOFTMAX: + encoded_sm_ptrs += BLOCK_N + return acc, l_i, m_i + + +def get_gfx_version(): + try: + # Run the rocminfo command + result = subprocess.run(['rocminfo'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + output = result.stdout + + # Parse the output to find the gfx version + for line in output.splitlines(): + line = line.strip() + if line.startswith("Name: gfx"): + gfx_version = line.split("Name:")[1].strip() + return gfx_version + except Exception as e: + print(f"Error: {e}") + return None + + +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + + +def is_cdna(): + return is_hip() and triton.runtime.driver.active.get_current_target().arch in ('gfx950', 'gfx940', 'gfx941', + 'gfx942', 'gfx90a', 'gfx908') + + +def is_rdna(): + return is_hip() and triton.runtime.driver.active.get_current_target().arch in ("gfx1030", "gfx1100", "gfx1101", + "gfx1102", "gfx1200", "gfx1201") + + +def get_cdna_autotune_configs(): + return [ + # triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 0, 'PRE_LOAD_V': False, 'GRID_CU_MULTIP': 2}, + # num_stages=2, num_warps=4), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False, 'GRID_CU_MULTIP': 2}, + num_stages=4, num_warps=8), + # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False, 'GRID_CU_MULTIP': 2}, + # num_stages=2, num_warps=4), + # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 1, 'PRE_LOAD_V': False, 'GRID_CU_MULTIP': 2}, + # num_stages=2, num_warps=4), + # triton.Config({'BLOCK_M': 128, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False, 'GRID_CU_MULTIP': 2}, + # num_stages=2, num_warps=4), + ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'] + + +def get_rdna_autotune_configs(): + return [ + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False, 'GRID_CU_MULTIP': 2}, + num_stages=1, num_warps=2), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 2, 'PRE_LOAD_V': False, 'GRID_CU_MULTIP': 2}, + num_stages=1, num_warps=2), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False, 'GRID_CU_MULTIP': 2}, + num_stages=1, num_warps=2), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False, 'GRID_CU_MULTIP': 2}, + num_stages=1, num_warps=2), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 4, 'PRE_LOAD_V': False, 'GRID_CU_MULTIP': 2}, + num_stages=1, num_warps=2), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 2, 'PRE_LOAD_V': False, 'GRID_CU_MULTIP': 2}, + num_stages=1, num_warps=2), + # Fall-back config. + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False, 'GRID_CU_MULTIP': 2}, + num_stages=1, num_warps=2), + ], ['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'] + + +def get_autotune_configs(): + if is_rdna(): + return get_rdna_autotune_configs() + elif is_cdna(): + return get_cdna_autotune_configs() + else: + raise ValueError("Unknown Device Type") + + +autotune_configs, autotune_keys = get_autotune_configs() + + +@triton.autotune( + configs=autotune_configs, + key=autotune_keys, + use_cuda_graph=True, +) +@triton.jit +def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, L, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, + stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, + stride_om, stride_on, stride_bz, stride_bh, stride_bm, stride_bn, stride_az, stride_ah, Q_descale, + K_descale, P_scale, P_descale, V_descale, cu_seqlens_q, cu_seqlens_k, dropout_p, philox_seed, + PERSISTENT: tl.constexpr, PERSISTENT_DYNAMIC: tl.constexpr, atomic_counter, NUM_CU: tl.constexpr, + GRID_CU_MULTIP: tl.constexpr, B: tl.constexpr, philox_offset_base, encoded_softmax, alibi_slopes, + HQ: tl.constexpr, HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, + MAX_SEQLENS_K: tl.constexpr, VARLEN: tl.constexpr, IS_CAUSAL: tl.constexpr, BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, USE_BIAS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, USE_ALIBI: tl.constexpr, + INT8: tl.constexpr, USE_P_SCALE: tl.constexpr, INT8_KV: tl.constexpr): + + tl.assume(stride_qz >= 0) + tl.assume(stride_qh >= 0) + tl.assume(stride_qm >= 0) + tl.assume(stride_qk >= 0) + tl.assume(stride_kz >= 0) + tl.assume(stride_kh >= 0) + tl.assume(stride_kn >= 0) + tl.assume(stride_kk >= 0) + tl.assume(stride_bz >= 0) + tl.assume(stride_bh >= 0) + tl.assume(stride_bm >= 0) + tl.assume(stride_bn >= 0) + tl.assume(stride_vz >= 0) + tl.assume(stride_vh >= 0) + tl.assume(stride_vk >= 0) + tl.assume(stride_vn >= 0) + tl.assume(stride_oz >= 0) + tl.assume(stride_oh >= 0) + tl.assume(stride_om >= 0) + tl.assume(stride_on >= 0) + + if PERSISTENT: # if persistent, kernel loops over multiple tiles + NUM_WG = NUM_CU * GRID_CU_MULTIP # number of workgroups launched + num_tiles_per_head = tl.cdiv(MAX_SEQLENS_Q, BLOCK_M) # the number of work units (tiles) of a single head + num_tiles_per_sample = num_tiles_per_head * HQ # times the number of heads + num_tiles_total = num_tiles_per_sample * B # times the number of samples + if PERSISTENT_DYNAMIC: + tile_id = atomic_counter.atomic_add(1) # retuns the value BEFORE the atomic operation + else: + tile_id = tl.program_id(0) + else: # standard, kernel processes only one tile + tile_id = 0 + num_tiles_total = 1 + + while tile_id < num_tiles_total: # loops more than once only if PERSISTENT + if PERSISTENT: + # tile id basically tells us the Q block we are handling + off_z = tile_id // num_tiles_per_sample # at which batch sample are we + off_h_q = tile_id % num_tiles_per_sample // num_tiles_per_head # at which head are we inside the sample + start_m = tile_id % num_tiles_per_sample % num_tiles_per_head # at which tile are we inside the head + else: + start_m = tl.program_id(0) + off_h_q = tl.program_id(1) + off_z = tl.program_id(2) + + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + + continue_condition = True # as we can't have return statements inside while loop in Triton + + if VARLEN: + cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) + cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) + seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start + # We have a one-size-fits-all grid in id(0). Some seqlens might be too + # small for all start_m so for those we return early. + if start_m * BLOCK_M > seqlen_q: + continue_condition = False + # return + cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) + cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) + seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start + else: + cu_seqlens_q_start = 0 + cu_seqlens_k_start = 0 + seqlen_q = MAX_SEQLENS_Q + seqlen_k = MAX_SEQLENS_K + + if continue_condition: + # Now we compute whether we need to exit early due to causal masking. + # This is because for seqlen_q > seqlen_k, M rows of the attn scores + # are completely masked, resulting in 0s written to the output, and + # inf written to LSE. We don't need to do any GEMMs in this case. + # This block of code determines what N is, and if this WG is operating + # on those M rows. + n_blocks = cdiv_fn(seqlen_k, BLOCK_N) + if (IS_CAUSAL): + # If seqlen_q == seqlen_k, the attn scores are a square matrix. + # If seqlen_q != seqlen_k, attn scores are rectangular which means + # the causal mask boundary is bottom right aligned, and ends at either + # the top edge (seqlen_q < seqlen_k) or left edge. + # This captures the decrease in n_blocks if we have a rectangular attn matrix + n_blocks_seqlen = cdiv_fn((start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N) + # This is what adjusts the block_max for the current WG, only + # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks + n_blocks = min(n_blocks, n_blocks_seqlen) + # If we have no blocks after adjusting for seqlen deltas, this WG is part of + # the blocks that are all 0. We exit early. + if n_blocks <= 0: + o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om + o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) + o_ptrs_mask = (offs_m[:, None] < seqlen_q).broadcast_to([BLOCK_M, BLOCK_DMODEL]) + # We still need to write 0s to the result + tl.store(o_ptrs, acc, mask=o_ptrs_mask) + # The tensor allocated for L is based on MAX_SEQLENS_Q as that is + # statically known. + l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + # We store inf to LSE, not -inf because in the bwd pass, we subtract this + # from qk which makes it -inf, such that exp(qk - inf) = 0 for these masked blocks. + l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) + l_ptrs_mask = offs_m < MAX_SEQLENS_Q + tl.store(l_ptrs, l, mask=l_ptrs_mask) + # TODO: Should dropout and return encoded softmax be handled here too? + continue_condition = False + # return + + if continue_condition: + # If MQA / GQA, set the K and V head offsets appropriately. + GROUP_SIZE: tl.constexpr = HQ // HK + if GROUP_SIZE != 1: + off_h_k = off_h_q // GROUP_SIZE + else: + off_h_k = off_h_q + + n_extra_tokens = 0 + if seqlen_k < BLOCK_N: + n_extra_tokens = BLOCK_N - seqlen_k + elif seqlen_k % BLOCK_N: + n_extra_tokens = seqlen_k % BLOCK_N + PADDED_HEAD: tl.constexpr = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL) + + # Compute pointers for all the tensors used in this kernel. + q_offset = Q + off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm + q_ptrs = q_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk + k_offset = K + off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn + k_ptrs = k_offset + offs_d[:, None] * stride_kk + offs_n[None, :] * stride_kn + v_offset = V + off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk + v_ptrs = v_offset + offs_n[:, None] * stride_vk + offs_d[None, :] * stride_vn + # Compute pointers for all the scale tensors used in this kernel. + + INT8_GEMM: tl.constexpr = INT8 & (not INT8_KV) + if INT8: + k_descale_ptrs = K_descale + off_h_k + v_descale_ptrs = V_descale + off_h_k + if not INT8_KV: + q_descale_ptrs = Q_descale + off_h_q + if USE_P_SCALE: + p_scale_ptrs = P_scale + off_h_q + p_descale_ptrs = P_descale + off_h_q + + if USE_BIAS: + # Note: this might get large enough to overflow on some configs + bias_offset = off_h_q * stride_bh + bias_ptrs = bias + bias_offset + offs_m[:, None] * stride_bm + offs_n[None, :] * stride_bn + else: + bias_ptrs = None + + if USE_ALIBI: + a_offset = off_z * stride_az + off_h_q * stride_ah + alibi_slope = tl.load(alibi_slopes + a_offset) + else: + alibi_slope = None + + if ENABLE_DROPOUT: + off_hz = off_z * HQ + off_h_q + batch_philox_offset = philox_offset_base + off_hz * seqlen_q * seqlen_k + else: + batch_philox_offset = 0 + # We can ask to return the dropout mask without actually doing any dropout. In + # this case, we return an invalid pointer so indicate the mask is not valid. + if RETURN_ENCODED_SOFTMAX: + encoded_sm_base = encoded_softmax + off_h_q * seqlen_q * seqlen_k + encoded_sm_ptrs = encoded_sm_base + offs_m[:, None] * seqlen_k + offs_n[None, :] + else: + encoded_sm_ptrs = None + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use 2^x in the loop as we do not + # have native e^x support in HW. + QK_SCALE: tl.constexpr = SM_SCALE * 1.44269504089 + # Q is loaded once at the beginning and shared by all N blocks. + q_ptrs_mask = offs_m[:, None] < seqlen_q + if PADDED_HEAD: + q_ptrs_mask = q_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) + q = tl.load(q_ptrs, mask=q_ptrs_mask, other=0.0) + + if INT8: + k_descale = tl.load(k_descale_ptrs) + v_descale = tl.load(v_descale_ptrs) + if not INT8_KV: + q_descale = tl.load(q_descale_ptrs) + else: + q_descale = None + if USE_P_SCALE: + p_scale = tl.load(p_scale_ptrs) + p_descale = tl.load(p_descale_ptrs) + else: + p_scale = None + p_descale = None + else: + q_descale = None + k_descale = None + v_descale = None + p_scale = None + p_descale = None + # Here we compute how many full and masked blocks we have. + padded_block_k = n_extra_tokens != 0 + is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) + if IS_CAUSAL: + # There are always at least BLOCK_M // BLOCK_N masked blocks. + # Additionally there might be one more due to dissimilar seqlens. + masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) + else: + # Padding on Q does not need to be masked in the FA loop. + masked_blocks = padded_block_k + # if IS_CAUSAL, not is_modulo_mn does not always result in an additional block. + # In this case we might exceed n_blocks so pick the min. + masked_blocks = min(masked_blocks, n_blocks) + n_full_blocks = n_blocks - masked_blocks + block_min = 0 + block_max = n_blocks * BLOCK_N + # Compute for full blocks. Here we set causal to false regardless of its actual + # value because there is no masking. Similarly we do not need padding. + if n_full_blocks > 0: + block_max = (n_blocks - masked_blocks) * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, + stride_bn, start_m, seqlen_k, seqlen_q, dropout_p, philox_seed, + batch_philox_offset, encoded_sm_ptrs, + # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ + block_min, block_max, 0, 0, 0, alibi_slope, q_descale, k_descale, + v_descale, p_scale, + # IS_CAUSAL, .... + False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, False, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, + PADDED_HEAD, ACTUAL_BLOCK_DMODEL, QK_SCALE, INT8_GEMM, USE_P_SCALE, + INT8_KV) + block_min = block_max + block_max = n_blocks * BLOCK_N + + tl.debug_barrier() + # Remaining blocks, if any, are full / not masked. + if (masked_blocks > 0): + if IS_CAUSAL: + offs_n_causal = offs_n + (seqlen_q - seqlen_k) + else: + offs_n_causal = 0 + k_ptrs += n_full_blocks * BLOCK_N * stride_kn + v_ptrs += n_full_blocks * BLOCK_N * stride_vk + if USE_BIAS: + bias_ptrs += n_full_blocks * BLOCK_N * stride_bn + if RETURN_ENCODED_SOFTMAX: + encoded_sm_ptrs += n_full_blocks * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner( + acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stride_vk, stride_bn, start_m, seqlen_k, + seqlen_q, dropout_p, philox_seed, batch_philox_offset, encoded_sm_ptrs, block_min, block_max, + offs_n_causal, masked_blocks, n_extra_tokens, alibi_slope, q_descale, k_descale, v_descale, + p_scale, IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, True, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD, ACTUAL_BLOCK_DMODEL, + QK_SCALE, INT8_GEMM, USE_P_SCALE, INT8_KV) + + if INT8 and not INT8_KV: + if USE_P_SCALE: + acc *= p_descale + acc *= v_descale + + # epilogue + # This helps the compiler do Newton Raphson on l_i vs on acc which is much larger. + l_recip = 1 / l_i[:, None] + acc = acc * l_recip + + if ENABLE_DROPOUT: + acc = acc / (1 - dropout_p) + # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, + # then we have one block with a row of all NaNs which come from computing + # softmax over a row of all -infs (-inf - inf = NaN). We check for that here + # and store 0s where there are NaNs as these rows should've been zeroed out. + end_m_idx = (start_m + 1) * BLOCK_M + start_m_idx = start_m * BLOCK_M + causal_start_idx = seqlen_q - seqlen_k + acc = acc.to(Out.type.element_ty) + if IS_CAUSAL: + if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: + out_mask_boundary = tl.full((BLOCK_DMODEL, ), causal_start_idx, dtype=tl.int32) + mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) + out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :] + z = 0.0 + acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) + # write back LSE + l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + # If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows. + # This is only true for the last M block. For others, overflow_size will be -ve + overflow_size = end_m_idx - seqlen_q + if overflow_size > 0: + boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow_size, dtype=tl.int32) + l_ptrs_mask = tl.arange(0, BLOCK_M) < boundary + tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) + else: + tl.store(l_ptrs, m_i + tl.math.log2(l_i)) + + # write back O + o_offset = Out + off_z * stride_oz + off_h_q * stride_oh + cu_seqlens_q_start * stride_om + o_ptrs = o_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on + o_ptrs_mask = tl.full([BLOCK_M, BLOCK_DMODEL], 1, dtype=tl.int1) + if overflow_size > 0: + o_ptrs_mask = o_ptrs_mask & (offs_m[:, None] < seqlen_q) + if PADDED_HEAD: + o_ptrs_mask = o_ptrs_mask & (offs_d[None, :] < ACTUAL_BLOCK_DMODEL) + tl.store(o_ptrs, acc.to(Out.dtype.element_ty), mask=o_ptrs_mask) + + if PERSISTENT: + if PERSISTENT_DYNAMIC: + tile_id = atomic_counter.atomic_add(1) + else: + tile_id += NUM_WG + else: + tile_id = num_tiles_total # break after single tile + + +@triton.jit +def _attn_bwd_preprocess( + Out, + DO, + Delta, + stride_oz, + stride_oh, + stride_om, + stride_on, + stride_doz, + stride_doh, + stride_dom, + stride_don, + seqlen_q, + head_dim, + BLOCK_M: tl.constexpr, + D_HEAD: tl.constexpr, +): + # off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + # off_n = tl.arange(0, D_HEAD) + off_m = tl.program_id(0) * BLOCK_M + off_h = tl.program_id(1) # head index + off_z = tl.program_id(2) # batch index + num_h = tl.num_programs(1) + o_offset = off_h * stride_oh + off_z * stride_oz + O_block_ptr = tl.make_block_ptr(base=Out + o_offset, shape=(seqlen_q, head_dim), strides=(stride_om, stride_on), + offsets=(off_m, 0), block_shape=(BLOCK_M, D_HEAD), order=(1, 0)) + do_offset = off_h * stride_doh + off_z * stride_doz + DO_block_ptr = tl.make_block_ptr(base=DO + do_offset, shape=(seqlen_q, head_dim), strides=(stride_dom, stride_don), + offsets=(off_m, 0), block_shape=(BLOCK_M, D_HEAD), order=(1, 0)) + # load + # o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + # do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + o = tl.load(O_block_ptr, boundary_check=(0, 1), padding_option="zero").to(tl.float32) + do = tl.load(DO_block_ptr, boundary_check=(0, 1), padding_option="zero").to(tl.float32) + # compute + delta = tl.sum(o * do, axis=1) + # write-back, shape (q.shape[0] * q.shape[1], q.shape[2]) + off_zh = off_z * num_h + off_h * 1 + # Check for OOB accesses + delta_ptrs = Delta + off_zh * seqlen_q + off_m + tl.arange(0, BLOCK_M) + overflow = off_m + BLOCK_M - seqlen_q + if overflow > 0: + boundary = tl.full((BLOCK_M, ), BLOCK_M - overflow, dtype=tl.int32) + mask = boundary > tl.arange(0, BLOCK_M) + tl.store(delta_ptrs, delta, mask=mask) + else: + tl.store(delta_ptrs, delta) + + +@triton.jit +def _bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D, + # shared by Q/K/V/DO. + stride_tok, stride_d, H, N_CTX, BLOCK_M1: tl.constexpr, BLOCK_N1: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + # Filled in by the wrapper. + start_n, start_m, num_steps, MASK: tl.constexpr): + offs_m = start_m + tl.arange(0, BLOCK_M1) + offs_n = start_n + tl.arange(0, BLOCK_N1) + # offs_k = tl.arange(0, BLOCK_DMODEL) + QT_block_ptr = tl.make_block_ptr(base=Q, shape=(BLOCK_DMODEL, N_CTX), strides=(stride_d, stride_tok), + offsets=(0, start_m), block_shape=(BLOCK_DMODEL, BLOCK_M1), order=(0, 1)) + DO_block_ptr = tl.make_block_ptr(base=DO, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), + offsets=(start_m, 0), block_shape=(BLOCK_M1, BLOCK_DMODEL), order=(1, 0)) + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + curr_m = start_m + step_m = BLOCK_M1 + for blk_idx in range(num_steps): + qT = tl.load(QT_block_ptr) + # Load m before computing qk to reduce pipeline stall. + offs_m = curr_m + tl.arange(0, BLOCK_M1) + m = tl.load(M + offs_m) + kqT = tl.dot(k, qT) + if alibi_slope is not None: + alibi_block = compute_alibi_block(alibi_slope, N_CTX, N_CTX, offs_m, offs_n, True) + kqT += alibi_block * 1.44269504089 + + pT = tl.math.exp2(kqT - m[None, :]) + # Autoregressive masking. + if MASK: + mask = (offs_m[None, :] >= offs_n[:, None]) + pT = tl.where(mask, pT, 0.0) + do = tl.load(DO_block_ptr) + # Compute dV. + ppT = pT + ppT = ppT.to(tl.float16) + dv += tl.dot(ppT, do) + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do)) + dsT = pT * (dpT - Di[None, :]) + dsT = dsT.to(tl.float16) + dk += tl.dot(dsT, tl.trans(qT)) + # Increment pointers. + curr_m += step_m + QT_block_ptr = tl.advance(QT_block_ptr, (0, step_m)) + DO_block_ptr = tl.advance(DO_block_ptr, (step_m, 0)) + return dk, dv + + +@triton.jit +def _bwd_kernel_dq(dq, q, K, V, do, m, D, alibi_slope, + # shared by Q/K/V/DO. + stride_tok, stride_d, H, N_CTX, BLOCK_M2: tl.constexpr, BLOCK_N2: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + # Filled in by the wrapper. + start_m, start_n, num_steps, MASK: tl.constexpr): + offs_m = start_m + tl.arange(0, BLOCK_M2) + offs_n = start_n + tl.arange(0, BLOCK_N2) + # offs_k = tl.arange(0, BLOCK_DMODEL) + KT_block_ptr = tl.make_block_ptr(base=K, shape=(BLOCK_DMODEL, N_CTX), strides=(stride_d, stride_tok), + offsets=(0, start_n), block_shape=(BLOCK_DMODEL, BLOCK_N2), order=(0, 1)) + VT_block_ptr = tl.make_block_ptr(base=V, shape=(BLOCK_DMODEL, N_CTX), strides=(stride_d, stride_tok), + offsets=(0, start_n), block_shape=(BLOCK_DMODEL, BLOCK_N2), order=(0, 1)) + # D (= delta) is pre-divided by ds_scale. + Di = tl.load(D + offs_m) + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. + tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + curr_n = start_n + step_n = BLOCK_N2 + for blk_idx in range(num_steps): + kT = tl.load(KT_block_ptr) + qk = tl.dot(q, kT) + if alibi_slope is not None: + alibi_block = compute_alibi_block(alibi_slope, N_CTX, N_CTX, offs_m, offs_n) + qk += alibi_block * 1.44269504089 + + p = tl.math.exp2(qk - m) + # Autoregressive masking. + if MASK: + offs_n = curr_n + tl.arange(0, BLOCK_N2) + mask = (offs_m[:, None] >= offs_n[None, :]) + p = tl.where(mask, p, 0.0) + # Compute dP and dS. + vT = tl.load(VT_block_ptr) + dp = tl.dot(do, vT).to(tl.float32) + ds = p * (dp - Di[:, None]) + ds = ds.to(tl.float16) + # Compute dQ.0. + # NOTE: We need to de-scale dq in the end, because kT was pre-scaled. + dq += tl.dot(ds, tl.trans(kT)) + # Increment pointers. + curr_n += step_n + KT_block_ptr = tl.advance(KT_block_ptr, (0, step_n)) + VT_block_ptr = tl.advance(VT_block_ptr, (0, step_n)) + return dq + + +@triton.jit +def _attn_bwd(Q, K, V, sm_scale, alibi_slopes, DO, DQ, DK, DV, M, D, + # shared by Q/K/V/DO. + stride_z, stride_h, stride_tok, stride_d, + # H = 16, N_CTX = 1024 + H, N_CTX, BLOCK_DMODEL: tl.constexpr, BLOCK_M1: tl.constexpr, BLOCK_N1: tl.constexpr, + BLOCK_M2: tl.constexpr, BLOCK_N2: tl.constexpr, BLK_SLICE_FACTOR: tl.constexpr, USE_ALIBI: tl.constexpr): + LN2: tl.constexpr = 0.6931471824645996 # = ln(2) + + bhid = tl.program_id(2) + off_chz = (bhid * N_CTX).to(tl.int64) + adj = (stride_h * (bhid % H) + stride_z * (bhid // H)).to(tl.int64) + pid = tl.program_id(0) + + # offset pointers for batch/head + Q += adj + K += adj + V += adj + DO += adj + DQ += adj + DK += adj + DV += adj + M += off_chz + D += off_chz + + # offs_k = tl.arange(0, BLOCK_DMODEL) + + start_n = pid * BLOCK_N1 + # This assignment is important. It is what allows us to pick the diagonal + # blocks. Later, when we want to do the lower triangular, we update start_m + # after the first dkdv call. + start_m = start_n + + MASK_BLOCK_M1: tl.constexpr = BLOCK_M1 // BLK_SLICE_FACTOR + # offs_n = start_n + tl.arange(0, BLOCK_N1) + + dv = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, BLOCK_DMODEL], dtype=tl.float32) + + K_block_ptr = tl.make_block_ptr( + base=K, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_tok, stride_d), + offsets=(start_n, 0), + block_shape=(BLOCK_N1, BLOCK_DMODEL), + order=(1, 0), + ) + V_block_ptr = tl.make_block_ptr( + base=V, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_tok, stride_d), + offsets=(start_n, 0), + block_shape=(BLOCK_N1, BLOCK_DMODEL), + order=(1, 0), + ) + + # load K and V: they stay in SRAM throughout the inner loop for dkdv. + k = tl.load(K_block_ptr) + v = tl.load(V_block_ptr) + + if USE_ALIBI: + a_offset = bhid + alibi_slope = tl.load(alibi_slopes + a_offset) + else: + alibi_slope = None + + # compute dK and dV for blocks close to the diagonal that need to be masked + num_steps = BLOCK_N1 // MASK_BLOCK_M1 + dk, dv = _bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D, stride_tok, stride_d, H, N_CTX, + MASK_BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, start_n, start_m, num_steps, MASK=True) + + # compute dK and dV for blocks that don't need masking further from the diagonal + start_m += num_steps * MASK_BLOCK_M1 + num_steps = (N_CTX - start_m) // BLOCK_M1 + + dk, dv = _bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D, stride_tok, stride_d, H, N_CTX, + BLOCK_M1, BLOCK_N1, BLOCK_DMODEL, start_n, start_m, num_steps, MASK=False) + + DV_block_ptrs = tl.make_block_ptr(base=DV, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), + offsets=(start_n, 0), block_shape=(BLOCK_N1, BLOCK_DMODEL), order=(1, 0)) + tl.store(DV_block_ptrs, dv.to(v.dtype)) + + # Write back dK. + dk *= sm_scale + DK_block_ptrs = tl.make_block_ptr(base=DK, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), + offsets=(start_n, 0), block_shape=(BLOCK_N1, BLOCK_DMODEL), order=(1, 0)) + tl.store(DK_block_ptrs, dk.to(k.dtype)) + + # THIS BLOCK DOES DQ: + start_m = pid * BLOCK_M2 + end_n = start_m + BLOCK_M2 + + MASK_BLOCK_N2: tl.constexpr = BLOCK_N2 // BLK_SLICE_FACTOR + offs_m = start_m + tl.arange(0, BLOCK_M2) + + Q_block_ptr = tl.make_block_ptr(base=Q, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), + offsets=(start_m, 0), block_shape=(BLOCK_M2, BLOCK_DMODEL), order=(1, 0)) + + DO_block_ptr = tl.make_block_ptr(base=DO, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), + offsets=(start_m, 0), block_shape=(BLOCK_M2, BLOCK_DMODEL), order=(1, 0)) + q = tl.load(Q_block_ptr) + do = tl.load(DO_block_ptr) + dq = tl.zeros([BLOCK_M2, BLOCK_DMODEL], dtype=tl.float32) + + m = tl.load(M + offs_m) + m = m[:, None] + + # Compute dQ for masked (diagonal) blocks. + # NOTE: This code scans each row of QK^T backward (from right to left, + # but inside each call to _attn_bwd_dq, from left to right), but that's + # not due to anything important. I just wanted to reuse the loop + # structure for dK & dV above as much as possible. + num_steps = BLOCK_M2 // MASK_BLOCK_N2 + dq = _bwd_kernel_dq(dq, q, K, V, do, m, D, alibi_slope, stride_tok, stride_d, H, N_CTX, BLOCK_M2, MASK_BLOCK_N2, + BLOCK_DMODEL, start_m, end_n - num_steps * MASK_BLOCK_N2, num_steps, MASK=True) + end_n -= num_steps * MASK_BLOCK_N2 + # stage 2 + num_steps = end_n // BLOCK_N2 + dq = _bwd_kernel_dq(dq, q, K, V, do, m, D, alibi_slope, stride_tok, stride_d, H, N_CTX, BLOCK_M2, BLOCK_N2, + BLOCK_DMODEL, start_m, end_n - num_steps * BLOCK_N2, num_steps, MASK=False) + # Write back dQ. + DQ_block_ptr = tl.make_block_ptr(base=DQ, shape=(N_CTX, BLOCK_DMODEL), strides=(stride_tok, stride_d), + offsets=(start_m, 0), block_shape=(BLOCK_M2, BLOCK_DMODEL), order=(1, 0)) + dq *= LN2 + tl.store(DQ_block_ptr, dq.to(q.dtype)) + + +def get_shape_from_layout(q, k, metadata): + if metadata.layout == 'thd': + nheads_q, nheads_k = q.shape[1], k.shape[1] + head_size = q.shape[-1] + batch = metadata.num_contexts + elif metadata.layout == 'bhsd': + batch, nheads_q, _, head_size = q.shape + nheads_k = k.shape[1] + elif metadata.layout == 'bshd': + batch, _, nheads_q, head_size = q.shape + nheads_k = k.shape[2] + else: + assert False, "Got unsupported layout." + return batch, nheads_q, nheads_k, head_size + + +# TODO: This can probably optimized to have fewer lines of code. +def get_strides_from_layout(q, k, v, o, metadata): + if metadata.layout == 'thd': + q_strides = (0, q.stride(1), q.stride(0), q.stride(2)) + k_strides = (0, k.stride(1), k.stride(0), k.stride(2)) + v_strides = (0, v.stride(1), v.stride(0), v.stride(2)) + o_strides = (0, o.stride(1), o.stride(0), o.stride(2)) + elif metadata.layout == 'bhsd': + q_strides = (q.stride(0), q.stride(1), q.stride(2), q.stride(3)) + k_strides = (k.stride(0), k.stride(1), k.stride(2), k.stride(3)) + v_strides = (v.stride(0), v.stride(1), v.stride(2), v.stride(3)) + o_strides = (o.stride(0), o.stride(1), o.stride(2), o.stride(3)) + elif metadata.layout == 'bshd': + q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) + k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) + v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) + o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) + else: + assert False, 'Got unsupported layout.' + return q_strides, k_strides, v_strides, o_strides + + +class _attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, o, metadata: MetaData): + # NOTE: a large bias tensor leads to overflow during pointer arithmetic + if (metadata.bias is not None): + assert (metadata.bias.numel() < 2**31) + + if o is None: + if not metadata.int8: + o = torch.empty_like(q, dtype=v.dtype) + else: + o = torch.empty_like(q, dtype=torch.float16) + + metadata.check_args(q, k, v, o) + + batch, nheads_q, nheads_k, head_size = get_shape_from_layout(q, k, metadata) + q_strides, k_strides, v_strides, o_strides = get_strides_from_layout(q, k, v, o, metadata) + + # Get closest power of 2 over or equal to 32. + padded_d_model = 1 << (head_size - 1).bit_length() + # Smallest head_dim supported is 16. If smaller, the tile in the + # kernel is padded - there is no padding in memory for any dims. + padded_d_model = max(padded_d_model, 16) + + # encoded_softmax is used to validate dropout behavior vs the PyTorch SDPA math backend reference. We zero this out + # to give a consistent starting point and then populate it with the output of softmax with the sign bit set according + # to the dropout mask. The resulting return allows this mask to be fed into the reference implementation for testing + # only. This return holds no useful output aside from debugging. + if metadata.return_encoded_softmax: + encoded_softmax = torch.zeros((q.shape[0], q.shape[1], q.shape[2], k.shape[2]), device=q.device, + dtype=torch.float32) + else: + encoded_softmax = None + + M = torch.empty((batch, nheads_q, metadata.max_seqlens_q), device=q.device, dtype=torch.float32) + + # Seed the RNG so we get reproducible results for testing. + philox_seed = 0x1BF52 + philox_offset = 0x1D4B42 + + if metadata.bias is not None: + bias_strides = (metadata.bias.stride(0), metadata.bias.stride(1), metadata.bias.stride(2), + metadata.bias.stride(3)) + else: + bias_strides = (0, 0, 0, 0) + + if metadata.alibi_slopes is not None: + alibi_strides = (metadata.alibi_slopes.stride(0), metadata.alibi_slopes.stride(1)) + else: + alibi_strides = (0, 0) + + if metadata.int8: + q_descale, k_descale, p_scale, p_descale, v_descale = metadata.q_descale, metadata.k_descale, metadata.p_scale, metadata.p_descale, metadata.v_descale + else: + q_descale = k_descale = p_scale = p_descale = v_descale = None + + # number of compute units available + NUM_CU = torch.cuda.get_device_properties("cuda").multi_processor_count + + if metadata.persistent is not None: + grid = lambda META: (min(NUM_CU * META['GRID_CU_MULTIP'], + triton.cdiv(metadata.max_seqlens_q, META['BLOCK_M']) * nheads_q * batch), ) + else: + grid = lambda META: (triton.cdiv(metadata.max_seqlens_q, META['BLOCK_M']), nheads_q, batch) + + atomic_counter = torch.zeros([1], device=q.device, dtype=torch.int32) + + # test_op_fwd(Z, x_vals_list[1], x_vals_list[2], N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, layout, dtype=torch.float16): + + attn_fwd[grid](q, k, v, metadata.bias, metadata.sm_scale, M, o, *q_strides, *k_strides, *v_strides, *o_strides, + *bias_strides, *alibi_strides, q_descale, k_descale, p_scale, p_descale, v_descale, + metadata.cu_seqlens_q, metadata.cu_seqlens_k, dropout_p=metadata.dropout_p, + philox_seed=philox_seed, philox_offset_base=philox_offset, encoded_softmax=encoded_softmax, + alibi_slopes=metadata.alibi_slopes, HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, + MAX_SEQLENS_Q=metadata.max_seqlens_q, MAX_SEQLENS_K=metadata.max_seqlens_k, + IS_CAUSAL=metadata.causal, VARLEN=metadata.varlen, BLOCK_DMODEL=padded_d_model, + USE_BIAS=False if metadata.bias is None else True, + USE_ALIBI=False if metadata.alibi_slopes is None else True, ENABLE_DROPOUT=metadata.dropout_p + > 0.0, RETURN_ENCODED_SOFTMAX=metadata.return_encoded_softmax, INT8=metadata.int8, + USE_P_SCALE=metadata.int8 and metadata.use_p_scale, INT8_KV=metadata.int8 and metadata.int8_kv, + PERSISTENT=metadata.persistent is not None, PERSISTENT_DYNAMIC=metadata.persistent == "dynamic", + NUM_CU=NUM_CU, atomic_counter=atomic_counter, B=batch) + + ctx.save_for_backward(q, k, v, o, M) + ctx.grid = grid + ctx.sm_scale = metadata.sm_scale + ctx.BLOCK_DMODEL = head_size + ctx.causal = metadata.causal + ctx.alibi_slopes = metadata.alibi_slopes + ctx.dropout_p = metadata.dropout_p + ctx.philox_seed = philox_seed + ctx.philox_offset = philox_offset + ctx.encoded_softmax = encoded_softmax + ctx.return_encoded_softmax = metadata.return_encoded_softmax + return o, encoded_softmax, attn_fwd.best_config + + @staticmethod + def backward(ctx, *gradients): + do = gradients[0] + if torch.version.hip is not None: + BLOCK = 64 + else: + BLOCK = 128 + q, k, v, o, M = ctx.saved_tensors + assert do.is_contiguous() + assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride() + seqlen_q = q.shape[2] + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + BATCH, N_HEAD, N_CTX = q.shape[:3] + PRE_BLOCK = 128 + # NUM_WARPS, NUM_STAGES = 4, 1 + BLOCK_M1, BLOCK_N1, BLOCK_M2, BLOCK_N2 = 32, 64, 64, 32 + BLK_SLICE_FACTOR = 2 + RCP_LN2 = 1.4426950408889634 # = 1.0 / ln(2) + arg_k = k + arg_k = arg_k * (ctx.sm_scale * RCP_LN2) + assert N_CTX % PRE_BLOCK == 0 + delta = torch.empty_like(M) + _, Lk, _ = q.shape[-1], k.shape[-1], v.shape[-1] + # padded_head = (Lk != ctx.BLOCK_DMODEL) + grid_preprocess = (triton.cdiv(do.shape[2], BLOCK), do.shape[1], do.shape[0]) + _attn_bwd_preprocess[grid_preprocess]( + o, + do, + delta, + o.stride(0), + o.stride(1), + o.stride(2), + o.stride(3), + do.stride(0), + do.stride(1), + do.stride(2), + do.stride(3), + seqlen_q, + head_dim=Lk, + BLOCK_M=BLOCK, + D_HEAD=ctx.BLOCK_DMODEL, + ) + grid = lambda META: (triton.cdiv(N_CTX, META['BLOCK_N1']), 1, BATCH * N_HEAD) + _attn_bwd[grid]( + q, + arg_k, + v, + ctx.sm_scale, + ctx.alibi_slopes, + do, + dq, + dk, + dv, + M, + delta, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + N_HEAD, + N_CTX, + BLOCK_DMODEL=ctx.BLOCK_DMODEL, + BLOCK_M1=BLOCK_M1, + BLOCK_N1=BLOCK_N1, + BLOCK_M2=BLOCK_M2, + BLOCK_N2=BLOCK_N2, + BLK_SLICE_FACTOR=BLK_SLICE_FACTOR, + USE_ALIBI=False if ctx.alibi_slopes is None else True, + ) + + return dq, dk, dv, None, None + + +attention = _attention.apply + +INT8_MAX = 127 + + +def quantize_int8(tensor: torch.Tensor, dim) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + max_vals = tensor.abs().amax(dim=[i for i in range(tensor.dim()) if i != dim], keepdim=True) + + # Avoid division by zero + max_vals[max_vals == 0] = 1e-8 + + # Compute scale factors for each channel + scale = INT8_MAX / max_vals.to(torch.float32) + + # Quantize the tensor + tensor = tensor * scale + tensor = tensor.round_() + tensor.clamp_(-INT8_MAX, INT8_MAX) + tensor_quantized = tensor.to(torch.int8) + + return tensor_quantized, scale, 1 / scale + + +def quantize_input(q, k, v, input_metadata: MetaData, quantize_p=False, int8_kv=False): + assert not (quantize_p and int8_kv) + if input_metadata.layout == 'bhsd': + qunatization_dim = 1 + elif input_metadata.layout == 'bshd': + qunatization_dim = 2 + else: + assert False, 'Got unsupported tensor layout' + assert not (quantize_p and int8_kv) + + q_descale = None + if not int8_kv: + q, _, q_descale = quantize_int8(q, dim=qunatization_dim) + k, _, k_descale = quantize_int8(k, dim=qunatization_dim) + v, _, v_descale = quantize_int8(v, dim=qunatization_dim) + + # In real world use case, the p scale would be a parameter trained by the model. + p_scale = p_descale = None + # The p shape is always bhqk + if quantize_p: + _, nheads_q, _, _ = get_shape_from_layout(q, k, input_metadata) + p_scale = torch.full((1, nheads_q, 1, 1), 127, dtype=torch.float32, device="cuda") + p_descale = 1 / p_scale + + # We are not multiplying the scales togather to get qk_desale / o_descale e.g. + # qk_desale = q_descale * k_descale + # o_desale = p_descale * v_descale + # it results in very small fp e.g. 0,0002, losing precision. They are applied on the run. + input_metadata.set_int8_params(q_descale=q_descale, k_descale=k_descale, v_descale=v_descale, + # By default p_scaling is not enabled + p_scale=p_scale, p_descale=p_descale) + + return q, k, v + + +def input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, requires_grad=True): + torch.manual_seed(20) + + # Initialize q, k, v + if layout == 'bhsd': + q_tensor_shape = (Z, HQ, N_CTX_Q, D_HEAD) + k_tensor_shape = (Z, HK, N_CTX_K, D_HEAD) + elif layout == 'bshd': + q_tensor_shape = (Z, N_CTX_Q, HQ, D_HEAD) + k_tensor_shape = (Z, N_CTX_K, HK, D_HEAD) + else: + assert False, 'Got unsupported tensor layout' + q = torch.randn(q_tensor_shape, dtype=dtype, device="cuda", requires_grad=requires_grad) + k = torch.randn(k_tensor_shape, dtype=dtype, device="cuda", requires_grad=requires_grad) + v = torch.randn(k_tensor_shape, dtype=dtype, device="cuda", requires_grad=requires_grad) + + sm_scale = D_HEAD**-0.5 + input_metadata = MetaData(sm_scale=sm_scale) + input_metadata.max_seqlens_q = N_CTX_Q + input_metadata.max_seqlens_k = N_CTX_K + input_metadata.layout = layout + return q, k, v, input_metadata + + +def varlen_input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, equal_seqlens=False): + torch.manual_seed(20) + + # Random sequence lengths. Using N_CTX as kind of max of sum of individual seqs + if not equal_seqlens: + max_seqlens_q = N_CTX_Q // Z + max_seqlens_k = N_CTX_K // Z + if N_CTX_Q == N_CTX_K: + seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z, ), dtype=torch.int32) + seqlens_k = seqlens_q + else: + seqlens_q = torch.randint(1, max_seqlens_q + 1, (Z, ), dtype=torch.int32) + seqlens_k = torch.randint(1, max_seqlens_k + 1, (Z, ), dtype=torch.int32) + else: + seqlens_q = torch.full((Z, ), N_CTX_Q // Z) + seqlens_k = torch.full((Z, ), N_CTX_K // Z) + + # Calculate cumulative sequence lengths + cu_seqlens_q = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_q.cumsum(dim=0, dtype=torch.int32)]) + cu_seqlens_k = torch.cat([torch.tensor([0], dtype=torch.int32), seqlens_k.cumsum(dim=0, dtype=torch.int32)]) + cu_seqlens_q = cu_seqlens_q.to(device="cuda") + cu_seqlens_k = cu_seqlens_k.to(device="cuda") + + # Initialize q, k, v with variable lengths + total_q = cu_seqlens_q[-1].item() + total_k = cu_seqlens_k[-1].item() + q = torch.randn((total_q, HQ, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + k = torch.randn((total_k, HK, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + v = torch.randn((total_k, HK, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0., std=0.5).requires_grad_() + sm_scale = D_HEAD**-0.5 + input_metadata = MetaData(sm_scale=sm_scale) + input_metadata.set_varlen_params(cu_seqlens_q, cu_seqlens_k) + return q, k, v, input_metadata + + +@pytest.mark.parametrize('Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD', [ + (4, 48, 24, 1024, 1024, 64), + (1, 24, 6, 8192, 8192, 64), + (1, 4, 2, 16384, 16384, 128), + (2, 16, 4, 1020, 987, 128), + (2, 16, 4, 15498, 2, 128), + (2, 16, 2, 7, 16219, 64), + (4, 48, 12, 1, 1, 64), + (4, 48, 48, 1, 1, 128), + (4, 48, 24, 3, 3, 128), + (4, 48, 48, 1001, 990, 64), + (1, 8, 8, 8081, 7099, 64), + (1, 4, 4, 16330, 15989, 128), + (4, 4, 1, 1024, 1024, 33), + (4, 4, 2, 65, 1018, 65), + (4, 4, 4, 128, 128, 65), + (4, 4, 4, 113, 123, 1), +]) +@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('use_alibi', [True, False]) +@pytest.mark.parametrize('layout', ['bshd', 'bhsd']) +def test_op_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, layout, dtype=torch.float16): + torch.manual_seed(20) + q, k, v, input_metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout) + if causal: + input_metadata.need_causal() + + if use_alibi: + # for n heads the set of slopes is the geometric sequence that starts 2^(-8/n) + alibi_slopes = torch.tensor([2**(-8 / HQ * i) for i in range(1, HQ + 1)], dtype=torch.float32, + device="cuda").repeat(Z, 1) + input_metadata.need_alibi(alibi_slopes, Z, HQ) + else: + alibi_slopes = None + + o = torch.empty_like(q) + + # triton implementation + tri_out, _, _ = attention(q, k, v, o, input_metadata) + + # Transpose here if layout is bshd so we have same reference code for all layouts + if layout == 'bshd': + q = q.transpose(1, 2).clone() + k = k.transpose(1, 2).clone() + v = v.transpose(1, 2).clone() + # Replicate K and V if using MQA/GQA + if HQ != HK: + k = k.view(k.shape[0], k.shape[1], -1, k.shape[2], + k.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(k.shape[0], -1, k.shape[2], k.shape[3]) + v = v.view(v.shape[0], v.shape[1], -1, v.shape[2], + v.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(v.shape[0], -1, v.shape[2], v.shape[3]) + + scores = torch.einsum('bhqd,bhkd->bhqk', q, k).float() * input_metadata.sm_scale + if causal: + mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"), diagonal=N_CTX_K - N_CTX_Q) + scores[:, :, mask == 0] = float("-inf") + if use_alibi: + scores += compute_alibi_tensor(alibi_slopes, N_CTX_Q, N_CTX_K) + + p = torch.softmax(scores, dim=-1) + if causal: + # If N_CTX_Q > N_CTX_K, there is at least one row of all -infs going into + # the softmax. This produces a row of NaNs as -inf - -inf == NaN. So we fix + # this by converting the NaNs to 0s, which is what they should be out of the softmax. + nan_mask = torch.isnan(p) + p[nan_mask == 1] = 0 + ref_out = torch.einsum('bhqk,bhkd->bhqd', p.half(), v) + # compare + if layout == 'bshd': + ref_out = ref_out.transpose(1, 2).clone() + ref_out = ref_out + 1 + # torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) + print("✅ Triton and Torch match") + + +@pytest.mark.parametrize('Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD', [ + (4, 48, 24, 1024, 1024, 64), + (1, 24, 6, 8192, 8192, 64), + (1, 4, 2, 16384, 16384, 128), + (2, 16, 4, 1020, 987, 128), + (2, 16, 4, 15498, 2, 128), + (2, 16, 2, 7, 16219, 64), + (4, 48, 12, 1, 1, 64), + (4, 48, 48, 1, 1, 128), + (4, 48, 24, 3, 3, 128), + (4, 48, 48, 1001, 990, 64), + (1, 8, 8, 8081, 7099, 64), + (1, 4, 4, 16330, 15989, 128), + (4, 4, 1, 1024, 1024, 33), + (4, 4, 2, 65, 1018, 65), + (4, 4, 4, 128, 128, 65), + (4, 4, 4, 113, 123, 1), +]) +@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('use_alibi', [True, False]) +@pytest.mark.parametrize('layout', ['bshd', 'bhsd']) +@pytest.mark.parametrize('persistent', ['fixed', 'dynamic']) +def test_op_persistent_fwd(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_alibi, layout, persistent, + dtype=torch.float16): + torch.manual_seed(20) + q, k, v, input_metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout) + if causal: + input_metadata.need_causal() + + if use_alibi: + # for n heads the set of slopes is the geometric sequence that starts 2^(-8/n) + alibi_slopes = torch.tensor([2**(-8 / HQ * i) for i in range(1, HQ + 1)], dtype=torch.float32, + device="cuda").repeat(Z, 1) + input_metadata.need_alibi(alibi_slopes, Z, HQ) + else: + alibi_slopes = None + + input_metadata.set_persistent(persistent) + + o = torch.empty_like(q) + + # triton implementation + tri_out, _, _ = attention(q, k, v, o, input_metadata) + + # Transpose here if layout is bshd so we have same reference code for all layouts + if layout == 'bshd': + q = q.transpose(1, 2).clone() + k = k.transpose(1, 2).clone() + v = v.transpose(1, 2).clone() + # Replicate K and V if using MQA/GQA + if HQ != HK: + k = k.view(k.shape[0], k.shape[1], -1, k.shape[2], + k.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(k.shape[0], -1, k.shape[2], k.shape[3]) + v = v.view(v.shape[0], v.shape[1], -1, v.shape[2], + v.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(v.shape[0], -1, v.shape[2], v.shape[3]) + + scores = torch.einsum('bhqd,bhkd->bhqk', q, k).float() * input_metadata.sm_scale + if causal: + mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"), diagonal=N_CTX_K - N_CTX_Q) + scores[:, :, mask == 0] = float("-inf") + if use_alibi: + scores += compute_alibi_tensor(alibi_slopes, N_CTX_Q, N_CTX_K) + + p = torch.softmax(scores, dim=-1) + if causal: + # If N_CTX_Q > N_CTX_K, there is at least one row of all -infs going into + # the softmax. This produces a row of NaNs as -inf - -inf == NaN. So we fix + # this by converting the NaNs to 0s, which is what they should be out of the softmax. + nan_mask = torch.isnan(p) + p[nan_mask == 1] = 0 + ref_out = torch.einsum('bhqk,bhkd->bhqd', p.half(), v) + # compare + if layout == 'bshd': + ref_out = ref_out.transpose(1, 2).clone() + torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) + + +@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ + (4, 48, 1024, 1024, 64), + (4, 12, 8192, 8192, 64), + (2, 4, 16384, 16384, 128), + (2, 16, 1020, 987, 128), + (2, 4, 7, 16219, 64), + (4, 48, 1, 1, 64), + (4, 48, 1, 1, 128), + (4, 48, 3, 3, 128), + (4, 48, 1001, 990, 64), + (1, 8, 8081, 7099, 64), + (1, 8, 16330, 15989, 128), + (4, 4, 1024, 1024, 33), + (4, 4, 65, 1019, 65), + (4, 4, 128, 128, 65), + (4, 4, 113, 123, 1), +]) +@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('quantize_p', [True, False]) +@pytest.mark.parametrize('layout', ['bhsd']) +def test_op_fwd_int8(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, quantize_p, layout, dtype=torch.float16): + torch.manual_seed(20) + + # Disable grad to save memeory it won't run into OOM on CI machine. + q, k, v, input_metadata = input_helper(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout, requires_grad=False) + if causal: + input_metadata.need_causal() + + o = torch.empty_like(q) + + q_quantized, k_quantized, v_quantized = quantize_input(q, k, v, input_metadata, quantize_p=quantize_p) + + tri_out, _, best_configs = attention(q_quantized, k_quantized, v_quantized, o, input_metadata) + + # Compute scores + q_descale, k_descale, v_descale = input_metadata.q_descale, input_metadata.k_descale, input_metadata.v_descale + scores = (torch.einsum('bhqd,bhkd->bhqk', q_quantized.half(), k_quantized.half()) * q_descale * + k_descale) * input_metadata.sm_scale + + if causal: + mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"), diagonal=N_CTX_K - N_CTX_Q) + scores[:, :, mask == 0] = float("-inf") + + # Quantization with tiling + if quantize_p: + tile_size = best_configs.kwargs["BLOCK_N"] # We need the tiling to match Block_N to work + m_i = torch.full((Z, H, N_CTX_Q), float('-inf'), device='cuda', dtype=torch.float32) + acc = torch.zeros((Z, H, N_CTX_Q, D_HEAD), device='cuda', dtype=torch.float32) + l_i = torch.zeros_like(m_i) + + for i in range(0, N_CTX_K, tile_size): + qk_tile = scores[:, :, :, i:i + tile_size] + v_tile = v_quantized[:, :, i:i + tile_size] + m_ij = torch.max(m_i, torch.max(qk_tile, dim=-1).values) + qk_tile -= m_ij.unsqueeze(-1) + p_tile = torch.exp(qk_tile) + l_ij = torch.sum(p_tile, dim=-1) + p_tile = (p_tile * input_metadata.p_scale).to(torch.int8) + + alpha = torch.exp(m_i - m_ij) + # We need float here since both p and v are quantized. So they might overflow the fp16 range. + acc = acc * alpha.unsqueeze(-1) + torch.einsum('bhqk,bhkd->bhqd', p_tile.float(), v_tile.float()) + m_i = m_ij + l_i = alpha * l_i + l_ij + + l_recip = 1 / l_i.unsqueeze(-1) + acc = acc * input_metadata.p_descale * input_metadata.v_descale * l_recip + ref_out = acc.to(torch.float16) + else: + p = torch.softmax(scores, dim=-1) + ref_out = (torch.einsum('bhqk,bhkd->bhqd', p.float(), v_quantized.float()) * v_descale).to(torch.float16) + + if causal: + nan_mask = torch.isnan(ref_out) + ref_out[nan_mask] = 0 + + torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) + + +@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ + (4, 48, 1024, 1024, 64), + (4, 12, 8192, 8192, 64), + (2, 4, 16384, 16384, 128), + (2, 16, 1020, 987, 128), + (2, 4, 7, 16219, 64), + (4, 48, 1, 1, 64), + (4, 48, 1, 1, 128), + (4, 48, 3, 3, 128), + (4, 48, 1001, 990, 64), + (1, 8, 8081, 7099, 64), + (1, 8, 16330, 15989, 128), + (4, 4, 1024, 1024, 33), + (4, 4, 65, 1019, 65), + (4, 4, 128, 128, 65), + (4, 4, 113, 123, 1), +]) +@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('layout', ['bhsd']) +def test_op_fwd_int8_kv(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, layout, dtype=torch.float16): + torch.manual_seed(20) + + q, k, v, input_metadata = input_helper(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout) + if causal: + input_metadata.need_causal() + + o = torch.empty_like(q) + + _, k_quantized, v_quantized = quantize_input(q, k, v, input_metadata, int8_kv=True) + k_descale, v_descale = input_metadata.k_descale, input_metadata.v_descale + k_dequantized = (k_quantized * k_descale).half() + v_dequantized = (v_quantized * v_descale).half() + + tri_out, _, _ = attention(q, k_quantized, v_quantized, o, input_metadata) + + # Compute scores + scores = torch.einsum('bhqd,bhkd->bhqk', q, k_dequantized).float() * input_metadata.sm_scale + + if causal: + mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"), diagonal=N_CTX_K - N_CTX_Q) + scores[:, :, mask == 0] = float("-inf") + + p = torch.softmax(scores, dim=-1) + ref_out = torch.einsum('bhqk,bhkd->bhqd', p.half(), v_dequantized).to(torch.float16) + + if causal: + nan_mask = torch.isnan(ref_out) + ref_out[nan_mask] = 0 + + torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) + + +@pytest.mark.parametrize('Z, H, N_CTX_Q, N_CTX_K, D_HEAD', [ + (4, 48, 1024, 1024, 64), + (4, 12, 8192, 8192, 64), + (2, 4, 16384, 16384, 128), + (2, 16, 1020, 987, 128), + (2, 4, 7, 16219, 64), + (4, 48, 1, 1, 64), + (4, 48, 1, 1, 128), + (4, 48, 3, 3, 128), + (4, 48, 1001, 990, 64), + (1, 8, 8081, 7099, 64), + (1, 8, 16330, 15989, 128), + (4, 4, 1024, 1024, 33), + (4, 4, 65, 1019, 65), + (4, 4, 128, 128, 65), + # TODO: This config fails. Disabled until triaged and fixed. + # (4, 4, 113, 123, 1), + # (2, 16, 15498, 2, 128), +]) +@pytest.mark.parametrize('causal', [True, False]) +@pytest.mark.parametrize('use_bias', [True]) +@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16]) +def test_op_fwd_bias(Z, H, N_CTX_Q, N_CTX_K, D_HEAD, causal, use_bias, dtype): + torch.manual_seed(20) + sm_scale = D_HEAD**-0.5 + input_metadata = MetaData(sm_scale=sm_scale) + q, k, v, input_metadata = input_helper(Z, H, H, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout='bhsd') + if causal: + input_metadata.need_causal() + if use_bias: + bias = torch.randn((1, H, N_CTX_Q, N_CTX_K), dtype=dtype, device="cuda") + input_metadata.need_bias(bias, Z, H, N_CTX_Q, N_CTX_K) + else: + bias = None + o = torch.empty_like(q) + + # triton implementation + tri_out, _, _ = attention(q, k, v, o, input_metadata) + # reference implementation:171 + + scores = torch.einsum('bhqd,bhkd->bhqk', q, k).float() * sm_scale + if causal: + mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"), diagonal=N_CTX_K - N_CTX_Q) + scores[:, :, mask == 0] = float("-inf") + if use_bias: + scores += input_metadata.bias + p = torch.softmax(scores, dim=-1) + if causal: + # If N_CTX_Q > N_CTX_K, there is at least one row of all -infs going into + # the softmax. This produces a row of NaNs as -inf - -inf == NaN. So we fix + # this by converting the NaNs to 0s, which is what they should be out of the softmax. + nan_mask = torch.isnan(p) + p[nan_mask == 1] = 0 + + ref_out = torch.einsum('bhqk,bhkd->bhqd', p.to(dtype), v) + # compare + torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2) + + +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 8192, 64), (4, 48, 256, 64), (4, 48, 512, 64), + (4, 48, 1024, 64), (8, 48, 4096, 64), (4, 48, 8192, 64), + (4, 48, 128, 128), (4, 48, 4096, 128), (4, 48, 16384, 128), + (4, 16, 1024, 128), (4, 16, 8192, 128), (32, 48, 8192, 128)]) +@pytest.mark.parametrize('causal', [True, False]) +def test_op_varlen_fwd(Z, H, N_CTX, D_HEAD, causal, dtype=torch.float16): + + q, k, v, input_metadata = varlen_input_helper(Z, H, H, N_CTX, N_CTX, D_HEAD, dtype) + + tri_out = torch.empty_like(q) + ref_out = torch.empty_like(q) + + for i in range(0, input_metadata.num_contexts): + start_q, start_k = input_metadata.cu_seqlens_q[i], input_metadata.cu_seqlens_k[i] + end_q, end_k = input_metadata.cu_seqlens_q[i + 1], input_metadata.cu_seqlens_k[i + 1] + scores = torch.einsum('qhd,khd->qhk', q[start_q:end_q], k[start_k:end_k]).float() + p = torch.softmax(scores * input_metadata.sm_scale, dim=-1).half() + ref_out[start_q:end_q] = torch.einsum('qhk,khd->qhd', p, v[start_k:end_k]) + attention(q, k, v, tri_out, input_metadata) + torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=1e-2) + + +@pytest.mark.parametrize('Z, HQ, HK, N_CTX, D_HEAD', [(2, 48, 24, 128, 64), (4, 48, 12, 256, 64), (4, 48, 4, 512, 64), + (4, 48, 2, 1024, 64), (8, 48, 6, 4096, 64), (4, 48, 8, 16384, 64), + (4, 64, 16, 128, 128), (4, 64, 4, 4096, 128), + (4, 64, 8, 16384, 128), (4, 16, 4, 1024, 128), + (4, 16, 2, 8192, 128), (32, 128, 32, 8192, 128)]) +@pytest.mark.parametrize('causal', [False]) +def test_op_varlen_mqa_fwd(Z, HQ, HK, N_CTX, D_HEAD, causal, dtype=torch.float16): + q, k, v, input_metadata = varlen_input_helper(Z, HQ, HK, N_CTX, N_CTX, D_HEAD, dtype) + ref_out = torch.empty_like(q) + tri_out = torch.empty_like(q) + # Make KV look like HQ/HK "groups" of HK. Later, we will reshape so the + # size aligns with Q. + k_ref = k.view(k.shape[0], k.shape[1], 1, k.shape[2]).expand(-1, -1, HQ // HK, -1) + v_ref = v.view(v.shape[0], v.shape[1], 1, v.shape[2]).expand(-1, -1, HQ // HK, -1) + for i in range(0, input_metadata.num_contexts): + start_q, start_k = input_metadata.cu_seqlens_q[i], input_metadata.cu_seqlens_k[i] + end_q, end_k = input_metadata.cu_seqlens_q[i + 1], input_metadata.cu_seqlens_k[i + 1] + k_curr = k_ref[start_k:end_k] + k_curr = k_curr.reshape(k_curr.shape[0], -1, k_curr.shape[3]) + v_curr = v_ref[start_k:end_k] + v_curr = v_curr.reshape(v_curr.shape[0], -1, v_curr.shape[3]) + scores = torch.einsum('qhd,khd->qhk', q[start_q:end_q], k_curr).float() + p = torch.softmax(scores * input_metadata.sm_scale, dim=-1).half() + ref_out[start_q:end_q] = torch.einsum('qhk,khd->qhd', p, v_curr) + attention(q, k, v, tri_out, input_metadata) + torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=1e-2) + + +@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [ + (4, 48, 1024, 64), + (4, 48, 2048, 64), + (2, 48, 4096, 64), + (1, 16, 1024, 64), + (1, 16, 1024, 128), + #(1, 16, 8192, 63), + #(1, 16, 1022, 64), +]) +@pytest.mark.parametrize('qseqlen_not_equal_kseqlen', [None]) +@pytest.mark.parametrize('torch_sdpa_test', [False, True]) +@pytest.mark.parametrize('causal', [True]) +@pytest.mark.parametrize('use_alibi', [False, True]) +def test_op_bwd(Z, H, N_CTX, D_HEAD, qseqlen_not_equal_kseqlen, causal, torch_sdpa_test, use_alibi, + dtype=torch.float16): + pytest.skip() + torch.manual_seed(20) + if qseqlen_not_equal_kseqlen is not None: + seqlen_q = qseqlen_not_equal_kseqlen + else: + seqlen_q = N_CTX + seqlen_k = N_CTX + + if causal and ((N_CTX - 1) & N_CTX): + pytest.skip() + if causal and seqlen_q != seqlen_k: + pytest.skip() + + sm_scale = D_HEAD**-0.5 + input_metadata = MetaData(sm_scale=sm_scale) + input_metadata.max_seqlens_q = seqlen_q + input_metadata.max_seqlens_k = seqlen_k + + dropout_p = 0 + q = (torch.empty((Z, H, seqlen_q, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + k = (torch.empty((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + v = (torch.empty((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_()) + o = torch.empty_like(q) + + if causal: + input_metadata.need_causal() + + if use_alibi and not torch_sdpa_test: + # for n heads the set of slopes is the geometric sequence that starts 2^(-8/n) + alibi_slopes = torch.tensor([2**(-8 / H * i) for i in range(1, H + 1)], dtype=torch.float32, + device="cuda").repeat(Z, 1) + input_metadata.need_alibi(alibi_slopes, Z, H) + dout = torch.randn_like(q) + # reference implementation + if torch_sdpa_test: + ref_out, ref_softmax = torch.ops.aten._scaled_dot_product_attention_math(q, k, v, dropout_p=dropout_p, + is_causal=causal, scale=sm_scale, + dropout_mask=None) + ref_out.backward(dout.to(device=ref_out.device, dtype=ref_out.dtype)) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + else: + M = torch.tril(torch.ones((seqlen_q, seqlen_k), device="cuda")) + p = torch.matmul(q, k.transpose(2, 3)) * sm_scale + if use_alibi: + p += compute_alibi_tensor(alibi_slopes, N_CTX, N_CTX) + if causal: + p[:, :, M == 0] = float("-inf") + + p = torch.softmax(p.float(), dim=-1).type(dtype=p.dtype) + ref_out = torch.matmul(p, v) + ref_out.backward(dout) + ref_dv, v.grad = v.grad.clone(), None + ref_dk, k.grad = k.grad.clone(), None + ref_dq, q.grad = q.grad.clone(), None + + # # triton implementation + tri_out, _, _ = attention(q, k, v, o, input_metadata) + tri_out.backward(dout) + tri_dv, v.grad = v.grad.clone(), None + tri_dk, k.grad = k.grad.clone(), None + tri_dq, q.grad = q.grad.clone(), None + # test + #print("reference") + #print(ref_dv) + #print("tri") + #print(tri_dv) + # compare + torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0) + # The current block size for gfx90a and gfx908 series is 64x64. This results in + # larger differences in float results due to rounding. + + if dtype == torch.bfloat16: + ATOL = 1e-1 * max(1.0, (seqlen_q + D_HEAD) / 64.0) + if dtype == torch.float32: + ATOL = 1e-3 * max(1.0, (seqlen_q + D_HEAD) / 64.0) + else: + ATOL = 1e-1 * max(1.0, (seqlen_q + D_HEAD) / 64.0) + + RTOL = 0 + + torch.testing.assert_close(ref_dv, tri_dv, atol=ATOL, rtol=RTOL) + torch.testing.assert_close(ref_dk, tri_dk, atol=ATOL, rtol=RTOL) + torch.testing.assert_close(ref_dq, tri_dq, atol=ATOL, rtol=RTOL) + + +def nonvarlen_benchmark_configs(): + configs = [ + # (16, 16, 16, 1024, 1024), + # (8, 16, 16, 2048, 2048), + # (4, 16, 16, 4096, 4096), + # (2, 16, 16, 8192, 8192), + # (1, 16, 16, 16384, 16384), + # (2, 48, 48, 1024, 1024), + # (2, 48, 48, 2048, 1024), + # (2, 48, 48, 4096, 8192), + # (2, 48, 48, 8192, 4096), + (2, 48, 48, 16384, 8192), + # (8, 16, 16, 1989, 15344), + # (4, 16, 16, 4097, 163), + # (2, 16, 16, 8122, 2159), + # (1, 16, 16, 16281, 7), + # (2, 48, 48, 1021, 1020), + # (2, 48, 48, 2001, 2048), + # (2, 48, 48, 3996, 9639), + # (2, 48, 48, 8181, 1021), + ] + return configs + + +def varlen_benchmark_configs(): + configs = [ + # (2, 16, 4, 1024, 1024), + # (8, 16, 2, 2048, 2048), + # (4, 16, 8, 4096, 4096), + # (2, 16, 4, 8192, 8192), + # (2, 16, 8, 16384, 16384), + # (2, 48, 12, 1024, 1024), + # (2, 48, 24, 2048, 2048), + # (2, 48, 8, 4096, 4096), + # (2, 48, 4, 8192, 8192), + (2, 48, 2, 16384, 16384), + # (2, 64, 32, 1024, 1024), + # (4, 64, 16, 2048, 2048), + # (4, 64, 8, 4096, 4096), + # (4, 64, 32, 8192, 8192), + # (4, 128, 16, 16384, 16384), + ] + return configs + + +def model_benchmark_configs(args): + config_file = args.model_configs + configs = get_model_configs(config_path=config_file, model_families=["llama3"], model=args.model) + fa_configs = [] + batch_size = args.b if args.b else 1 + + for model_name, config in configs.items(): + HQ = config["num_attention_heads"] + HK = HQ if config["num_key_value_heads"] is None else config["num_key_value_heads"] + N_CTX_Q = args.sq if args.sq else 8192 + N_CTX_K = args.sk if args.sk else N_CTX_Q + HEAD_DIM = config["hidden_size"] // HQ + fa_configs.append((model_name, batch_size, HQ, HK, N_CTX_Q, N_CTX_K, HEAD_DIM)) + + return fa_configs + + +def run_benchmark(custom, args): + + dtype = arg_to_torch_dtype[args.dtype] + hk = args.hq if not args.hk else args.hk + sk = args.sq if not args.sk else args.sk + head_size = 128 if not args.d else args.d + mode = 'fwd' + x_names = ['BATCH', 'HQ', 'HK', 'N_CTX_Q', 'N_CTX_K'] + causal = args.causal if not args.model else True + int8 = args.int8 + quantize_p = args.quantize_p and int8 + int8_kv = args.int8_kv and int8 + varlen = True if args.model else args.layout == 'thd' + configs = [] + plot_name = f'fused-attention-{mode}-d{head_size}-layout{args.layout}' + extra_args = {'D_HEAD': head_size, 'dtype': dtype, 'causal': causal, 'mode': mode} + if custom: + x_vals_list = [(args.b, args.hq, hk, args.sq, sk)] + else: + if varlen: + x_vals_list = varlen_benchmark_configs() + else: + x_vals_list = nonvarlen_benchmark_configs() + + if mode == 'bwd': + # Only those with N_CTX_Q == N_CTX_K work + new_x = [] + for v in x_vals_list: + if v[-1] == v[-2]: + new_x.append(v) + x_vals_list = new_x + + if args.model: + x_vals_list = model_benchmark_configs(args) + x_names = ['model', 'BATCH', 'HQ', 'HK', 'N_CTX_Q', 'N_CTX_K', 'D_HEAD'] + plot_name = f'fused-attention-{mode}-layout{args.layout}' + extra_args = {'dtype': dtype, 'causal': causal, 'mode': mode} + print_time = args.return_time + + line_vals = ['triton', 'torch'] # 'Time (ms)' if print_time else 'TFLOPS' + configs.append( + triton.testing.Benchmark(x_names=x_names, x_vals=x_vals_list, line_arg='provider', line_vals=line_vals, + line_names=line_vals, styles=[('green', '-'), ('red', '-')], + ylabel='Time (ms)' if print_time else 'TFLOPS', plot_name=plot_name, args=extra_args)) + + @triton.testing.perf_report(configs) + def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal, mode, provider, device="cuda", + model=None): + assert mode in ["fwd", "bwd"] + assert not (int8_kv and quantize_p) + warmup = 25 + rep = 100 + # TODO: Enable bias after testing. + # if use_bias: + # bias = torch.randn((1, H, N_CTX, N_CTX), dtype=torch.float32, device="cuda") + # input_metadata.need_bias(bias, BATCH, H, N_CTX, N_CTX) + # else: + # bias = None + # bias = None + + # Bwd pass only supports causal=True right now + if mode == 'bwd': + causal = True + + flops_per_matmul = 0 + if varlen: + q, k, v, input_metadata = varlen_input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, + args.equal_seqlens) + for i in range(0, input_metadata.num_contexts): + seqlen_q = (input_metadata.cu_seqlens_q[i + 1] - input_metadata.cu_seqlens_q[i]).item() + seqlen_k = (input_metadata.cu_seqlens_k[i + 1] - input_metadata.cu_seqlens_k[i]).item() + # x2 in both cases for 2 GEMMs + if causal: + # If seqlen_q != seqlen_k then the causal mask ignores computation + # depending on which seqlen is larger. Either the lower triangle, or right triangle + # If seqlen_q is greater than seqlen_k, the lower triangle is non zero + # where the last row has seqlen_k valid element, the second last row has + # seqlen_k - 1 valid elements and so on until one element is valid in the + # seqlen_q - seqlen_k row, hence total valid elements are 1+2+...+seqlen_k + # which is seqlen_k*(seqlen_k+1)/2 + # If seqlen_q is less than seqlen_k, then we count the zero elements + # the first row has seqlen_q-1 zero elements, the second row has seqlen_q-2 + # zero elements and so on until the second last row has 1 zero element + # Total zero elements are 1+2+...+(seqlen_q-1) = seqlen_q*(seqlen_q-1)/2 + # Total non zero elements are seqlen_q*seqlen_k - (seqlen_q*(seqlen_q-1)/2) + valid_out_elements = ((seqlen_k**2 + seqlen_k) / 2) if seqlen_q > seqlen_k else \ + (seqlen_q * seqlen_k - ((seqlen_q**2 - seqlen_q) / 2)) + flops_per_matmul += valid_out_elements * HQ * D_HEAD * 2 + else: + flops_per_matmul += seqlen_q * seqlen_k * HQ * D_HEAD * 2 + else: + q, k, v, input_metadata = input_helper(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, args.layout) + if causal: + # Same calculation as if varlen/if causal above + valid_out_elements = ((N_CTX_K**2 + N_CTX_K) / 2) if N_CTX_Q > N_CTX_K else \ + (N_CTX_Q * N_CTX_K - ((N_CTX_Q**2 - N_CTX_Q) / 2)) + flops_per_matmul = 2.0 * BATCH * HQ * valid_out_elements * D_HEAD + else: + flops_per_matmul = 2.0 * BATCH * HQ * N_CTX_Q * N_CTX_K * D_HEAD + if causal: + input_metadata.need_causal() + + if "triton" in provider: + o = torch.empty_like(q) + if int8: + q, k, v = quantize_input(q, k, v, input_metadata, quantize_p=quantize_p, int8_kv=int8_kv) + input_metadata.set_persistent(args.persistent) + fn = lambda: attention(q, k, v, o, input_metadata) + if mode == 'bwd': + o, _, _ = fn() + do = torch.randn_like(o) + fn = lambda: o.backward(do, retain_graph=True) + + elif "torch" in provider and args.layout in ["thd", "bhsd", "bshd"]: + # torch requires the layout to be (b (optional),...,h,s,d) + if args.layout in ["thd", "bshd"]: + q = q.transpose(-3, -2) + k = k.transpose(-3, -2) + v = v.transpose(-3, -2) + # check if GQA + HQ = q.shape[-3] + HK = k.shape[-3] + if HQ != HK: # TODO: sdpa(..., enable_gqa=True work) should work + k = k.repeat_interleave(q.size(-3) // k.size(-3), -3) + v = v.repeat_interleave(q.size(-3) // v.size(-3), -3) + + fn = lambda: torch.nn.functional.scaled_dot_product_attention( + q, k, v, attn_mask=None, dropout_p=0.0, is_causal=causal, scale=input_metadata.sm_scale) + else: + assert False, f"Unknown provider {provider} in flash-attention." + + ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep) + total_flops = 2 * flops_per_matmul + if mode == "bwd": + total_flops *= 2.5 # 2.0(bwd) + 0.5(recompute) + if print_time: + return ms + else: + return total_flops / ms * 1e-9 + + bench_flash_attention.run(save_path=".", print_data=True, show_plots=True) + + +def supported_layouts(): + layouts = \ + 'bhsd: Q, K, V are individual tensors of [batch, num_heads, seqlen_q/k, head_size]' \ + 'bshd: Q, K, V are individual tensors of [batch, seqlen_q/k, num_heads, head_size]' \ + 'thd: Q, K, V are individual tensors of [total_q/k, num_heads, head_size]' \ + 'This layout is sometimes called "varlen" or "grouped" layout.' + return layouts + + +def parse_args(): + parser = argparse.ArgumentParser( + prog="Benchmark FlashAttention", + allow_abbrev=False, + ) + parser.add_argument('-model_configs', type=str, default="model_configs.json", help="Model config json file.") + + available_models = get_available_models(model_families=["llama3"]) # Dynamically load model names + model_help = ( + "Model name to benchmark. Select from: [" + ", ".join(available_models) + + "]. Use 'all' to benchmark all models. Not providing runs the default benchmark script with custom configs.") + parser.add_argument('-model', type=str, default=None, help=model_help) + parser.add_argument("-b", type=int, default=0) + parser.add_argument("-hq", type=int, default=0) + parser.add_argument("-hk", type=int, default=0) + parser.add_argument("-sq", type=int, default=0) + parser.add_argument("-sk", type=int, default=0) + parser.add_argument("-equal_seqlens", action='store_true', default=False, + help='If specified, each context within the thd layout' \ + ' has same seqlen as sq and sk') + parser.add_argument("-d", type=int, default=0) + parser.add_argument("-causal", action='store_true', default=False) + parser.add_argument("-int8", action='store_true', default=False) + parser.add_argument("-quantize_p", action='store_true', default=False) + parser.add_argument("-int8_kv", action='store_true', default=False) + parser.add_argument("-dtype", default='fp16') + parser.add_argument("-return_time", action='store_true', default=False) + parser.add_argument("-layout", type=str, default='bhsd', help=supported_layouts()) + parser.add_argument( + "-persistent", nargs='?', const='fixed', choices=['fixed', 'dynamic'], default=None, + help="Enable persistent kernels. Use '-persistent dynamic' for dynamic scheduling of the tiles.") + return parser.parse_args() + + +arg_to_torch_dtype = {'fp16': torch.float16, 'bf16': torch.bfloat16, 'fp32': torch.float32} + + +def main(): + args = parse_args() + custom_config = False + assert args.layout == 'thd' or not args.equal_seqlens or args.model, \ + "Equal sequence lengths arg must be used with the thd layout or a model config." + if args.hq or args.hk or args.d: + custom_config = True + assert args.b and args.hq and args.sq and args.d, \ + "If custom config is specified, please provide \ + all of batch, number of Q heads, Q sequence length \ + and head size." + + if args.model: + assert not (args.hq or args.hk or args.d), \ + "Specifying model fixes hq, hk and d already. Do not provide them!" + + assert args.dtype in arg_to_torch_dtype, \ + "Only fp16, bf16 and f32 types currently supported." + + if args.model: + print("Note: Model config sets causal masking and THD layout (varlen) by default.") + + run_benchmark(custom_config, args) + + +if __name__ == '__main__': + test_op_fwd(2, 48, 48, 16384, 8192, 128, False, False, 'bshd', dtype=torch.float16) + sys.exit(main()) diff --git a/fa/model_configs.json b/fa/model_configs.json new file mode 100644 index 000000000000..5f0c28cd0e23 --- /dev/null +++ b/fa/model_configs.json @@ -0,0 +1,42 @@ +{ + "llama3": { + "8B": { + "num_attention_heads": 32, + "num_key_value_heads": 8, + "hidden_size": 4096, + "intermediate_size": 14336, + "vocab_size": 128256 + }, + "70B": { + "num_attention_heads": 64, + "num_key_value_heads": 8, + "hidden_size": 8192, + "intermediate_size": 28672, + "vocab_size": 128256 + }, + "405B": { + "num_attention_heads": 128, + "num_key_value_heads": 8, + "hidden_size": 16384, + "intermediate_size": 53248, + "vocab_size": 128256 + } + }, + "mistral": { + "7B": { + "hidden_size": 4096, + "intermediate_size": 14336, + "num_attention_heads": 32, + "num_key_value_heads": 8, + "vocab_size": 32000 + }, + "22B": { + "hidden_size": 6144, + "intermediate_size": 16384, + "num_attention_heads": 48, + "num_key_value_heads": 8, + "vocab_size": 32000 + } + + } +} diff --git a/fa/utils/__init__.py b/fa/utils/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/fa/utils/benchmark_utils.py b/fa/utils/benchmark_utils.py new file mode 100644 index 000000000000..11c19bcd0c18 --- /dev/null +++ b/fa/utils/benchmark_utils.py @@ -0,0 +1,71 @@ +import os +import json + +# Base directory where configs are located +BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "../")) + + +def get_model_configs(config_path='model_configs.json', model_families=["llama3"], model="all"): + """ + Load model names from the configuration file. + + Args: + config_path (str): User-provided path to the configuration JSON file. + model_families (list): List of model family names to retrieve. + + Returns: + dict: A dictionary of available models and their configurations for the specified families. + """ + # Resolve config path relative to ./perf-kernels/ + config_path = os.path.join(BASE_DIR, config_path) + + with open(config_path, 'r') as f: + configs = json.load(f) + + # Extract models and their configurations for the specified families + filtered_configs = {} + + for family in model_families: + if family in configs: + # Check if model filtering is required + if model == "all": + # Include all models in the family + for model_size, model_configs in configs[family].items(): + filtered_configs[f"{family}-{model_size}"] = model_configs + else: + # Parse the model string (e.g., llama3_8B or llama3-8B) + delimiter = "_" if "_" in model else "-" + model_parts = model.split(delimiter) + + # Check if the family and size match + if len(model_parts) == 2 and model_parts[0] == family: + model_size = model_parts[1] + if model_size in configs[family]: + filtered_configs[f"{family}-{model_size}"] = configs[family][model_size] + + if not filtered_configs: + print(f"Warning: No models selected for families: {model_families} with filter: '{model}'") + + return filtered_configs + + +def get_available_models(config_file='model_configs.json', model_families=["llama3"]): + """ + Load model names from the configuration file. + + Args: + config_file (str): Path to the configuration JSON file. + model_families (list): List of model family names to retrieve. + + Returns: + list: A list of available models for the specified families. + """ + # Resolve config path relative to ./perf-kernels/ + config_path = os.path.join(BASE_DIR, config_file) + + with open(config_path, 'r') as f: + configs = json.load(f) + + models = [f"{family}-{model}" for family in model_families if family in configs for model in configs[family]] + + return models diff --git a/fa/utils/rocprof_benchmark.py b/fa/utils/rocprof_benchmark.py new file mode 100644 index 000000000000..edccb7f3d314 --- /dev/null +++ b/fa/utils/rocprof_benchmark.py @@ -0,0 +1,59 @@ +import subprocess +import os +import pandas as pd +from prettytable import PrettyTable + + +def run_profiling(triton_dir, batch_size, output_file): + command = [ + "rocprof", "--stats", "-o", output_file, "python", f"{triton_dir}/python/perf-kernels/MLA_decode_rope.py", "-B", + str(batch_size), "-dtype", "bf16", "-use_rope" + ] + subprocess.run(command, check=True) + + +def parse_profiling_output(output_file, kernel_names): + df = pd.read_csv(output_file) + results = {} + for kernel in kernel_names: + kernel_data = df[df['Name'].str.strip('"') == kernel] + if not kernel_data.empty: + results[kernel] = kernel_data['AverageNs'].iloc[0] / 1000.0 + else: + results[kernel] = None + + # Calculate sum of other kernels + other_kernels = df[~df['Name'].str.strip('"').isin(kernel_names)] + other_kernels_sum = other_kernels['AverageNs'].sum() / 1000.0 + results['other_kernels_sum'] = other_kernels_sum + + return results + + +def main(): + triton_dir = os.environ.get("TRITONDIR", "~/triton") # Default to ~/triton if not set + output_file = os.path.expanduser("~/profiling.csv") + kernel_names = ["_fwd_grouped_kernel_stage1_rope.kd", "_fwd_grouped_kernel_stage1.kd"] + batch_sizes = [1, 4, 32, 64, 128] + + results = {B: {} for B in batch_sizes} + for B in batch_sizes: + print(f"Running profiling for B={B}...") + run_profiling(triton_dir, B, output_file) + output_stats_file = os.path.expanduser("~/profiling.stats.csv") + kernel_results = parse_profiling_output(output_stats_file, kernel_names) + results[B] = kernel_results + + table = PrettyTable() + table.field_names = ["B"] + kernel_names + ["Other Kernels Sum (µs)"] + for B in batch_sizes: + row = [B] + [results[B].get(kernel, "N/A") + for kernel in kernel_names] + [results[B].get('other_kernels_sum', "N/A")] + table.add_row(row) + + print("\nProfiling Summary (in microseconds):") + print(table) + + +if __name__ == "__main__": + main() diff --git a/fa/utils/rotary_embedding.py b/fa/utils/rotary_embedding.py new file mode 100644 index 000000000000..a864710601f1 --- /dev/null +++ b/fa/utils/rotary_embedding.py @@ -0,0 +1,283 @@ +# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/refs/tags/v0.6.6.post1/vllm/model_executor/layers/rotary_embedding.py +"""Rotary Positional Embeddings.""" +import math +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn +# from vllm.model_executor.custom_op import CustomOp + +# from sglang.srt.layers.custom_op_util import register_custom_op + + +def _rotate_neox(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + + return torch.cat((-x2, x1), dim=-1) + + +def _rotate_gptj(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., ::2] + x2 = x[..., 1::2] + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) + + +def _apply_rotary_emb( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool, +) -> torch.Tensor: + """ + Args: + x: [num_tokens, num_heads, head_size] + cos: [num_tokens, head_size // 2] + sin: [num_tokens, head_size // 2] + is_neox_style: Whether to use the Neox-style or GPT-J-style rotary + positional embeddings. + """ + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + x1 = x[..., ::2] + x2 = x[..., 1::2] + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + if is_neox_style: + return torch.cat((o1, o2), dim=-1) + else: + return torch.stack((o1, o2), dim=-1).flatten(-2) + + +class RotaryEmbedding(nn.Module): + """Original rotary positional embedding.""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + self.dtype = dtype + + cache = self._compute_cos_sin_cache() + cache = cache.to(dtype) + self.cos_sin_cache: torch.Tensor + self.register_buffer("cos_sin_cache", cache, persistent=False) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + """Compute the inverse frequency.""" + # NOTE(woosuk): To exactly match the HF implementation, we need to + # use CPU to compute the cache and then move it to GPU. However, we + # create the cache on GPU for faster initialization. This may cause + # a slight numerical difference between the HF implementation and ours. + inv_freq = 1.0 / (base**(torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + t = torch.arange(self.max_position_embeddings, dtype=torch.float) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward_native( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """A PyTorch-native implementation of forward().""" + if offsets is not None: + positions = positions + offsets + positions = positions.flatten() + num_tokens = positions.shape[0] + cos_sin = self.cos_sin_cache.index_select(0, positions) + cos, sin = cos_sin.chunk(2, dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., :self.rotary_dim] + query_pass = query[..., self.rotary_dim:] + query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., :self.rotary_dim] + key_pass = key[..., self.rotary_dim:] + key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + def extra_repr(self) -> str: + s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" + s += f", max_position_embeddings={self.max_position_embeddings}" + s += f", base={self.base}, is_neox_style={self.is_neox_style}" + return s + + +# Inverse dim formula to find dim based on number of rotations +def _yarn_find_correction_dim( + num_rotations: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048, +) -> float: + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (2 * math.log(base)) + + +# Find dim range bounds based on rotations +def _yarn_find_correction_range( + low_rot: int, + high_rot: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048, +) -> Tuple[int, int]: + low = math.floor(_yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil(_yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def _yarn_linear_ramp_mask(low: float, high: float, dim: int, dtype: torch.dtype, device) -> torch.Tensor: + if low == high: + high += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=dtype, device=device) - low) / (high - low) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +class DeepseekScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with YaRN method. + + Credits to Peng et al. github.com/jquesnelle/yarn + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + *, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, + mscale: float = 1, + mscale_all_dim: float = 0, + device: Optional[str] = "cuda", + ) -> None: + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + # Get n-d magnitude scaling corrected for interpolation. + self.mscale = float( + yarn_get_mscale(self.scaling_factor, float(mscale)) / + yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) * attn_factor) + self.device = device + super().__init__(head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype) + + def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: + pos_freqs = self.base**(torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device=self.device) / + self.rotary_dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) + + low, high = _yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + self.rotary_dim, + self.base, + self.max_position_embeddings, + ) + # Get n-d rotational scaling corrected for extrapolation + inv_freq_mask = (1 - _yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float, + device=self.device)) * self.extrapolation_factor + inv_freq = (inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.scaling_factor) + t = torch.arange( + self.max_position_embeddings * self.scaling_factor, + device=self.device, + dtype=torch.float32, + ) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() * self.mscale + sin = freqs.sin() * self.mscale + cache = torch.cat((cos, sin), dim=-1) + + return cache + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """PyTorch-native implementation equivalent to forward().""" + query_rot = query[..., :self.rotary_dim] + key_rot = key[..., :self.rotary_dim] + if self.rotary_dim < self.head_size: + query_pass = query[..., self.rotary_dim:] + key_pass = key[..., self.rotary_dim:] + + self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device) + cos_sin = self.cos_sin_cache[torch.add(positions, offsets) if offsets is not None else positions] + # (max_seq, 64). 32 sin, 32 cos + cos, sin = cos_sin.chunk(2, dim=-1) + + if self.is_neox_style: + # NOTE(woosuk): Here we assume that the positions tensor has the + # shape [batch_size, seq_len]. + cos = cos.repeat(1, 1, 2).unsqueeze(-2) + sin = sin.repeat(1, 1, 2).unsqueeze(-2) + else: + cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) + sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) + + rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj + query_rot = query_rot * cos + rotate_fn(query_rot) * sin + key_rot = key_rot * cos + rotate_fn(key_rot) * sin + + if self.rotary_dim < self.head_size: + query = torch.cat((query_rot, query_pass), dim=-1) + key = torch.cat((key_rot, key_pass), dim=-1) + else: + query = query_rot + key = key_rot + return query, key diff --git a/fa/utils/sglang_ref.py b/fa/utils/sglang_ref.py new file mode 100644 index 000000000000..f862aee48f1e --- /dev/null +++ b/fa/utils/sglang_ref.py @@ -0,0 +1,619 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" +Memory-efficient attention for decoding. +It supports page size = 1. +""" + +# Adapted from +# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py +# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py + +import logging + +import triton +import triton.language as tl + + +def is_hip(): + return triton.runtime.driver.active.get_current_target().backend == "hip" + + +is_hip_ = is_hip() + +logger = logging.getLogger(__name__) + +# TODO: Remove this when triton>=3.2.0. This issue will not affect performance and accuracy. +logger.warning("The following error message 'operation scheduled before its operands' can be ignored.") + + +@triton.jit +def tanh(x): + # Tanh is just a scaled sigmoid + return 2 * tl.sigmoid(2 * x) - 1 + + +@triton.jit +def _fwd_kernel_stage1( + Q, + K_Buffer, + V_Buffer, + sm_scale, + kv_indptr, + kv_indices, + Att_Out, + stride_qbs, + stride_qh, + stride_buf_kbs, + stride_buf_kh, + stride_buf_vbs, + stride_buf_vh, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + kv_group_num: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DV: tl.constexpr, + BLOCK_N: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, + logit_cap: tl.constexpr, + Lk: tl.constexpr, + Lv: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + split_kv_id = tl.program_id(2) + + cur_kv_head = cur_head // kv_group_num + + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lk + mask_dv = offs_dv < Lv + + cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch) + cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx + + off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d + q = tl.load(Q + off_q, mask=mask_d, other=0.0) + + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + e_max = -float("inf") + e_sum = 0.0 + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + if split_kv_end > split_kv_start: + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_loc = tl.load( + kv_indices + cur_batch_kv_start_idx + offs_n, + mask=offs_n < split_kv_end, + other=0, + ) + offs_buf_k = (kv_loc[:, None] * stride_buf_kbs + cur_kv_head * stride_buf_kh + offs_d[None, :]) + k = tl.load( + K_Buffer + offs_buf_k, + mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]), + other=0.0, + ) + qk = tl.sum(q[None, :] * k, 1) + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) + + qk = tl.where(offs_n < split_kv_end, qk, float("-inf")) + + offs_buf_v = (kv_loc[:, None] * stride_buf_vbs + cur_kv_head * stride_buf_vh + offs_dv[None, :]) + v = tl.load( + V_Buffer + offs_buf_v, + mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), + other=0.0, + ) + + n_e_max = tl.maximum(tl.max(qk, 0), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max) + acc *= re_scale + acc += tl.sum(p[:, None] * v, 0) + + e_sum = e_sum * re_scale + tl.sum(p, 0) + e_max = n_e_max + + offs_mid_o = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh + split_kv_id * stride_mid_os + offs_dv) + + tl.store( + Att_Out + offs_mid_o, + acc / e_sum, + mask=(mask_dv), + ) + + offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh + split_kv_id * stride_mid_os + Lv) + + tl.store( + Att_Out + offs_mid_o_1, + e_max + tl.log(e_sum), + ) + + +def _decode_att_m_fwd( + q, + k_buffer, + v_buffer, + att_out, + kv_indptr, + kv_indices, + num_kv_splits, + sm_scale, + logit_cap, +): + BLOCK = 64 + # [TODO] work around SGPR limit on MI3xx + if is_hip_: + BLOCK = 32 + NUM_KV_SPLITS = num_kv_splits + Lk = k_buffer.shape[-1] + Lv = v_buffer.shape[-1] + + batch, head_num = kv_indptr.shape[0] - 1, q.shape[1] + + grid = (batch, head_num, NUM_KV_SPLITS) + kv_group_num = q.shape[1] // k_buffer.shape[1] + + if kv_group_num == 1: + num_warps = 4 + else: + num_warps = 2 + if is_hip_: + num_warps = 1 + + BLOCK_DMODEL = triton.next_power_of_2(Lk) + BLOCK_DV = triton.next_power_of_2(Lv) + + _fwd_kernel_stage1[grid]( + q, + k_buffer, + v_buffer, + sm_scale, + kv_indptr, + kv_indices, + att_out, + q.stride(0), + q.stride(1), + k_buffer.stride(0), + k_buffer.stride(1), + v_buffer.stride(0), + v_buffer.stride(1), + att_out.stride(0), + att_out.stride(1), + att_out.stride(2), + kv_group_num=kv_group_num, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DV=BLOCK_DV, + BLOCK_N=BLOCK, + NUM_KV_SPLITS=NUM_KV_SPLITS, + logit_cap=logit_cap, + num_warps=num_warps, + num_stages=2, + Lk=Lk, + Lv=Lv, + ) + + +@triton.jit +def _fwd_grouped_kernel_stage1( + Q, + K_Buffer, + V_Buffer, + sm_scale, + kv_indptr, + kv_indices, + Att_Out, + stride_qbs, + stride_qh, + stride_buf_kbs, + stride_buf_kh, + stride_buf_vbs, + stride_buf_vh, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + kv_group_num: tl.constexpr, + q_head_num: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DPE: tl.constexpr, + BLOCK_DV: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_H: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, + logit_cap: tl.constexpr, + Lk: tl.constexpr, + Lv: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head_id = tl.program_id(1) + cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H) + split_kv_id = tl.program_id(2) + + if BLOCK_H < kv_group_num: + VALID_BLOCK_H: tl.constexpr = BLOCK_H + else: + VALID_BLOCK_H: tl.constexpr = kv_group_num + cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H) + mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H + mask_h = mask_h & (cur_head < q_head_num) + + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lk + mask_dv = offs_dv < Lv + + cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch) + cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx + + offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] + q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0) + + if BLOCK_DPE > 0: + offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) + mask_dpe = offs_dpe < Lk + off_qpe = (cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]) + qpe = tl.load(Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0) + + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") + e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) + acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32) + + if split_kv_end > split_kv_start: + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_loc = tl.load( + kv_indices + cur_batch_kv_start_idx + offs_n, + mask=offs_n < split_kv_end, + other=0, + ) + offs_buf_k = (kv_loc[None, :] * stride_buf_kbs + cur_kv_head * stride_buf_kh + offs_d[:, None]) + k = tl.load( + K_Buffer + offs_buf_k, + mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]), + other=0.0, + ) + qk = tl.dot(q, k.to(q.dtype)) + if BLOCK_DPE > 0: + offs_buf_kpe = (kv_loc[None, :] * stride_buf_kbs + cur_kv_head * stride_buf_kh + offs_dpe[:, None]) + kpe = tl.load( + K_Buffer + offs_buf_kpe, + mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]), + other=0.0, + ) + qk += tl.dot(qpe, kpe.to(qpe.dtype)) + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) + + qk = tl.where(mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf")) + + offs_buf_v = (kv_loc[:, None] * stride_buf_vbs + cur_kv_head * stride_buf_vh + offs_dv[None, :]) + v = tl.load( + V_Buffer + offs_buf_v, + mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), + other=0.0, + ) + + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max[:, None]) + acc *= re_scale[:, None] + acc += tl.dot(p.to(v.dtype), v) + + e_sum = e_sum * re_scale + tl.sum(p, 1) + e_max = n_e_max + + offs_mid_o = (cur_batch * stride_mid_ob + cur_head[:, None] * stride_mid_oh + split_kv_id * stride_mid_os + + offs_dv[None, :]) + + tl.store( + Att_Out + offs_mid_o, + acc / e_sum[:, None], + mask=(mask_h[:, None]) & (mask_dv[None, :]), + ) + + offs_mid_o_1 = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh + split_kv_id * stride_mid_os + Lv) + + tl.store( + Att_Out + offs_mid_o_1, + e_max + tl.log(e_sum), + mask=mask_h, + ) + + +def _decode_grouped_att_m_fwd( + q, + k_buffer, + v_buffer, + att_out, + kv_indptr, + kv_indices, + num_kv_splits, + sm_scale, + logit_cap, +): + BLOCK = 32 + Lk = k_buffer.shape[-1] + Lv = v_buffer.shape[-1] + + # [TODO] work around shmem limit on MI3xx + if is_hip_ and Lk >= 576: + BLOCK = 16 + + if Lk == 576: + BLOCK_DMODEL = 512 + BLOCK_DPE = 64 + elif Lk == 288: + BLOCK_DMODEL = 256 + BLOCK_DPE = 32 + else: + BLOCK_DMODEL = triton.next_power_of_2(Lk) + BLOCK_DPE = 0 + BLOCK_DV = triton.next_power_of_2(Lv) + + batch, head_num = kv_indptr.shape[0] - 1, q.shape[1] + kv_group_num = q.shape[1] // k_buffer.shape[1] + + BLOCK_H = 16 + NUM_KV_SPLITS = num_kv_splits + grid = ( + batch, + triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), + NUM_KV_SPLITS, + ) + + extra_kargs = {} + num_stages = 2 + if is_hip_: + # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html + # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py + extra_kargs = {"waves_per_eu": 1, "matrix_instr_nonkdim": 16, "kpack": 2} + num_stages = 1 + + _fwd_grouped_kernel_stage1[grid]( + q, + k_buffer, + v_buffer, + sm_scale, + kv_indptr, + kv_indices, + att_out, + q.stride(0), + q.stride(1), + k_buffer.stride(0), + k_buffer.stride(1), + v_buffer.stride(0), + v_buffer.stride(1), + att_out.stride(0), + att_out.stride(1), + att_out.stride(2), + kv_group_num=kv_group_num, + q_head_num=head_num, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DPE=BLOCK_DPE, + BLOCK_DV=BLOCK_DV, + BLOCK_N=BLOCK, + BLOCK_H=BLOCK_H, + NUM_KV_SPLITS=NUM_KV_SPLITS, + logit_cap=logit_cap, + num_warps=4, + num_stages=num_stages, + Lk=Lk, + Lv=Lv, + **extra_kargs, + ) + + +@triton.jit +def _fwd_kernel_stage2( + Mid_O, + O, + kv_indptr, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_obs, + stride_oh, + NUM_KV_SPLITS: tl.constexpr, + BLOCK_DV: tl.constexpr, + Lv: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - tl.load(kv_indptr + cur_batch) + + offs_d = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lv + + e_sum = 0.0 + e_max = -float("inf") + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d + offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv + + for split_kv_id in range(0, NUM_KV_SPLITS): + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + if split_kv_end > split_kv_start: + tv = tl.load(Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0) + tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os) + n_e_max = tl.maximum(tlogic, e_max) + + old_scale = tl.exp(e_max - n_e_max) + acc *= old_scale + exp_logic = tl.exp(tlogic - n_e_max) + acc += exp_logic * tv + + e_sum = e_sum * old_scale + exp_logic + e_max = n_e_max + + tl.store( + O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, + acc / e_sum, + mask=mask_d, + ) + + +def _decode_softmax_reducev_fwd( + logits, + q, + o, + v_buffer, + kv_indptr, + num_kv_splits, +): + batch, head_num = q.shape[0], q.shape[1] + Lv = v_buffer.shape[-1] + BLOCK_DV = triton.next_power_of_2(Lv) + + NUM_KV_SPLITS = num_kv_splits + + extra_kargs = {} + if is_hip_: + # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html + # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py + extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2} + + grid = (batch, head_num) + _fwd_kernel_stage2[grid]( + logits, + o, + kv_indptr, + logits.stride(0), + logits.stride(1), + logits.stride(2), + o.stride(0), + o.stride(1), + NUM_KV_SPLITS=NUM_KV_SPLITS, + BLOCK_DV=BLOCK_DV, + Lv=Lv, + num_warps=4, + num_stages=2, + **extra_kargs, + ) + + +def decode_attention_fwd_normal( + q, + k_buffer, + v_buffer, + o, + kv_indptr, + kv_indices, + attn_logits, + num_kv_splits, + sm_scale, + logit_cap=0.0, +): + _decode_att_m_fwd( + q, + k_buffer, + v_buffer, + attn_logits, + kv_indptr, + kv_indices, + num_kv_splits, + sm_scale, + logit_cap, + ) + _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits) + + +def decode_attention_fwd_grouped( + q, + k_buffer, + v_buffer, + o, + kv_indptr, + kv_indices, + attn_logits, + num_kv_splits, + sm_scale, + logit_cap=0.0, +): + _decode_grouped_att_m_fwd( + q, + k_buffer, + v_buffer, + attn_logits, + kv_indptr, + kv_indices, + num_kv_splits, + sm_scale, + logit_cap, + ) + _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits) + + +def decode_attention_fwd( + q, + k_buffer, + v_buffer, + o, + kv_indptr, + kv_indices, + attn_logits, + num_kv_splits, + sm_scale, + logit_cap=0.0, +): + assert num_kv_splits == attn_logits.shape[2] + kv_group_num = q.shape[1] // v_buffer.shape[1] + + if kv_group_num == 1: + # MHA + decode_attention_fwd_normal( + q, + k_buffer, + v_buffer, + o, + kv_indptr, + kv_indices, + attn_logits, + num_kv_splits, + sm_scale, + logit_cap, + ) + else: + # GQA/MQA/MLA + decode_attention_fwd_grouped( + q, + k_buffer, + v_buffer, + o, + kv_indptr, + kv_indices, + attn_logits, + num_kv_splits, + sm_scale, + logit_cap, + ) From 06cf75a8d19a4a5f4651b1c42c26233d4dd40f90 Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Wed, 16 Apr 2025 11:00:36 +0000 Subject: [PATCH 03/44] [FA] Place cvt layout in the same stage and cluster as LocalLoad so canonicalize can fold it --- .../amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp index 4b95e2f2f930..d933268806ef 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp @@ -403,6 +403,17 @@ bool StreamPipeliner::createAsyncCopy(tt::LoadOp loadOp, Value alloc, scheduleOp(sharedLoad, SCHED_LOCAL_LOAD); loadOp->replaceAllUsesWith(ValueRange{sharedLoad}); + + // Make sure that a possible cvt is in the same stage or otherwise it will not + // get folded + if (sharedLoad->hasOneUse()) { + if (auto cvt = + dyn_cast(*sharedLoad->getUsers().begin())) { + LDBG("Change cvt layout stage and cluster"); + schedule.insert(cvt, localLoadStage, clusterVec[localLoadCluster]); + } + } + if (stages[SCHED_LOCAL_LOAD] != stages[SCHED_COMPUTE] && sharedLoad->hasOneUse()) { if (auto cvt = From 203fe11f19a58b3b958382d327da21a4179de3d1 Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Wed, 23 Apr 2025 09:11:44 +0000 Subject: [PATCH 04/44] [ASYNC_COPY] Add env var to bypass permute, only works if the load dim is contiguous --- include/triton/Tools/Sys/GetEnv.hpp | 1 + .../TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 41 ++++++++++++------- 2 files changed, 27 insertions(+), 15 deletions(-) diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index adf8131b9263..b07fb91f263b 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ -34,6 +34,7 @@ inline const std::set CACHE_INVALIDATING_ENV_VARS = { "TRITON_HIP_GLOBAL_PREFETCH", "TRITON_HIP_LOCAL_PREFETCH", "TRITON_HIP_USE_ASYNC_COPY", + "TRITON_HIP_ASYNC_COPY_BYPASS_PERMUTE", "TRITON_HIP_USE_BLOCK_PINGPONG", "TRITON_HIP_USE_IN_THREAD_TRANSPOSE", "TRITON_LLVM_DEBUG_ONLY", diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index fa9c1f48d72a..a668b6031782 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -238,11 +238,11 @@ struct DirectToLdsLoadConversionBase : public LoadStoreConversionBase { } } - // Emits the computation to get the lane index which holds the source + // Emits the computation to get the lane id offset which holds the source // pointers/offsets we need to store to shared memory - Value emitSwizzledLaneIndex(RewriterBase &rewriter, TritonLLVMOpBuilder &b, - Location loc, Value coalescedShmem, - Value swizzledShmem, Value vecBytes) const { + Value emitSwizzledLaneOffset(RewriterBase &rewriter, TritonLLVMOpBuilder &b, + Location loc, Value coalescedShmem, + Value swizzledShmem, Value vecBytes) const { // Compute the laneOffset based on the difference in elements between // the two shmem addresses. laneOffset will be negative for half the // lanes because a smaller laneId might hold our global_ptr. @@ -250,9 +250,7 @@ struct DirectToLdsLoadConversionBase : public LoadStoreConversionBase { auto swizzledAddr = b.ptrtoint(i64_ty, swizzledShmem); auto diff = b.trunc(i32_ty, b.sub(swizzledAddr, coalescedAddr)); Value laneOffset = b.sdiv(diff, vecBytes); - // laneId + laneOffset will always stay inside the warp [0, - // threadsPerWarp) because we only swizzle inside a warp - return b.add(getLaneId(rewriter, loc), laneOffset); + return laneOffset; } // Swizzle the mask (1bit) based on selectLane via ballot @@ -541,11 +539,21 @@ struct BufferLoadToLocalOpConversion if (hasSwizzling) { // Apply swizzling to the src offsets - Value swizzledLaneId = - emitSwizzledLaneIndex(rewriter, b, loc, coalescedShmemAddr[i], - swizzledShmemAddr[i], vecBytesVal); - offsetIn = - targetInfo.shuffleIdx(rewriter, loc, offsetIn, swizzledLaneId); + Value laneOffset = + emitSwizzledLaneOffset(rewriter, b, loc, coalescedShmemAddr[i], + swizzledShmemAddr[i], vecBytesVal); + // laneId + laneOffset will always stay inside the warp [0, + // threadsPerWarp) because we only swizzle inside a warp + Value swizzledLaneId = b.add(getLaneId(rewriter, loc), laneOffset); + + if (tools::getBoolEnv("TRITON_HIP_ASYNC_COPY_BYPASS_PERMUTE")) { + offsetIn = b.add( + offsetIn, b.mul(laneOffset, b.i32_val(vecTy.getNumElements()))); + } else { + offsetIn = + targetInfo.shuffleIdx(rewriter, loc, offsetIn, swizzledLaneId); + } + if (mask) { pred = shuffleMask(rewriter, b, loc, targetInfo, swizzledLaneId, pred); @@ -666,9 +674,12 @@ struct AsyncCopyGlobalToLocalOpConversion if (hasSwizzling) { // Apply swizzling to the src pointers - Value swizzledLaneId = - emitSwizzledLaneIndex(rewriter, b, loc, coalescedShmemAddr[i], - swizzledShmemAddr[i], vecBytesVal); + Value laneOffset = + emitSwizzledLaneOffset(rewriter, b, loc, coalescedShmemAddr[i], + swizzledShmemAddr[i], vecBytesVal); + // laneId + laneOffset will always stay inside the warp [0, + // threadsPerWarp) because we only swizzle inside a warp + Value swizzledLaneId = b.add(getLaneId(rewriter, loc), laneOffset); srcPtr = targetInfo.shuffleIdx(rewriter, loc, srcPtr, swizzledLaneId); if (!maskElements.empty()) { pred = From b6643537c999a00c45d37247b8c81b8c2b90ebaa Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Wed, 23 Apr 2025 14:07:19 +0000 Subject: [PATCH 05/44] [FA] Do not combine AsyncWaits to have a barrier in front of each memory ops cluster --- .../amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp index d933268806ef..72eafc88e34a 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp @@ -1153,11 +1153,11 @@ struct PipelinePass : public TritonAMDGPUStreamPipelineBase { (void)sp.pipelineLoop(); } - if (useAsyncCopy) { - llvm::SmallSetVector waitOps; - moduleOp.walk([&](ttg::AsyncWaitOp waitOp) { waitOps.insert(waitOp); }); - tt::combineRedundantWaitOps(waitOps); - } + // if (useAsyncCopy) { + // llvm::SmallSetVector waitOps; + // moduleOp.walk([&](ttg::AsyncWaitOp waitOp) { waitOps.insert(waitOp); }); + // tt::combineRedundantWaitOps(waitOps); + // } } }; } // namespace From 012793ae2809884143a4e8dfdbdfd7f72c3382d0 Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Wed, 23 Apr 2025 14:08:12 +0000 Subject: [PATCH 06/44] [ASYNC_COPY] Remove MemoryEffect of BufferLoadToLocal to avoid implicit barrier from Membar --- .../amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td index 17d9409468d8..093fedc56fb4 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td @@ -248,7 +248,7 @@ def BufferLoadToLocalOp : TT_AMDGPU_Op<"buffer_load_to_local", [ let description = [{ AMD Buffer load operation. Similar to amdgpu.buffer_load op but directly wirtes to shared memory instead of into registers. }]; let arguments = (ins - Arg]>:$dest, + Arg:$dest, Arg]>:$ptr, I32Tensor:$offsets, Optional:$mask, From 3b74f4a3bd59d374a3f134b83642bad74229da62 Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Wed, 23 Apr 2025 17:23:16 +0000 Subject: [PATCH 07/44] [FA] Compute max before mul QK_SCALE to fold sub into fma --- fa/flash-attention.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/fa/flash-attention.py b/fa/flash-attention.py index 3c982482309a..97dbb8b5e9ed 100644 --- a/fa/flash-attention.py +++ b/fa/flash-attention.py @@ -278,11 +278,11 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri qk = tl.where(causal_mask, qk, float("-inf")) # -- compute qk ---- if INT8_GEMM: - qk += ((((tl.dot(q, k).to(tl.float32) * q_descale)) * k_descale) * QK_SCALE) + qk += ((tl.dot(q, k).to(tl.float32) * q_descale)) * k_descale else: if INT8_KV: k = (k * k_descale).to(q.type.element_ty) - qk += (tl.dot(q, k) * QK_SCALE) + qk += tl.dot(q, k) if bias_ptrs is not None: bias_offs_n = start_n + tl.arange(0, BLOCK_N) if MASK_STEPS else None @@ -290,7 +290,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri # While bias is added after multiplying qk with sm_scale, # our optimization to use 2^x instead of e^x results in an additional # scale factor of log2(e) which we must also multiply the bias with. - qk += (bias * 1.44269504089) + qk += (bias * 1.44269504089 / QK_SCALE) if alibi_slope is not None: # Compute the global position of each token within the sequence @@ -298,11 +298,12 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri global_n_positions = start_n + tl.arange(0, BLOCK_N) alibi_block = compute_alibi_block(alibi_slope, actual_seqlen_q, actual_seqlen_k, global_m_positions, global_n_positions) - qk += (alibi_block * 1.44269504089) # scale factor of log2(e) + qk += (alibi_block * 1.44269504089 / QK_SCALE) # scale factor of log2(e) # softmax m_ij = tl.maximum(m_i, tl.max(qk, 1)) - qk = qk - m_ij[:, None] + m_ij_scaled = m_ij * QK_SCALE + qk = qk * QK_SCALE - m_ij_scaled[:, None] p = tl.math.exp2(qk) # CAVEAT: Must update l_ij before applying dropout @@ -316,7 +317,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri elif RETURN_ENCODED_SOFTMAX: tl.store(encoded_sm_ptrs, p.to(encoded_sm_ptrs.type.element_ty)) # -- update output accumulator -- - alpha = tl.math.exp2(m_i - m_ij) + alpha = tl.math.exp2(m_i * QK_SCALE - m_ij_scaled) acc = acc * alpha[:, None] if not PRE_LOAD_V: v = load_fn(v_ptrs, k_offs_n, k_offs_k, actual_seqlen_k, ACTUAL_BLOCK_DMODEL) From b059372640886b498f678b5e45bfc56ccad09d32 Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Thu, 24 Apr 2025 08:23:35 +0000 Subject: [PATCH 08/44] [FA] Added 2 extra clusters to have async_waits in front of memory clusters --- .../TritonAMDGPUTransforms/StreamPipeline.cpp | 60 ++++++++++++------- 1 file changed, 40 insertions(+), 20 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp index 72eafc88e34a..44376ab59e4f 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp @@ -178,9 +178,13 @@ class StreamPipeliner { // Stage for each SchedType Op int stages[SCHED_SIZE]; - // Cluster for each SchedType Op + // (not used anymore) Cluster for each SchedType Op std::array clusters; - std::array clusterVec; + // Clusters to hold ops defined by the design document 0-3 (will be mapped to + // clusters 0,2,3,5) + std::array mainClusters; + // Clusters to hold the two async waits (will be mapped to clusters 1,4) + std::array waitClusters; // Scheduling clusters tt::CoarseSchedule schedule; @@ -290,20 +294,32 @@ LogicalResult StreamPipeliner::initSchedule(int maxIndirectionLevel) { // std::array clusterVec; // std::generate(clusterVec.begin(), clusterVec.end(), // [&]() { return schedule.clusters.newAtBack(); }); - clusterVec = {schedule.clusters.newAtBack(), schedule.clusters.newAtBack(), - schedule.clusters.newAtBack(), schedule.clusters.newAtBack()}; - clusters[SCHED_GLOBAL_LOAD] = clusterVec[globalLoadCluster]; - clusters[SCHED_LOCAL_STORE] = clusterVec[localStoreCluster]; - clusters[SCHED_LOCAL_LOAD] = clusterVec[localLoadCluster]; - clusters[SCHED_COMPUTE] = clusterVec[computeCluster]; - // clusters[SCHED_ASYNC_WAIT] = clusterVec[asyncWaitCluster]; + // Create clusters, we have 6 because we have 2 extra for wait ops to be + // placed before the memory blocks + // SM2, DOT1 + mainClusters[0] = schedule.clusters.newAtBack(); + // Wait for V + waitClusters[0] = schedule.clusters.newAtBack(); + // LRV, ACK + mainClusters[1] = schedule.clusters.newAtBack(); + // DOT2, SM1 + mainClusters[2] = schedule.clusters.newAtBack(); + // Wait for K + waitClusters[1] = schedule.clusters.newAtBack(); + // LRK, ACV + mainClusters[3] = schedule.clusters.newAtBack(); + + clusters[SCHED_GLOBAL_LOAD] = mainClusters[globalLoadCluster]; + clusters[SCHED_LOCAL_STORE] = mainClusters[localStoreCluster]; + clusters[SCHED_LOCAL_LOAD] = mainClusters[localLoadCluster]; + clusters[SCHED_COMPUTE] = mainClusters[computeCluster]; // ATTENTION 4-stage - clusters[SCHED_GLOBAL_LOAD] = clusterVec[2]; - clusters[SCHED_LOCAL_STORE] = clusterVec[1]; - clusters[SCHED_LOCAL_LOAD] = clusterVec[1]; - clusters[SCHED_COMPUTE] = clusterVec[0]; + clusters[SCHED_GLOBAL_LOAD] = mainClusters[2]; + clusters[SCHED_LOCAL_STORE] = mainClusters[1]; + clusters[SCHED_LOCAL_LOAD] = mainClusters[1]; + clusters[SCHED_COMPUTE] = mainClusters[0]; // Always have ASYNC_WAIT as the first cluster because we want it at the top // of the schedule block @@ -381,6 +397,7 @@ bool StreamPipeliner::createAsyncCopy(tt::LoadOp loadOp, Value alloc, auto [loadStage, loadCluster] = schedule[loadOp]; auto localLoadStage = loadStage == 0 ? 1 : 3; auto localLoadCluster = loadStage == 0 ? 3 : 1; + auto waitCluster = loadStage == 0 ? 1 : 0; schedule.erase(loadOp); // Schedule new ops @@ -394,7 +411,8 @@ bool StreamPipeliner::createAsyncCopy(tt::LoadOp loadOp, Value alloc, // will write into. This is done by scheduling AsyncWait as the first // cluster. If AsyncCopy and LocalLoads are in the same stage we do not // assign a schdule so they are placed before the LocalLoads - schedule.insert(sharedLoad, localLoadStage, clusterVec[localLoadCluster]); + schedule.insert(sharedLoad, localLoadStage, mainClusters[localLoadCluster]); + schedule.insert(waitOp, localLoadStage, waitClusters[waitCluster]); // if (loadStage != stages[SCHED_LOCAL_LOAD]) // scheduleOp(waitOp, SCHED_ASYNC_WAIT); @@ -410,7 +428,7 @@ bool StreamPipeliner::createAsyncCopy(tt::LoadOp loadOp, Value alloc, if (auto cvt = dyn_cast(*sharedLoad->getUsers().begin())) { LDBG("Change cvt layout stage and cluster"); - schedule.insert(cvt, localLoadStage, clusterVec[localLoadCluster]); + schedule.insert(cvt, localLoadStage, mainClusters[localLoadCluster]); } } @@ -421,6 +439,8 @@ bool StreamPipeliner::createAsyncCopy(tt::LoadOp loadOp, Value alloc, scheduleOp(cvt, SCHED_LOCAL_LOAD); } + // Delete old loadOp + schedule.erase(loadOp); loadOp.erase(); return true; } @@ -747,7 +767,7 @@ LogicalResult StreamPipeliner::scheduleLoads(DenseSet &rootUsers) { if (schedule.count(loadOp) > 0) continue; // scheduleOp(loadOp, SCHED_GLOBAL_LOAD, stage); - schedule.insert(loadOp, i, clusterVec[i == 0 ? 1 : 3]); + schedule.insert(loadOp, i, mainClusters[i == 0 ? 1 : 3]); i++; } @@ -756,7 +776,7 @@ LogicalResult StreamPipeliner::scheduleLoads(DenseSet &rootUsers) { // Non-LoadOp(s) are the (final) root uses of all LoadOp(s). if (!isa(use)) { auto loadStage = schedule[loadOp].first; - schedule.insert(use, loadStage + 2, clusterVec[loadStage == 0 ? 0 : 2]); + schedule.insert(use, loadStage + 2, mainClusters[loadStage == 0 ? 0 : 2]); // scheduleOp(use, SCHED_COMPUTE); rootUsers.insert(use); } @@ -794,7 +814,7 @@ void StreamPipeliner::scheduleDependencies() { bool override = false; if (llvm::isa(op) && stage == 3) { LDBG("Update sched to 0"); - depCluster = clusterVec[0]; + depCluster = mainClusters[0]; override = true; } @@ -993,12 +1013,12 @@ LogicalResult StreamPipeliner::preprocessLoopAndBuildSchedule() { // Schedule reductions int c = 2; for (auto reduceOp : forOp.getBody()->getOps()) { - schedule.insert(reduceOp, c, clusterVec[c == 2 ? 2 : 0]); + schedule.insert(reduceOp, c, mainClusters[c == 2 ? 2 : 0]); c++; } for (auto exp2Op : forOp.getBody()->getOps()) { - schedule.insert(exp2Op, 2, clusterVec[2]); + schedule.insert(exp2Op, 2, mainClusters[2]); } LLVM_DEBUG({ LDBG("Coarse schedule after schedule reduction:"); From f3bb29368ac751ca7255abccf02706710418169f Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Thu, 24 Apr 2025 08:55:43 +0000 Subject: [PATCH 09/44] [FA] Place LocalLoads before AsyncCopies --- .../amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp index 44376ab59e4f..0ae5967d4391 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp @@ -396,7 +396,7 @@ bool StreamPipeliner::createAsyncCopy(tt::LoadOp loadOp, Value alloc, auto [loadStage, loadCluster] = schedule[loadOp]; auto localLoadStage = loadStage == 0 ? 1 : 3; - auto localLoadCluster = loadStage == 0 ? 3 : 1; + auto localLoadCluster = loadStage == 0 ? 1 : 0; auto waitCluster = loadStage == 0 ? 1 : 0; schedule.erase(loadOp); @@ -411,7 +411,7 @@ bool StreamPipeliner::createAsyncCopy(tt::LoadOp loadOp, Value alloc, // will write into. This is done by scheduling AsyncWait as the first // cluster. If AsyncCopy and LocalLoads are in the same stage we do not // assign a schdule so they are placed before the LocalLoads - schedule.insert(sharedLoad, localLoadStage, mainClusters[localLoadCluster]); + schedule.insert(sharedLoad, localLoadStage, waitClusters[localLoadCluster]); schedule.insert(waitOp, localLoadStage, waitClusters[waitCluster]); // if (loadStage != stages[SCHED_LOCAL_LOAD]) @@ -428,7 +428,7 @@ bool StreamPipeliner::createAsyncCopy(tt::LoadOp loadOp, Value alloc, if (auto cvt = dyn_cast(*sharedLoad->getUsers().begin())) { LDBG("Change cvt layout stage and cluster"); - schedule.insert(cvt, localLoadStage, mainClusters[localLoadCluster]); + schedule.insert(cvt, localLoadStage, waitClusters[localLoadCluster]); } } From 77884fab76619dab4d14e491727fe094a637a3cd Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Thu, 24 Apr 2025 10:20:04 +0000 Subject: [PATCH 10/44] [FA][ASYNC_COPY] Force vec=8 for shared encodings to avoid 32bit direct to lds loads --- lib/Dialect/TritonGPU/IR/Dialect.cpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 216e4dc2efb1..9112476a5857 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -1797,6 +1797,15 @@ SwizzledSharedEncodingAttr AMDMfmaEncodingAttr::composeSharedLayoutForOperand( int innerDimLength = operandShape[sharedOrder[0]]; int elemsPerOneBanksRow = (numBanks * bankBitWidth) / elemBitWidth; + // If we load via AsyncCopy we have to write coalesced into LDS so if + // vectorSize * elemBitWidth < 128 we can only load 32 bit direct-to-lds loads + // (64bit is not suppoted by the HW). So we extend vectorSize to get 128bit to + // get wider loads and accept some bank conflicts durcing ds_reads + // TODO (alex): Make this more generic for async copy + if (vectorSize == 4 && elemBitWidth == 16) { + vectorSize = 8; + } + int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength); int maxPhase = std::max(std::min(simdWidth / perPhase, innerDimLength / vectorSize), 1u); From b2e2ad02fa8ed37fd72175390560077984ba667d Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Thu, 24 Apr 2025 15:20:49 +0000 Subject: [PATCH 11/44] [FA] Place dots at the top of clusters --- .../TritonAMDGPUTransforms/StreamPipeline.cpp | 20 ++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp index 0ae5967d4391..0870ec29b8de 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp @@ -185,6 +185,7 @@ class StreamPipeliner { std::array mainClusters; // Clusters to hold the two async waits (will be mapped to clusters 1,4) std::array waitClusters; + std::array dotClusters; // Scheduling clusters tt::CoarseSchedule schedule; @@ -297,17 +298,22 @@ LogicalResult StreamPipeliner::initSchedule(int maxIndirectionLevel) { // Create clusters, we have 6 because we have 2 extra for wait ops to be // placed before the memory blocks - // SM2, DOT1 + + // DOT1 + dotClusters[0] = schedule.clusters.newAtBack(); + // SM2, mainClusters[0] = schedule.clusters.newAtBack(); - // Wait for V + // Wait for V, LRV waitClusters[0] = schedule.clusters.newAtBack(); - // LRV, ACK + // ACK mainClusters[1] = schedule.clusters.newAtBack(); - // DOT2, SM1 + // DOT2 + dotClusters[1] = schedule.clusters.newAtBack(); + // SM1 mainClusters[2] = schedule.clusters.newAtBack(); - // Wait for K + // Wait for K, LRK waitClusters[1] = schedule.clusters.newAtBack(); - // LRK, ACV + // ACV mainClusters[3] = schedule.clusters.newAtBack(); clusters[SCHED_GLOBAL_LOAD] = mainClusters[globalLoadCluster]; @@ -776,7 +782,7 @@ LogicalResult StreamPipeliner::scheduleLoads(DenseSet &rootUsers) { // Non-LoadOp(s) are the (final) root uses of all LoadOp(s). if (!isa(use)) { auto loadStage = schedule[loadOp].first; - schedule.insert(use, loadStage + 2, mainClusters[loadStage == 0 ? 0 : 2]); + schedule.insert(use, loadStage + 2, dotClusters[loadStage == 0 ? 0 : 1]); // scheduleOp(use, SCHED_COMPUTE); rootUsers.insert(use); } From fab1281fccb6fd6b0c46f8841f367c87b2cdbd6b Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Thu, 24 Apr 2025 15:48:40 +0000 Subject: [PATCH 12/44] [FA] Split 4-stage clusters into 8 clusters to better controll the order in the loop --- .../TritonAMDGPUTransforms/StreamPipeline.cpp | 67 ++++++++----------- 1 file changed, 28 insertions(+), 39 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp index 0870ec29b8de..7c8731355ad3 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp @@ -180,11 +180,11 @@ class StreamPipeliner { int stages[SCHED_SIZE]; // (not used anymore) Cluster for each SchedType Op std::array clusters; - // Clusters to hold ops defined by the design document 0-3 (will be mapped to - // clusters 0,2,3,5) - std::array mainClusters; - // Clusters to hold the two async waits (will be mapped to clusters 1,4) - std::array waitClusters; + + // Clusters to hold the different Ops for the 4-stage pipeliner + std::array localReadClusters; + std::array softmaxClusters; + std::array asyncCopyClusters; std::array dotClusters; // Scheduling clusters @@ -296,36 +296,33 @@ LogicalResult StreamPipeliner::initSchedule(int maxIndirectionLevel) { // std::generate(clusterVec.begin(), clusterVec.end(), // [&]() { return schedule.clusters.newAtBack(); }); - // Create clusters, we have 6 because we have 2 extra for wait ops to be - // placed before the memory blocks + // Create clusters in order of 4-stage pipeliner. You can swap lines below to + // change the schedule of the loop. Not all combination are valid, e.g. if a + // consumer and producer from the same stage are in the wrong cluster order + // the loop expander will silently fail // DOT1 dotClusters[0] = schedule.clusters.newAtBack(); // SM2, - mainClusters[0] = schedule.clusters.newAtBack(); + softmaxClusters[0] = schedule.clusters.newAtBack(); // Wait for V, LRV - waitClusters[0] = schedule.clusters.newAtBack(); + localReadClusters[0] = schedule.clusters.newAtBack(); // ACK - mainClusters[1] = schedule.clusters.newAtBack(); + asyncCopyClusters[0] = schedule.clusters.newAtBack(); // DOT2 dotClusters[1] = schedule.clusters.newAtBack(); // SM1 - mainClusters[2] = schedule.clusters.newAtBack(); + softmaxClusters[1] = schedule.clusters.newAtBack(); // Wait for K, LRK - waitClusters[1] = schedule.clusters.newAtBack(); + localReadClusters[1] = schedule.clusters.newAtBack(); // ACV - mainClusters[3] = schedule.clusters.newAtBack(); - - clusters[SCHED_GLOBAL_LOAD] = mainClusters[globalLoadCluster]; - clusters[SCHED_LOCAL_STORE] = mainClusters[localStoreCluster]; - clusters[SCHED_LOCAL_LOAD] = mainClusters[localLoadCluster]; - clusters[SCHED_COMPUTE] = mainClusters[computeCluster]; + asyncCopyClusters[1] = schedule.clusters.newAtBack(); - // ATTENTION 4-stage - clusters[SCHED_GLOBAL_LOAD] = mainClusters[2]; - clusters[SCHED_LOCAL_STORE] = mainClusters[1]; - clusters[SCHED_LOCAL_LOAD] = mainClusters[1]; - clusters[SCHED_COMPUTE] = mainClusters[0]; + // ATTENTION 4-stage (not used) + clusters[SCHED_GLOBAL_LOAD] = softmaxClusters[1]; + clusters[SCHED_LOCAL_STORE] = asyncCopyClusters[0]; + clusters[SCHED_LOCAL_LOAD] = asyncCopyClusters[0]; + clusters[SCHED_COMPUTE] = dotClusters[0]; // Always have ASYNC_WAIT as the first cluster because we want it at the top // of the schedule block @@ -417,8 +414,9 @@ bool StreamPipeliner::createAsyncCopy(tt::LoadOp loadOp, Value alloc, // will write into. This is done by scheduling AsyncWait as the first // cluster. If AsyncCopy and LocalLoads are in the same stage we do not // assign a schdule so they are placed before the LocalLoads - schedule.insert(sharedLoad, localLoadStage, waitClusters[localLoadCluster]); - schedule.insert(waitOp, localLoadStage, waitClusters[waitCluster]); + schedule.insert(sharedLoad, localLoadStage, + localReadClusters[localLoadCluster]); + schedule.insert(waitOp, localLoadStage, localReadClusters[localLoadCluster]); // if (loadStage != stages[SCHED_LOCAL_LOAD]) // scheduleOp(waitOp, SCHED_ASYNC_WAIT); @@ -434,7 +432,7 @@ bool StreamPipeliner::createAsyncCopy(tt::LoadOp loadOp, Value alloc, if (auto cvt = dyn_cast(*sharedLoad->getUsers().begin())) { LDBG("Change cvt layout stage and cluster"); - schedule.insert(cvt, localLoadStage, waitClusters[localLoadCluster]); + schedule.insert(cvt, localLoadStage, localReadClusters[localLoadCluster]); } } @@ -772,8 +770,7 @@ LogicalResult StreamPipeliner::scheduleLoads(DenseSet &rootUsers) { int stage = (maxIndirectionLevel - indLevel) * stagesBetweenLoads; if (schedule.count(loadOp) > 0) continue; - // scheduleOp(loadOp, SCHED_GLOBAL_LOAD, stage); - schedule.insert(loadOp, i, mainClusters[i == 0 ? 1 : 3]); + schedule.insert(loadOp, i, asyncCopyClusters[i == 0 ? 0 : 1]); i++; } @@ -816,25 +813,17 @@ void StreamPipeliner::scheduleDependencies() { if (stage_ != stage) continue; auto depCluster = cluster; - LDBG("Stage: " << stage); bool override = false; if (llvm::isa(op) && stage == 3) { - LDBG("Update sched to 0"); - depCluster = mainClusters[0]; + depCluster = softmaxClusters[0]; override = true; } auto moveStages = [this, stage, cluster = cluster, depCluster = depCluster, override](Operation *op) { - LDBG("Schedule Op: " << *op); if (llvm::isa(op)) { - LDBG("Is a cvt layout\n"); return std::make_pair(stage, cluster); } - if (override) { - LDBG("Override to 0!"); - // return std::make_pair(stage, clusterVec[0]); - } return std::make_pair(stage, depCluster); }; schedule.insertDepsOfOp(op, false, false, moveStages); @@ -1019,12 +1008,12 @@ LogicalResult StreamPipeliner::preprocessLoopAndBuildSchedule() { // Schedule reductions int c = 2; for (auto reduceOp : forOp.getBody()->getOps()) { - schedule.insert(reduceOp, c, mainClusters[c == 2 ? 2 : 0]); + schedule.insert(reduceOp, c, softmaxClusters[c == 2 ? 1 : 0]); c++; } for (auto exp2Op : forOp.getBody()->getOps()) { - schedule.insert(exp2Op, 2, mainClusters[2]); + schedule.insert(exp2Op, 2, softmaxClusters[1]); } LLVM_DEBUG({ LDBG("Coarse schedule after schedule reduction:"); From e0ea5e741d8fabc5d161a7606c84dff2bfeee2eb Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Thu, 24 Apr 2025 16:02:25 +0000 Subject: [PATCH 13/44] [FA] Revert order change in SM clusters --- third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp index 7c8731355ad3..1e37dee8609c 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp @@ -322,7 +322,7 @@ LogicalResult StreamPipeliner::initSchedule(int maxIndirectionLevel) { clusters[SCHED_GLOBAL_LOAD] = softmaxClusters[1]; clusters[SCHED_LOCAL_STORE] = asyncCopyClusters[0]; clusters[SCHED_LOCAL_LOAD] = asyncCopyClusters[0]; - clusters[SCHED_COMPUTE] = dotClusters[0]; + clusters[SCHED_COMPUTE] = softmaxClusters[0]; // Always have ASYNC_WAIT as the first cluster because we want it at the top // of the schedule block From 119846272b92e330dcd09f64ebb188d228f1a9a9 Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Sun, 27 Apr 2025 02:38:09 +0000 Subject: [PATCH 14/44] [FA] Set vecSize=nonKDim for V shared layout to avoid bank conflicts I'll submit a PR upstream later. --- lib/Dialect/TritonGPU/IR/Dialect.cpp | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 9112476a5857..af28de4e12de 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -1806,6 +1806,15 @@ SwizzledSharedEncodingAttr AMDMfmaEncodingAttr::composeSharedLayoutForOperand( vectorSize = 8; } + // This is a hack optimization for the V tensor shared layout, which + // - is not kContig + // - local_load from the tensor will have kWidth=4 + // - ds_read_tr is used + // In this case, we can set vecSize to nonkDim of the mfma instruction + // to avoid read bank conflicts + if (!isKContig) + vectorSize = getMDim(); + int perPhase = std::max(1, elemsPerOneBanksRow / innerDimLength); int maxPhase = std::max(std::min(simdWidth / perPhase, innerDimLength / vectorSize), 1u); From fb186d4af0301b31a5583c9bc4cd029b57607e87 Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Mon, 28 Apr 2025 09:19:13 +0000 Subject: [PATCH 15/44] [FA] Removed old vectorSize workaround --- lib/Dialect/TritonGPU/IR/Dialect.cpp | 9 --------- 1 file changed, 9 deletions(-) diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index af28de4e12de..a6a817a247f3 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -1797,15 +1797,6 @@ SwizzledSharedEncodingAttr AMDMfmaEncodingAttr::composeSharedLayoutForOperand( int innerDimLength = operandShape[sharedOrder[0]]; int elemsPerOneBanksRow = (numBanks * bankBitWidth) / elemBitWidth; - // If we load via AsyncCopy we have to write coalesced into LDS so if - // vectorSize * elemBitWidth < 128 we can only load 32 bit direct-to-lds loads - // (64bit is not suppoted by the HW). So we extend vectorSize to get 128bit to - // get wider loads and accept some bank conflicts durcing ds_reads - // TODO (alex): Make this more generic for async copy - if (vectorSize == 4 && elemBitWidth == 16) { - vectorSize = 8; - } - // This is a hack optimization for the V tensor shared layout, which // - is not kContig // - local_load from the tensor will have kWidth=4 From 321248196344ebaa436053f73bdadb10f7c28fe8 Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Tue, 29 Apr 2025 12:42:49 +0000 Subject: [PATCH 16/44] [FA] Revert "Place AsyncWait at the top of schedule" This reverts commit 718ee20be3bcdeb75057d210f659697f34995653. --- .../amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp index 1e37dee8609c..c226441047c6 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp @@ -113,7 +113,6 @@ class StreamPipeliner { SCHED_LOCAL_STORE, SCHED_LOCAL_LOAD, SCHED_COMPUTE, - // SCHED_ASYNC_WAIT, SCHED_SIZE }; @@ -127,8 +126,6 @@ class StreamPipeliner { stages[SCHED_GLOBAL_LOAD] = 0; stages[SCHED_LOCAL_STORE] = _globalPrefetch; stages[SCHED_LOCAL_LOAD] = lastStage - _localPrefetch; - // AsyncWait should be in same stage as the LocalLoad - // stages[SCHED_ASYNC_WAIT] = stages[SCHED_LOCAL_LOAD]; stages[SCHED_COMPUTE] = lastStage; // stages[SCHED_ASYNC_WAIT] = stages[SCHED_LOCAL_LOAD]; @@ -324,9 +321,10 @@ LogicalResult StreamPipeliner::initSchedule(int maxIndirectionLevel) { clusters[SCHED_LOCAL_LOAD] = asyncCopyClusters[0]; clusters[SCHED_COMPUTE] = softmaxClusters[0]; - // Always have ASYNC_WAIT as the first cluster because we want it at the top - // of the schedule block - // clusters[SCHED_ASYNC_WAIT] = schedule.clusters.newAtFront(); + // clusters[SCHED_GLOBAL_LOAD] = clusterVec[globalLoadCluster]; + // clusters[SCHED_LOCAL_STORE] = clusterVec[localStoreCluster]; + // clusters[SCHED_LOCAL_LOAD] = clusterVec[localLoadCluster]; + // clusters[SCHED_COMPUTE] = clusterVec[computeCluster]; LDBG("Cluster schedule:" << " GLOBAL_LOAD cluster = " << globalLoadCluster << ", LOCAL_STORE cluster = " << localStoreCluster From 34beed72823989f292795a315f9a25447c47d049 Mon Sep 17 00:00:00 2001 From: Jungwook Park Date: Tue, 29 Apr 2025 13:19:28 -0500 Subject: [PATCH 17/44] [FA][PINGPONG] Add support for FAv3 pingpong. Initial support over already arranged ops. --- .../TritonAMDGPUTransforms/BlockPingpong.cpp | 63 +++++++++++++++++-- 1 file changed, 59 insertions(+), 4 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp index 7be03c4e6fda..f6df8bcb3cd2 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp @@ -37,6 +37,8 @@ class Pingponger { SmallVector gLoadOps; SmallVector lLoadOps; SmallVector lStoreOps; + SmallVector asyncCopyOps; + SmallVector asyncWaitOps; SmallVector dotOps; SmallVector> subViewOps; SmallVector> loadSliceOps; @@ -73,6 +75,7 @@ class Pingponger { void transformOnePPClusters(OpBuilder &builder, Location loc); LogicalResult transformFourPPClusters(OpBuilder &builder, Location loc); LogicalResult transformTwoPPClusters(OpBuilder &builder, Location loc); + LogicalResult transformFAv3(OpBuilder &builder, Location loc); void addAsymmetricSyncToLoop(OpBuilder &builder, Location loc); void updateOpInsertion(Operation *Op); void appendOp(Operation *Op); @@ -625,6 +628,43 @@ LogicalResult Pingponger::transformTwoPPClusters(OpBuilder &builder, return success(); } +// Fixme : document the scheduling. +// Assuming pipeliner already ordered the ops. +LogicalResult Pingponger::transformFAv3(OpBuilder &builder, Location loc) { + builder.setInsertionPointToStart(forOp.getBody()); + updateOpInsertion(dotOps[0]); + prependOp(builder.create(loc, lowPriority), false); + + // dot cluster 0 operations here. + + updateOpInsertion(asyncWaitOps[0]); + prependOp(builder.create(loc, highPriority), false); + appendOp(builder.create(loc, 0)); + + // mem cluster 0 operations here. + + updateOpInsertion(dotOps[1]); + // below ops are inserted backward + prependOp(builder.create(loc, lowPriority), true); + prependOp(builder.create(loc), true); + prependOp(builder.create(loc, 0), true); + + // dot cluster 1 operations here. + + updateOpInsertion(asyncWaitOps[1]); + prependOp(builder.create(loc, highPriority), false); + appendOp(builder.create(loc, 0)); + + // mem cluster 1 operations here. + + updateOpInsertion(lastInsertedOp->getBlock()->getTerminator()); + prependOp(builder.create(loc), true); + prependOp(builder.create(loc, 0), true); + + // Fixme: validate the case here? + return success(); +} + // This function wraps forOp with cond_barrier. First, hold half of the warps // (warpHigh) in a block before the loop so the barriers in the loop synchronize // warps at the different point per the warp groups. After the loop, hold @@ -657,10 +697,10 @@ void Pingponger::addAsymmetricSyncToLoop(OpBuilder &builder, Location loc) { } void Pingponger::getDotPingponged() { - if (numStages != 2) { + if (numStages != 2 && numStages != 4) { std::stringstream message; - message << "All ping pong scheduling requires 2 stages. Found " << numStages - << " stages"; + message << "All ping pong scheduling requires 2 or 4 stages. Found " + << numStages << " stages"; LDBG(message.str()); return; } @@ -683,11 +723,26 @@ void Pingponger::getDotPingponged() { lLoadOps.push_back(lLoad); } else if (auto lStore = dyn_cast(op)) lStoreOps.push_back(lStore); - else if (auto pingpongDot = dyn_cast(op)) + else if (auto pingpongDot = dyn_cast(op)) { if (pingpongDot.getType().getRank() == 2) dotOps.push_back(pingpongDot); + } else if (auto asyncOp = dyn_cast(op)) + asyncCopyOps.push_back(asyncOp); + else if (auto asyncOp = dyn_cast(op)) + asyncWaitOps.push_back(asyncOp); }); + // Fixme : use proper condition to identify FAv3 + if (numStages == 4 && dotOps.size() == 2) { + if (transformFAv3(builder, loc).failed()) { + LDBG("Encountered failure when trying to execute the FAv3 ping pong " + "cluster transformation"); + return; + } + addAsymmetricSyncToLoop(builder, loc); + return; + } + // Currently, pingpong scheduling is known as helpful under limited condition. // Individual conditions are checked while collecting each operation such as // software pipelining and dot rank=2. Also only accept the for-loop with From 38610635525564bedaafe6d8a4049b8c298c9fbc Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Wed, 30 Apr 2025 08:42:43 +0000 Subject: [PATCH 18/44] [FA][PINGPONG] Allow block pingpong with num_stages==4 --- third_party/amd/backend/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index 8bf06547d33e..ff5d25d33015 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -261,7 +261,7 @@ def make_ttgir(mod, metadata, options): passes.ttgpuir.add_remove_layout_conversions(pm) amd.passes.ttgpuir.add_reorder_instructions(pm) use_block_pingpong = is_pingpong_schedule_enabled(options.arch) - if use_block_pingpong and options.num_stages == 2: + if use_block_pingpong and options.num_stages in [2, 4]: amd.passes.ttgpuir.add_block_pingpong(pm, options.num_stages) if knobs.amd.use_buffer_ops: From fc6d1d9510bb19a77e8e042dd84a24db22d73450 Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Fri, 2 May 2025 10:36:38 +0000 Subject: [PATCH 19/44] [FA][PINGPONG] Bail out if async wait count != 2 --- third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp index f6df8bcb3cd2..4ffcd35d0a0f 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp @@ -631,6 +631,10 @@ LogicalResult Pingponger::transformTwoPPClusters(OpBuilder &builder, // Fixme : document the scheduling. // Assuming pipeliner already ordered the ops. LogicalResult Pingponger::transformFAv3(OpBuilder &builder, Location loc) { + if (asyncWaitOps.size() != 2) { + return llvm::failure(); + } + builder.setInsertionPointToStart(forOp.getBody()); updateOpInsertion(dotOps[0]); prependOp(builder.create(loc, lowPriority), false); From d6a041912b476e047de925546a71127af210f295 Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Sat, 3 May 2025 12:59:43 +0000 Subject: [PATCH 20/44] [FA] Do not pipeline second loop (causal) --- .../amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp index c226441047c6..f927d0f8d819 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp @@ -1158,7 +1158,13 @@ struct PipelinePass : public TritonAMDGPUStreamPipelineBase { loops.push_back(forOp); }); + // Skip second (causal) loop + int loopCount{}; + for (scf::ForOp forOp : loops) { + if (loopCount++ == 1) + continue; + if (!checkPrecondition(forOp)) continue; StreamPipeliner sp(forOp, tt::getNumStagesOrDefault(forOp, numStages), From 1d1e8cc626bccb37859669ecac26f869cfe9060a Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Wed, 7 May 2025 13:33:31 +0000 Subject: [PATCH 21/44] [FA] Split FourStagePipeliner to separate file and do very basic selection based on the loop. This is not meant as a permanent solution just to make this branch useable for other workloads --- .../lib/TritonAMDGPUTransforms/CMakeLists.txt | 1 + .../FourStagePipeliner.cpp | 952 ++++++++++++++++++ .../FourStagePipeliner.h | 168 ++++ .../TritonAMDGPUTransforms/StreamPipeline.cpp | 199 +--- 4 files changed, 1173 insertions(+), 147 deletions(-) create mode 100644 third_party/amd/lib/TritonAMDGPUTransforms/FourStagePipeliner.cpp create mode 100644 third_party/amd/lib/TritonAMDGPUTransforms/FourStagePipeliner.h diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt index 836720b43901..fbb6a71df190 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt @@ -4,6 +4,7 @@ add_triton_library(TritonAMDGPUTransforms CanonicalizePointers.cpp CoalesceAsyncCopy.cpp ConvertToBufferOps.cpp + FourStagePipeliner.cpp OptimizeEpilogue.cpp HoistLayoutConversions.cpp ReorderInstructions.cpp diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/FourStagePipeliner.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/FourStagePipeliner.cpp new file mode 100644 index 000000000000..e30da8cf74a5 --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUTransforms/FourStagePipeliner.cpp @@ -0,0 +1,952 @@ +#include "FourStagePipeliner.h" +#include "TritonAMDGPUTransforms/Passes.h" +#include "mlir/Support/LLVM.h" +#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" +#include "third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/OpInterfaces.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Debug.h" + +//===----------------------------------------------------------------------===// +// This file will create a schedule that will be handed over to the pipeline +// expander. +// Software pipeliners are usually separated into two pieces, one that create a +// modulo schedule and an expander that rewrites the loop and emits a prologue +// and epilogue. This pass first calls a helper that will pre-process the IR +// to create stream operations and create a modulo schedule. Then we call the +// expander to generate the prologue and new loop and epilogue. +//===----------------------------------------------------------------------===// + +#define GEN_PASS_CLASSES +#include "TritonAMDGPUTransforms/Passes.h.inc" + +#define DEBUG_TYPE "tritonamdgpu-four-stage-pipeline" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; + +static Operation *streamPredication(RewriterBase &rewriter, Operation *op, + Value pred) { + // The epilogue peeling generates a select for the stage output. This causes + // too much register pressure with the loop result and the epilogue-dot in + // regs for the select. Conditionally executing the dot will allow the backend + // to optimize the select away as redundant. + if (auto dotOp = dyn_cast(op)) { + auto loc = dotOp->getLoc(); + auto ifOp = rewriter.create(loc, dotOp->getResult(0).getType(), + pred, /*withElseRegion=*/true); + auto thenB = ifOp.getThenBodyBuilder(); + auto yield = thenB.create(loc, dotOp->getResult(0)); + dotOp->moveBefore(yield); + ifOp.getElseBodyBuilder().create(loc, dotOp->getOperand(2)); + return ifOp; + } + return tt::predicateOp(rewriter, op, pred); +} + +FourStagePipeliner::FourStagePipeliner(scf::ForOp _forOp, int _numStages, + int _globalPrefetch, int _localPrefetch, + bool _useAsyncCopy) + : forOp(_forOp), numStages(_numStages), numBuffers(1), + useAsyncCopy(_useAsyncCopy), schedule(numStages), + axisInfoAnalysis(forOp->getParentOfType()) { + int lastStage = numStages - 1; + stages[SCHED_GLOBAL_LOAD] = 0; + stages[SCHED_LOCAL_STORE] = _globalPrefetch; + stages[SCHED_LOCAL_LOAD] = lastStage - _localPrefetch; + stages[SCHED_COMPUTE] = lastStage; + stages[SCHED_ASYNC_WAIT] = stages[SCHED_LOCAL_LOAD]; + + options.supportDynamicLoops = true; + options.peelEpilogue = true; + options.predicateFn = streamPredication; +} + +bool FourStagePipeliner::checkPrecondition(scf::ForOp forOp, int numStages) { + // Skip the second loop (causual loop) + static bool isFirst = true; + if (!isFirst) + return false; + isFirst = false; + + unsigned dotCount{}; + unsigned reduceCount{}; + + if (tt::getNumStagesOrDefault(forOp, numStages) != 4) + return false; + + if (!forOp.getBody()) + return false; + + for (auto &op : *forOp.getBody()) { + if (isa(op)) { + dotCount++; + } else if (isa(op)) { + reduceCount++; + } + } + return dotCount == 2 && reduceCount == 2; +} + +// Init Schedule Config based on settings and loop characteristics. +// Create clusters in order of ops in loop. This can interleave ops +// from different stages in the same cluster to achieve better backend +// scheduling. +// WARNING: Changing the order of schedule.clusters.newAtBack() calls +// can cause invalid schedules to be produced. +LogicalResult FourStagePipeliner::initSchedule(int maxIndirectionLevel) { + bool pairedGlobalLoadLocalStore = stages[SCHED_LOCAL_STORE] == 0; + stages[SCHED_LOCAL_STORE] += maxIndirectionLevel; + + LDBG( + "Stage schedule:" << " GLOBAL_LOAD stage = " << stages[SCHED_GLOBAL_LOAD] + << ", LOCAL_STORE stage = " << stages[SCHED_LOCAL_STORE] + << ", LOCAL_LOAD stage = " << stages[SCHED_LOCAL_LOAD] + << ", COMPUTE stage = " << stages[SCHED_COMPUTE] + << ", ASYNC_WAIT stage = " << stages[SCHED_ASYNC_WAIT] + << "; total = " << numStages); + + if (stages[SCHED_LOCAL_STORE] >= numStages || + stages[SCHED_LOCAL_STORE] > stages[SCHED_LOCAL_LOAD]) { + LDBG("Invalid stage schedule"); + return failure(); + } + + // Calculate the number of buffers needed for each load. + // TODO: Use the precise number of buffers needed by the particular load. + numBuffers = + std::max(1, stages[SCHED_LOCAL_LOAD] - stages[SCHED_LOCAL_STORE]); + // If we use AsyncCopy we need one more buffer since we are not using a + // register buffer + if (useAsyncCopy) { + numBuffers += 1; + } + numBuffers = 2; + + LDBG("deduced max shared memory buffer number = " << numBuffers); + + // We place async wait as the first cluster because we want to have it being + // the first in the main loop after pipelining. + int asyncWaitCluster = 0; + + // If tt.load and ttg.local_store are in the same stage + // spread them apart to allow overlap with compute + // else + // Initiate ttg.local_store before tt.load + int globalLoadCluster = 1; + int localStoreCluster = 3; + if (!pairedGlobalLoadLocalStore) { + globalLoadCluster = 3; + localStoreCluster = 2; + } + + // If ttg.local_load and ttg.local_store are in the same stage + // spread them apart to allow overlap with compute + // else if they share the buffer + // ttg.local_load must come first + // else + // schedule ttg.local_load in the middle + int localLoadCluster = globalLoadCluster; + if (stages[SCHED_LOCAL_LOAD] == stages[SCHED_LOCAL_STORE]) { + localLoadCluster = std::max(3, localStoreCluster + 1); + } else if (numBuffers == 1 && localLoadCluster >= localStoreCluster) { + // For 1 buffer, ttg.local_load must occur before ttg.local_store + localLoadCluster = localStoreCluster - 1; + } + + // Schedule compute with ttg.local_load if paired + // otherwise, schedule in the middle + int computeCluster = 2; + if (stages[SCHED_LOCAL_LOAD] == stages[SCHED_COMPUTE]) { + computeCluster = localLoadCluster; + } + + // Create clusters in order of 4-stage pipeliner. You can swap lines below to + // change the schedule of the loop. Not all combination are valid, e.g. if a + // consumer and producer from the same stage are in the wrong cluster order + // the loop expander will silently fail + + // DOT1 + dotClusters[0] = schedule.clusters.newAtBack(); + // SM2, + softmaxClusters[0] = schedule.clusters.newAtBack(); + // Wait for V, LRV + localReadClusters[0] = schedule.clusters.newAtBack(); + // ACK + asyncCopyClusters[0] = schedule.clusters.newAtBack(); + // DOT2 + dotClusters[1] = schedule.clusters.newAtBack(); + // SM1 + softmaxClusters[1] = schedule.clusters.newAtBack(); + // Wait for K, LRK + localReadClusters[1] = schedule.clusters.newAtBack(); + // ACV + asyncCopyClusters[1] = schedule.clusters.newAtBack(); + + // ATTENTION 4-stage (not used) + clusters[SCHED_GLOBAL_LOAD] = softmaxClusters[1]; + clusters[SCHED_LOCAL_STORE] = asyncCopyClusters[0]; + clusters[SCHED_LOCAL_LOAD] = asyncCopyClusters[0]; + clusters[SCHED_ASYNC_WAIT] = asyncCopyClusters[0]; + clusters[SCHED_COMPUTE] = softmaxClusters[0]; + // Make assignments + // std::array clusterVec; + // std::generate(clusterVec.begin(), clusterVec.end(), + // [&]() { return schedule.clusters.newAtBack(); }); + + // clusters[SCHED_GLOBAL_LOAD] = clusterVec[globalLoadCluster]; + // clusters[SCHED_LOCAL_STORE] = clusterVec[localStoreCluster]; + // clusters[SCHED_LOCAL_LOAD] = clusterVec[localLoadCluster]; + // clusters[SCHED_COMPUTE] = clusterVec[computeCluster]; + // clusters[SCHED_ASYNC_WAIT] = clusterVec[asyncWaitCluster]; + + LDBG("Cluster schedule:" << " GLOBAL_LOAD cluster = " << globalLoadCluster + << ", LOCAL_STORE cluster = " << localStoreCluster + << ", LOCAL_LOAD cluster = " << localLoadCluster + << ", COMPUTE cluster = " << computeCluster + << ", ASYNC_WAIT cluster = " << asyncWaitCluster + << "; total = " << SCHED_SIZE); + + return success(); +} + +bool FourStagePipeliner::createAsyncCopy(tt::LoadOp loadOp, Value alloc, + Value extractIdx) { + assert(useAsyncCopy); + // If we have a single buffer we would require another barrier after the + // local_reads so instead we fall back to pipeline with registers + // Removing this check will create incorrect IR, see + // MembarUtility.h:membarFilter + if (numBuffers == 1) + return false; + + OpBuilder builder(loadOp); + Location loc = loadOp.getLoc(); + + Value src = loadOp.getPtr(); + auto srcTy = cast(src.getType()); + + ttg::MemDescType allocTy = cast(alloc.getType()); + auto sharedEncodingAttr = + cast(allocTy.getEncoding()); + + // Extract local subview from shared allocation + Value zero = builder.create(forOp.getLoc(), 0, 32); + SmallVector loadOffsets(allocTy.getRank(), zero); + loadOffsets[0] = extractIdx; + auto sharedMemorySpace = ttg::SharedMemorySpaceAttr::get(forOp.getContext()); + auto subviewTy = ttg::MemDescType::get( + allocTy.getShape().drop_front(), allocTy.getElementType(), + allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true); + auto viewLoad = + builder.create(loc, subviewTy, alloc, loadOffsets); + + // If the load is used by an existing local allocation we replace it with the + // new subview + SmallVector allocsToErase; + for (Operation *user : loadOp->getUsers()) { + if (auto alloc = dyn_cast(user)) { + tt::replaceUsesAndPropagateType(builder, alloc, viewLoad); + allocsToErase.push_back(alloc); + } + } + for (auto alloc : allocsToErase) + alloc.erase(); + + auto copyOp = builder.create( + loadOp.getLoc(), src, viewLoad, loadOp.getMask(), loadOp.getOther(), + loadOp.getCache(), loadOp.getEvict(), loadOp.getIsVolatile()); + + // Insert synchronization primitives to create barriers during lowering + auto commitOp = + builder.create(loc, copyOp->getResult(0)); + + ttg::AsyncWaitOp waitOp = + builder.create(loc, commitOp->getResult(0), 0); + + // Create local load which consumes the async token from the AsyncWait + auto sharedLoad = + builder.create(loc, loadOp.getType(), viewLoad, waitOp); + + auto [loadStage, loadCluster] = schedule[loadOp]; + // Schedule new ops + schedule.insert(copyOp, loadStage, loadCluster); + // Place ttg.async_commit_group op following AsyncCopyGlobalToLocal so the + // later UpdateAsyncWaitCount pass can deduce better waitcnts + schedule.insert(commitOp, loadStage, loadCluster); + // If the LocalLoads are scheduled to a later stage than AsyncCopy we need to + // place the AsyncCopy prefetches after the AsyncWaits which create a barrier + // to ensure all warps are finished reading the shared buffer we will write + // into. This is done by scheduling AsyncWait as the first cluster. + // If AsyncCopy and LocalLoads are in the same stage we do not assign a + // schdule so they are placed before the LocalLoads + // Disable for FA + // if (loadStage != stages[SCHED_LOCAL_LOAD]) + // scheduleOp(waitOp, SCHED_ASYNC_WAIT); + + // if (stages[SCHED_LOCAL_LOAD] != stages[SCHED_COMPUTE]) + // scheduleOp(sharedLoad, SCHED_LOCAL_LOAD); + + loadOp->replaceAllUsesWith(ValueRange{sharedLoad}); + + // 4-stage pipeliner scheduleing + auto localLoadStage = loadStage == 0 ? 1 : 3; + auto localLoadCluster = loadStage == 0 ? 1 : 0; + schedule.insert(sharedLoad, localLoadStage, + localReadClusters[localLoadCluster]); + schedule.insert(waitOp, localLoadStage, localReadClusters[localLoadCluster]); + + // Make sure that a possible cvt is in the same stage or otherwise it will not + // get folded + if (sharedLoad->hasOneUse()) { + if (auto cvt = + dyn_cast(*sharedLoad->getUsers().begin())) { + LDBG("Change cvt layout stage and cluster"); + schedule.insert(cvt, localLoadStage, localReadClusters[localLoadCluster]); + } + } + + if (stages[SCHED_LOCAL_LOAD] != stages[SCHED_COMPUTE] && + sharedLoad->hasOneUse()) { + if (auto cvt = + dyn_cast(*sharedLoad->getUsers().begin())) + scheduleOp(cvt, SCHED_LOCAL_LOAD); + } + + // Delete old loadOp + schedule.erase(loadOp); + loadOp.erase(); + return true; +} + +void FourStagePipeliner::createStreamCopy(tt::LoadOp loadOp, Value alloc, + Value extractIdx) { + OpBuilder builder(forOp); + Value zero = builder.create(forOp.getLoc(), 0, 32); + // Replace the load with insert/extract slice. + builder.setInsertionPoint(loadOp); + Location loc = loadOp.getLoc(); + Value src = loadOp.getPtr(); + Value mask = loadOp.getMask(); + Value other = loadOp.getOther(); + + ttg::MemDescType allocTy = cast(alloc.getType()); + SmallVector copyOffsets(allocTy.getRank(), zero); + Operation *copy = builder.clone(*loadOp); + + auto [stage, cluster] = schedule[loadOp]; + schedule.erase(loadOp); + schedule.insert(copy, stage, cluster); + + // Extract part. + SmallVector loadOffsets(allocTy.getRank(), zero); + loadOffsets[0] = extractIdx; + auto sharedMemorySpace = ttg::SharedMemorySpaceAttr::get(forOp.getContext()); + auto subviewTy = ttg::MemDescType::get( + allocTy.getShape().drop_front(), allocTy.getElementType(), + allocTy.getEncoding(), sharedMemorySpace, /*mutableMemory=*/true); + auto viewLoad = + builder.create(loc, subviewTy, alloc, loadOffsets); + // Clean up old local caches. + SmallVector allocsToErase; + for (Operation *user : loadOp->getUsers()) { + if (auto alloc = dyn_cast(user)) { + tt::replaceUsesAndPropagateType(builder, alloc, viewLoad.getResult()); + allocsToErase.push_back(alloc); + } + } + for (auto alloc : allocsToErase) + alloc.erase(); + + // Prefetch load ahead of the dot stage if is used by the dot. + auto storeOp = + builder.create(loc, copy->getResult(0), viewLoad); + scheduleOp(viewLoad, SCHED_LOCAL_STORE); + scheduleOp(storeOp, SCHED_LOCAL_STORE); + + // Create local load + auto sharedLoad = + builder.create(loc, loadOp.getType(), viewLoad); + Value result = sharedLoad.getResult(); + if (stages[SCHED_LOCAL_LOAD] != stages[SCHED_COMPUTE]) + scheduleOp(sharedLoad, SCHED_LOCAL_LOAD); + + // If the currently processed `LoadOp` is labeled with an index regarding + // to which `DotOp` operand the corresponding data belongs to, then label the + // expanded `LocalStoreOp` with the same index. This is required for + // instruction scheduling hints to correctly count the emitted `ds_write` + // instructions for each GEMM tile. + if (auto attr = loadOp->getAttr(tt::amdgpu::OpIdxAttr::getMnemonic())) { + storeOp->setAttr(tt::amdgpu::OpIdxAttr::getMnemonic(), attr); + } + + loadOp->replaceAllUsesWith(ValueRange{result}); + + if (stages[SCHED_LOCAL_LOAD] != stages[SCHED_COMPUTE] && result.hasOneUse()) { + if (auto cvt = dyn_cast(*result.getUsers().begin())) + scheduleOp(cvt, SCHED_LOCAL_LOAD); + } + + loadOp.erase(); +} + +// Returns the given |inputValue|'s dot user result encoding and updates |opIdx| +// with which dot operand |inputValue| is fed into if possible. +static ttg::AMDMfmaEncodingAttr getDotEncoding(Value inputValue, + unsigned *opIdx) { + if (!llvm::hasSingleElement(inputValue.getUses())) + return nullptr; + + Operation *user = *inputValue.getUsers().begin(); + if (user->getNumResults() != 1 || + user->getBlock() != inputValue.getParentBlock()) + return nullptr; + + if (auto dotOp = dyn_cast(user)) { + OpOperand &use = *inputValue.getUses().begin(); + *opIdx = use.getOperandNumber(); + auto dotType = cast(dotOp->getResult(0).getType()); + return dyn_cast(dotType.getEncoding()); + } + return getDotEncoding(user->getResult(0), opIdx); +} + +// If all the transitive uses of the given value have are used by a convert to +// the same dot operand encoding, return true and get the shared encoding that +// needs to be used to be compatible with users' layouts. +static std::optional +getSharedEncIfAllUsersAreDotEnc(Value loadedValue) { + ttg::SwizzledSharedEncodingAttr attr; + for (Operation *user : loadedValue.getUsers()) { + LDBG(" getSharedEncIfAllUsersAreDotEnc current user: " << *user); + if (user->getNumResults() != 1) + return std::nullopt; + + ttg::SwizzledSharedEncodingAttr tempAttr; + Value userResult = user->getResult(0); + Type userResType = userResult.getType(); + if (auto memDesc = dyn_cast(userResType)) { + // First time we find a shared encoding in the chain, save it and try to + // use it if it is compatible with the other users. + tempAttr = cast(memDesc.getEncoding()); + if (!getSharedEncIfAllUsersAreDotEnc(userResult).has_value()) + return std::nullopt; + } else { + if (!isa(user)) + return std::nullopt; + + auto srcTy = cast(loadedValue.getType()); + auto ctaLayout = ttg::getCTALayout(srcTy.getEncoding()); + auto order = getOrderForMemory(srcTy); + unsigned bitWidth = srcTy.getElementType().getIntOrFloatBitWidth(); + SmallVector sharedOrder; + int rank = order.size(); + // TODO rework this when shared -> dotOperand conversions support + // arbitrary shared memory ordering + if (rank == 3) { + // Move the batch dimension (dim #0) to be the last so that it will be + // the slowest varying dimension. + for (unsigned i = 0; i < rank; ++i) + if (order[i] != 0) + sharedOrder.emplace_back(order[i]); + sharedOrder.emplace_back(0); + } else { + sharedOrder = order; + } + + auto userResEnc = cast(userResType).getEncoding(); + if (auto dotOpEnc = dyn_cast(userResEnc)) { + tempAttr = ttg::SwizzledSharedEncodingAttr::get( + loadedValue.getContext(), dotOpEnc, srcTy.getShape(), sharedOrder, + ctaLayout, bitWidth, /*needTrans=*/false); + } else if (auto llEnc = dyn_cast(userResEnc)) { + // We use linear layout directly for scaled dot fp8 operands. For such + // cases, we need to look further down the def-use chain to find the dot + // op for the mfma layout to deduce operand index and other information. + unsigned opIdx; + if (auto dotEnc = getDotEncoding(userResult, &opIdx)) { + unsigned vecSize = llEnc.getLinearLayout().getNumConsecutiveInOut(); + LDBG("deduced opIdx: " << opIdx << "; deduced vecSize: " << vecSize); + tempAttr = dotEnc.composeSharedLayoutForOperand( + ctaLayout, opIdx, srcTy.getShape(), order, vecSize, bitWidth, + /*needTrans=*/false); + } + } + } + // Check that the shared encodings needed by the users are compatible. + if (!tempAttr || (attr != nullptr && attr != tempAttr)) + return std::nullopt; + attr = tempAttr; + } + return attr; +} + +// Create a map from load ops to their indirection levels and the final uses +// of the load op (another load op, or a dot op). +// +// Indirection level is "0" for the load op directly used by the dot op, +// "1" for the load op used by the load op used by the dot op, and so on. +void FourStagePipeliner::computeLoadOpsToIndirectionLevelAndUse() { + DenseSet seen; + + // Recursively visit the given op and its operands to discover all load ops + // and collect their indirection levels and uses. + std::function dfs = + [&](Operation *op, int distance, Operation *use) { + // Skip previously visited load ops. + if (!seen.insert(op).second) + return; + + if (isa(op)) { + // TODO: What if there are multiple uses at different distances? + loadOpToIndLevelAndUse.emplace_back(op, distance, use); + use = op; + ++distance; + } + for (Value operand : op->getOperands()) { + Operation *defOp = operand.getDefiningOp(); + if (defOp && defOp->getBlock() == op->getBlock()) { + dfs(defOp, distance, use); + } + } + }; + + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!isa(op)) + continue; + seen.clear(); + dfs(&op, 0, &op); + } + + // If the loop has numStages attribute, also consider pipelining other loads + // that are not directly used by dot ops. + if (forOp->hasAttr(tt::kNumStagesAttrName)) { + for (Operation &op : forOp.getBody()->without_terminator()) { + if (!isa(op)) + dfs(&op, 0, &op); + } + } +} + +// Goes through all load ops to identify those that can be pipelined and assign +// layout to them. +void FourStagePipeliner::assignMemoryLayouts() { + for (auto &[op, dist, use] : loadOpToIndLevelAndUse) { + if (loadToInfo.count(op)) + // TODO: We'd need to verify that the distance is the same. + continue; + + auto loadOp = cast(op); + assert(!isLoadFromTensorPtr(loadOp) && + "Block ptr should have been lowered before this pass."); + auto ptr = loadOp.getPtr(); + unsigned vec = axisInfoAnalysis.getContiguity(ptr); + if (auto mask = loadOp.getMask()) + vec = std::min(vec, axisInfoAnalysis.getMaskAlignment(mask)); + + auto tensorTy = dyn_cast(ptr.getType()); + if (!tensorTy) { + LDBG("Skip non-tensor load " << loadOp); + continue; + } + + auto pointeeTy = + cast(tensorTy.getElementType()).getPointeeType(); + unsigned width = vec * pointeeTy.getIntOrFloatBitWidth(); + + LDBG("assign memory layouts (width=" << width << ") for load " << loadOp); + LoadInfo loadInfo; + if (isa(use)) { + // Only use shared memory when feeding into a dot op. + loadInfo.usedByDot = true; + // If the max continugous bits we can read is < 32, buffer in registers. + if (width >= 32) { + loadInfo.sharedEncoding = + getSharedEncIfAllUsersAreDotEnc(op->getResult(0)).value_or(nullptr); + } + } else if (auto useOp = dyn_cast(use)) { + // The use of this loadOp is another loadOp. If the use is not in the + // loadToInfo already, it means that the use is not valid for pipelining + // for some reason. We should skip this loadOp, too. + // + // Note that we have an assumption that the use of this loadOp has already + // be processed in a previous loop iteration. This assumption is held by + // how loadOpsToIndirectionLevelAndUse recursively collects + // loadOpToIndLevelAndUse using DFS. + if (loadToInfo.count(useOp) == 0) { + continue; + } + } + + loadToInfo[op] = loadInfo; + } +} + +LogicalResult +FourStagePipeliner::scheduleLoads(DenseSet &rootUsers) { + // Get all loads that are (transitively) used by dot ops and their distance + // to the dot op. + computeLoadOpsToIndirectionLevelAndUse(); + LLVM_DEBUG({ + LDBG("Found " << loadOpToIndLevelAndUse.size() << " loads to pipeline:"); + for (const auto &[l, i, u] : loadOpToIndLevelAndUse) { + LDBG(" - load: " << *l); + LDBG(" at indirection level: " << i); + LDBG(" used by op: " << *u); + } + }); + if (loadOpToIndLevelAndUse.empty()) + return failure(); + + // Check which loads are good for pipelining, and assign them memory layouts. + assignMemoryLayouts(); + if (loadToInfo.empty()) + return failure(); + + // Filter out load ops that cannot be pipelined. + int resize = 0; + for (int i = 0, e = loadOpToIndLevelAndUse.size(); i < e; ++i) { + auto [loadOp, distance, use] = loadOpToIndLevelAndUse[i]; + if (loadToInfo.count(loadOp) != 0) + loadOpToIndLevelAndUse[resize++] = loadOpToIndLevelAndUse[i]; + } + loadOpToIndLevelAndUse.resize(resize); + + // Calculate the stage distance between applicable loads. + int maxIndirectionLevel = -1; + for (auto [loadOp, dist, use] : loadOpToIndLevelAndUse) + maxIndirectionLevel = std::max(maxIndirectionLevel, dist); + + LDBG("maxIndirectionLevel = " << maxIndirectionLevel); + if (maxIndirectionLevel >= numStages) + return failure(); + + if (failed(initSchedule(maxIndirectionLevel))) + return failure(); + + // The stage gap between chained loads--this allows us to "spread" loads + // with a non-one step in case the number of stages given by the user is + // large. + assert(numStages >= 2 && "requires num_stages=2 at least"); + unsigned stagesBetweenLoads = + llvm::divideCeil(numStages - 2, maxIndirectionLevel + 1); + LDBG("stagesBetweenLoads = " << stagesBetweenLoads); + + // Assign stages to the loads. + // FA: + // Load1: Stage=0, cluster=1 + // Load2: Stage=1, cluster=3 + int i{}; + for (auto [loadOp, indLevel, _] : loadOpToIndLevelAndUse) { + int stage = (maxIndirectionLevel - indLevel) * stagesBetweenLoads; + if (schedule.count(loadOp) > 0) + continue; + schedule.insert(loadOp, i, asyncCopyClusters[i == 0 ? 0 : 1]); + i++; + } + + // Put the root uses of the loads in the last stage. + for (auto &[loadOp, dist, use] : loadOpToIndLevelAndUse) { + // Non-LoadOp(s) are the (final) root uses of all LoadOp(s). + if (!isa(use)) { + auto loadStage = schedule[loadOp].first; + schedule.insert(use, loadStage + 2, dotClusters[loadStage == 0 ? 0 : 1]); + // scheduleOp(use, SCHED_COMPUTE); + rootUsers.insert(use); + } + } + + // Calculate distance from the load to the use. + for (auto [loadOp, _, use] : loadOpToIndLevelAndUse) { + loadToInfo[loadOp].distToUse = schedule[use].first - schedule[loadOp].first; + } + + LLVM_DEBUG({ + LDBG("Chosen loads to pipeline:"); + for (const auto &[load, info] : loadToInfo) { + LDBG(" - load: " << *load); + LDBG(" distToUse: " << info.distToUse); + LDBG(" usedByDot: " << info.usedByDot); + } + }); + + return success(); +} + +// Add dependencies of anchor ops to the coarse schedule. Schedule them to +// the same stage and ordering cluster as the anchor op. +void FourStagePipeliner::scheduleDependencies() { + SmallVector> + opsInOrder = schedule.getOpsInOrder(forOp); + // Schedule dependencies stage by stage. + for (int stage = 0; stage < numStages; ++stage) { + for (auto [op, stage_, cluster] : opsInOrder) { + if (stage_ != stage) + continue; + auto depCluster = cluster; + bool override = false; + if (llvm::isa(op) && stage == 3) { + depCluster = softmaxClusters[0]; + override = true; + } + + auto moveStages = [this, stage, cluster = cluster, + depCluster = depCluster, override](Operation *op) { + if (llvm::isa(op)) { + return std::make_pair(stage, cluster); + } + return std::make_pair(stage, depCluster); + }; + schedule.insertDepsOfOp(op, false, false, moveStages); + } + } +} + +// Find dependencies with distance of 1. They will go to the next stage, +// but in the cluster before the current op. +void FourStagePipeliner::scheduleDistanceOneDependencies() { + auto getNestedOperands = [](Operation *op) { + SmallVector operands; + op->walk([&](Operation *nestedOp) { + for (Value operand : nestedOp->getOperands()) { + if (operand.getParentBlock()->getParentOp()->isAncestor(nestedOp)) + operands.push_back(operand); + } + }); + return operands; + }; + + // Mapping from the cluster to the cluster before it. + DenseMap + dist1Cluster; + for (auto &op : forOp.getBody()->without_terminator()) { + if (schedule.count(&op) == 0) + continue; + auto [stage, cluster] = schedule[&op]; + // Can't schedule past the last stage. + if (stage == numStages - 1) + continue; + for (Value operand : getNestedOperands(&op)) { + auto arg = dyn_cast(operand); + if (!arg || arg.getArgNumber() == 0 || arg.getOwner() != op.getBlock()) + continue; + auto yieldOp = op.getBlock()->getTerminator(); + Value v = yieldOp->getOperand(arg.getArgNumber() - 1); + Operation *defOp = v.getDefiningOp(); + if (!defOp || schedule.count(defOp) != 0) + continue; + if (isa(defOp)) { + // Exception: schedule loads with a distance of 1 together with the + // current op. + schedule.insertIfAbsent(defOp, stage, cluster); + schedule.insertDepsOfOp(defOp, stage, cluster, true); + } else { + if (dist1Cluster.count(&cluster) == 0) { + dist1Cluster[&cluster] = schedule.clusters.newBefore(cluster); + } + schedule.insertIfAbsent(defOp, stage + 1, dist1Cluster[&cluster]); + schedule.insertDepsOfOp(defOp, stage + 1, dist1Cluster[&cluster], true); + } + } + } +} + +void FourStagePipeliner::scheduleRemainingToLastStage() { + int lastStage = numStages - 1; + // Assign the rest of the ops to the last stage. + // Take care of the ordering of the ops - uses cannot be scheduled to the + // cluster before the definition. + auto cluster = clusters[SCHED_COMPUTE]; + DenseMap opToCluster; + for (auto &op : forOp.getBody()->without_terminator()) { + if (schedule.count(&op) == 0) + opToCluster[&op] = cluster; + } + SmallVector queue; + for (auto [op, stage, cluster] : schedule.getOpsInOrder(forOp)) { + // We really only care about the producers from the last stage. + // Others will be scheduled before these ops anyway. + if (stage == lastStage) { + queue.push_back(op); + } + } + while (!queue.empty()) { + Operation *op = queue.pop_back_val(); + for (auto user : op->getUsers()) { + if (opToCluster.count(user)) { + tt::CoarseSchedule::Cluster userCluster = opToCluster[user]; + tt::CoarseSchedule::Cluster opCluster = schedule[op].second; + if (*userCluster < *opCluster) { + opToCluster[user] = opCluster; + queue.push_back(user); + } + } + } + } + for (auto [op, cluster] : opToCluster) { + schedule.insert(op, lastStage, cluster); + } +} + +// Create an allocation that can hold distance number of loadOp shapes. +Value FourStagePipeliner::createAlloc( + Operation *loadOp, ttg::SwizzledSharedEncodingAttr sharedEnc) { + OpBuilder builder(forOp); + Attribute sharedMemorySpace = + ttg::SharedMemorySpaceAttr::get(forOp.getContext()); + auto ty = cast(loadOp->getResultTypes()[0]); + SmallVector bufferShape(ty.getShape().begin(), ty.getShape().end()); + bufferShape.insert(bufferShape.begin(), numBuffers); + Type memdescType = ttg::MemDescType::get(bufferShape, ty.getElementType(), + sharedEnc, sharedMemorySpace, + /*mutableMemory=*/true); + auto alloc = builder.create(loadOp->getLoc(), memdescType); + sharedMemAllocs.push_back(alloc); + return alloc; +} + +// Convert load ops into shared memory allocation loads and apply +// multi-buffering based on the required number of buffers. +void FourStagePipeliner::createStreamOps() { + SmallVector> loadToAllocs; + for (auto &[loadOp, info] : loadToInfo) { + if (!info.sharedEncoding || info.isAsync) + continue; + + Value alloc = createAlloc(loadOp, info.sharedEncoding); + assert(alloc && "Failed to create alloc for the async load."); + loadToAllocs.emplace_back(loadOp, alloc); + } + + IRRewriter builder(forOp.getContext()); + builder.setInsertionPoint(forOp); + + Location loc = forOp.getLoc(); + Value minusOne = builder.create(loc, -1, 32); + Value zero = builder.create(loc, 0, 32); + Value one = builder.create(loc, 1, 32); + Value extractIdx = minusOne; + Value numBuffersVal = + builder.create(loc, numBuffers, 32); + + unsigned newOperandIndex = forOp.getBody()->getNumArguments(); + // Patch the loop to add the new loop carried dependencies. + (void)addIterArgsToLoop(builder, forOp, {extractIdx}); + + // Create one counter for the extract indices to avoid creating long + // live range. + extractIdx = forOp.getBody()->getArgument(newOperandIndex); + + builder.setInsertionPoint(forOp.getBody(), forOp.getBody()->begin()); + extractIdx = builder.create(loc, extractIdx, one); + Value cndExt = builder.create(loc, arith::CmpIPredicate::slt, + extractIdx, numBuffersVal); + extractIdx = builder.create(loc, cndExt, extractIdx, zero); + + // Replace tt.loads with async copies or stream copies + for (auto &[op, alloc] : loadToAllocs) { + if (auto loadOp = dyn_cast(op)) { + if (useAsyncCopy && createAsyncCopy(loadOp, alloc, extractIdx)) + continue; + createStreamCopy(loadOp, alloc, extractIdx); + } + } + // Patch the yield with the updated counters. + appendToForOpYield(forOp, {extractIdx}); +} + +LogicalResult FourStagePipeliner::preprocessLoopAndBuildSchedule() { + // Schedule the loads and root ops (dot ops) in the loop. This will give us + // a scaffold for the final schedule. + DenseSet rootUsers; + if (failed(scheduleLoads(rootUsers))) + return failure(); + if (loadToInfo.empty()) + return failure(); + + LLVM_DEBUG({ + LDBG("Coarse schedule loads only:"); + schedule.dump(); + }); + + // Convert the loads into shared memory allocations and loads from them. + createStreamOps(); + LLVM_DEBUG({ + LDBG("Coarse schedule with replaced laod ops:"); + schedule.dump(); + }); + + // Schedule reductions + int c = 2; + for (auto reduceOp : forOp.getBody()->getOps()) { + schedule.insert(reduceOp, c, softmaxClusters[c == 2 ? 1 : 0]); + c++; + } + + for (auto exp2Op : forOp.getBody()->getOps()) { + schedule.insert(exp2Op, 2, softmaxClusters[1]); + } + LLVM_DEBUG({ + LDBG("Coarse schedule after schedule reduction:"); + schedule.dump(); + }); + + scheduleDependencies(); + LLVM_DEBUG({ + LDBG("Coarse schedule with dependencies:"); + schedule.dump(); + }); + + scheduleDistanceOneDependencies(); + LLVM_DEBUG({ + LDBG("Coarse schedule with dist 1:"); + schedule.dump(); + }); + + scheduleRemainingToLastStage(); + LLVM_DEBUG({ + LDBG("Final coarse schedule:"); + schedule.dump(); + }); + + // Create the final schedule for the kernel loop. This will dictate the + // stages and order of operations to the pipeline expander. + std::vector> coarseSchedule = + schedule.createFinalSchedule(forOp); + + // Fill out the pipeline options. + options.getScheduleFn = + [coarseSchedule](scf::ForOp, + std::vector> &s) { + s = std::move(coarseSchedule); + }; + + OpBuilder builder(forOp); + builder.setInsertionPointAfter(forOp); + // Explicitly deallocate created allocations. + for (auto alloc : sharedMemAllocs) + builder.create(forOp.getLoc(), alloc); + + return success(); +} + +LogicalResult FourStagePipeliner::pipelineLoop() { + if (failed(preprocessLoopAndBuildSchedule())) + return failure(); + LDBG("Loop before sending to expander:\n" << *forOp); + + IRRewriter rewriter(forOp->getContext()); + rewriter.setInsertionPoint(forOp); + return tt::pipelineForLoop(rewriter, forOp, options); +} diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/FourStagePipeliner.h b/third_party/amd/lib/TritonAMDGPUTransforms/FourStagePipeliner.h new file mode 100644 index 000000000000..dd01de342cdb --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUTransforms/FourStagePipeliner.h @@ -0,0 +1,168 @@ +#ifndef TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTRANSFORMS_FOURSTAGEPIPELINE_H_ +#define TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTRANSFORMS_FOURSTAGEPIPELINE_H_ + +#include "TritonAMDGPUTransforms/Passes.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Support/LLVM.h" +#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" +#include "third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.h" +#include "triton/Analysis/AxisInfo.h" +#include "triton/Dialect/Triton/IR/OpInterfaces.h" +#include "triton/Dialect/TritonGPU/IR/Attributes.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/Transforms/PipelineExpander.h" +#include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" +#include "triton/Dialect/TritonGPU/Transforms/Schedule.h" +#include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Debug.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Software pipelining generally works by anchoring on global load ops in the +// main loop and rotating the loop to schedule global load ops for future loop +// iterations together with compute for the current iteration. In this way, we +// can 1) issue memory operations earlier to hide the latency and 2) break the +// strong dependency inside on loop iteration to give backends flexibility to +// better interleave instructions for better instruction-level parallelism. +// +// This FourStagePipeliner class creates the pipelining schedule and calls the +// PipelineExpander to rewrite the `scf.for` loop accordingly. A schedule +// consists of multiple stages, where ops from different stages can overlap +// executions because the dependencies are loop carried. +// +// The general flow of this process is: +// +// 1. The user provides a `num_stages` that specifies how many stages the +// pipeline will have. The number of stages must be larger than the distance +// from the first independent load to the compute in order to pipeline. +// 1.a. User may also specify `global_prefetch=` to set the number of +// stages between tt.load and ttg.local_store ops. +// 1.b. User may also specify `local_prefetch=` to set the number of +// stages between ttg.local_load and compute. +// 2. A schedule is created based on the distance between the global loads +// in the first stages and the compute that uses the loaded values in the +// last stage (num_stages - 1). Each operation will be clustered in the +// order to best overlap with other operations (see details below in the +// initSchedule method). +// 3. When the compute is a tt.dot, the scheduler will insert a shared +// memory allocation between the global load and tt.dot. The ttg.local_store +// will save the global load value to shared memory and the ttg.local_load +// will load the relevant tiles for the tt.dot. These operations will be +// scheduled according to various scheduling schemes outlined below in the +// initSchedule method (see details there). +// 4. Finally the schedule will be passed to the PipelineExpander to rewrite +// accordingly. The new implementation will consist of: +// a. Prologue: containing the ramp-up of num_stages-1 stages for +// iteratorions i=[0, num_stages-1). +// b. New loop: ordered by cluster and iterated on each operation by +// `i + (num_stages-op_stage)`. +// c. Epilogue: ramp-down of the last `num_stages-1` iterations for the +// ops in stages 1 to last_stage. This must consider that the loop +// bounds may be shorter than num_stages. In this case, the epilogue +// iterations must align with the prologue. +// +class FourStagePipeliner { + // Define categories of scheduling details per Operation types. + // The FourStagePipeliner schedules 5 types of operations: + // 1. GLOBAL_LOAD: tt.load / ttg.async_copy_global_to_local + // 2. LOCAL_STORE: ttg.local_store + // 3. LOCAL_LOAD: ttg.local_load + // 4. COMPUTE: ops that use the loaded data + // 5. ASYNC_WAIT: ttg.async_wait + // Note that ttg ops mentioned in the above list are created in this pass. + enum SchedType { + SCHED_GLOBAL_LOAD, + SCHED_LOCAL_STORE, + SCHED_LOCAL_LOAD, + SCHED_COMPUTE, + SCHED_ASYNC_WAIT, + SCHED_SIZE + }; + +public: + FourStagePipeliner(scf::ForOp _forOp, int _numStages, int _globalPrefetch, + int _localPrefetch, bool _useAsyncCopy); + + static bool checkPrecondition(scf::ForOp forOp, int numStages); + + LogicalResult pipelineLoop(); + +private: + LogicalResult initSchedule(int maxIndirectionLevel); + + void computeLoadOpsToIndirectionLevelAndUse(); + void assignMemoryLayouts(); + LogicalResult scheduleLoads(DenseSet &rootUsers); + void scheduleDependencies(); + void scheduleDistanceOneDependencies(); + void scheduleRemainingToLastStage(); + + LogicalResult preprocessLoopAndBuildSchedule(); + + Value createAlloc(Operation *loadOp, + triton::gpu::SwizzledSharedEncodingAttr sharedEnc); + bool createAsyncCopy(triton::LoadOp loadOp, Value alloc, Value extractIdx); + void createStreamCopy(triton::LoadOp loadOp, Value alloc, Value extractIdx); + void createStreamOps(); + + void scheduleOp(Operation *op, SchedType type, int stage = -1) { + if (stage < 0) + stage = stages[type]; + schedule.insert(op, stage, clusters[type]); + } + +private: + // Data members + scf::ForOp forOp; + + // User settings + int numStages; + + // Computed number of buffers + int numBuffers; + + // Directly store to shared memory with AsyncCopy when pipelining tt.loads + bool useAsyncCopy; + + // Stage for each SchedType Op + int stages[SCHED_SIZE]; + // (not used anymore) Cluster for each SchedType Op + std::array clusters; + + // Clusters to hold the different Ops for the 4-stage pipeliner + std::array localReadClusters; + std::array softmaxClusters; + std::array asyncCopyClusters; + std::array dotClusters; + + // Scheduling clusters + triton::CoarseSchedule schedule; + + // Mapping and indirection level for each `tt.load` to its use. + SmallVector> loadOpToIndLevelAndUse; + + struct LoadInfo { + // Shared layout is used for loads feeding into dot ops. + triton::gpu::SwizzledSharedEncodingAttr sharedEncoding = nullptr; + // The distance of this load's stage to its use' stage. + int distToUse = 0; + bool usedByDot = false; + bool isAsync = false; + }; + + // Mapping for each pipelined load to scheduling details. + llvm::MapVector loadToInfo; + + // Lookup alignment/contiguity mappings for the current module. + triton::ModuleAxisInfoAnalysis axisInfoAnalysis; + + // Capture list of new shared memory buffers. + SmallVector sharedMemAllocs; + + // Pipelining options for the PipelineExpander + triton::PipeliningOption options; +}; + +#endif diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp index f927d0f8d819..a37e170223dd 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp @@ -1,3 +1,4 @@ +#include "FourStagePipeliner.h" #include "TritonAMDGPUTransforms/Passes.h" #include "mlir/Support/LLVM.h" #include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" @@ -113,6 +114,7 @@ class StreamPipeliner { SCHED_LOCAL_STORE, SCHED_LOCAL_LOAD, SCHED_COMPUTE, + SCHED_ASYNC_WAIT, SCHED_SIZE }; @@ -127,7 +129,7 @@ class StreamPipeliner { stages[SCHED_LOCAL_STORE] = _globalPrefetch; stages[SCHED_LOCAL_LOAD] = lastStage - _localPrefetch; stages[SCHED_COMPUTE] = lastStage; - // stages[SCHED_ASYNC_WAIT] = stages[SCHED_LOCAL_LOAD]; + stages[SCHED_ASYNC_WAIT] = stages[SCHED_LOCAL_LOAD]; options.supportDynamicLoops = true; options.peelEpilogue = true; @@ -175,15 +177,9 @@ class StreamPipeliner { // Stage for each SchedType Op int stages[SCHED_SIZE]; - // (not used anymore) Cluster for each SchedType Op + // Cluster for each SchedType Op std::array clusters; - // Clusters to hold the different Ops for the 4-stage pipeliner - std::array localReadClusters; - std::array softmaxClusters; - std::array asyncCopyClusters; - std::array dotClusters; - // Scheduling clusters tt::CoarseSchedule schedule; @@ -224,14 +220,13 @@ LogicalResult StreamPipeliner::initSchedule(int maxIndirectionLevel) { bool pairedGlobalLoadLocalStore = stages[SCHED_LOCAL_STORE] == 0; stages[SCHED_LOCAL_STORE] += maxIndirectionLevel; - LDBG("Stage schedule:" - << " GLOBAL_LOAD stage = " << stages[SCHED_GLOBAL_LOAD] - << ", LOCAL_STORE stage = " << stages[SCHED_LOCAL_STORE] - << ", LOCAL_LOAD stage = " << stages[SCHED_LOCAL_LOAD] - << ", COMPUTE stage = " - << stages[SCHED_COMPUTE] - // << ", ASYNC_WAIT stage = " << stages[SCHED_ASYNC_WAIT] - << "; total = " << numStages); + LDBG( + "Stage schedule:" << " GLOBAL_LOAD stage = " << stages[SCHED_GLOBAL_LOAD] + << ", LOCAL_STORE stage = " << stages[SCHED_LOCAL_STORE] + << ", LOCAL_LOAD stage = " << stages[SCHED_LOCAL_LOAD] + << ", COMPUTE stage = " << stages[SCHED_COMPUTE] + << ", ASYNC_WAIT stage = " << stages[SCHED_ASYNC_WAIT] + << "; total = " << numStages); if (stages[SCHED_LOCAL_STORE] >= numStages || stages[SCHED_LOCAL_STORE] > stages[SCHED_LOCAL_LOAD]) { @@ -248,7 +243,6 @@ LogicalResult StreamPipeliner::initSchedule(int maxIndirectionLevel) { if (useAsyncCopy) { numBuffers += 1; } - numBuffers = 2; LDBG("deduced max shared memory buffer number = " << numBuffers); @@ -289,42 +283,15 @@ LogicalResult StreamPipeliner::initSchedule(int maxIndirectionLevel) { } // Make assignments - // std::array clusterVec; - // std::generate(clusterVec.begin(), clusterVec.end(), - // [&]() { return schedule.clusters.newAtBack(); }); - - // Create clusters in order of 4-stage pipeliner. You can swap lines below to - // change the schedule of the loop. Not all combination are valid, e.g. if a - // consumer and producer from the same stage are in the wrong cluster order - // the loop expander will silently fail - - // DOT1 - dotClusters[0] = schedule.clusters.newAtBack(); - // SM2, - softmaxClusters[0] = schedule.clusters.newAtBack(); - // Wait for V, LRV - localReadClusters[0] = schedule.clusters.newAtBack(); - // ACK - asyncCopyClusters[0] = schedule.clusters.newAtBack(); - // DOT2 - dotClusters[1] = schedule.clusters.newAtBack(); - // SM1 - softmaxClusters[1] = schedule.clusters.newAtBack(); - // Wait for K, LRK - localReadClusters[1] = schedule.clusters.newAtBack(); - // ACV - asyncCopyClusters[1] = schedule.clusters.newAtBack(); - - // ATTENTION 4-stage (not used) - clusters[SCHED_GLOBAL_LOAD] = softmaxClusters[1]; - clusters[SCHED_LOCAL_STORE] = asyncCopyClusters[0]; - clusters[SCHED_LOCAL_LOAD] = asyncCopyClusters[0]; - clusters[SCHED_COMPUTE] = softmaxClusters[0]; - - // clusters[SCHED_GLOBAL_LOAD] = clusterVec[globalLoadCluster]; - // clusters[SCHED_LOCAL_STORE] = clusterVec[localStoreCluster]; - // clusters[SCHED_LOCAL_LOAD] = clusterVec[localLoadCluster]; - // clusters[SCHED_COMPUTE] = clusterVec[computeCluster]; + std::array clusterVec; + std::generate(clusterVec.begin(), clusterVec.end(), + [&]() { return schedule.clusters.newAtBack(); }); + + clusters[SCHED_GLOBAL_LOAD] = clusterVec[globalLoadCluster]; + clusters[SCHED_LOCAL_STORE] = clusterVec[localStoreCluster]; + clusters[SCHED_LOCAL_LOAD] = clusterVec[localLoadCluster]; + clusters[SCHED_COMPUTE] = clusterVec[computeCluster]; + clusters[SCHED_ASYNC_WAIT] = clusterVec[asyncWaitCluster]; LDBG("Cluster schedule:" << " GLOBAL_LOAD cluster = " << globalLoadCluster << ", LOCAL_STORE cluster = " << localStoreCluster @@ -389,51 +356,31 @@ bool StreamPipeliner::createAsyncCopy(tt::LoadOp loadOp, Value alloc, ttg::AsyncWaitOp waitOp = builder.create(loc, commitOp->getResult(0), 0); - // scheduleOp(waitOp, SCHED_ASYNC_WAIT); // Create local load which consumes the async token from the AsyncWait auto sharedLoad = builder.create(loc, loadOp.getType(), viewLoad, waitOp); auto [loadStage, loadCluster] = schedule[loadOp]; - auto localLoadStage = loadStage == 0 ? 1 : 3; - auto localLoadCluster = loadStage == 0 ? 1 : 0; - auto waitCluster = loadStage == 0 ? 1 : 0; - schedule.erase(loadOp); // Schedule new ops schedule.insert(copyOp, loadStage, loadCluster); // Place ttg.async_commit_group op following AsyncCopyGlobalToLocal so the // later UpdateAsyncWaitCount pass can deduce better waitcnts schedule.insert(commitOp, loadStage, loadCluster); - // If the LocalLoads are scheduled to a later stage than AsyncCopy we need - // to place the AsyncCopy prefetches after the AsyncWaits which create a - // barrier to ensure all warps are finished reading the shared buffer we - // will write into. This is done by scheduling AsyncWait as the first - // cluster. If AsyncCopy and LocalLoads are in the same stage we do not - // assign a schdule so they are placed before the LocalLoads - schedule.insert(sharedLoad, localLoadStage, - localReadClusters[localLoadCluster]); - schedule.insert(waitOp, localLoadStage, localReadClusters[localLoadCluster]); - - // if (loadStage != stages[SCHED_LOCAL_LOAD]) - // scheduleOp(waitOp, SCHED_ASYNC_WAIT); + // If the LocalLoads are scheduled to a later stage than AsyncCopy we need to + // place the AsyncCopy prefetches after the AsyncWaits which create a barrier + // to ensure all warps are finished reading the shared buffer we will write + // into. This is done by scheduling AsyncWait as the first cluster. + // If AsyncCopy and LocalLoads are in the same stage we do not assign a + // schdule so they are placed before the LocalLoads + if (loadStage != stages[SCHED_LOCAL_LOAD]) + scheduleOp(waitOp, SCHED_ASYNC_WAIT); if (stages[SCHED_LOCAL_LOAD] != stages[SCHED_COMPUTE]) scheduleOp(sharedLoad, SCHED_LOCAL_LOAD); loadOp->replaceAllUsesWith(ValueRange{sharedLoad}); - - // Make sure that a possible cvt is in the same stage or otherwise it will not - // get folded - if (sharedLoad->hasOneUse()) { - if (auto cvt = - dyn_cast(*sharedLoad->getUsers().begin())) { - LDBG("Change cvt layout stage and cluster"); - schedule.insert(cvt, localLoadStage, localReadClusters[localLoadCluster]); - } - } - if (stages[SCHED_LOCAL_LOAD] != stages[SCHED_COMPUTE] && sharedLoad->hasOneUse()) { if (auto cvt = @@ -441,8 +388,6 @@ bool StreamPipeliner::createAsyncCopy(tt::LoadOp loadOp, Value alloc, scheduleOp(cvt, SCHED_LOCAL_LOAD); } - // Delete old loadOp - schedule.erase(loadOp); loadOp.erase(); return true; } @@ -759,30 +704,21 @@ LogicalResult StreamPipeliner::scheduleLoads(DenseSet &rootUsers) { llvm::divideCeil(numStages - 2, maxIndirectionLevel + 1); LDBG("stagesBetweenLoads = " << stagesBetweenLoads); - // Assign stages to the loads. - // FA: - // Load1: Stage=0, cluster=1 - // Load2: Stage=1, cluster=3 - int i{}; - for (auto [loadOp, indLevel, _] : loadOpToIndLevelAndUse) { - int stage = (maxIndirectionLevel - indLevel) * stagesBetweenLoads; - if (schedule.count(loadOp) > 0) - continue; - schedule.insert(loadOp, i, asyncCopyClusters[i == 0 ? 0 : 1]); - i++; - } - // Put the root uses of the loads in the last stage. for (auto &[loadOp, dist, use] : loadOpToIndLevelAndUse) { // Non-LoadOp(s) are the (final) root uses of all LoadOp(s). if (!isa(use)) { - auto loadStage = schedule[loadOp].first; - schedule.insert(use, loadStage + 2, dotClusters[loadStage == 0 ? 0 : 1]); - // scheduleOp(use, SCHED_COMPUTE); + scheduleOp(use, SCHED_COMPUTE); rootUsers.insert(use); } } + // Assign stages to the loads. + for (auto [loadOp, indLevel, _] : loadOpToIndLevelAndUse) { + int stage = (maxIndirectionLevel - indLevel) * stagesBetweenLoads; + scheduleOp(loadOp, SCHED_GLOBAL_LOAD, stage); + } + // Calculate distance from the load to the use. for (auto [loadOp, _, use] : loadOpToIndLevelAndUse) { loadToInfo[loadOp].distToUse = schedule[use].first - schedule[loadOp].first; @@ -810,21 +746,7 @@ void StreamPipeliner::scheduleDependencies() { for (auto [op, stage_, cluster] : opsInOrder) { if (stage_ != stage) continue; - auto depCluster = cluster; - bool override = false; - if (llvm::isa(op) && stage == 3) { - depCluster = softmaxClusters[0]; - override = true; - } - - auto moveStages = [this, stage, cluster = cluster, - depCluster = depCluster, override](Operation *op) { - if (llvm::isa(op)) { - return std::make_pair(stage, cluster); - } - return std::make_pair(stage, depCluster); - }; - schedule.insertDepsOfOp(op, false, false, moveStages); + schedule.insertDepsOfOp(op, stage, cluster, false); } } } @@ -998,25 +920,6 @@ LogicalResult StreamPipeliner::preprocessLoopAndBuildSchedule() { // Convert the loads into shared memory allocations and loads from them. createStreamOps(); - LLVM_DEBUG({ - LDBG("Coarse schedule with replaced laod ops:"); - schedule.dump(); - }); - - // Schedule reductions - int c = 2; - for (auto reduceOp : forOp.getBody()->getOps()) { - schedule.insert(reduceOp, c, softmaxClusters[c == 2 ? 1 : 0]); - c++; - } - - for (auto exp2Op : forOp.getBody()->getOps()) { - schedule.insert(exp2Op, 2, softmaxClusters[1]); - } - LLVM_DEBUG({ - LDBG("Coarse schedule after schedule reduction:"); - schedule.dump(); - }); scheduleDependencies(); LLVM_DEBUG({ @@ -1158,25 +1061,27 @@ struct PipelinePass : public TritonAMDGPUStreamPipelineBase { loops.push_back(forOp); }); - // Skip second (causal) loop - int loopCount{}; - for (scf::ForOp forOp : loops) { - if (loopCount++ == 1) - continue; - if (!checkPrecondition(forOp)) continue; - StreamPipeliner sp(forOp, tt::getNumStagesOrDefault(forOp, numStages), - globalPrefetch, localPrefetch, useAsyncCopy); - (void)sp.pipelineLoop(); + + if (FourStagePipeliner::checkPrecondition(forOp, numStages)) { + FourStagePipeliner fsp(forOp, + tt::getNumStagesOrDefault(forOp, numStages), + globalPrefetch, localPrefetch, useAsyncCopy); + (void)fsp.pipelineLoop(); + } else { + StreamPipeliner sp(forOp, tt::getNumStagesOrDefault(forOp, numStages), + globalPrefetch, localPrefetch, useAsyncCopy); + (void)sp.pipelineLoop(); + } } - // if (useAsyncCopy) { - // llvm::SmallSetVector waitOps; - // moduleOp.walk([&](ttg::AsyncWaitOp waitOp) { waitOps.insert(waitOp); }); - // tt::combineRedundantWaitOps(waitOps); - // } + if (useAsyncCopy) { + llvm::SmallSetVector waitOps; + moduleOp.walk([&](ttg::AsyncWaitOp waitOp) { waitOps.insert(waitOp); }); + tt::combineRedundantWaitOps(waitOps); + } } }; } // namespace From d46f750035ec95a4342fdc8ca3435efbf6681fa9 Mon Sep 17 00:00:00 2001 From: Jungwook Park Date: Thu, 1 May 2025 14:57:46 -0500 Subject: [PATCH 22/44] [GEMM] Add combine dot_scaled and addF --- lib/Dialect/Triton/Transforms/Combine.cpp | 25 +++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/lib/Dialect/Triton/Transforms/Combine.cpp b/lib/Dialect/Triton/Transforms/Combine.cpp index 6fab87c8a562..22496cf67192 100644 --- a/lib/Dialect/Triton/Transforms/Combine.cpp +++ b/lib/Dialect/Triton/Transforms/Combine.cpp @@ -239,6 +239,29 @@ class RankedReduceDescriptorLoads : public mlir::OpRewritePattern { } }; +class CombineDotScaledAddPattern : public mlir::OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(triton::DotScaledOp dotOp, + mlir::PatternRewriter &rewriter) const override { + if (!dotOp->hasOneUse() || !isZero(dotOp.getC())) + return failure(); + auto user = dotOp->getUsers().begin(); + if (auto addOp = llvm::dyn_cast(*user)) { + auto acc = (addOp.getRhs() == dotOp) ? addOp.getLhs() : addOp.getRhs(); + IRMapping mapping; + mapping.map(dotOp.getC(), acc); + auto newOp = rewriter.clone(*dotOp, mapping); + rewriter.replaceOp(addOp, newOp->getResults()); + rewriter.eraseOp(dotOp); + return success(); + } + return failure(); + } +}; + } // anonymous namespace class CombineOpsPass : public impl::TritonCombineOpsBase { @@ -253,6 +276,8 @@ class CombineOpsPass : public impl::TritonCombineOpsBase { patterns.add(context); patterns.add(context); patterns.add(context); + + patterns.add(context); // %} patterns.add(context); patterns.add(context); From 4a5ece6345231d9b8c8b6f0a876ef168c5c764cf Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Tue, 6 May 2025 17:24:48 +0000 Subject: [PATCH 23/44] [GEMM] Do not swizzle the scale --- lib/Dialect/TritonGPU/IR/Dialect.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index a6a817a247f3..f63fda14ecde 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -1814,6 +1814,12 @@ SwizzledSharedEncodingAttr AMDMfmaEncodingAttr::composeSharedLayoutForOperand( if (getMDim() == 4) maxPhase = 4; + // Disable swizzling for scales + if (operandIdx >= 2) { + return SwizzledSharedEncodingAttr::get(getContext(), 1, 1, 1, sharedOrder, + ctaLayout); + } + return SwizzledSharedEncodingAttr::get(getContext(), vectorSize, perPhase, maxPhase, sharedOrder, ctaLayout); } From 8285bfcba26040e5dd9f7a5805ccbfe1602edb08 Mon Sep 17 00:00:00 2001 From: Vinayak Gokhale Date: Mon, 12 May 2025 15:11:51 +0000 Subject: [PATCH 24/44] Add layout conversion pass optim at the end --- third_party/amd/backend/compiler.py | 1 + 1 file changed, 1 insertion(+) diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index ff5d25d33015..3fb38ac11be2 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -275,6 +275,7 @@ def make_ttgir(mod, metadata, options): passes.common.add_symbol_dce(pm) if use_async_copy: amd.passes.ttgpuir.add_update_async_wait_count(pm, options.arch) + passes.ttgpuir.add_remove_layout_conversions(pm) pm.run(mod) return mod From 916cb0623a7bb527b957bf73305f2a8770f0f789 Mon Sep 17 00:00:00 2001 From: Jungwook Park Date: Mon, 12 May 2025 15:34:54 -0500 Subject: [PATCH 25/44] Initial commit to enable pingpong for dot_scaled with mxfp4 Computation part interleaves mfma and ds_read Placed extra conditional barrier to overlap computation part and buffer_load part. Dot slicing by plognjen at https://github.com/plognjen/triton/tree/slice_dot_scaled requires vmcnt fix to achieve full performance. --- .../TritonAMDGPUTransforms/BlockPingpong.cpp | 212 +++++++++++++++++- 1 file changed, 202 insertions(+), 10 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp index 4ffcd35d0a0f..0d3e0aaa5bf9 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp @@ -40,6 +40,7 @@ class Pingponger { SmallVector asyncCopyOps; SmallVector asyncWaitOps; SmallVector dotOps; + SmallVector dotSOps; SmallVector> subViewOps; SmallVector> loadSliceOps; SmallVector dotSliceOps; @@ -70,12 +71,19 @@ class Pingponger { LogicalResult genLocalSlice(OpBuilder &builder, Value v, Attribute dotEncoding, unsigned opIdx, unsigned numSlices, int64_t sliceWidth); + LogicalResult genLocalSliceScales(OpBuilder &builder, Value v, + Attribute dotEncoding, unsigned opIdx, + unsigned numSlices, int64_t sliceWidth); LogicalResult sliceDot(OpBuilder &builder, Location loc, tt::DotOp op, unsigned numSlices); + LogicalResult sliceDotScaled(OpBuilder &builder, Location loc, + tt::DotScaledOp op, unsigned numSlices); + void transformOnePPClusters(OpBuilder &builder, Location loc); LogicalResult transformFourPPClusters(OpBuilder &builder, Location loc); LogicalResult transformTwoPPClusters(OpBuilder &builder, Location loc); LogicalResult transformFAv3(OpBuilder &builder, Location loc); + LogicalResult transformFP4(OpBuilder &builder, Location loc); void addAsymmetricSyncToLoop(OpBuilder &builder, Location loc); void updateOpInsertion(Operation *Op); void appendOp(Operation *Op); @@ -97,6 +105,10 @@ class Pingponger { DenseSet &dotLocalStores); template void findClosestPredOps(Value v, DenseSet &matchingOps); + + LogicalResult genLocalSliceHelper(OpBuilder &builder, Value v, unsigned opIdx, + unsigned numSlices, int64_t sliceWidth, + RankedTensorType tensorType); }; void Pingponger::updateOpInsertion(Operation *op) { lastInsertedOp = op; } @@ -412,8 +424,6 @@ LogicalResult Pingponger::genLocalSlice(OpBuilder &builder, Value v, Attribute dotEncoding, unsigned opIdx, unsigned numSlices, int64_t sliceWidth) { - SmallVector slices; - SmallVector subviews; // TODO: support transformed input to dot auto localLoad = v.getDefiningOp(); if (!localLoad) @@ -429,21 +439,71 @@ LogicalResult Pingponger::genLocalSlice(OpBuilder &builder, Value v, return failure(); auto dotOperandEnc = ttg::DotOperandEncodingAttr::get( builder.getContext(), opIdx, dotEncoding, kWidth); + + auto tensorType = RankedTensorType::get(shape, elementType, dotOperandEnc); + + return genLocalSliceHelper(builder, v, opIdx, numSlices, sliceWidth, + tensorType); +} + +LogicalResult Pingponger::genLocalSliceScales(OpBuilder &builder, Value v, + Attribute dotEncoding, + unsigned opIdx, + unsigned numSlices, + int64_t sliceWidth) { + auto localLoad = v.getDefiningOp(); + if (!localLoad) + return failure(); + auto memDesc = localLoad.getSrc(); + auto type = cast(memDesc.getType()); + SmallVector shape = llvm::to_vector(type.getShape()); + Type elementType = type.getElementType(); + int64_t kIdx = opIdx == 0 ? 1 : 0; + shape[kIdx] = sliceWidth; + + auto ll = mlir::triton::gpu::toLinearLayout(shape, dotEncoding); + auto dotOperandEnc = ttg::LinearEncodingAttr::get(type.getContext(), ll); + auto tensorType = RankedTensorType::get(shape, elementType, dotOperandEnc); + + return genLocalSliceHelper(builder, v, 0, numSlices, sliceWidth, tensorType); +} + +LogicalResult Pingponger::genLocalSliceHelper(OpBuilder &builder, Value v, + unsigned opIdx, + unsigned numSlices, + int64_t sliceWidth, + RankedTensorType tensorType) { + + SmallVector slices; + SmallVector subviews; + + auto localLoad = v.getDefiningOp(); + if (!localLoad) + return failure(); + + auto memDesc = localLoad.getSrc(); + auto type = cast(memDesc.getType()); + SmallVector shape = llvm::to_vector(type.getShape()); + Type elementType = type.getElementType(); + int64_t kIdx = opIdx == 0 ? 1 : 0; + shape[kIdx] = sliceWidth; + auto subviewDescType = ttg::MemDescType::get( shape, elementType, type.getEncoding(), type.getMemorySpace(), type.getMutableMemory(), type.getAllocShape()); + for (int i = 0; i < numSlices; i++) { SmallVector offsetsVal; SmallVector offsets = {0, 0}; - offsets[kIdx] = i; + offsets[opIdx == 0 ? 1 : 0] = i; for (int64_t off : offsets) { - offsetsVal.push_back(constOffsets[off]); + offsetsVal.push_back(builder.create( + v.getLoc(), off * sliceWidth, 32)); } Value newSmem = builder.create( v.getLoc(), subviewDescType, memDesc, offsetsVal); - Value prefetchSlice = builder.create( - v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc), - newSmem); + Value prefetchSlice = + builder.create(v.getLoc(), tensorType, newSmem); subviews.push_back(newSmem.getDefiningOp()); slices.push_back(prefetchSlice.getDefiningOp()); } @@ -491,6 +551,78 @@ LogicalResult Pingponger::sliceDot(OpBuilder &builder, Location loc, return success(); } +LogicalResult Pingponger::sliceDotScaled(OpBuilder &builder, Location loc, + tt::DotScaledOp op, + unsigned numSlices) { + builder.setInsertionPointToStart(forOp.getBody()); + auto typeB = op.getB().getType(); + auto typeScaleB = op.getBScale().getType(); + auto shapeB = typeB.getShape(); + auto shapeScaleB = typeScaleB.getShape(); + + int64_t sliceWidth = shapeB[0] / numSlices; + int64_t sliceScaleWidth = shapeScaleB[1] / numSlices; + if (shapeB[1] % numSlices != 0) + return failure(); + + builder.setInsertionPointAfter(op); + auto dotEncoding = op.getType().getEncoding(); + + // Generate slices for operands A and B + if (genLocalSlice(builder, op.getA(), dotEncoding, 0, numSlices, sliceWidth) + .failed() || + genLocalSlice(builder, op.getB(), dotEncoding, 1, numSlices, sliceWidth) + .failed()) + return failure(); + + // Generate slices for scale tensors if they exist + Value aScale = op.getAScale(); + Value bScale = op.getBScale(); + + if (aScale) { + if (genLocalSliceScales(builder, aScale, + op.getAScale().getType().getEncoding(), 0, + numSlices, sliceScaleWidth) + .failed()) + return failure(); + } + + if (bScale) { + if (genLocalSliceScales(builder, bScale, + op.getBScale().getType().getEncoding(), 0, + numSlices, sliceScaleWidth) + .failed()) + return failure(); + } + + Operation *prevDot = op; + for (int i = 0; i < numSlices; i++) { + IRMapping mapping; + mapping.map(op.getA(), loadSliceOps[0][i]->getResult(0)); + mapping.map(op.getB(), loadSliceOps[1][i]->getResult(0)); + + // Map scale tensors if they exist + if (aScale) + mapping.map(op.getAScale(), loadSliceOps[2][i]->getResult(0)); + if (bScale) + mapping.map(op.getBScale(), loadSliceOps[3][i]->getResult(0)); + + if (i > 0) + mapping.map(op.getC(), prevDot->getResult(0)); + + auto newOp = builder.clone(*op, mapping); + prevDot = newOp; + dotSliceOps.push_back(newOp); + } + + // Replace original op with the last slice and cleanup + op->replaceAllUsesWith(prevDot); + op->erase(); + for (auto loads : lLoadOps) + loads->erase(); + return success(); +} + // Transform a loop into four Dot - Memory (ping - pong) clusters // This transform is useful when the original dot tile is too large that there's // not enough registers to hold data for a Dot cluster. This path slices the dot @@ -669,6 +801,51 @@ LogicalResult Pingponger::transformFAv3(OpBuilder &builder, Location loc) { return success(); } +LogicalResult Pingponger::transformFP4(OpBuilder &builder, Location loc) { + + + builder.setInsertionPointAfter(forOp); + + //FIXME: This is duplicated code, need to refactorize. + auto i32ty = builder.getIntegerType(32); + auto workIDX = builder.create(loc, i32ty); + workIDX->moveBefore(forOp); + builder.setInsertionPointAfter(workIDX); + auto constZero = builder.create(loc, 0, 32); + auto constWarpSize = builder.create(loc, 256, 32); + auto warpIDX = builder.create(loc, workIDX, constWarpSize); + auto warpLow = builder.create(loc, arith::CmpIPredicate::eq, + warpIDX, constZero); + auto warpHigh = builder.create(loc, arith::CmpIPredicate::ne, + warpIDX, constZero); + + + + builder.setInsertionPointAfter(dotSOps[0]); + + if (sliceDotScaled(builder, loc, dotSOps[0], 4).failed()) + return failure(); + updateOpInsertion(dotSliceOps[0]); + + appendOp(builder.create(loc, 0)); + appendOp(builder.create(loc, warpLow)); + appendOp(builder.create(loc, 0)); + for (int j=0; j<4; j++){ + for (int i=0; i<4; i++) + appendOp(subViewOps[i][j]); + for (int i=0; i<4; i++) + appendOp(loadSliceOps[i][j]); + appendOp(builder.create(loc, 0)); + appendOp(dotSliceOps[j]); + } + + appendOp(builder.create(loc, 0)); + appendOp(builder.create(loc, warpHigh)); + + + return success(); +} + // This function wraps forOp with cond_barrier. First, hold half of the warps // (warpHigh) in a block before the loop so the barriers in the loop synchronize // warps at the different point per the warp groups. After the loop, hold @@ -730,9 +907,11 @@ void Pingponger::getDotPingponged() { else if (auto pingpongDot = dyn_cast(op)) { if (pingpongDot.getType().getRank() == 2) dotOps.push_back(pingpongDot); - } else if (auto asyncOp = dyn_cast(op)) + } else if (auto pingpongDot = dyn_cast(op)) { + dotSOps.push_back(pingpongDot); + } else if (auto asyncOp = dyn_cast(op)) { asyncCopyOps.push_back(asyncOp); - else if (auto asyncOp = dyn_cast(op)) + } else if (auto asyncOp = dyn_cast(op)) asyncWaitOps.push_back(asyncOp); }); @@ -752,7 +931,9 @@ void Pingponger::getDotPingponged() { // software pipelining and dot rank=2. Also only accept the for-loop with // supported combination of operations because this transformation is very // tightly scheduling the latencies. - if (gLoadOps.size() < 2 || lLoadOps.size() < 2 || dotOps.size() != 1) { + + //FIXME: get better condition to enable pingpong either for dot or for dot_scaled + if ((dotSOps.size() != 1) || (gLoadOps.size() < 2 || lLoadOps.size() < 2 || dotSOps.size() != 1)){ std::stringstream message; message << "Unable to match ping pong scheduling pattern. Details: " << gLoadOps.size() << " global loads, " << lLoadOps.size() @@ -761,6 +942,17 @@ void Pingponger::getDotPingponged() { return; } + //FIXME: place tile size restriction here and obtain kWidth + kWidth = 16; + if (transformFP4(builder, dotSOps[0]->getLoc()).failed()) { + LDBG("Encountered failure when trying to execute the two ping pong " + "cluster transformation"); + return; + } + addAsymmetricSyncToLoop(builder, loc); + + return; + // Determine if we have a persistent GEMM. This will decide how we interpret // any memory operations that we find in conditionals. auto assumeNotTaken = isPersistentGemm(dotOps.size()); From b3c2f943320f70e753fd416772a64bcf6e8559ad Mon Sep 17 00:00:00 2001 From: Jungwook Park Date: Tue, 13 May 2025 07:11:16 -0500 Subject: [PATCH 26/44] Fix to the gemm pingpong. Fix incorrect condition to choose enable transforms. Fix missing tokens to the local_load --- third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp index 0d3e0aaa5bf9..7bdd6e6f9d76 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp @@ -481,6 +481,7 @@ LogicalResult Pingponger::genLocalSliceHelper(OpBuilder &builder, Value v, if (!localLoad) return failure(); + auto waitToken = localLoad.getToken(); auto memDesc = localLoad.getSrc(); auto type = cast(memDesc.getType()); SmallVector shape = llvm::to_vector(type.getShape()); @@ -503,7 +504,7 @@ LogicalResult Pingponger::genLocalSliceHelper(OpBuilder &builder, Value v, Value newSmem = builder.create( v.getLoc(), subviewDescType, memDesc, offsetsVal); Value prefetchSlice = - builder.create(v.getLoc(), tensorType, newSmem); + builder.create(v.getLoc(), tensorType, newSmem, waitToken); subviews.push_back(newSmem.getDefiningOp()); slices.push_back(prefetchSlice.getDefiningOp()); } @@ -933,7 +934,7 @@ void Pingponger::getDotPingponged() { // tightly scheduling the latencies. //FIXME: get better condition to enable pingpong either for dot or for dot_scaled - if ((dotSOps.size() != 1) || (gLoadOps.size() < 2 || lLoadOps.size() < 2 || dotSOps.size() != 1)){ + if ((dotSOps.size() != 1) && (gLoadOps.size() < 2 || lLoadOps.size() < 2 || dotSOps.size() != 1)){ std::stringstream message; message << "Unable to match ping pong scheduling pattern. Details: " << gLoadOps.size() << " global loads, " << lLoadOps.size() From a19dd6d6060a1053993fb98ac53b85372de051e4 Mon Sep 17 00:00:00 2001 From: Jungwook Park Date: Tue, 13 May 2025 09:05:19 -0500 Subject: [PATCH 27/44] Add restriction to dot_scaled pingpong. Only enable for 256x256x256 tilesize --- .../TritonAMDGPUTransforms/BlockPingpong.cpp | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp index 7bdd6e6f9d76..82c5cddf32aa 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp @@ -945,12 +945,23 @@ void Pingponger::getDotPingponged() { //FIXME: place tile size restriction here and obtain kWidth kWidth = 16; - if (transformFP4(builder, dotSOps[0]->getLoc()).failed()) { - LDBG("Encountered failure when trying to execute the two ping pong " + if (dotSOps.size() == 1){ + auto dotSType = dotSOps[0].getType(); + auto dotSShape = dotSType.getShape(); + auto aType = dotSOps[0].getA().getType(); + auto aShape = aType.getShape(); + auto elemWidth = aType.getElementTypeBitWidth(); + int64_t tileSize = dotSShape[0] * dotSShape[1] * aShape[1]; + if(tileSize != 8388608 || aShape[1] != 128 || elemWidth != 8) + return; + + if (transformFP4(builder, dotSOps[0]->getLoc()).failed()) { + LDBG("Encountered failure when trying to execute the two ping pong " "cluster transformation"); - return; + return; + } + addAsymmetricSyncToLoop(builder, loc); } - addAsymmetricSyncToLoop(builder, loc); return; From f6065b93e96e1cdb5137e27f7b8f21199399cf54 Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Tue, 13 May 2025 17:21:23 +0000 Subject: [PATCH 28/44] Revert "[AMD] Use v_permlane to optimize MFAM to linear layout on GFX950 (#6744)" This reverts commit 4ecc496ed82c0e2cfd721c6ee71d0b468975c7f9. --- .../TritonGPU/IR/LinearLayoutConversions.h | 3 +- .../TritonGPU/IR/LinearLayoutConversions.cpp | 19 +- test/Conversion/amd/mfma-shortcut.mlir | 229 ++++++++---------- .../ConvertLayoutOpToLLVM.cpp | 115 --------- .../OptimizeEpilogue.cpp | 20 +- 5 files changed, 129 insertions(+), 257 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h b/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h index b00eb5084112..6a6047216b71 100644 --- a/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h +++ b/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h @@ -287,7 +287,8 @@ LinearLayout chooseScaledMfmaScaleLayout( // 8 elements. This layout is useful for emitting the widest 128-bit global // store instructions. Since it closely resembles mfmaLayout, conversion between // the two can be done using transferWithinWarp, without involving LDS -std::optional chooseMfmaLikeStoreLayout(RankedTensorType valType); +LinearLayout chooseMfmaLikeStoreLayout(AMDMfmaEncodingAttr mfmaLayout, + ArrayRef shape); } // namespace mlir::triton::gpu #endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index d1397549de27..22949a1489a3 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -1533,17 +1533,10 @@ LinearLayout chooseScaledMfmaScaleLayout( return newLL; } -std::optional -chooseMfmaLikeStoreLayout(RankedTensorType valType) { - auto mfmaLayout = cast(valType.getEncoding()); - - // Currently support transposed [B]F16 MFMA32x32 on CDNA4 - bool isMfma32 = mfmaLayout.getMDim() == 32 && mfmaLayout.getNDim() == 32; - Type elemType = valType.getElementType(); - if (!(valType.getRank() == 2 && (elemType.isF16() || elemType.isBF16()) && - mfmaLayout.getVersionMajor() == 4 && mfmaLayout.getIsTransposed() && - isMfma32)) - return {}; +LinearLayout chooseMfmaLikeStoreLayout(AMDMfmaEncodingAttr mfmaLayout, + ArrayRef shape) { + assert(shape.size() == 2 && mfmaLayout.getMDim() == 32 && + mfmaLayout.getNDim() == 32 && mfmaLayout.getIsTransposed()); MLIRContext *ctx = mfmaLayout.getContext(); StringAttr kRegister = S("register"); @@ -1568,8 +1561,8 @@ chooseMfmaLikeStoreLayout(RankedTensorType valType) { identityStandardND(kWarp, mfmaLayout.getWarpsPerCTA(), order); LinearLayout ctaLayout = mfma8Layout.transposeOuts(standardOutDims) * warpLayout.transposeOuts(standardOutDims); - mfma8Layout = combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCTALayout(), - valType.getShape()); + mfma8Layout = + combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCTALayout(), shape); return mfma8Layout; } diff --git a/test/Conversion/amd/mfma-shortcut.mlir b/test/Conversion/amd/mfma-shortcut.mlir index 94f1650c39de..0e64eed47040 100644 --- a/test/Conversion/amd/mfma-shortcut.mlir +++ b/test/Conversion/amd/mfma-shortcut.mlir @@ -1,14 +1,13 @@ -// RUN: triton-opt %s --tritongpu-reduce-data-duplication --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch="gfx942" -split-input-file | FileCheck %s --check-prefix=GFX942 -// RUN: triton-opt %s --tritongpu-reduce-data-duplication --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch="gfx950" -split-input-file | FileCheck %s --check-prefix=GFX950 +// RUN: triton-opt %s --tritongpu-reduce-data-duplication --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch="gfx942" -split-input-file | FileCheck %s #mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> #dotop = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=4}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { - // GFX942-LABEL: shortcut_mfma16 + // CHECK-LABEL: shortcut_mfma16 tt.func public @shortcut_mfma16(%arg0: tensor<16x16xf16, #mfma>) { - // GFX942-NOT: store - // GFX942-NOT: load - // GFX942: llvm.return + // CHECK-NOT: store + // CHECK-NOT: load + // CHECK: llvm.return %0 = ttg.convert_layout %arg0 : tensor<16x16xf16, #mfma> -> tensor<16x16xf16, #dotop> tt.return } @@ -19,11 +18,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr #mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> #dotop = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { - // GFX942-LABEL: no_shortcut_mfma16 + // CHECK-LABEL: no_shortcut_mfma16 tt.func public @no_shortcut_mfma16(%arg0: tensor<16x16xf16, #mfma>) { - // GFX942: store - // GFX942: load - // GFX942: llvm.return + // CHECK: store + // CHECK: load + // CHECK: llvm.return %0 = ttg.convert_layout %arg0 : tensor<16x16xf16, #mfma> -> tensor<16x16xf16, #dotop> tt.return } @@ -35,38 +34,38 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr #dotop0 = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { - // GFX942-LABEL: mfma_dot_cvt_f8_mfma32 + // CHECK-LABEL: mfma_dot_cvt_f8_mfma32 tt.func public @mfma_dot_cvt_f8_mfma32(%arg0: tensor<128x32xf8E4M3FNUZ, #mfma>) { - // GFX942-NOT: store - // GFX942-NOT: load + // CHECK-NOT: store + // CHECK-NOT: load - // GFX942: [[val3:%.*]] = llvm.extractvalue %arg0[3] - // GFX942: [[val7:%.*]] = llvm.extractvalue %arg0[7] + // CHECK: [[val3:%.*]] = llvm.extractvalue %arg0[3] + // CHECK: [[val7:%.*]] = llvm.extractvalue %arg0[7] - // GFX942-DAG: [[c32:%.*]] = llvm.mlir.constant(32 : i32) - // GFX942-DAG: [[c64:%.*]] = llvm.mlir.constant(64 : i32) + // CHECK-DAG: [[c32:%.*]] = llvm.mlir.constant(32 : i32) + // CHECK-DAG: [[c64:%.*]] = llvm.mlir.constant(64 : i32) - // GFX942: [[threadId:%.*]] = rocdl.workitem.id.x - // GFX942: [[laneId:%.*]] = llvm.urem [[threadId]], [[c64]] - // GFX942: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]] + // CHECK: [[threadId:%.*]] = rocdl.workitem.id.x + // CHECK: [[laneId:%.*]] = llvm.urem [[threadId]], [[c64]] + // CHECK: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]] - // GFX942: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c32]] - // GFX942: [[addr32:%.*]] = llvm.urem [[shflLaneId]], [[c64]] + // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c32]] + // CHECK: [[addr32:%.*]] = llvm.urem [[shflLaneId]], [[c64]] - // GFX942: [[vec0:%.*]] = llvm.insertelement [[val3]], {{.*}} : vector<4xi8> - // GFX942: [[vec1:%.*]] = llvm.insertelement [[val7]], {{.*}} : vector<4xi8> + // CHECK: [[vec0:%.*]] = llvm.insertelement [[val3]], {{.*}} : vector<4xi8> + // CHECK: [[vec1:%.*]] = llvm.insertelement [[val7]], {{.*}} : vector<4xi8> - // GFX942: [[bvec0:%.*]] = llvm.bitcast [[vec0]] - // GFX942: [[c2:%.*]] = llvm.mlir.constant(2 : i32) - // GFX942: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]] - // GFX942: [[bShflVec0:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]] - // GFX942: [[shflVec0:%.*]] = llvm.bitcast [[bShflVec0]] + // CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]] + // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) + // CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]] + // CHECK: [[bShflVec0:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]] + // CHECK: [[shflVec0:%.*]] = llvm.bitcast [[bShflVec0]] - // GFX942: [[bvec1:%.*]] = llvm.bitcast [[vec1]] - // GFX942: [[c2:%.*]] = llvm.mlir.constant(2 : i32) - // GFX942: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]] - // GFX942: [[bShflVec1:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]] - // GFX942: [[shflVec1:%.*]] = llvm.bitcast [[bShflVec1]] + // CHECK: [[bvec1:%.*]] = llvm.bitcast [[vec1]] + // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) + // CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]] + // CHECK: [[bShflVec1:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]] + // CHECK: [[shflVec1:%.*]] = llvm.bitcast [[bShflVec1]] // Input (8 values): (vec0, vec1) // Output (8 values shuffled, '>> n' - take the value from (lane + n) % 64): @@ -74,18 +73,18 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr // lanes 0-31: (vec0 , vec0 >> 32) (mask0=1) // lanes 32-63: (vec1 >> 32, vec1 ) (mask0=0) - // GFX942: [[resVec0:%.*]] = llvm.select [[mask0]], [[vec0]], [[shflVec1]] - // GFX942: [[resVec1:%.*]] = llvm.select [[mask0]], [[shflVec0]], [[vec1]] + // CHECK: [[resVec0:%.*]] = llvm.select [[mask0]], [[vec0]], [[shflVec1]] + // CHECK: [[resVec1:%.*]] = llvm.select [[mask0]], [[shflVec0]], [[vec1]] - // GFX942: [[c3:%.*]] = llvm.mlir.constant(3 : i32) - // GFX942: [[resVal3:%.*]] = llvm.extractelement [[resVec0]][[[c3]] : i32] : vector<4xi8> - // GFX942: [[c3:%.*]] = llvm.mlir.constant(3 : i32) : i32 - // GFX942: [[resVal7:%.*]] = llvm.extractelement [[resVec1]][[[c3]] : i32] : vector<4xi8> + // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) + // CHECK: [[resVal3:%.*]] = llvm.extractelement [[resVec0]][[[c3]] : i32] : vector<4xi8> + // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) : i32 + // CHECK: [[resVal7:%.*]] = llvm.extractelement [[resVec1]][[[c3]] : i32] : vector<4xi8> - // GFX942: llvm.insertvalue [[resVal3]], {{.*}}[3] - // GFX942: llvm.insertvalue [[resVal7]], {{.*}}[7] + // CHECK: llvm.insertvalue [[resVal3]], {{.*}}[3] + // CHECK: llvm.insertvalue [[resVal7]], {{.*}}[7] - // GFX942: llvm.return + // CHECK: llvm.return %0 = ttg.convert_layout %arg0 : tensor<128x32xf8E4M3FNUZ, #mfma> -> tensor<128x32xf8E4M3FNUZ, #dotop0> tt.return } @@ -97,12 +96,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr #dotop0 = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { - // GFX942-LABEL: mfma_dot_cvt_bf8_mfma32 + // CHECK-LABEL: mfma_dot_cvt_bf8_mfma32 tt.func public @mfma_dot_cvt_bf8_mfma32(%arg0: tensor<128x32xf8E5M2, #mfma>) { - // GFX942-NOT: store - // GFX942-NOT: load - // GFX942: rocdl.ds_bpermute - // GFX942: llvm.return + // CHECK-NOT: store + // CHECK-NOT: load + // CHECK: rocdl.ds_bpermute + // CHECK: llvm.return %0 = ttg.convert_layout %arg0 : tensor<128x32xf8E5M2, #mfma> -> tensor<128x32xf8E5M2, #dotop0> tt.return } @@ -114,61 +113,61 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr #dotop0 = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { - // GFX942-LABEL: mfma_dot_cvt_f8_mfma16 + // CHECK-LABEL: mfma_dot_cvt_f8_mfma16 tt.func public @mfma_dot_cvt_f8_mfma16(%arg0: tensor<128x32xf8E4M3FNUZ, #mfma>) { - // GFX942-NOT: store - // GFX942-NOT: load + // CHECK-NOT: store + // CHECK-NOT: load - // GFX942: [[val3:%.*]] = llvm.extractvalue %arg0[3] - // GFX942: [[val7:%.*]] = llvm.extractvalue %arg0[7] + // CHECK: [[val3:%.*]] = llvm.extractvalue %arg0[3] + // CHECK: [[val7:%.*]] = llvm.extractvalue %arg0[7] - // GFX942-DAG: [[c16:%.*]] = llvm.mlir.constant(16 : i32) - // GFX942-DAG: [[c32:%.*]] = llvm.mlir.constant(32 : i32) - // GFX942-DAG: [[c48:%.*]] = llvm.mlir.constant(48 : i32) - // GFX942-DAG: [[c64:%.*]] = llvm.mlir.constant(64 : i32) + // CHECK-DAG: [[c16:%.*]] = llvm.mlir.constant(16 : i32) + // CHECK-DAG: [[c32:%.*]] = llvm.mlir.constant(32 : i32) + // CHECK-DAG: [[c48:%.*]] = llvm.mlir.constant(48 : i32) + // CHECK-DAG: [[c64:%.*]] = llvm.mlir.constant(64 : i32) - // GFX942: [[threadId:%.*]] = rocdl.workitem.id.x - // GFX942: [[laneId:%.*]] = llvm.urem [[threadId]], [[c64]] - // GFX942: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]] + // CHECK: [[threadId:%.*]] = rocdl.workitem.id.x + // CHECK: [[laneId:%.*]] = llvm.urem [[threadId]], [[c64]] + // CHECK: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]] - // GFX942: [[laneIdRem:%.*]] = llvm.urem [[laneId]], [[c32]] - // GFX942: [[mask1:%.*]] = llvm.icmp "slt" [[laneIdRem]], [[c16]] + // CHECK: [[laneIdRem:%.*]] = llvm.urem [[laneId]], [[c32]] + // CHECK: [[mask1:%.*]] = llvm.icmp "slt" [[laneIdRem]], [[c16]] - // GFX942: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c16]] - // GFX942: [[addr16:%.*]] = llvm.urem [[shflLaneId]], [[c64]] + // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c16]] + // CHECK: [[addr16:%.*]] = llvm.urem [[shflLaneId]], [[c64]] - // GFX942: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c32]] - // GFX942: [[addr32:%.*]] = llvm.urem [[shflLaneId]], [[c64]] + // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c32]] + // CHECK: [[addr32:%.*]] = llvm.urem [[shflLaneId]], [[c64]] - // GFX942: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c48]] - // GFX942: [[addr48:%.*]] = llvm.urem [[shflLaneId]], [[c64]] + // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c48]] + // CHECK: [[addr48:%.*]] = llvm.urem [[shflLaneId]], [[c64]] - // GFX942: [[vec0:%.*]] = llvm.insertelement [[val3]], {{.*}} : vector<4xi8> - // GFX942: [[vec1:%.*]] = llvm.insertelement [[val7]], {{.*}} : vector<4xi8> + // CHECK: [[vec0:%.*]] = llvm.insertelement [[val3]], {{.*}} : vector<4xi8> + // CHECK: [[vec1:%.*]] = llvm.insertelement [[val7]], {{.*}} : vector<4xi8> - // GFX942: [[bvec0:%.*]] = llvm.bitcast [[vec0]] - // GFX942: [[c2:%.*]] = llvm.mlir.constant(2 : i32) - // GFX942: [[addr:%.*]] = llvm.shl [[addr16]], [[c2]] - // GFX942: [[bShflVec0_16:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]] - // GFX942: [[shflVec0_16:%.*]] = llvm.bitcast [[bShflVec0_16]] + // CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]] + // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) + // CHECK: [[addr:%.*]] = llvm.shl [[addr16]], [[c2]] + // CHECK: [[bShflVec0_16:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]] + // CHECK: [[shflVec0_16:%.*]] = llvm.bitcast [[bShflVec0_16]] - // GFX942: [[bvec0:%.*]] = llvm.bitcast [[vec0]] - // GFX942: [[c2:%.*]] = llvm.mlir.constant(2 : i32) - // GFX942: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]] - // GFX942: [[bShflVec0_32:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]] - // GFX942: [[shflVec0_32:%.*]] = llvm.bitcast [[bShflVec0_32]] + // CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]] + // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) + // CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]] + // CHECK: [[bShflVec0_32:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]] + // CHECK: [[shflVec0_32:%.*]] = llvm.bitcast [[bShflVec0_32]] - // GFX942: [[bvec1:%.*]] = llvm.bitcast [[vec1]] - // GFX942: [[c2:%.*]] = llvm.mlir.constant(2 : i32) - // GFX942: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]] - // GFX942: [[bShflVec1_32:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]] - // GFX942: [[shflVec1_32:%.*]] = llvm.bitcast [[bShflVec1_32]] + // CHECK: [[bvec1:%.*]] = llvm.bitcast [[vec1]] + // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) + // CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]] + // CHECK: [[bShflVec1_32:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]] + // CHECK: [[shflVec1_32:%.*]] = llvm.bitcast [[bShflVec1_32]] - // GFX942: [[bvec1:%.*]] = llvm.bitcast [[vec1]] - // GFX942: [[c2:%.*]] = llvm.mlir.constant(2 : i32) - // GFX942: [[addr:%.*]] = llvm.shl [[addr48]], [[c2]] - // GFX942: [[bShflVec1_48:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]] - // GFX942: [[shflVec1_48:%.*]] = llvm.bitcast [[bShflVec1_48]] + // CHECK: [[bvec1:%.*]] = llvm.bitcast [[vec1]] + // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) + // CHECK: [[addr:%.*]] = llvm.shl [[addr48]], [[c2]] + // CHECK: [[bShflVec1_48:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]] + // CHECK: [[shflVec1_48:%.*]] = llvm.bitcast [[bShflVec1_48]] // Input (8 values): (vec0, vec1) // Output (8 values shuffled, '>> n' - take the value from (lane + n) % 64): @@ -178,23 +177,23 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr // lanes 32-47: (vec1 >> 32, vec1 >> 48) (mask0=0, mask1=1) // lanes 48-63: (vec1 >> 48, vec1 ) (mask0=0, mask1=0) - // GFX942-DAG: [[mask0_true:%.*]] = llvm.select [[mask1]], [[vec0]], [[shflVec0_16]] : i1, vector<4xi8> - // GFX942-DAG: [[mask0_false:%.*]] = llvm.select [[mask1]], [[shflVec1_32]], [[shflVec1_48]] : i1, vector<4xi8> - // GFX942: [[resVec0:%.*]] = llvm.select [[mask0]], [[mask0_true]], [[mask0_false]] : i1, vector<4xi8> + // CHECK-DAG: [[mask0_true:%.*]] = llvm.select [[mask1]], [[vec0]], [[shflVec0_16]] : i1, vector<4xi8> + // CHECK-DAG: [[mask0_false:%.*]] = llvm.select [[mask1]], [[shflVec1_32]], [[shflVec1_48]] : i1, vector<4xi8> + // CHECK: [[resVec0:%.*]] = llvm.select [[mask0]], [[mask0_true]], [[mask0_false]] : i1, vector<4xi8> - // GFX942-DAG: [[mask0_true:%.*]] = llvm.select [[mask1]], [[shflVec0_16]], [[shflVec0_32]] : i1, vector<4xi8> - // GFX942-DAG: [[mask0_false:%.*]] = llvm.select [[mask1]], [[shflVec1_48]], [[vec1]] : i1, vector<4xi8> - // GFX942: [[resVec1:%.*]] = llvm.select [[mask0]], [[mask0_true]], [[mask0_false]] : i1, vector<4xi8> + // CHECK-DAG: [[mask0_true:%.*]] = llvm.select [[mask1]], [[shflVec0_16]], [[shflVec0_32]] : i1, vector<4xi8> + // CHECK-DAG: [[mask0_false:%.*]] = llvm.select [[mask1]], [[shflVec1_48]], [[vec1]] : i1, vector<4xi8> + // CHECK: [[resVec1:%.*]] = llvm.select [[mask0]], [[mask0_true]], [[mask0_false]] : i1, vector<4xi8> - // GFX942: [[c3:%.*]] = llvm.mlir.constant(3 : i32) - // GFX942: [[resVal3:%.*]] = llvm.extractelement [[resVec0]][[[c3]] : i32] : vector<4xi8> - // GFX942: [[c3:%.*]] = llvm.mlir.constant(3 : i32) : i32 - // GFX942: [[resVal7:%.*]] = llvm.extractelement [[resVec1]][[[c3]] : i32] : vector<4xi8> + // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) + // CHECK: [[resVal3:%.*]] = llvm.extractelement [[resVec0]][[[c3]] : i32] : vector<4xi8> + // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) : i32 + // CHECK: [[resVal7:%.*]] = llvm.extractelement [[resVec1]][[[c3]] : i32] : vector<4xi8> - // GFX942: llvm.insertvalue [[resVal3]], {{.*}}[3] - // GFX942: llvm.insertvalue [[resVal7]], {{.*}}[7] + // CHECK: llvm.insertvalue [[resVal3]], {{.*}}[3] + // CHECK: llvm.insertvalue [[resVal7]], {{.*}}[7] - // GFX942: llvm.return + // CHECK: llvm.return %0 = ttg.convert_layout %arg0 : tensor<128x32xf8E4M3FNUZ, #mfma> -> tensor<128x32xf8E4M3FNUZ, #dotop0> tt.return } @@ -206,27 +205,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr #dotop0 = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { - // GFX942-LABEL: mfma_dot_cvt_bf8_mfma16 + // CHECK-LABEL: mfma_dot_cvt_bf8_mfma16 tt.func public @mfma_dot_cvt_bf8_mfma16(%arg0: tensor<128x32xf8E5M2, #mfma>) { - // GFX942-NOT: store - // GFX942-NOT: load - // GFX942: rocdl.ds_bpermute - // GFX942: llvm.return + // CHECK-NOT: store + // CHECK-NOT: load + // CHECK: rocdl.ds_bpermute + // CHECK: llvm.return %0 = ttg.convert_layout %arg0 : tensor<128x32xf8E5M2, #mfma> -> tensor<128x32xf8E5M2, #dotop0> tt.return } } - -// ----- - -#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [0, 1]}> -#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 16], [0, 32], [0, 64]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 8]], warp = [[32, 0], [64, 0]], block = []}> -#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}> -module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { - // GFX950-LABEL: mfma_linear_permlane_swap - tt.func public @mfma_linear_permlane_swap(%arg0: tensor<128x128xf16, #mma>) attributes {noinline = false} { - // GFX950-COUNT-16: llvm.call_intrinsic "llvm.amdgcn.permlane32.swap" - %1 = ttg.convert_layout %arg0: tensor<128x128xf16, #mma> -> tensor<128x128xf16, #linear> - tt.return - } -} diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 7af92231b4a1..ef55a3448950 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -8,7 +8,6 @@ using ::mlir::triton::gpu::AMDMfmaEncodingAttr; using ::mlir::triton::gpu::AMDWmmaEncodingAttr; using ::mlir::triton::gpu::DotOperandEncodingAttr; using ::mlir::triton::gpu::MemDescType; -using ::triton::gpu::LinearEncodingAttr; namespace SharedToDotOperandMFMA { Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, @@ -181,118 +180,6 @@ struct ConvertLayoutOpMFMAToDotOpConversion const TargetInfoBase &targetInfo; }; -// Match MFMA->Linear Layout conversion -static bool matchMFMAAndLinearLayoutCase(RankedTensorType srcTy, - RankedTensorType dstTy) { - auto mfmaLayout = dyn_cast(srcTy.getEncoding()); - auto linearLayout = dyn_cast(dstTy.getEncoding()); - if (!mfmaLayout || !linearLayout) - return false; - - std::optional srcLL = - mlir::triton::gpu::chooseMfmaLikeStoreLayout(srcTy); - if (!srcLL) - return false; - - MLIRContext *ctx = linearLayout.getContext(); - StringAttr kLane = StringAttr::get(ctx, "lane"); - StringAttr kRegister = StringAttr::get(ctx, "register"); - auto srcBase = srcLL.value().getBases(); - auto srcReg = srcBase.lookup(kRegister); - auto srcLane = srcBase.lookup(kLane); - auto dstBases = linearLayout.getLinearLayout().getBases(); - auto dstReg = dstBases.lookup(kRegister); - auto dstLane = dstBases.lookup(kLane); - return dstReg == srcReg && dstLane == srcLane; -}; - -struct ConvertLayoutOpMFMAToLinearConversion - : public ConvertOpToLLVMPattern { -public: - explicit ConvertLayoutOpMFMAToLinearConversion( - LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, - PatternBenefit benefit) - : ConvertOpToLLVMPattern(typeConverter, - benefit), - targetInfo(targetInfo) {} - - LogicalResult - matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto srcType = cast(op.getSrc().getType()); - auto dstType = cast(op.getType()); - - if (!matchMFMAAndLinearLayoutCase(srcType, dstType)) - return failure(); - - auto loc = op.getLoc(); - auto b = TritonLLVMOpBuilder(loc, rewriter); - - SmallVector inVals = - unpackLLElements(loc, adaptor.getSrc(), rewriter); - if (inVals.empty() || inVals.size() % 8 != 0) - return failure(); - - auto mfmaLayout = dyn_cast(srcType.getEncoding()); - assert(mfmaLayout.getMDim() == 32 && "Expected MFMA size 32"); - assert(triton::gpu::lookupThreadsPerWarp(rewriter) == 64 && - "Expected warp size 64 for MFMA"); - - auto elemTy = srcType.getElementType(); - auto vecTy = vec_ty(elemTy, 2); - - SmallVector outVals; - auto idx0 = b.i32_val(0); - auto idx1 = b.i32_val(1); - // Convert MFMA layout to a MFMA-like linear layout where each thread - // holds 8 consecutive elements - for (size_t idx = 0; idx < inVals.size(); idx += 8) { - SmallVector inVecs; - for (size_t vIdx = 0; vIdx < 4; vIdx++) { - Value vec = b.undef(vecTy); - vec = b.insert_element(vecTy, vec, inVals[idx + vIdx * 2 + 0], idx0); - vec = b.insert_element(vecTy, vec, inVals[idx + vIdx * 2 + 1], idx1); - inVecs.push_back(vec); - } - - Value resVec0, resVec1, resVec2, resVec3; - - // Swap the row 2 and 3 of vec0 and the row 0 and 1 of vec2 - MLIRContext *ctx = rewriter.getContext(); - Type retType = struct_ty({i32_ty, i32_ty}); - Value falseVal = b.false_val(); - Value perm = - LLVM::createLLVMIntrinsicCallOp( - rewriter, loc, "llvm.amdgcn.permlane32.swap", retType, - ValueRange{b.bitcast(inVecs[0], i32_ty), - b.bitcast(inVecs[2], i32_ty), falseVal, falseVal}) - ->getResult(0); - resVec0 = b.bitcast(b.extract_val(i32_ty, perm, 0), vecTy); - resVec2 = b.bitcast(b.extract_val(i32_ty, perm, 1), vecTy); - - // Swap the row 2 and 3 of vec1 and the row 0 and 1 of vec3 - perm = LLVM::createLLVMIntrinsicCallOp( - rewriter, loc, "llvm.amdgcn.permlane32.swap", retType, - ValueRange{b.bitcast(inVecs[1], i32_ty), - b.bitcast(inVecs[3], i32_ty), falseVal, falseVal}) - ->getResult(0); - resVec1 = b.bitcast(b.extract_val(i32_ty, perm, 0), vecTy); - resVec3 = b.bitcast(b.extract_val(i32_ty, perm, 1), vecTy); - - for (Value res : {resVec0, resVec1, resVec2, resVec3}) - for (Value idx : {idx0, idx1}) - outVals.push_back(b.extract_element(elemTy, res, idx)); - } - - Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, - op.getType()); - rewriter.replaceOp(op, result); - return success(); - } - -protected: - const TargetInfoBase &targetInfo; -}; } // namespace void mlir::triton::AMD::populateConvertLayoutOpToLLVMPatterns( @@ -300,6 +187,4 @@ void mlir::triton::AMD::populateConvertLayoutOpToLLVMPatterns( RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add(typeConverter, targetInfo, benefit); - patterns.add(typeConverter, targetInfo, - benefit); } diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/OptimizeEpilogue.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/OptimizeEpilogue.cpp index a14e42d2d93f..a613a54b79a0 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/OptimizeEpilogue.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/OptimizeEpilogue.cpp @@ -69,17 +69,25 @@ static triton::StoreOp convertMfmaLayoutForCDNA4(PatternRewriter &rewriter, auto mfmaLayout = cast(valType.getEncoding()); - // Create a new layout where each thread holds 8 consecutive elements, in - // order to enable wide 128-bit global stores. - std::optional mfma8Layout = - triton::gpu::chooseMfmaLikeStoreLayout(valType); + bool mfma32 = mfmaLayout.getMDim() == 32 && mfmaLayout.getNDim() == 32; - if (!mfma8Layout) + if (valType.getRank() != 2 || + (!valType.getElementType().isF16() && + !valType.getElementType().isBF16()) || + mfmaLayout.getVersionMajor() != 4 || !mfmaLayout.getIsTransposed() || + !mfma32) { return rewriter.create(oldStOp.getLoc(), ptr, val, mask, oldStOp.getCache(), oldStOp.getEvict()); + } + + // Create a new layout where each thread holds 8 consecutive elements, in + // order to enable wide 128-bit global stores. + triton::LinearLayout mfma8Layout = + chooseMfmaLikeStoreLayout(mfmaLayout, valType.getShape()); + Attribute newEncoding = triton::gpu::LinearEncodingAttr::get( - mfmaLayout.getContext(), mfma8Layout.value()); + mfmaLayout.getContext(), mfma8Layout); auto newPtrType = RankedTensorType::get( ptrType.getShape(), ptrType.getElementType(), newEncoding); Value newPtr = rewriter.create(ptr.getLoc(), From d7e2e2c13b5eaa3bcb2fedf2972d95fecc8e68c5 Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Tue, 13 May 2025 17:21:30 +0000 Subject: [PATCH 29/44] Revert "[BACKEND] bump to llvm/llvm-project@3c709802d31b (#6754)" This reverts commit f3076b136c62ba95423a1d49f228376bc3da800f. --- cmake/llvm-hash.txt | 2 +- .../TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp | 4 ++-- lib/Conversion/TritonGPUToLLVM/Utility.cpp | 4 ++-- test/Conversion/cvt_to_llvm.mlir | 2 +- .../amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp | 2 +- .../amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp | 12 ++---------- .../TritonAMDGPUTransforms/CanonicalizePointers.cpp | 2 +- .../lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp | 4 ++-- .../nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp | 4 ++-- 9 files changed, 14 insertions(+), 22 deletions(-) diff --git a/cmake/llvm-hash.txt b/cmake/llvm-hash.txt index 4f839b752cab..594fab8f36a1 100644 --- a/cmake/llvm-hash.txt +++ b/cmake/llvm-hash.txt @@ -1 +1 @@ -3c709802d31b5bc5ed3af8284b40593ff39b9eec +092b6e73e651469527662443b592f98f442ece72 diff --git a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 473f79170240..5f0368401a16 100644 --- a/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -294,8 +294,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion b.shl(b.lshr(offset, b.i32_val(rshiftVal)), b.i32_val(lshiftVal)), offset); } - auto vecAddr = b.gep(sharedPtrTy, elemTy, smemBase, offset, - LLVM::GEPNoWrapFlags::inbounds); + auto vecAddr = b.gep(sharedPtrTy, elemTy, smemBase, offset); + vecAddr.setInbounds(true); return vecAddr; }; diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index efe00265eb11..e09a08105926 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -398,8 +398,8 @@ Value getSmemVecAddr(const LinearLayout ®Layout, smemOffset = b.sub(smemOffset, baseToAllocBaseDist); } auto ptrTy = smemBase.getType(); - auto vecAddr = b.gep(ptrTy, elemLlvmTy, smemBase, smemOffset, - LLVM::GEPNoWrapFlags::inbounds); + auto vecAddr = b.gep(ptrTy, elemLlvmTy, smemBase, smemOffset); + vecAddr.setInbounds(true); return vecAddr; } diff --git a/test/Conversion/cvt_to_llvm.mlir b/test/Conversion/cvt_to_llvm.mlir index 5ec73e4c8a32..f577bc5af53e 100644 --- a/test/Conversion/cvt_to_llvm.mlir +++ b/test/Conversion/cvt_to_llvm.mlir @@ -48,7 +48,7 @@ tt.func private @convert_layout_blocked_blocked_vec(%arg0: tensor<16x16xi32, #bl // CHECK-DAG: [[X_MOD_2:%.*]] = and i32 [[TID]], 1 // CHECK-DAG: [[X_2_4_LOWER:%.*]] = shl {{.*}} i32 [[IS_UPPER_HALF]], 1 - // CHECK-DAG: [[X_2_4_UPPER0:%.*]] = shl {{.*}} i32 [[TID]], 1 + // CHECK-DAG: [[X_2_4_UPPER0:%.*]] = shl i32 [[TID]], 1 // CHECK-DAG: [[X_2_4_UPPER1:%.*]] = and i32 [[X_2_4_UPPER0]], 24 // CHECK-DAG: [[X_GE_16:%.*]] = and i32 [[TID]], 16 // CHECK-DAG: [[X_GE_16_2:%.*]] = lshr exact i32 [[X_GE_16]], 3 diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp index 7fe495ff3dd5..9cec2cd8b51d 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp @@ -204,7 +204,7 @@ struct ConvertBuiltinFuncToLLVM ModuleOp mod = getOperation(); GreedyRewriteConfig config; - config.setRegionSimplificationLevel(GreedySimplifyRegionLevel::Aggressive); + config.enableRegionSimplification = GreedySimplifyRegionLevel::Aggressive; RewritePatternSet patterns(context); patterns.add(context, this->ftz); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp index 1647f4b0680a..02014f732838 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp @@ -11,7 +11,6 @@ #include "mlir/Conversion/MathToLLVM/MathToLLVM.h" #include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h" #include "mlir/Conversion/UBToLLVM/UBToLLVM.h" -#include "mlir/Dialect/AMDGPU/Utils/Chipset.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" @@ -210,16 +209,9 @@ struct ConvertTritonAMDGPUToLLVM mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns); mlir::populateMathToLLVMConversionPatterns(typeConverter, patterns); - FailureOr maybeChipset = - mlir::amdgpu::Chipset::parse(this->arch); - if (failed(maybeChipset)) { - emitError(UnknownLoc::get(&getContext()), - "Invalid AMDGPU chipset name: " + this->arch); - return signalPassFailure(); - } // Native lowering patterns - mlir::populateGpuToROCDLConversionPatterns( - typeConverter, patterns, mlir::gpu::amd::HIP, *maybeChipset); + mlir::populateGpuToROCDLConversionPatterns(typeConverter, patterns, + mlir::gpu::amd::HIP); mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter, patterns); diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp index 51d77ca942e7..0844cc941efb 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp @@ -13,7 +13,7 @@ #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" #include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/OneToNTypeConversion.h" #include "triton/Analysis/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp index 6b9659d9d4c2..577db1c0b543 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/MemoryOpToLLVM.cpp @@ -241,8 +241,8 @@ LogicalResult lowerDistributedToSharedStmatrix( for (int i = 0; i < srcVals.size(); i += step) { auto regIdx = reps.apply({{kReg, i}, {kLane, 0}, {kWarp, 0}})[0].second; Value offset = b.xor_(regBase, b.i32_val(regIdx)); - auto vecAddr = b.gep(smemPtrTy, llvmElemTy, smemBase, offset, - LLVM::GEPNoWrapFlags::inbounds); + auto vecAddr = b.gep(smemPtrTy, llvmElemTy, smemBase, offset); + vecAddr.setInbounds(true); SmallVector inValsVec; for (int j = 0; j < step; j++) inValsVec.push_back(srcVals[i + j]); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp index e3b5ef77b7cf..a4673738dc67 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp @@ -230,7 +230,7 @@ void TargetInfo::storeDShared(RewriterBase &rewriter, Location loc, Value ptr, SmallVector vals = unpackLLVector(loc, val, rewriter); for (int i = 0; i < vec / maxVec; i++) { auto newPtr = b.gep(ptr.getType(), elemTy, ptr, b.i32_val(i * maxVec), - LLVM::GEPNoWrapFlags::inbounds); + /*inbounds=*/true); storeDShared( rewriter, loc, newPtr, ctaId, packLLVector(loc, ArrayRef(vals).slice(i * maxVec, maxVec), rewriter), @@ -343,7 +343,7 @@ Value TargetInfo::loadDShared(RewriterBase &rewriter, Location loc, Value ptr, SmallVector vals; for (int i = 0; i < vec / maxVec; i++) { auto newPtr = b.gep(ptr.getType(), elemTy, ptr, b.i32_val(i * maxVec), - LLVM::GEPNoWrapFlags::inbounds); + /*inbounds=*/true); auto newVal = loadDShared(rewriter, loc, newPtr, ctaId, vec_ty(elemTy, maxVec), pred); for (Value v : unpackLLVector(loc, newVal, rewriter)) { From 247f4f440cef46cf95130c23700f610080e6b441 Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Wed, 14 May 2025 17:26:53 +0000 Subject: [PATCH 30/44] Revert because no longer needed: "[ASYNC_COPY] Remove MemoryEffect of BufferLoadToLocal to avoid implicit barrier from Membar" This reverts commit 012793ae2809884143a4e8dfdbdfd7f72c3382d0. --- .../amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td index 093fedc56fb4..17d9409468d8 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td @@ -248,7 +248,7 @@ def BufferLoadToLocalOp : TT_AMDGPU_Op<"buffer_load_to_local", [ let description = [{ AMD Buffer load operation. Similar to amdgpu.buffer_load op but directly wirtes to shared memory instead of into registers. }]; let arguments = (ins - Arg:$dest, + Arg]>:$dest, Arg]>:$ptr, I32Tensor:$offsets, Optional:$mask, From 1028c8f60f8167eddf3e360a2f5c5fc46f48d7f9 Mon Sep 17 00:00:00 2001 From: Stanley Winata <68087699+raikonenfnu@users.noreply.github.com> Date: Thu, 15 May 2025 08:24:01 -0700 Subject: [PATCH 31/44] [AMD] Enable async pingpong for F16 GEMMs (#796) * [AMD] Generalize PingPong to have different type of Load/Store Ops This main motivation behind this commit is to add support for PingPong with AsyncOps. In order to accomplish that we made these changes: - Fork "determineDotMemoryOps" to "determineDotAsyncMemoryOps" that handles async memory ops. - Refactor validation and pruning of memory ops to "pruneDotMemoryOps" S.T we can have clean interface for it's async memory ops counterpart "pruneAsyncDotMemoryOps". - Plumb "useBlockPingpong" into StreamPipeliner S.T it can adjust AsyncWait stage/cluster to hoist first AsyncWait and allow set AsyncWait towards the end of the loop to make it easier for 4 PP cluster to move it before the 3rd dot-slice / 2 s_barrier before localLoads this is to ensure no race conditions. - Add check to enable handling of dotSOps (dot scaled) VS dotOps (dot) Signed-off-by: Stanley Winata Co-authored-by: Alexander Weinrauch --- include/triton/Tools/Sys/GetEnv.hpp | 1 + third_party/amd/backend/compiler.py | 2 +- .../include/TritonAMDGPUToLLVM/AsyncUtility.h | 12 + .../include/TritonAMDGPUTransforms/Passes.h | 3 +- .../include/TritonAMDGPUTransforms/Passes.td | 11 +- .../lib/TritonAMDGPUToLLVM/AsyncUtility.cpp | 62 +++ .../amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt | 1 + .../lib/TritonAMDGPUToLLVM/MembarUtility.cpp | 45 +- .../amd/lib/TritonAMDGPUToLLVM/Utility.cpp | 5 +- .../TritonAMDGPUTransforms/BlockPingpong.cpp | 454 +++++++++++++----- .../TritonAMDGPUTransforms/StreamPipeline.cpp | 37 +- third_party/amd/python/triton_amd.cc | 4 +- 12 files changed, 467 insertions(+), 170 deletions(-) create mode 100644 third_party/amd/include/TritonAMDGPUToLLVM/AsyncUtility.h create mode 100644 third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.cpp diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index b07fb91f263b..8dcd2917b6b5 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ -35,6 +35,7 @@ inline const std::set CACHE_INVALIDATING_ENV_VARS = { "TRITON_HIP_LOCAL_PREFETCH", "TRITON_HIP_USE_ASYNC_COPY", "TRITON_HIP_ASYNC_COPY_BYPASS_PERMUTE", + "TRITON_HIP_ENABLE_F16_ASYNC_PINGPONG", "TRITON_HIP_USE_BLOCK_PINGPONG", "TRITON_HIP_USE_IN_THREAD_TRANSPOSE", "TRITON_LLVM_DEBUG_ONLY", diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index 3fb38ac11be2..d2d6b2073c66 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -262,7 +262,7 @@ def make_ttgir(mod, metadata, options): amd.passes.ttgpuir.add_reorder_instructions(pm) use_block_pingpong = is_pingpong_schedule_enabled(options.arch) if use_block_pingpong and options.num_stages in [2, 4]: - amd.passes.ttgpuir.add_block_pingpong(pm, options.num_stages) + amd.passes.ttgpuir.add_block_pingpong(pm, options.num_stages, use_async_copy) if knobs.amd.use_buffer_ops: amd.passes.ttgpuir.add_canonicalize_pointers(pm) diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/AsyncUtility.h b/third_party/amd/include/TritonAMDGPUToLLVM/AsyncUtility.h new file mode 100644 index 000000000000..82063e528ce6 --- /dev/null +++ b/third_party/amd/include/TritonAMDGPUToLLVM/AsyncUtility.h @@ -0,0 +1,12 @@ +#ifndef TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTOLLVM_ASYNCUTILITY_H_ +#define TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTOLLVM_ASYNCUTILITY_H_ + +#include "mlir/IR/Value.h" + +namespace mlir::triton::AMD { +// Traverses the def-chain including control flow of the token and returns true +// if all defining operations are an AsyncWait +bool comesFromAsyncWait(mlir::Value value); +} // namespace mlir::triton::AMD + +#endif diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.h b/third_party/amd/include/TritonAMDGPUTransforms/Passes.h index fccb65d061ab..9d48e1ffe208 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.h +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.h @@ -34,7 +34,8 @@ std::unique_ptr createTritonAMDGPUConvertToBufferOpsPass( std::string archGenName = std::string()); std::unique_ptr -createTritonAMDGPUBlockPingpongPass(int32_t numStages = 2); +createTritonAMDGPUBlockPingpongPass(int32_t numStages = 2, + bool useAsyncCopy = false); std::unique_ptr createTritonAMDGPUInThreadTransposePass(); diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td index 91bd40000222..9f9bf9cf7b0e 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td @@ -168,11 +168,12 @@ def TritonAMDGPUBlockPingpong: Pass<"tritonamdgpu-block-pingpong", "mlir::Module let dependentDialects = ["mlir::ROCDL::ROCDLDialect, mlir::triton::amdgpu::TritonAMDGPUDialect"]; - let options = [ - Option<"numStages", "num-stages", - "int32_t", /*default*/"2", - "Number of Pipeline stages">, - ]; + let options = + [Option<"numStages", "num-stages", "int32_t", /*default*/ "2", + "Number of Pipeline stages">, + Option<"useAsyncCopy", "use_async_copy", "bool", /*default*/ "false", + "Use AsyncCopyGlobalToLocal to directly load to shared memory">, + ]; } def TritonAMDGPUInThreadTranspose: Pass<"tritonamdgpu-in-thread-transpose", "mlir::triton::FuncOp"> { diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.cpp new file mode 100644 index 000000000000..484422021894 --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/AsyncUtility.cpp @@ -0,0 +1,62 @@ +#include "third_party/amd/include/TritonAMDGPUToLLVM/AsyncUtility.h" +#include "Dialect/TritonAMDGPU/IR/Dialect.h" +#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/IR/Operation.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +namespace mlir::triton::AMD { + +// Traverses the def-chain including control flow of the token and returns true +// if all defining operations are an AsyncWait +bool comesFromAsyncWait(mlir::Value token) { + if (auto defOp = token.getDefiningOp()) { + if (isa(defOp)) + return true; + else if (auto castOp = dyn_cast(defOp)) + return comesFromAsyncWait(castOp.getInputs()[0]); + else + return false; + } + + auto blockArg = llvm::dyn_cast(token); + // If the token has no defining op and is not an BlockArgument bail out + if (!blockArg) { + return false; + } + + auto block = blockArg.getOwner(); + auto argId = blockArg.getArgNumber(); + + auto destOperandFromAsyncWait = [argId](auto &&operands) { + assert(argId < operands.size()); + return comesFromAsyncWait(operands[argId]); + }; + + // Check all predecessor block's terminator and follow the passed value at + // argId to see if they are immediately an AsyncWait. + for (auto *pred : block->getPredecessors()) { + auto terminator = pred->getTerminator(); + if (auto br = llvm::dyn_cast(terminator)) { + if (!destOperandFromAsyncWait(br.getDestOperands())) + return false; + } else if (auto condBr = llvm::dyn_cast(terminator)) { + if (condBr.getTrueDest() == block) { + if (!destOperandFromAsyncWait(condBr.getTrueDestOperands())) + return false; + } + if (condBr.getFalseDest() == block) { + if (!destOperandFromAsyncWait(condBr.getFalseDestOperands())) + return false; + } + } else if (auto br = llvm::dyn_cast(terminator)) { + if (!destOperandFromAsyncWait(br.getDestOperands())) + return false; + } else { + llvm::dbgs() << "no terminator!" << *terminator << "\n"; + return false; + } + } + return true; +} + +} // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt index 2842cc76bfc4..4a3db488a21f 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt @@ -1,4 +1,5 @@ add_triton_library(TritonAMDGPUToLLVM + AsyncUtility.cpp AtomicRMWOpsEmitter.cpp BufferOpsEmitter.cpp ConvertLayoutOpToLLVM/SharedToDotOperandHelper.cpp diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/MembarUtility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/MembarUtility.cpp index 32c9f4c4c730..26673d320a33 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/MembarUtility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/MembarUtility.cpp @@ -1,55 +1,12 @@ #include "third_party/amd/include/TritonAMDGPUToLLVM/MembarUtility.h" #include "Dialect/TritonAMDGPU/IR/Dialect.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "third_party/amd/include/TritonAMDGPUToLLVM/AsyncUtility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" namespace mlir::triton::AMD { namespace { -// Traverses the def-chain including control flow of the token and returns true -// if all defining operations are an AsyncWait -bool comesFromAsyncWait(Value token) { - if (auto defOp = token.getDefiningOp()) { - return isa(defOp); - } - - auto blockArg = dyn_cast(token); - // If the token has no defining op and is not an BlockArgument bail out - if (!blockArg) { - return false; - } - - auto block = blockArg.getOwner(); - auto argId = blockArg.getArgNumber(); - - auto destOperandFromAsyncWait = [argId](auto &&operands) { - assert(argId < operands.size()); - return comesFromAsyncWait(operands[argId]); - }; - - // Check all predecessor block's terminator and follow the passed value at - // argId to see if they are immediately an AsyncWait. - for (auto *pred : block->getPredecessors()) { - auto terminator = pred->getTerminator(); - if (auto br = dyn_cast(terminator)) { - if (!destOperandFromAsyncWait(br.getDestOperands())) - return false; - } else if (auto condBr = dyn_cast(terminator)) { - if (condBr.getTrueDest() == block) { - if (!destOperandFromAsyncWait(condBr.getTrueDestOperands())) - return false; - } - if (condBr.getFalseDest() == block) { - if (!destOperandFromAsyncWait(condBr.getFalseDestOperands())) - return false; - } - } else { - return false; - } - } - return true; -} - // Returns true if one of the operands is a LocalLoad synced via AsyncWait. bool filterAsyncLocalLoadsDeppendencies(Operation *op1, Operation *op2) { auto isAsyncLoad = [](Operation *op) { diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp index 8ec4ff2de468..11ff747b6675 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp @@ -4,12 +4,14 @@ #include "mlir/Dialect/LLVMIR/LLVMTypes.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/IR/PatternMatch.h" +#include "third_party/amd/include/TritonAMDGPUToLLVM/AsyncUtility.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" namespace tt = mlir::triton; using mlir::triton::ModuleAxisInfoAnalysis; +using mlir::triton::AMD::comesFromAsyncWait; using mlir::triton::AMD::DppCtrl; using mlir::triton::AMD::ISAFamily; using mlir::triton::gpu::appendOrGetExternFuncOp; @@ -734,8 +736,9 @@ void addAsyncCopyAliasScope(AliasAnalysisOpInterface directToLdsOp) { void addLocalLoadNoAliasScope(triton::gpu::LocalLoadOp localLoadOp, AliasAnalysisOpInterface llLoadOp) { auto token = localLoadOp.getToken(); - if (!token || !token.getDefiningOp()) + if (!token || !comesFromAsyncWait(token)) { return; + } return addLocalLoadNoAliasScope(llLoadOp); } diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp index 82c5cddf32aa..90ec3d8d83a3 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp @@ -39,6 +39,7 @@ class Pingponger { SmallVector lStoreOps; SmallVector asyncCopyOps; SmallVector asyncWaitOps; + SmallVector asyncCommitOps; SmallVector dotOps; SmallVector dotSOps; SmallVector> subViewOps; @@ -59,10 +60,13 @@ class Pingponger { int32_t kWidth; int32_t numWarps; int32_t numStages; + bool useAsyncCopy; public: - Pingponger(scf::ForOp forOp, int32_t numWarps, int32_t numStages) - : forOp(forOp), numWarps(numWarps), numStages(numStages) {} + Pingponger(scf::ForOp forOp, int32_t numWarps, int32_t numStages, + bool useAsyncCopy) + : forOp(forOp), numWarps(numWarps), numStages(numStages), + useAsyncCopy(useAsyncCopy) {} void getDotPingponged(); private: @@ -72,13 +76,12 @@ class Pingponger { Attribute dotEncoding, unsigned opIdx, unsigned numSlices, int64_t sliceWidth); LogicalResult genLocalSliceScales(OpBuilder &builder, Value v, - Attribute dotEncoding, unsigned opIdx, - unsigned numSlices, int64_t sliceWidth); + Attribute dotEncoding, unsigned opIdx, + unsigned numSlices, int64_t sliceWidth); LogicalResult sliceDot(OpBuilder &builder, Location loc, tt::DotOp op, unsigned numSlices); LogicalResult sliceDotScaled(OpBuilder &builder, Location loc, tt::DotScaledOp op, unsigned numSlices); - void transformOnePPClusters(OpBuilder &builder, Location loc); LogicalResult transformFourPPClusters(OpBuilder &builder, Location loc); LogicalResult transformTwoPPClusters(OpBuilder &builder, Location loc); @@ -103,6 +106,21 @@ class Pingponger { DenseSet &dotGlobalLoads, DenseSet &dotLocalLoads, DenseSet &dotLocalStores); + LogicalResult pruneDotMemoryOps(DenseSet &dotGlobalLoads, + DenseSet &dotLocalLoads, + DenseSet &dotLocalStores, + bool assumeNotTaken); + void determineDotAsyncMemoryOps( + tt::DotOp dotOp, + DenseSet &dotAsyncGlobalLoads, + DenseSet &dotLocalLoads, + DenseSet &dotAsyncCommitGroups, + DenseSet &dotAsyncWaits); + LogicalResult pruneDotAsyncMemoryOps( + DenseSet &dotGlobalLoads, + DenseSet &dotLocalLoads, + DenseSet &dotAsyncCommitGroups, + DenseSet &dotAsyncWaits, bool assumeNotTaken); template void findClosestPredOps(Value v, DenseSet &matchingOps); @@ -373,6 +391,208 @@ void Pingponger::determineDotMemoryOps( findClosestPredOps(localStore.getSrc(), dotGlobalLoads); } +// Populate the dotAsyncGlobalLoads, dotLocalLoads, dotAsyncCommitGroups, and +// dotAsyncWaits set with any loads that are generated by the current dot +// product. This occurs in steps to: +// 1. Determine which loads are generated by the dot product via getA() +// and getB(). +// 2. Determine which asyncCopyGlobalToLcals are used to populate the +// inputs to the local loads. +// 3. Determine which async commit are using asyncCopyGlobalToLcals. +// 4. Determine which async waits are consuming async commits +// Note: This function currently depends on num_stages=2, which is a +// precondition for the pingpong scheduling. +void Pingponger::determineDotAsyncMemoryOps( + tt::DotOp dotOp, + DenseSet &dotAsyncGlobalLoads, + DenseSet &dotLocalLoads, + DenseSet &dotAsyncCommitGroups, + DenseSet &dotAsyncWaits) { + // Find the locals loads used to compute the dot inputs. These + // must come before the dot op. + findClosestPredOps(dotOp.getA(), dotLocalLoads); + findClosestPredOps(dotOp.getB(), dotLocalLoads); + + // Determine the local stores from the local loads. + // With pipelining we expect this to be a single local + // store within the loop based on a block argument after routing through + // a ttg.MemDescSubviewOp. + DenseSet subviews; + for (auto &&localLoad : dotLocalLoads) + findClosestPredOps(localLoad.getSrc(), subviews); + + for (auto &&subview : subviews) { + for (auto &&user : subview->getUsers()) { + if (auto globalLoad = dyn_cast(user)) { + if (!globalLoad->hasOneUse()) + continue; + auto asyncCommitGroup = + dyn_cast(*globalLoad->getUsers().begin()); + if (!asyncCommitGroup) + continue; + + dotAsyncGlobalLoads.insert(globalLoad); + dotAsyncCommitGroups.insert(asyncCommitGroup); + } + } + } + + // Looks for AsyncWaitOp, which after StreamPipeliner should be + // located/consumed by the iter arg which represent the AsyncCommits. + for (auto &&asyncCommitGroup : dotAsyncCommitGroups) { + if (!asyncCommitGroup->hasOneUse()) + return; + auto asyncWaitOp = + dyn_cast(*asyncCommitGroup->getUsers().begin()); + if (!asyncWaitOp) + return; + dotAsyncWaits.insert(asyncWaitOp); + } +} + +LogicalResult +Pingponger::pruneDotMemoryOps(DenseSet &dotGlobalLoads, + DenseSet &dotLocalLoads, + DenseSet &dotLocalStores, + bool assumeNotTaken) { + // Prune Memory operations that may be moved to only those involved in dot + // computation. To understand the "cluster assumptions" we also estimate + // the impact of any additional loads/stores. + auto gLoadIt = std::stable_partition( + gLoadOps.begin(), gLoadOps.end(), + [&dotGlobalLoads](tt::LoadOp op) { return dotGlobalLoads.contains(op); }); + auto lLoadIt = std::stable_partition(lLoadOps.begin(), lLoadOps.end(), + [&dotLocalLoads](ttg::LocalLoadOp op) { + return dotLocalLoads.contains(op); + }); + auto lStoreIt = + std::stable_partition(lStoreOps.begin(), lStoreOps.end(), + [&dotLocalStores](ttg::LocalStoreOp op) { + return dotLocalStores.contains(op); + }); + + if (estimateNonDotMemoryImpact(gLoadIt, gLoadOps.end(), + assumeNotTaken) != 0) { + std::stringstream message; + message << "Unable to match ping pong scheduling pattern. Details: " + << "Non-dot global loads found in non-persistent GEMM"; + LDBG(message.str()); + return failure(); + } + if (estimateNonDotMemoryImpact(lLoadIt, lLoadOps.end(), + assumeNotTaken) != 0) { + std::stringstream message; + message << "Unable to match ping pong scheduling pattern. Details: " + << "Non-dot local loads found in non-persistent GEMM"; + LDBG(message.str()); + return failure(); + } + if (estimateNonDotMemoryImpact(lStoreIt, lStoreOps.end(), + assumeNotTaken) != 0) { + std::stringstream message; + message << "Unable to match ping pong scheduling pattern. Details: " + << "Non-dot local stores found in non-persistent GEMM"; + LDBG(message.str()); + return failure(); + } + + // Remove non-dot memory operations. + gLoadOps.erase(gLoadIt, gLoadOps.end()); + lLoadOps.erase(lLoadIt, lLoadOps.end()); + lStoreOps.erase(lStoreIt, lStoreOps.end()); + // All PingPong Scheduler assumes there are 2 movable global loads and 2 + // movable local loads. + if (gLoadOps.size() != 2 || lLoadOps.size() != 2) { + std::stringstream message; + message << "Unable to match ping pong slicing pattern. Details: " + << gLoadOps.size() << " global loads in dot computation, " + << lLoadOps.size() << " local loads in dot computation"; + LDBG(message.str()); + return failure(); + } + return success(); +} + +LogicalResult Pingponger::pruneDotAsyncMemoryOps( + DenseSet &dotGlobalLoads, + DenseSet &dotLocalLoads, + DenseSet &dotAsyncCommitGroups, + DenseSet &dotAsyncWaits, bool assumeNotTaken) { + // Prune Memory operations that may be moved to only those involved in dot + // computation. To understand the "cluster assumptions" we also estimate + // the impact of any additional loads/stores. + auto asyncCopyIt = std::stable_partition( + asyncCopyOps.begin(), asyncCopyOps.end(), + [&dotGlobalLoads](ttg::AsyncCopyGlobalToLocalOp op) { + return dotGlobalLoads.contains(op); + }); + auto lLoadIt = std::stable_partition(lLoadOps.begin(), lLoadOps.end(), + [&dotLocalLoads](ttg::LocalLoadOp op) { + return dotLocalLoads.contains(op); + }); + auto asyncCommitIt = std::stable_partition( + asyncCommitOps.begin(), asyncCommitOps.end(), + [&dotAsyncCommitGroups](ttg::AsyncCommitGroupOp op) { + return dotAsyncCommitGroups.contains(op); + }); + auto asyncWaitIt = + std::stable_partition(asyncWaitOps.begin(), asyncWaitOps.end(), + [&dotAsyncWaits](ttg::AsyncWaitOp op) { + return dotAsyncWaits.contains(op); + }); + + if (estimateNonDotMemoryImpact( + asyncCopyIt, asyncCopyOps.end(), assumeNotTaken) != 0) { + std::stringstream message; + message << "Unable to match ping pong scheduling pattern. Details: " + << "Non-dot global loads found in non-persistent GEMM"; + LDBG(message.str()); + return failure(); + } + if (estimateNonDotMemoryImpact(lLoadIt, lLoadOps.end(), + assumeNotTaken) != 0) { + std::stringstream message; + message << "Unable to match ping pong scheduling pattern. Details: " + << "Non-dot local loads found in non-persistent GEMM"; + LDBG(message.str()); + return failure(); + } + if (estimateNonDotMemoryImpact( + asyncCommitIt, asyncCommitOps.end(), assumeNotTaken) != 0) { + std::stringstream message; + message << "Unable to match ping pong scheduling pattern. Details: " + << "Non-dot local stores found in non-persistent GEMM"; + LDBG(message.str()); + return failure(); + } + if (estimateNonDotMemoryImpact( + asyncWaitIt, asyncWaitOps.end(), assumeNotTaken) != 0) { + std::stringstream message; + message << "Unable to match ping pong scheduling pattern. Details: " + << "Non-dot local stores found in non-persistent GEMM"; + LDBG(message.str()); + return failure(); + } + + // Remove non-dot memory operations. + asyncCopyOps.erase(asyncCopyIt, asyncCopyOps.end()); + lLoadOps.erase(lLoadIt, lLoadOps.end()); + asyncCommitOps.erase(asyncCommitIt, asyncCommitOps.end()); + asyncWaitOps.erase(asyncWaitIt, asyncWaitOps.end()); + // All PingPong Scheduler assumes there are 2 movable global loads and 2 + // movable local loads. + if (asyncCopyOps.size() != 2 || lLoadOps.size() != 2 || + asyncWaitOps.size() != 2) { + std::stringstream message; + message << "Unable to match ping pong slicing pattern. Details: " + << asyncCopyOps.size() << " global loads in dot computation, " + << lLoadOps.size() << " local loads in dot computation"; + LDBG(message.str()); + return failure(); + } + return success(); +} + // Transform a loop into one Dot - Memory (ping - pong) clusters // Each cluster, especially the Dot cluster is guarded with setprio(1->0) so // each warp can complete the execution of the cluster without being @@ -503,8 +723,8 @@ LogicalResult Pingponger::genLocalSliceHelper(OpBuilder &builder, Value v, } Value newSmem = builder.create( v.getLoc(), subviewDescType, memDesc, offsetsVal); - Value prefetchSlice = - builder.create(v.getLoc(), tensorType, newSmem, waitToken); + Value prefetchSlice = builder.create( + v.getLoc(), tensorType, newSmem, waitToken); subviews.push_back(newSmem.getDefiningOp()); slices.push_back(prefetchSlice.getDefiningOp()); } @@ -525,7 +745,7 @@ LogicalResult Pingponger::sliceDot(OpBuilder &builder, Location loc, if (shapeB[0] % numSlices != 0) return failure(); genOffsetConstants(loc, builder, numSlices, sliceWidth); - builder.setInsertionPointAfter(gLoadOps[0]); + builder.setInsertionPointAfter(useAsyncCopy ? asyncCopyOps[0] : gLoadOps[0]); auto dotEncoding = op.getType().getEncoding(); if (genLocalSlice(builder, op.getA(), dotEncoding, 0, numSlices, sliceWidth) .failed() || @@ -650,13 +870,14 @@ LogicalResult Pingponger::transformFourPPClusters(OpBuilder &builder, // First, slice local_loads and dot into 4 parts if (sliceDot(builder, loc, dotOps[0], 4).failed()) return failure(); - builder.setInsertionPointAfter(gLoadOps[1]); + Operation *gLoadRhs = useAsyncCopy ? asyncCopyOps[1] : gLoadOps[1]; + builder.setInsertionPointAfter(gLoadRhs); // Reorder operations into four mem/dot clusters // mem0: global load A, local load A(1/4), local load B(1/4) // set insertion point at the last global_load where all the addresses are // ready to be used. - updateOpInsertion(gLoadOps[1]); + updateOpInsertion(gLoadRhs); appendSlicedLoadAB(/*slice=*/0); appendClusterBarrier(builder, loc); @@ -665,7 +886,10 @@ LogicalResult Pingponger::transformFourPPClusters(OpBuilder &builder, appendClusterBarrier(builder, loc); // mem1: global load B, local load A(2/4), local load B(2/4) - appendOp(gLoadOps[1]); + appendOp(gLoadRhs); + if (useAsyncCopy) { + appendOp(asyncCommitOps[1]); + } appendSlicedLoadAB(/*slice=*/1); appendClusterBarrier(builder, loc); @@ -686,10 +910,14 @@ LogicalResult Pingponger::transformFourPPClusters(OpBuilder &builder, // Matmul kernels may use the output of the dot product in another operation // before the local store (e.g. persistent matmul epilogue). To accommodate // such cases, we need to move the local store up in the loop. - moveOpAndPredecessorsUpSameBlock(lStoreOps[0]); - moveOpAndPredecessorsUpSameBlock(lStoreOps[1]); - appendClusterBarrier(builder, loc); - + if (!useAsyncCopy) { + moveOpAndPredecessorsUpSameBlock(lStoreOps[0]); + moveOpAndPredecessorsUpSameBlock(lStoreOps[1]); + appendClusterBarrier(builder, loc); + } else { + appendOp(asyncWaitOps[0]); + appendOp(asyncWaitOps[1]); + } // dot3 (4/4) appendOpWithPrio(builder, dotSliceOps[3], loc); @@ -804,10 +1032,9 @@ LogicalResult Pingponger::transformFAv3(OpBuilder &builder, Location loc) { LogicalResult Pingponger::transformFP4(OpBuilder &builder, Location loc) { - builder.setInsertionPointAfter(forOp); - //FIXME: This is duplicated code, need to refactorize. + // FIXME: This is duplicated code, need to refactorize. auto i32ty = builder.getIntegerType(32); auto workIDX = builder.create(loc, i32ty); workIDX->moveBefore(forOp); @@ -819,8 +1046,6 @@ LogicalResult Pingponger::transformFP4(OpBuilder &builder, Location loc) { warpIDX, constZero); auto warpHigh = builder.create(loc, arith::CmpIPredicate::ne, warpIDX, constZero); - - builder.setInsertionPointAfter(dotSOps[0]); @@ -831,10 +1056,10 @@ LogicalResult Pingponger::transformFP4(OpBuilder &builder, Location loc) { appendOp(builder.create(loc, 0)); appendOp(builder.create(loc, warpLow)); appendOp(builder.create(loc, 0)); - for (int j=0; j<4; j++){ - for (int i=0; i<4; i++) + for (int j = 0; j < 4; j++) { + for (int i = 0; i < 4; i++) appendOp(subViewOps[i][j]); - for (int i=0; i<4; i++) + for (int i = 0; i < 4; i++) appendOp(loadSliceOps[i][j]); appendOp(builder.create(loc, 0)); appendOp(dotSliceOps[j]); @@ -843,7 +1068,6 @@ LogicalResult Pingponger::transformFP4(OpBuilder &builder, Location loc) { appendOp(builder.create(loc, 0)); appendOp(builder.create(loc, warpHigh)); - return success(); } @@ -894,6 +1118,8 @@ void Pingponger::getDotPingponged() { forOp->walk([&](Operation *op) { if (auto gLoad = dyn_cast(op)) gLoadOps.push_back(gLoad); + if (auto asyncCopy = dyn_cast(op)) + asyncCopyOps.push_back(asyncCopy); else if (auto lLoad = dyn_cast(op)) { // This scheduling doesn't help hiding intra-warp latency. So, we only // collect local_load ops that are software pipelined, which means their @@ -912,6 +1138,9 @@ void Pingponger::getDotPingponged() { dotSOps.push_back(pingpongDot); } else if (auto asyncOp = dyn_cast(op)) { asyncCopyOps.push_back(asyncOp); + } else if (auto asyncCommitGroupOp = + dyn_cast(op)) { + asyncCommitOps.push_back(asyncCommitGroupOp); } else if (auto asyncOp = dyn_cast(op)) asyncWaitOps.push_back(asyncOp); }); @@ -933,42 +1162,112 @@ void Pingponger::getDotPingponged() { // supported combination of operations because this transformation is very // tightly scheduling the latencies. - //FIXME: get better condition to enable pingpong either for dot or for dot_scaled - if ((dotSOps.size() != 1) && (gLoadOps.size() < 2 || lLoadOps.size() < 2 || dotSOps.size() != 1)){ + // FIXME: get better condition to enable pingpong either for dot or for + // dot_scaled + int64_t numOfDotLikeOps = dotSOps.size() + dotOps.size(); + if (numOfDotLikeOps != 1) { + LDBG("Only handle a single of either dot or dot_scaled op"); + return; + } + int64_t gloadSize = useAsyncCopy ? asyncCopyOps.size() : gLoadOps.size(); + int64_t dotSize = dotSOps.size() > 0 ? dotSOps.size() : dotOps.size(); + if ((gloadSize < 2 || lLoadOps.size() < 2 || dotSize != 1)) { std::stringstream message; message << "Unable to match ping pong scheduling pattern. Details: " - << gLoadOps.size() << " global loads, " << lLoadOps.size() - << " local loads, " << dotOps.size() << " dot products"; + << gloadSize << " global loads, " << lLoadOps.size() + << " local loads, " << dotSize << " dot products"; LDBG(message.str()); return; } - //FIXME: place tile size restriction here and obtain kWidth - kWidth = 16; - if (dotSOps.size() == 1){ + // FIXME: place tile size restriction here and obtain kWidth + if (dotSOps.size() == 1) { + kWidth = 16; auto dotSType = dotSOps[0].getType(); auto dotSShape = dotSType.getShape(); auto aType = dotSOps[0].getA().getType(); auto aShape = aType.getShape(); auto elemWidth = aType.getElementTypeBitWidth(); int64_t tileSize = dotSShape[0] * dotSShape[1] * aShape[1]; - if(tileSize != 8388608 || aShape[1] != 128 || elemWidth != 8) + if (tileSize != 8388608 || aShape[1] != 128 || elemWidth != 8) return; if (transformFP4(builder, dotSOps[0]->getLoc()).failed()) { LDBG("Encountered failure when trying to execute the two ping pong " - "cluster transformation"); + "cluster transformation"); return; } addAsymmetricSyncToLoop(builder, loc); + return; } - return; - // Determine if we have a persistent GEMM. This will decide how we interpret // any memory operations that we find in conditionals. auto assumeNotTaken = isPersistentGemm(dotOps.size()); + // Compute tile size, kWidth, and mfma type. + auto dotType = dotOps[0].getType(); + auto dotShape = dotType.getShape(); + auto aType = dotOps[0].getA().getType(); + auto aShape = aType.getShape(); + auto elemWidth = aType.getElementTypeBitWidth(); + int64_t tileSize = dotShape[0] * dotShape[1] * aShape[1] * elemWidth; + + const int64_t minTile = 262144; // e.g. 32x128x64x16bit + const int64_t smallTile = 16777216; // e.g. 128x128x64x16bit + const int64_t mediumTile = 33554432; // smallTile x 2 + const int64_t largeTile = 67108864; // e.g. 256x256x64x16bit + + auto encoding = cast(aType).getEncoding(); + auto srcEncoding = cast(encoding); + kWidth = srcEncoding.getKWidth(); + auto mfmaEncoding = cast(srcEncoding.getParent()); + SmallVector intShape; + intShape.push_back(mfmaEncoding.getMDim()); + intShape.push_back(mfmaEncoding.getNDim()); + + if (dotOps.size() == 1 && useAsyncCopy) { + if (numWarps != 8) { + LDBG("Currently only support num_warp=8 for async PP"); + return; + } + if (tileSize != largeTile || aShape[1] != 64 || elemWidth != 16) { + LDBG("Only support tile size of 256x256x64 tile size for async PP"); + return; + } + + auto encoding = cast(aType).getEncoding(); + auto srcEncoding = cast(encoding); + kWidth = srcEncoding.getKWidth(); + auto mfmaEncoding = cast(srcEncoding.getParent()); + if (mfmaEncoding.getMDim() != 16 && mfmaEncoding.getNDim() != 16 && + kWidth != 8) { + LDBG("Only support 16x16 intrinsic and kWidth=8 for async PP"); + } + + DenseSet dotGlobalLoads; + DenseSet dotLocalLoads; + DenseSet dotAsyncCommitGroups; + DenseSet dotAsyncWaits; + determineDotAsyncMemoryOps(dotOps[0], dotGlobalLoads, dotLocalLoads, + dotAsyncCommitGroups, dotAsyncWaits); + if (failed(pruneDotAsyncMemoryOps(dotGlobalLoads, dotLocalLoads, + dotAsyncCommitGroups, dotAsyncWaits, + assumeNotTaken))) { + std::stringstream message; + message << "Failed to match ping pong scheduling pattern and prune async " + "memory ops."; + LDBG(message.str()); + return; + } + if (transformFourPPClusters(builder, dotOps[0]->getLoc()).failed()) { + LDBG("Encountered failure when trying to execute the four ping pong " + "cluster transformation"); + return; + } + addAsymmetricSyncToLoop(builder, loc); + return; + } // The existing code depends on the loads being targeted being safe to move, // which will not hold if we do not properly have a GEMM. As a result, we // filter the associated load operations to only those that are associated @@ -978,58 +1277,11 @@ void Pingponger::getDotPingponged() { DenseSet dotLocalStores; determineDotMemoryOps(dotOps[0], dotGlobalLoads, dotLocalLoads, dotLocalStores); - - // Prune Memory operations that may be moved to only those involved in dot - // computation. To understand the "cluster assumptions" we also estimate - // the impact of any additional loads/stores. - auto gLoadIt = std::stable_partition( - gLoadOps.begin(), gLoadOps.end(), - [&dotGlobalLoads](tt::LoadOp op) { return dotGlobalLoads.contains(op); }); - auto lLoadIt = std::stable_partition(lLoadOps.begin(), lLoadOps.end(), - [&dotLocalLoads](ttg::LocalLoadOp op) { - return dotLocalLoads.contains(op); - }); - auto lStoreIt = - std::stable_partition(lStoreOps.begin(), lStoreOps.end(), - [&dotLocalStores](ttg::LocalStoreOp op) { - return dotLocalStores.contains(op); - }); - if (estimateNonDotMemoryImpact(gLoadIt, gLoadOps.end(), - assumeNotTaken) != 0) { - std::stringstream message; - message << "Unable to match ping pong scheduling pattern. Details: " - << "Non-dot global loads found in non-persistent GEMM"; - LDBG(message.str()); - return; - } - if (estimateNonDotMemoryImpact(lLoadIt, lLoadOps.end(), - assumeNotTaken) != 0) { - std::stringstream message; - message << "Unable to match ping pong scheduling pattern. Details: " - << "Non-dot local loads found in non-persistent GEMM"; - LDBG(message.str()); - return; - } - if (estimateNonDotMemoryImpact(lStoreIt, lStoreOps.end(), - assumeNotTaken) != 0) { - std::stringstream message; - message << "Unable to match ping pong scheduling pattern. Details: " - << "Non-dot local stores found in non-persistent GEMM"; - LDBG(message.str()); - return; - } - - // Remove non-dot memory operations. - gLoadOps.erase(gLoadIt, gLoadOps.end()); - lLoadOps.erase(lLoadIt, lLoadOps.end()); - lStoreOps.erase(lStoreIt, lStoreOps.end()); - // All PingPong Scheduler assumes there are 2 movable global loads and 2 - // movable local loads. - if (gLoadOps.size() != 2 || lLoadOps.size() != 2) { + if (failed(pruneDotMemoryOps(dotGlobalLoads, dotLocalLoads, dotLocalStores, + assumeNotTaken))) { std::stringstream message; - message << "Unable to match ping pong slicing pattern. Details: " - << gLoadOps.size() << " global loads in dot computation, " - << lLoadOps.size() << " local loads in dot computation"; + message << "Failed to match ping pong scheduling pattern and prune " + "memory ops."; LDBG(message.str()); return; } @@ -1062,26 +1314,6 @@ void Pingponger::getDotPingponged() { // N.B., Tile size smaller than 128x128x64_FP16 is likely not compute-bound // that pingpong scheduling doesn't help much. - auto dotType = dotOps[0].getType(); - auto dotShape = dotType.getShape(); - auto aType = dotOps[0].getA().getType(); - auto aShape = aType.getShape(); - auto elemWidth = aType.getElementTypeBitWidth(); - int64_t tileSize = dotShape[0] * dotShape[1] * aShape[1] * elemWidth; - - const int64_t minTile = 262144; // e.g. 32x128x64x16bit - const int64_t smallTile = 16777216; // e.g. 128x128x64x16bit - const int64_t mediumTile = 33554432; // smallTile x 2 - const int64_t largeTile = 67108864; // e.g. 256x256x64x16bit - - auto encoding = cast(aType).getEncoding(); - auto srcEncoding = cast(encoding); - kWidth = srcEncoding.getKWidth(); - auto mfmaEncoding = cast(srcEncoding.getParent()); - SmallVector intShape; - intShape.push_back(mfmaEncoding.getMDim()); - intShape.push_back(mfmaEncoding.getNDim()); - if (numWarps == 4) { // Pingpong between warps from different blocks // Transform a loop with small tile size. // We've observed that this small tile size spent almost equivalent cycle @@ -1135,14 +1367,16 @@ class TritonAMDGPUBlockPingpongPass : public TritonAMDGPUBlockPingpongBase { public: TritonAMDGPUBlockPingpongPass() = default; - TritonAMDGPUBlockPingpongPass(int32_t numStages) { + TritonAMDGPUBlockPingpongPass(int32_t numStages, bool useAsyncCopy) { this->numStages = numStages; + this->useAsyncCopy = useAsyncCopy; } void runOnOperation() override { ModuleOp m = getOperation(); for (auto funcOp : m.getOps()) { funcOp.walk([&](scf::ForOp forOp) { - Pingponger pingponger(forOp, ttg::lookupNumWarps(forOp), numStages); + Pingponger pingponger(forOp, ttg::lookupNumWarps(forOp), numStages, + useAsyncCopy); pingponger.getDotPingponged(); }); } @@ -1151,6 +1385,8 @@ class TritonAMDGPUBlockPingpongPass } // namespace std::unique_ptr -mlir::createTritonAMDGPUBlockPingpongPass(int32_t numStages) { - return std::make_unique(numStages); +mlir::createTritonAMDGPUBlockPingpongPass(int32_t numStages, + bool useAsyncCopy) { + return std::make_unique(numStages, + useAsyncCopy); } diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp index a37e170223dd..96dd6ef51130 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp @@ -11,6 +11,7 @@ #include "triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h" #include "triton/Dialect/TritonGPU/Transforms/Schedule.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "triton/Tools/Sys/GetEnv.hpp" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Debug.h" @@ -120,9 +121,11 @@ class StreamPipeliner { public: StreamPipeliner(scf::ForOp _forOp, int _numStages, int _globalPrefetch, - int _localPrefetch, bool _useAsyncCopy) + int _localPrefetch, bool _useAsyncCopy, + bool _useF16BlockPingpong) : forOp(_forOp), numStages(_numStages), numBuffers(1), - useAsyncCopy(_useAsyncCopy), schedule(numStages), + useAsyncCopy(_useAsyncCopy), useF16BlockPingpong(_useF16BlockPingpong), + schedule(numStages), axisInfoAnalysis(forOp->getParentOfType()) { int lastStage = numStages - 1; stages[SCHED_GLOBAL_LOAD] = 0; @@ -175,6 +178,9 @@ class StreamPipeliner { // Directly store to shared memory with AsyncCopy when pipelining tt.loads bool useAsyncCopy; + // Whether or not we are intend to ping-pong. + bool useF16BlockPingpong; + // Stage for each SchedType Op int stages[SCHED_SIZE]; // Cluster for each SchedType Op @@ -220,6 +226,15 @@ LogicalResult StreamPipeliner::initSchedule(int maxIndirectionLevel) { bool pairedGlobalLoadLocalStore = stages[SCHED_LOCAL_STORE] == 0; stages[SCHED_LOCAL_STORE] += maxIndirectionLevel; + // In useAsyncCopy + PingPong case, we'd want to hoist out first async_wait + // out of the loop, and async_wait within the loop be towards the end. + // This is beneficial for maximizing hiding of latency, while ensuring + // 2 barriers between asyncWait and localLoad at start of loop S.T + // we do not hit race conditions between warp-lo and warp-hi. + if (useAsyncCopy && useF16BlockPingpong) { + stages[SCHED_ASYNC_WAIT] = std::max(0, stages[SCHED_LOCAL_LOAD] - 1); + } + LDBG( "Stage schedule:" << " GLOBAL_LOAD stage = " << stages[SCHED_GLOBAL_LOAD] << ", LOCAL_STORE stage = " << stages[SCHED_LOCAL_STORE] @@ -247,9 +262,9 @@ LogicalResult StreamPipeliner::initSchedule(int maxIndirectionLevel) { LDBG("deduced max shared memory buffer number = " << numBuffers); // We place async wait as the first cluster because we want to have it being - // the first in the main loop after pipelining. - int asyncWaitCluster = 0; - + // the first in the main loop after pipelining. However if we intend on doing + // PP then we set it near the end of the loop for reasons state above. + int asyncWaitCluster = useF16BlockPingpong ? 4 : 0; // If tt.load and ttg.local_store are in the same stage // spread them apart to allow overlap with compute // else @@ -1053,6 +1068,10 @@ struct PipelinePass : public TritonAMDGPUStreamPipelineBase { return signalPassFailure(); } + // TODO: Replace this with more stable argument/env, once we unify strategy + // between MXFP4 and FP16. + bool useF16BlockPingpong = + triton::tools::getBoolEnv("TRITON_HIP_ENABLE_F16_ASYNC_PINGPONG"); SmallVector loops; getOperation()->walk([&](scf::ForOp forOp) { labelLoadOpsForTritonDot(forOp); @@ -1072,12 +1091,16 @@ struct PipelinePass : public TritonAMDGPUStreamPipelineBase { (void)fsp.pipelineLoop(); } else { StreamPipeliner sp(forOp, tt::getNumStagesOrDefault(forOp, numStages), - globalPrefetch, localPrefetch, useAsyncCopy); + globalPrefetch, localPrefetch, useAsyncCopy, + useF16BlockPingpong); (void)sp.pipelineLoop(); } } - if (useAsyncCopy) { + // This removes additional barrier but pingpong will add the barrier again. + // So we should just not do it to get a better vmcnt in front of each + // AsyncCopy. + if (useAsyncCopy && !useF16BlockPingpong) { llvm::SmallSetVector waitOps; moduleOp.walk([&](ttg::AsyncWaitOp waitOp) { waitOps.insert(waitOp); }); tt::combineRedundantWaitOps(waitOps); diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc index a4095a27ae75..24cabe6bb70a 100644 --- a/third_party/amd/python/triton_amd.cc +++ b/third_party/amd/python/triton_amd.cc @@ -76,8 +76,8 @@ void init_triton_amd_passes_ttgpuir(py::module &&m) { mlir::createTritonAMDGPUReorderInstructionsPass); ADD_PASS_WRAPPER_0("add_fold_true_cmpi", mlir::createTritonAMDGPUFoldTrueCmpIPass); - ADD_PASS_WRAPPER_1("add_block_pingpong", - mlir::createTritonAMDGPUBlockPingpongPass, int32_t); + ADD_PASS_WRAPPER_2("add_block_pingpong", + mlir::createTritonAMDGPUBlockPingpongPass, int32_t, bool); ADD_PASS_WRAPPER_4("add_stream_pipeline", mlir::createTritonAMDGPUStreamPipelinePass, int, int, int, bool); From c5c0e67732f080c88d978241218861c99b7e933f Mon Sep 17 00:00:00 2001 From: Jungwook Park Date: Sat, 17 May 2025 10:44:09 -0500 Subject: [PATCH 32/44] Add initial support for skinny mxfp gemm Overlapping buffer_load and local_load+dot --- .../TritonAMDGPUTransforms/BlockPingpong.cpp | 60 +++++++++++++++++-- 1 file changed, 54 insertions(+), 6 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp index 90ec3d8d83a3..c95112ab0875 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp @@ -87,6 +87,7 @@ class Pingponger { LogicalResult transformTwoPPClusters(OpBuilder &builder, Location loc); LogicalResult transformFAv3(OpBuilder &builder, Location loc); LogicalResult transformFP4(OpBuilder &builder, Location loc); + LogicalResult transformFP4s(OpBuilder &builder, Location loc); void addAsymmetricSyncToLoop(OpBuilder &builder, Location loc); void updateOpInsertion(Operation *Op); void appendOp(Operation *Op); @@ -1030,6 +1031,43 @@ LogicalResult Pingponger::transformFAv3(OpBuilder &builder, Location loc) { return success(); } +LogicalResult Pingponger::transformFP4s(OpBuilder &builder, Location loc) { + //FIXME: support nonscale. + if (lLoadOps.size() != 4) + return failure(); + + builder.setInsertionPointAfter(forOp); + + // FIXME: This is duplicated code, need to refactorize. + auto i32ty = builder.getIntegerType(32); + auto workIDX = builder.create(loc, i32ty); + workIDX->moveBefore(forOp); + builder.setInsertionPointAfter(workIDX); + auto constZero = builder.create(loc, 0, 32); + auto constWarpSize = builder.create(loc, 256, 32); + auto warpIDX = builder.create(loc, workIDX, constWarpSize); + auto warpLow = builder.create(loc, arith::CmpIPredicate::eq, + warpIDX, constZero); + auto warpHigh = builder.create(loc, arith::CmpIPredicate::ne, + warpIDX, constZero); + + builder.setInsertionPointAfter(dotSOps[0]); + updateOpInsertion(dotSOps[0]); + + appendOp(builder.create(loc, 0)); + appendOp(builder.create(loc, warpLow)); + appendOp(builder.create(loc, 0)); + + for (int i = 0; i < 4; i++) + appendOp(lLoadOps[i]); + appendOp(dotSOps[0]); + + appendOp(builder.create(loc, 0)); + appendOp(builder.create(loc, warpHigh)); + return success(); +} + + LogicalResult Pingponger::transformFP4(OpBuilder &builder, Location loc) { builder.setInsertionPointAfter(forOp); @@ -1189,14 +1227,24 @@ void Pingponger::getDotPingponged() { auto aShape = aType.getShape(); auto elemWidth = aType.getElementTypeBitWidth(); int64_t tileSize = dotSShape[0] * dotSShape[1] * aShape[1]; - if (tileSize != 8388608 || aShape[1] != 128 || elemWidth != 8) - return; - if (transformFP4(builder, dotSOps[0]->getLoc()).failed()) { - LDBG("Encountered failure when trying to execute the two ping pong " - "cluster transformation"); - return; + // 256x256x256 (128xi8) + if (tileSize == 8388608 && aShape[1] == 128 && elemWidth == 8){ + if (transformFP4(builder, dotSOps[0]->getLoc()).failed()) { + LDBG("Encountered failure when trying to execute the two ping pong " + "cluster transformation"); + return; + } + } + // 128x128x512 (256xi8) + else if (tileSize == 4194304 && aShape[1] == 256 && elemWidth == 8){ + if (transformFP4s(builder, dotSOps[0]->getLoc()).failed()) { + LDBG("Encountered failure when trying to execute the two ping pong " + "cluster transformation"); + return; + } } + addAsymmetricSyncToLoop(builder, loc); return; } From bcc871d0f7d9631ff612f9015875ab127e925905 Mon Sep 17 00:00:00 2001 From: Jungwook Park Date: Sun, 18 May 2025 14:03:32 -0500 Subject: [PATCH 33/44] add AB load separated pingpong for skinny gemm. --- .../TritonAMDGPUTransforms/BlockPingpong.cpp | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp index c95112ab0875..7a21bbb4e737 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp @@ -1036,6 +1036,9 @@ LogicalResult Pingponger::transformFP4s(OpBuilder &builder, Location loc) { if (lLoadOps.size() != 4) return failure(); +//#define OBO + +#if defined (OBO) builder.setInsertionPointAfter(forOp); // FIXME: This is duplicated code, need to refactorize. @@ -1064,6 +1067,51 @@ LogicalResult Pingponger::transformFP4s(OpBuilder &builder, Location loc) { appendOp(builder.create(loc, 0)); appendOp(builder.create(loc, warpHigh)); + + +#else + auto tokens = asyncWaitOps[0].getAsyncToken(); + Operation *aWait = asyncWaitOps[0]; + builder.setInsertionPointToStart(forOp.getBody()); + asyncWaitOps.clear(); + for (int i = 0; i < 2; i++) { + auto newOp = builder.clone(*aWait); + newOp->eraseOperand(3 - i); + newOp->eraseOperand(1 - i); + asyncWaitOps.push_back(cast(newOp)); + } + lLoadOps[0]->replaceUsesOfWith(aWait->getResult(0), asyncWaitOps[0]); + lLoadOps[2]->replaceUsesOfWith(aWait->getResult(0), asyncWaitOps[0]); + lLoadOps[1]->replaceUsesOfWith(aWait->getResult(0), asyncWaitOps[1]); + lLoadOps[3]->replaceUsesOfWith(aWait->getResult(0), asyncWaitOps[1]); + aWait->erase(); + + builder.setInsertionPointAfter(dotSOps[0]); + updateOpInsertion(dotSOps[0]); + + appendOp(builder.create(loc, 0)); + appendOp(builder.create(loc)); + appendOp(builder.create(loc, 0)); + appendOp(lLoadOps[0]); + appendOp(lLoadOps[2]); + + appendOp(asyncWaitOps[1]); + + appendOp(asyncCopyOps[1]); + appendOp(asyncCopyOps[3]); + appendOp(asyncCommitOps[1]); + appendOp(asyncCommitOps[3]); + + appendOp(builder.create(loc, 0)); + appendOp(builder.create(loc)); + appendOp(builder.create(loc, 0)); + + appendOp(lLoadOps[1]); + appendOp(lLoadOps[3]); + appendOp(dotSOps[0]); + +#endif + return success(); } From 1b2a86b88f5ee8e835b44fb5551595f3b6b456d7 Mon Sep 17 00:00:00 2001 From: ravil-mobile Date: Thu, 15 May 2025 17:12:20 +0000 Subject: [PATCH 34/44] [AMD] add slicing `async-copy-local-to-global` --- lib/Analysis/AxisInfo.cpp | 2 + lib/Analysis/CMakeLists.txt | 1 + .../TritonAMDGPUTransforms/BlockPingpong.cpp | 566 ++++++++++++++++-- .../CanonicalizePointers.cpp | 58 +- 4 files changed, 589 insertions(+), 38 deletions(-) diff --git a/lib/Analysis/AxisInfo.cpp b/lib/Analysis/AxisInfo.cpp index 7edc5c45aa30..bfade5b0d823 100644 --- a/lib/Analysis/AxisInfo.cpp +++ b/lib/Analysis/AxisInfo.cpp @@ -3,6 +3,7 @@ #include "llvm/Support/Debug.h" #include "llvm/Support/raw_ostream.h" +#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" #include "triton/Analysis/AxisInfo.h" #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Utility.h" @@ -1043,6 +1044,7 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver) CastOpAxisInfoVisitor, CastOpAxisInfoVisitor, CastOpAxisInfoVisitor, + CastOpAxisInfoVisitor, CastOpAxisInfoVisitor>(); visitors.append(); visitors.append(); diff --git a/lib/Analysis/CMakeLists.txt b/lib/Analysis/CMakeLists.txt index 693d222f2f39..dac700dc1335 100644 --- a/lib/Analysis/CMakeLists.txt +++ b/lib/Analysis/CMakeLists.txt @@ -17,4 +17,5 @@ add_triton_library(TritonAnalysis TritonIR TritonGPUIR TritonNvidiaGPUIR + TritonAMDGPUIR ) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp index 90ec3d8d83a3..7a7d907234bb 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp @@ -7,6 +7,7 @@ #include "mlir/Pass/Pass.h" #include "mlir/Pass/PassManager.h" #include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" +#include "triton/Dialect/Triton/IR/Utility.h" #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" @@ -23,6 +24,25 @@ namespace tt = mlir::triton; namespace { +template std::optional getSingleUserOf(Value val) { + auto users = llvm::to_vector(val.getUsers()); + if (users.size() == 1) { + auto targetOp = dyn_cast(users[0]); + if (targetOp != nullptr) + return targetOp; + } + return std::nullopt; +} + +template +std::optional getIndex(T consumer, Value target) { + auto it = llvm::find_if(consumer->getOperands(), + [&](Value v) { return v == target; }); + if (it == consumer->getOperands().end()) + return std::nullopt; + return std::distance(consumer->getOperands().begin(), it); +} + // This pass transforms a for-loop calculating a GEMM. Main purpose of the // transform is improve the efficiency of the GPU dot instruction (mfma) // by interleaving the execution of two warps on each SIMD. Especially it groups @@ -40,6 +60,9 @@ class Pingponger { SmallVector asyncCopyOps; SmallVector asyncWaitOps; SmallVector asyncCommitOps; + DenseSet preservedAsyncCommits; + DenseMap> newAsyncGroups; + DenseMap> asyncTokenReassociation; SmallVector dotOps; SmallVector dotSOps; SmallVector> subViewOps; @@ -47,6 +70,7 @@ class Pingponger { SmallVector dotSliceOps; SmallVector constOffsets; Operation *lastInsertedOp; + const static inline std::string sliceAttrName = "sliceIdx"; // rocdl.s.setprio will be mapped to `s_setprio` instruction which set the // priority of the warp within a SIMD, determines which warp to occupy the @@ -74,14 +98,20 @@ class Pingponger { int64_t sliceWidth); LogicalResult genLocalSlice(OpBuilder &builder, Value v, Attribute dotEncoding, unsigned opIdx, - unsigned numSlices, int64_t sliceWidth); + unsigned numSlices, int64_t sliceWidth, + bool needCopySliced); LogicalResult genLocalSliceScales(OpBuilder &builder, Value v, Attribute dotEncoding, unsigned opIdx, - unsigned numSlices, int64_t sliceWidth); + unsigned numSlices, int64_t sliceWidth, + bool needCopySliced); LogicalResult sliceDot(OpBuilder &builder, Location loc, tt::DotOp op, unsigned numSlices); LogicalResult sliceDotScaled(OpBuilder &builder, Location loc, tt::DotScaledOp op, unsigned numSlices); + LogicalResult genAsyncCopySlices(OpBuilder &builder); + LogicalResult updateForOpSignature(OpBuilder &builder); + LogicalResult adjustRefinedAsyncTokens(OpBuilder &builder); + void transformOnePPClusters(OpBuilder &builder, Location loc); LogicalResult transformFourPPClusters(OpBuilder &builder, Location loc); LogicalResult transformTwoPPClusters(OpBuilder &builder, Location loc); @@ -126,7 +156,8 @@ class Pingponger { LogicalResult genLocalSliceHelper(OpBuilder &builder, Value v, unsigned opIdx, unsigned numSlices, int64_t sliceWidth, - RankedTensorType tensorType); + RankedTensorType tensorType, + bool needCopySliced); }; void Pingponger::updateOpInsertion(Operation *op) { lastInsertedOp = op; } @@ -642,10 +673,11 @@ void Pingponger::genOffsetConstants(Location loc, OpBuilder &builder, // generates ops when succeed, return fail() otherwise. LogicalResult Pingponger::genLocalSlice(OpBuilder &builder, Value v, Attribute dotEncoding, unsigned opIdx, - unsigned numSlices, - int64_t sliceWidth) { + unsigned numSlices, int64_t sliceWidth, + bool needCopySliced) { // TODO: support transformed input to dot auto localLoad = v.getDefiningOp(); + if (!localLoad) return failure(); auto memDesc = localLoad.getSrc(); @@ -663,14 +695,12 @@ LogicalResult Pingponger::genLocalSlice(OpBuilder &builder, Value v, auto tensorType = RankedTensorType::get(shape, elementType, dotOperandEnc); return genLocalSliceHelper(builder, v, opIdx, numSlices, sliceWidth, - tensorType); + tensorType, needCopySliced); } -LogicalResult Pingponger::genLocalSliceScales(OpBuilder &builder, Value v, - Attribute dotEncoding, - unsigned opIdx, - unsigned numSlices, - int64_t sliceWidth) { +LogicalResult Pingponger::genLocalSliceScales( + OpBuilder &builder, Value v, Attribute dotEncoding, unsigned opIdx, + unsigned numSlices, int64_t sliceWidth, bool needCopySliced) { auto localLoad = v.getDefiningOp(); if (!localLoad) return failure(); @@ -685,18 +715,13 @@ LogicalResult Pingponger::genLocalSliceScales(OpBuilder &builder, Value v, auto dotOperandEnc = ttg::LinearEncodingAttr::get(type.getContext(), ll); auto tensorType = RankedTensorType::get(shape, elementType, dotOperandEnc); - return genLocalSliceHelper(builder, v, 0, numSlices, sliceWidth, tensorType); + return genLocalSliceHelper(builder, v, 0, numSlices, sliceWidth, tensorType, + needCopySliced); } -LogicalResult Pingponger::genLocalSliceHelper(OpBuilder &builder, Value v, - unsigned opIdx, - unsigned numSlices, - int64_t sliceWidth, - RankedTensorType tensorType) { - - SmallVector slices; - SmallVector subviews; - +LogicalResult Pingponger::genLocalSliceHelper( + OpBuilder &builder, Value v, unsigned opIdx, unsigned numSlices, + int64_t sliceWidth, RankedTensorType tensorType, bool needCopySliced) { auto localLoad = v.getDefiningOp(); if (!localLoad) return failure(); @@ -709,11 +734,36 @@ LogicalResult Pingponger::genLocalSliceHelper(OpBuilder &builder, Value v, int64_t kIdx = opIdx == 0 ? 1 : 0; shape[kIdx] = sliceWidth; + auto resEncoding = localLoad.getResult().getType().getEncoding(); + auto dotOperandResEncoding = + dyn_cast(resEncoding); + const bool refineOrigSubview = dotOperandResEncoding != nullptr; + + auto arg = mlir::dyn_cast(memDesc); + if (!arg) { + LDBG("failed to cast input to `ttg.LocalLoadOp` to `BlockArgument`"); + return failure(); + } + + auto forOp = localLoad->getParentOfType(); + auto argIdx = arg.getArgNumber(); + auto yieldOperand = forOp.getTiedLoopYieldedValue(arg); + auto yieldOp = cast(yieldOperand->getOwner()); + auto origMemDesc = + cast(yieldOperand->get().getDefiningOp()); + auto subviewDescType = ttg::MemDescType::get( shape, elementType, type.getEncoding(), type.getMemorySpace(), type.getMutableMemory(), type.getAllocShape()); + SmallVector slices; + SmallVector subviews; + MLIRContext *ctx = localLoad->getContext(); + auto intType = mlir::IntegerType::get(ctx, 32); for (int i = 0; i < numSlices; i++) { + auto ip = builder.saveInsertionPoint(); + builder.setInsertionPoint(&forOp.front()); + SmallVector offsetsVal; SmallVector offsets = {0, 0}; offsets[opIdx == 0 ? 1 : 0] = i; @@ -721,13 +771,34 @@ LogicalResult Pingponger::genLocalSliceHelper(OpBuilder &builder, Value v, offsetsVal.push_back(builder.create( v.getLoc(), off * sliceWidth, 32)); } + + builder.setInsertionPointAfter(origMemDesc); + auto sliceIdAttr = mlir::IntegerAttr::get(intType, i); + if (needCopySliced && refineOrigSubview) { + Value newOrigSmem = builder.create( + origMemDesc.getLoc(), subviewDescType, origMemDesc, offsetsVal); + + // set attributes - i.e., which dot-operand, which slice + newOrigSmem.getDefiningOp()->setAttr(Pingponger::sliceAttrName, + sliceIdAttr); + newOrigSmem.getDefiningOp()->setAttr( + triton::amdgpu::OpIdxAttr::getMnemonic(), + triton::amdgpu::OpIdxAttr::get(ctx, + dotOperandResEncoding.getOpIdx())); + } + builder.restoreInsertionPoint(ip); + Value newSmem = builder.create( v.getLoc(), subviewDescType, memDesc, offsetsVal); Value prefetchSlice = builder.create( v.getLoc(), tensorType, newSmem, waitToken); + + prefetchSlice.getDefiningOp()->setAttr(Pingponger::sliceAttrName, + sliceIdAttr); subviews.push_back(newSmem.getDefiningOp()); slices.push_back(prefetchSlice.getDefiningOp()); } + subViewOps.push_back(subviews); loadSliceOps.push_back(slices); return success(); @@ -747,9 +818,12 @@ LogicalResult Pingponger::sliceDot(OpBuilder &builder, Location loc, genOffsetConstants(loc, builder, numSlices, sliceWidth); builder.setInsertionPointAfter(useAsyncCopy ? asyncCopyOps[0] : gLoadOps[0]); auto dotEncoding = op.getType().getEncoding(); - if (genLocalSlice(builder, op.getA(), dotEncoding, 0, numSlices, sliceWidth) + const bool needCopySliced = false; + if (genLocalSlice(builder, op.getA(), dotEncoding, 0, numSlices, sliceWidth, + needCopySliced) .failed() || - genLocalSlice(builder, op.getB(), dotEncoding, 1, numSlices, sliceWidth) + genLocalSlice(builder, op.getB(), dotEncoding, 1, numSlices, sliceWidth, + needCopySliced) .failed()) return failure(); @@ -772,6 +846,390 @@ LogicalResult Pingponger::sliceDot(OpBuilder &builder, Location loc, return success(); } +LogicalResult Pingponger::genAsyncCopySlices(OpBuilder &builder) { + if (asyncCopyOps.empty()) + return success(); + + auto &loopBody = forOp.getRegion().front(); + auto yieldOp = cast(loopBody.getTerminator()); + + for (auto asyncCopy : asyncCopyOps) { + MLIRContext *ctx = asyncCopy->getContext(); + auto srcPointers = asyncCopy.getSrc(); + auto subView = cast( + asyncCopy.getResult().getDefiningOp()); + auto subViewEncoding = subView.getType().getEncoding(); + + DenseMap subViews; + for (auto user : subView->getUsers()) { + if (auto subView = dyn_cast(user)) { + if (auto attr = subView->getAttrOfType( + Pingponger::sliceAttrName)) { + auto sliceIdx = attr.getValue().getSExtValue(); + subViews.insert({sliceIdx, subView}); + + if (!newAsyncGroups.contains(sliceIdx)) { + newAsyncGroups.insert({sliceIdx, {}}); + } + } + } + } + + if (subViews.empty()) { + auto commit = + getSingleUserOf(asyncCopy.getToken()); + preservedAsyncCommits.insert(commit->getResult()); + continue; + } + + // infer the sliced shape + triton::gpu::MemDescSubviewOp subViewSlice = subViews[0]; + auto origShape = subView.getType().getShape(); + auto slicedShape = subViewSlice.getType().getShape(); + assert(origShape.size() == slicedShape.size()); + const auto numDims = origShape.size(); + int64_t slicedDim = -1; + for (size_t dim = 0; dim < numDims; ++dim) { + if (origShape[dim] != slicedShape[dim]) { + slicedDim = dim; + break; + } + } + + builder.setInsertionPointAfter(asyncCopy); + + auto elementType = srcPointers.getType().getElementType(); + auto encoding = + cast(srcPointers.getType().getEncoding()); + + auto warpsPerCTA = encoding.getWarpsPerCTA(); + auto sizePerThread = encoding.getSizePerThread(); + SmallVector threadPerWarp(warpsPerCTA.size(), 0); + for (size_t dim = 0; dim < numDims; ++dim) { + threadPerWarp[dim] = + slicedShape[dim] / (warpsPerCTA[dim] * sizePerThread[dim]); + } + assert(mlir::product(threadPerWarp) == 64); + + auto newEncoding = ttg::BlockedEncodingAttr::get( + ctx, sizePerThread, threadPerWarp, warpsPerCTA, encoding.getOrder(), + encoding.getCTALayout()); + + auto convertTensor = [&](mlir::TypedValue tensor) { + RankedTensorType newType = nullptr; + Value newTensor = nullptr; + RankedTensorType slicedTensorType = nullptr; + if (tensor) { + assert(encoding == tensor.getType().getEncoding()); + auto elemType = tensor.getType().getElementType(); + newType = RankedTensorType::get(origShape, elemType, newEncoding); + newTensor = + builder + .create(tensor.getLoc(), newType, tensor) + .getResult(); + slicedTensorType = + RankedTensorType::get(slicedShape, elemType, newEncoding); + } + + return std::make_tuple(newType, newTensor, slicedTensorType); + }; + + mlir::TypedValue origMask = nullptr; + mlir::TypedValue origOtherTensor = nullptr; + + if (auto value = asyncCopy.getMask()) { + origMask = dyn_cast(value); + } + if (auto value = asyncCopy.getOther()) { + origOtherTensor = cast(value); + } + + auto [newSrcType, newSrcPointers, slicedSrcType] = + convertTensor(srcPointers); + auto [newMaskType, newMask, slicedMaskType] = convertTensor(origMask); + auto [newOtherType, newOther, slicedOtherType] = + convertTensor(origOtherTensor); + + auto extract = [&builder](Type resType, Value src, + DenseI64ArrayAttr &offset) { + Value resValue = nullptr; + if (src) { + resValue = builder.create( + src.getLoc(), resType, src, offset); + } + return resValue; + }; + + assert(slicedDim != -1); + SmallVector newCommits; + auto numReps = origShape[slicedDim] / slicedShape[slicedDim]; + for (size_t rep = 0; rep < numReps; ++rep) { + SmallVector offset(slicedShape.size(), 0); + offset[slicedDim] = slicedShape[slicedDim] * rep; + auto offsetAttr = DenseI64ArrayAttr::get(ctx, offset); + + auto extractedSrc = extract(slicedSrcType, newSrcPointers, offsetAttr); + auto extractedMask = extract(slicedMaskType, newMask, offsetAttr); + auto extractedOther = extract(slicedOtherType, newOther, offsetAttr); + + auto newAsyncCopy = builder.create( + asyncCopy->getLoc(), extractedSrc, Value{subViews[rep].getResult()}, + extractedMask, extractedOther, asyncCopy.getCache(), + asyncCopy.getEvict(), asyncCopy.getIsVolatile()); + + auto newCommit = builder.create( + asyncCopy->getLoc(), newAsyncCopy.getToken()); + + // propagate all attributes from `mem-view` to the commit token + newAsyncCopy->setAttrs(subViews[rep]->getAttrs()); + newCommit->setAttrs(subViews[rep]->getAttrs()); + + newAsyncGroups[rep].push_back(newCommit); + newCommits.push_back(newCommit); + } + + auto origCommitGroup = getSingleUserOf(asyncCopy); + auto maybeResultIdx = getIndex(yieldOp, origCommitGroup->getResult()); + assert(maybeResultIdx.has_value()); + auto origYieldOperand = yieldOp->getOperand(maybeResultIdx.value()); + asyncTokenReassociation.insert({origYieldOperand, newCommits}); + } + + return success(); +} + +LogicalResult Pingponger::updateForOpSignature(OpBuilder &builder) { + // Note: call this method at the very end when reference to the + // original ops are not needed anymore + + if (asyncCopyOps.empty()) + return llvm::success(); + + Block &oldBlock = forOp.getRegion().front(); + auto origYieldOp = cast(oldBlock.getTerminator()); + auto orgiInitArgs = forOp.getInitArgs(); + + SmallVector newInputArgTokens; + for (auto &[origCommit, newCommits] : asyncTokenReassociation) { + auto maybeIdx = getIndex(origYieldOp, origCommit); + assert(maybeIdx.has_value()); + auto initCommitArgValue = orgiInitArgs[maybeIdx.value()]; + auto initCommitOp = + cast(initCommitArgValue.getDefiningOp()); + builder.setInsertionPointAfter(initCommitOp); + for (size_t i = 0; i < newCommits.size(); ++i) { + auto newInputArgToken = builder.create( + initCommitOp->getLoc(), initCommitOp.getAsyncToken().getType(), + initCommitOp.getInputTokens()); + newInputArgTokens.push_back(newInputArgToken); + } + } + + builder.setInsertionPointAfter(forOp); + DenseSet preservedArgsIndices; + DenseSet removedArgsIndices; + auto origYeildValues = forOp.getYieldedValues(); + for (auto [idx, value] : llvm::enumerate(origYeildValues)) { + bool copyable = dyn_cast(value.getType()) == nullptr; + copyable |= preservedAsyncCommits.contains(value); + if (copyable) { + preservedArgsIndices.insert(idx); + } else { + removedArgsIndices.insert(idx); + } + } + + DenseMap argIndicesMap; + SmallVector newInitArgs; + for (auto [idx, value] : llvm::enumerate(orgiInitArgs)) { + if (preservedArgsIndices.contains(idx)) { + argIndicesMap.insert({idx, newInitArgs.size()}); + newInitArgs.push_back(value); + } + } + + for (auto newInputToken : newInputArgTokens) { + newInitArgs.push_back(newInputToken); + } + + // Create a new ForOp + scf::ForOp newForOp = builder.create( + forOp->getLoc(), forOp.getLowerBound(), forOp.getUpperBound(), + forOp.getStep(), newInitArgs); + + // Map original block arguments to new ones + Block &newBlock = newForOp.getRegion().front(); + + IRMapping mapping; + auto oldIterArgs = forOp.getRegionIterArgs(); + auto newIterArgs = newForOp.getRegionIterArgs(); + mapping.map(oldBlock.getArgument(0), newBlock.getArgument(0)); // loop index + for (auto [origIdx, newIdx] : argIndicesMap) { + mapping.map(oldIterArgs[origIdx], newIterArgs[newIdx]); + } + + // Clone the body of the loop + builder.setInsertionPointToStart(&newBlock); + for (auto &op : oldBlock.without_terminator()) { + builder.clone(op, mapping); + } + + // Clone the yield terminator + builder.setInsertionPointToEnd(&newBlock); + SmallVector newYieldResults; + for (auto [idx, value] : llvm::enumerate(forOp.getYieldedValues())) { + if (preservedArgsIndices.contains(idx)) { + newYieldResults.push_back(mapping.lookup(value)); + } + } + + for (auto &[origCommit, newCommits] : asyncTokenReassociation) { + for (auto commit : newCommits) + newYieldResults.push_back(mapping.lookup(commit)); + } + + builder.create(origYieldOp.getLoc(), newYieldResults); + + auto newForOpResults = newForOp.getResults(); + DenseSet adjustedUsers; + for (auto [idx, oldResult] : llvm::enumerate(forOp->getResults())) { + if (preservedArgsIndices.contains(idx)) { + auto newArgIdx = argIndicesMap[idx]; + oldResult.replaceAllUsesWith(newForOpResults[newArgIdx]); + } else { + for (auto user : oldResult.getUsers()) { + adjustedUsers.insert(user); + } + } + } + + // Adjust async-wait outside the newForOp + assert(adjustedUsers.size() == 1); + auto asyncWaitEpilogue = dyn_cast(*adjustedUsers.begin()); + assert(asyncWaitEpilogue != nullptr); + + builder.setInsertionPointAfter(asyncWaitEpilogue); + SmallVector newOperands; + for (auto newResult : newForOp->getResults()) { + if (dyn_cast(newResult.getType())) { + newOperands.push_back(newResult); + } + } + auto newAsyncWaitEpilogue = builder.create( + asyncWaitEpilogue->getLoc(), newOperands, 0); + asyncWaitEpilogue->replaceAllUsesWith(newAsyncWaitEpilogue); + asyncWaitEpilogue->erase(); + + SmallVector newAsyncTokens; + for (auto &arg : newForOp.getRegionIterArgs()) { + if (dyn_cast(arg.getType())) + newAsyncTokens.push_back(arg); + } + + // adjust async-wait inside the newForOp block + ttg::AsyncWaitOp asyncWait = nullptr; + newForOp.walk([&asyncWait](ttg::AsyncWaitOp op) { + asyncWait = op; + return WalkResult::interrupt(); + }); + assert(asyncWait != nullptr); + builder.setInsertionPointAfter(asyncWait); + auto newAsyncToken = + builder.create(asyncWait->getLoc(), newAsyncTokens, 0); + asyncWait->replaceAllUsesWith(newAsyncToken); + asyncWait.erase(); + + this->forOp->erase(); + this->forOp = newForOp; + return success(); +} + +LogicalResult Pingponger::adjustRefinedAsyncTokens(OpBuilder &builder) { + auto yeildOp = cast(forOp.getBody()->getTerminator()); + auto forOpArgs = forOp.getRegionIterArgs(); + + DenseMap, Value> refinedTokens; + SmallVector nonRefinedTokens; + forOp->walk([&](ttg::AsyncCommitGroupOp commit) { + auto tokenIdx = getIndex(yeildOp, commit.getResult()); + if (!tokenIdx.has_value()) + return WalkResult::advance(); + + int32_t opIdx = -1; + if (auto attr = commit->getAttrOfType( + triton::amdgpu::OpIdxAttr::getMnemonic())) { + opIdx = attr.getValue(); + } + + int32_t sliceId = -1; + if (auto attr = commit->getAttrOfType( + Pingponger::sliceAttrName)) { + sliceId = attr.getValue().getSExtValue(); + } + + assert(tokenIdx.has_value()); + auto asyncToken = forOpArgs[tokenIdx.value()]; + if ((opIdx > -1) && (sliceId > -1)) { + refinedTokens.insert({{opIdx, sliceId}, asyncToken}); + } else { + nonRefinedTokens.push_back(asyncToken); + } + return WalkResult::advance(); + }); + + // leave only `scaleA` and `scaleB` wait-tokens + ttg::AsyncWaitOp origAsyncWait; + forOp->walk([&origAsyncWait](ttg::AsyncWaitOp wait) { + origAsyncWait = wait; + return WalkResult::interrupt(); + }); + builder.setInsertionPointAfter(origAsyncWait); + auto newAsyncWait = builder.create(origAsyncWait->getLoc(), + nonRefinedTokens, 0); + origAsyncWait->replaceAllUsesWith(newAsyncWait); + origAsyncWait->erase(); + + // collect all refined localLoads + DenseMap, ttg::LocalLoadOp> refinedLocalLoads; + forOp->walk([&](ttg::LocalLoadOp localLoad) { + int32_t opIdx = -1; + auto resultType = cast(localLoad.getResult().getType()); + if (auto encding = + dyn_cast(resultType.getEncoding())) { + opIdx = encding.getOpIdx(); + } + + int32_t sliceId = -1; + if (auto attr = localLoad->getAttrOfType( + Pingponger::sliceAttrName)) { + sliceId = attr.getValue().getSExtValue(); + } + + if ((opIdx > -1) && (sliceId > -1)) { + refinedLocalLoads.insert({{opIdx, sliceId}, localLoad}); + } + }); + + // create new local load preceeded by new wait-tokens + for (auto &item : refinedTokens) { + auto [opIdx, sliceIdx] = item.first; + auto commit = item.second; + if (!refinedLocalLoads.contains({opIdx, sliceIdx})) + continue; + auto localLoad = refinedLocalLoads[{opIdx, sliceIdx}]; + builder.setInsertionPointAfter(localLoad); + auto token = builder.create(localLoad->getLoc(), + ValueRange{commit}, 0); + auto newLocalLoad = builder.create( + localLoad->getLoc(), localLoad.getResult().getType(), + localLoad.getSrc(), token); + localLoad->replaceAllUsesWith(newLocalLoad); + localLoad->erase(); + } + + return success(); +} + LogicalResult Pingponger::sliceDotScaled(OpBuilder &builder, Location loc, tt::DotScaledOp op, unsigned numSlices) { @@ -786,13 +1244,22 @@ LogicalResult Pingponger::sliceDotScaled(OpBuilder &builder, Location loc, if (shapeB[1] % numSlices != 0) return failure(); - builder.setInsertionPointAfter(op); + if (!gLoadOps.empty()) + builder.setInsertionPointAfter(gLoadOps[0]); + else if (!asyncCopyOps.empty()) { + builder.setInsertionPointAfter(asyncCopyOps[0]); + } else { + return failure(); + } auto dotEncoding = op.getType().getEncoding(); + bool needCopySliced = true; // Generate slices for operands A and B - if (genLocalSlice(builder, op.getA(), dotEncoding, 0, numSlices, sliceWidth) + if (genLocalSlice(builder, op.getA(), dotEncoding, 0, numSlices, sliceWidth, + needCopySliced) .failed() || - genLocalSlice(builder, op.getB(), dotEncoding, 1, numSlices, sliceWidth) + genLocalSlice(builder, op.getB(), dotEncoding, 1, numSlices, sliceWidth, + needCopySliced) .failed()) return failure(); @@ -800,10 +1267,11 @@ LogicalResult Pingponger::sliceDotScaled(OpBuilder &builder, Location loc, Value aScale = op.getAScale(); Value bScale = op.getBScale(); + needCopySliced = false; if (aScale) { if (genLocalSliceScales(builder, aScale, op.getAScale().getType().getEncoding(), 0, - numSlices, sliceScaleWidth) + numSlices, sliceScaleWidth, needCopySliced) .failed()) return failure(); } @@ -811,7 +1279,7 @@ LogicalResult Pingponger::sliceDotScaled(OpBuilder &builder, Location loc, if (bScale) { if (genLocalSliceScales(builder, bScale, op.getBScale().getType().getEncoding(), 0, - numSlices, sliceScaleWidth) + numSlices, sliceScaleWidth, needCopySliced) .failed()) return failure(); } @@ -1048,7 +1516,6 @@ LogicalResult Pingponger::transformFP4(OpBuilder &builder, Location loc) { warpIDX, constZero); builder.setInsertionPointAfter(dotSOps[0]); - if (sliceDotScaled(builder, loc, dotSOps[0], 4).failed()) return failure(); updateOpInsertion(dotSliceOps[0]); @@ -1115,6 +1582,7 @@ void Pingponger::getDotPingponged() { MLIRContext *ctx = forOp.getContext(); Location loc = forOp.getLoc(); + SmallVector asyncWaitsOps; forOp->walk([&](Operation *op) { if (auto gLoad = dyn_cast(op)) gLoadOps.push_back(gLoad); @@ -1136,15 +1604,26 @@ void Pingponger::getDotPingponged() { dotOps.push_back(pingpongDot); } else if (auto pingpongDot = dyn_cast(op)) { dotSOps.push_back(pingpongDot); - } else if (auto asyncOp = dyn_cast(op)) { - asyncCopyOps.push_back(asyncOp); + } else if (auto asyncCopy = dyn_cast(op)) { + asyncCopyOps.push_back(asyncCopy); } else if (auto asyncCommitGroupOp = dyn_cast(op)) { asyncCommitOps.push_back(asyncCommitGroupOp); - } else if (auto asyncOp = dyn_cast(op)) - asyncWaitOps.push_back(asyncOp); + } else if (auto wait = dyn_cast(op)) { + asyncWaitsOps.push_back(wait); + } }); + const bool isAsyncOpsInUse = !(asyncWaitsOps.empty()); + if (isAsyncOpsInUse && (asyncWaitsOps.size() != 1)) { + std::stringstream message; + message << "Unable to match ping pong scheduling pattern. " + << "Found " << asyncWaitsOps.size() + << " `AsyncWaitOp` in the scheduled region. Only one is allowed."; + LDBG(message.str()); + return; + } + // Fixme : use proper condition to identify FAv3 if (numStages == 4 && dotOps.size() == 2) { if (transformFAv3(builder, loc).failed()) { @@ -1189,14 +1668,33 @@ void Pingponger::getDotPingponged() { auto aShape = aType.getShape(); auto elemWidth = aType.getElementTypeBitWidth(); int64_t tileSize = dotSShape[0] * dotSShape[1] * aShape[1]; - if (tileSize != 8388608 || aShape[1] != 128 || elemWidth != 8) + if (tileSize != 8388608 || aShape[1] != 128 || elemWidth != 8) { + LDBG("encountered large matrix for scale dot: " + << "TileSize==" << tileSize << "; aShape[1]==" << aShape[1] + << "; elemWidth: " << elemWidth); return; + } if (transformFP4(builder, dotSOps[0]->getLoc()).failed()) { LDBG("Encountered failure when trying to execute the two ping pong " "cluster transformation"); return; } + + if (llvm::failed(genAsyncCopySlices(builder))) { + LDBG("failed to slice global-to-local async copies"); + } + + auto updateSignature = updateForOpSignature(builder); + if (llvm::failed(updateSignature)) { + LDBG("failed to update forOp signature"); + } + + if (llvm::succeeded(updateSignature)) { + if (llvm::failed(adjustRefinedAsyncTokens(builder))) { + LDBG("failed to update forOp signature"); + } + } addAsymmetricSyncToLoop(builder, loc); return; } diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp index 0844cc941efb..f782c237b59e 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp @@ -1146,6 +1146,54 @@ class ConvertConvertLayoutOp } }; +/// slice integer offset, keep base +class ConvertExtractSliceOp + : public PointerCanonicalizationPattern { +public: + using PointerCanonicalizationPattern::PointerCanonicalizationPattern; + + LogicalResult + matchAndRewrite_(tt::amdgpu::ExtractSliceOp extractSliceOp, + OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + ValueRange remappedOperands = adaptor.getSource(); + if (remappedOperands.size() != 2) { + // some prior op materialized the fat ptr, e.g.: + // %3 = tt.bitcast %2 + // %4 = tt.splat %3 + return success(); + } + Value fatPtrBase = remappedOperands[0]; + Value fatPtrOffset = remappedOperands[1]; + if (!llvm::isa(fatPtrBase.getType())) { + return rewriter.notifyMatchFailure(extractSliceOp, + "non tt.ptr base unimplemented"); + } + auto offsetTensorTy = dyn_cast(fatPtrOffset.getType()); + if (!offsetTensorTy) { + return rewriter.notifyMatchFailure( + extractSliceOp, "non RankedTensorType offset unimplemented"); + } + + Location loc = extractSliceOp->getLoc(); + + const FatPointers::FatPtrAttrs &fatPtrAttrs = + fatPtrs.at({fatPtrBase, fatPtrOffset}); + auto newSrc = createTensorPointer(rewriter, fatPtrBase, fatPtrOffset, loc, + fatPtrAttrs); + + RankedTensorType resType = extractSliceOp.getResult().getType(); + tt::amdgpu::ExtractSliceOp newExtractSliceOp = + rewriter.create( + loc, Type{resType}, Value{newSrc}, + extractSliceOp.getStaticOffsetsAttr()); + rewriter.replaceOp(extractSliceOp, newExtractSliceOp); + fatPtrs[{fatPtrBase, newExtractSliceOp}] = + fatPtrs.at({fatPtrBase, fatPtrOffset}); + return success(); + } +}; + template class MaterializeFatPointer : public PointerCanonicalizationPattern { public: @@ -1508,6 +1556,8 @@ void TritonAMDGPUCanonicalizePointersPass::runOnOperation() { target.addDynamicallyLegalDialect(isLegal); target.addDynamicallyLegalDialect(isLegal); target.addDynamicallyLegalDialect(isLegal); + target.addDynamicallyLegalDialect( + isLegal); // Rewrite the rest of the ops. // Note we *do not* declare unrealized_cast an illegal op here in order that @@ -1529,10 +1579,10 @@ void TritonAMDGPUCanonicalizePointersPass::runOnOperation() { MaterializeFatPointerVariadic, MaterializeFatPointerVariadic, MaterializeFatPointerVariadic, ConvertSCFForOp, - ConvertExpandDims, ConvertSCFYieldOp, ConvertSCFIfOp, - ConvertSCFConditionOp, ConvertSCFWhileOp, ConvertCFCondBranch, - ConvertCFBranch, ConvertArithSelectOp, ConvertReturnOp>( - patterns.getContext(), opsToRewrite, fatPrs); + ConvertExpandDims, ConvertExtractSliceOp, ConvertSCFYieldOp, + ConvertSCFIfOp, ConvertSCFConditionOp, ConvertSCFWhileOp, + ConvertCFCondBranch, ConvertCFBranch, ConvertArithSelectOp, + ConvertReturnOp>(patterns.getContext(), opsToRewrite, fatPrs); if (failed(applyPartialConversion(func, target, std::move(patterns), config))) return signalPassFailure(); From 33f6ce92ad09398a8ad057716ed0067020420296 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Mon, 19 May 2025 09:39:45 -0700 Subject: [PATCH 35/44] Revert "Revert "[AMD] Use v_permlane to optimize MFAM to linear layout on GFX950 (#6744)"" This reverts commit f6065b93e96e1cdb5137e27f7b8f21199399cf54. --- .../TritonGPU/IR/LinearLayoutConversions.h | 3 +- .../TritonGPU/IR/LinearLayoutConversions.cpp | 19 +- test/Conversion/amd/mfma-shortcut.mlir | 229 ++++++++++-------- .../ConvertLayoutOpToLLVM.cpp | 115 +++++++++ .../OptimizeEpilogue.cpp | 20 +- 5 files changed, 257 insertions(+), 129 deletions(-) diff --git a/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h b/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h index 6a6047216b71..b00eb5084112 100644 --- a/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h +++ b/include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h @@ -287,8 +287,7 @@ LinearLayout chooseScaledMfmaScaleLayout( // 8 elements. This layout is useful for emitting the widest 128-bit global // store instructions. Since it closely resembles mfmaLayout, conversion between // the two can be done using transferWithinWarp, without involving LDS -LinearLayout chooseMfmaLikeStoreLayout(AMDMfmaEncodingAttr mfmaLayout, - ArrayRef shape); +std::optional chooseMfmaLikeStoreLayout(RankedTensorType valType); } // namespace mlir::triton::gpu #endif // TRITON_DIALECT_TRITONGPU_IR_LINEARLAYOUTCONVERSIONS_H diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index 22949a1489a3..d1397549de27 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -1533,10 +1533,17 @@ LinearLayout chooseScaledMfmaScaleLayout( return newLL; } -LinearLayout chooseMfmaLikeStoreLayout(AMDMfmaEncodingAttr mfmaLayout, - ArrayRef shape) { - assert(shape.size() == 2 && mfmaLayout.getMDim() == 32 && - mfmaLayout.getNDim() == 32 && mfmaLayout.getIsTransposed()); +std::optional +chooseMfmaLikeStoreLayout(RankedTensorType valType) { + auto mfmaLayout = cast(valType.getEncoding()); + + // Currently support transposed [B]F16 MFMA32x32 on CDNA4 + bool isMfma32 = mfmaLayout.getMDim() == 32 && mfmaLayout.getNDim() == 32; + Type elemType = valType.getElementType(); + if (!(valType.getRank() == 2 && (elemType.isF16() || elemType.isBF16()) && + mfmaLayout.getVersionMajor() == 4 && mfmaLayout.getIsTransposed() && + isMfma32)) + return {}; MLIRContext *ctx = mfmaLayout.getContext(); StringAttr kRegister = S("register"); @@ -1561,8 +1568,8 @@ LinearLayout chooseMfmaLikeStoreLayout(AMDMfmaEncodingAttr mfmaLayout, identityStandardND(kWarp, mfmaLayout.getWarpsPerCTA(), order); LinearLayout ctaLayout = mfma8Layout.transposeOuts(standardOutDims) * warpLayout.transposeOuts(standardOutDims); - mfma8Layout = - combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCTALayout(), shape); + mfma8Layout = combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCTALayout(), + valType.getShape()); return mfma8Layout; } diff --git a/test/Conversion/amd/mfma-shortcut.mlir b/test/Conversion/amd/mfma-shortcut.mlir index 0e64eed47040..94f1650c39de 100644 --- a/test/Conversion/amd/mfma-shortcut.mlir +++ b/test/Conversion/amd/mfma-shortcut.mlir @@ -1,13 +1,14 @@ -// RUN: triton-opt %s --tritongpu-reduce-data-duplication --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch="gfx942" -split-input-file | FileCheck %s +// RUN: triton-opt %s --tritongpu-reduce-data-duplication --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch="gfx942" -split-input-file | FileCheck %s --check-prefix=GFX942 +// RUN: triton-opt %s --tritongpu-reduce-data-duplication --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch="gfx950" -split-input-file | FileCheck %s --check-prefix=GFX950 #mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> #dotop = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=4}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: shortcut_mfma16 + // GFX942-LABEL: shortcut_mfma16 tt.func public @shortcut_mfma16(%arg0: tensor<16x16xf16, #mfma>) { - // CHECK-NOT: store - // CHECK-NOT: load - // CHECK: llvm.return + // GFX942-NOT: store + // GFX942-NOT: load + // GFX942: llvm.return %0 = ttg.convert_layout %arg0 : tensor<16x16xf16, #mfma> -> tensor<16x16xf16, #dotop> tt.return } @@ -18,11 +19,11 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr #mfma = #ttg.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> #dotop = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: no_shortcut_mfma16 + // GFX942-LABEL: no_shortcut_mfma16 tt.func public @no_shortcut_mfma16(%arg0: tensor<16x16xf16, #mfma>) { - // CHECK: store - // CHECK: load - // CHECK: llvm.return + // GFX942: store + // GFX942: load + // GFX942: llvm.return %0 = ttg.convert_layout %arg0 : tensor<16x16xf16, #mfma> -> tensor<16x16xf16, #dotop> tt.return } @@ -34,38 +35,38 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr #dotop0 = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: mfma_dot_cvt_f8_mfma32 + // GFX942-LABEL: mfma_dot_cvt_f8_mfma32 tt.func public @mfma_dot_cvt_f8_mfma32(%arg0: tensor<128x32xf8E4M3FNUZ, #mfma>) { - // CHECK-NOT: store - // CHECK-NOT: load + // GFX942-NOT: store + // GFX942-NOT: load - // CHECK: [[val3:%.*]] = llvm.extractvalue %arg0[3] - // CHECK: [[val7:%.*]] = llvm.extractvalue %arg0[7] + // GFX942: [[val3:%.*]] = llvm.extractvalue %arg0[3] + // GFX942: [[val7:%.*]] = llvm.extractvalue %arg0[7] - // CHECK-DAG: [[c32:%.*]] = llvm.mlir.constant(32 : i32) - // CHECK-DAG: [[c64:%.*]] = llvm.mlir.constant(64 : i32) + // GFX942-DAG: [[c32:%.*]] = llvm.mlir.constant(32 : i32) + // GFX942-DAG: [[c64:%.*]] = llvm.mlir.constant(64 : i32) - // CHECK: [[threadId:%.*]] = rocdl.workitem.id.x - // CHECK: [[laneId:%.*]] = llvm.urem [[threadId]], [[c64]] - // CHECK: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]] + // GFX942: [[threadId:%.*]] = rocdl.workitem.id.x + // GFX942: [[laneId:%.*]] = llvm.urem [[threadId]], [[c64]] + // GFX942: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]] - // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c32]] - // CHECK: [[addr32:%.*]] = llvm.urem [[shflLaneId]], [[c64]] + // GFX942: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c32]] + // GFX942: [[addr32:%.*]] = llvm.urem [[shflLaneId]], [[c64]] - // CHECK: [[vec0:%.*]] = llvm.insertelement [[val3]], {{.*}} : vector<4xi8> - // CHECK: [[vec1:%.*]] = llvm.insertelement [[val7]], {{.*}} : vector<4xi8> + // GFX942: [[vec0:%.*]] = llvm.insertelement [[val3]], {{.*}} : vector<4xi8> + // GFX942: [[vec1:%.*]] = llvm.insertelement [[val7]], {{.*}} : vector<4xi8> - // CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]] - // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) - // CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]] - // CHECK: [[bShflVec0:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]] - // CHECK: [[shflVec0:%.*]] = llvm.bitcast [[bShflVec0]] + // GFX942: [[bvec0:%.*]] = llvm.bitcast [[vec0]] + // GFX942: [[c2:%.*]] = llvm.mlir.constant(2 : i32) + // GFX942: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]] + // GFX942: [[bShflVec0:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]] + // GFX942: [[shflVec0:%.*]] = llvm.bitcast [[bShflVec0]] - // CHECK: [[bvec1:%.*]] = llvm.bitcast [[vec1]] - // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) - // CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]] - // CHECK: [[bShflVec1:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]] - // CHECK: [[shflVec1:%.*]] = llvm.bitcast [[bShflVec1]] + // GFX942: [[bvec1:%.*]] = llvm.bitcast [[vec1]] + // GFX942: [[c2:%.*]] = llvm.mlir.constant(2 : i32) + // GFX942: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]] + // GFX942: [[bShflVec1:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]] + // GFX942: [[shflVec1:%.*]] = llvm.bitcast [[bShflVec1]] // Input (8 values): (vec0, vec1) // Output (8 values shuffled, '>> n' - take the value from (lane + n) % 64): @@ -73,18 +74,18 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr // lanes 0-31: (vec0 , vec0 >> 32) (mask0=1) // lanes 32-63: (vec1 >> 32, vec1 ) (mask0=0) - // CHECK: [[resVec0:%.*]] = llvm.select [[mask0]], [[vec0]], [[shflVec1]] - // CHECK: [[resVec1:%.*]] = llvm.select [[mask0]], [[shflVec0]], [[vec1]] + // GFX942: [[resVec0:%.*]] = llvm.select [[mask0]], [[vec0]], [[shflVec1]] + // GFX942: [[resVec1:%.*]] = llvm.select [[mask0]], [[shflVec0]], [[vec1]] - // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) - // CHECK: [[resVal3:%.*]] = llvm.extractelement [[resVec0]][[[c3]] : i32] : vector<4xi8> - // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) : i32 - // CHECK: [[resVal7:%.*]] = llvm.extractelement [[resVec1]][[[c3]] : i32] : vector<4xi8> + // GFX942: [[c3:%.*]] = llvm.mlir.constant(3 : i32) + // GFX942: [[resVal3:%.*]] = llvm.extractelement [[resVec0]][[[c3]] : i32] : vector<4xi8> + // GFX942: [[c3:%.*]] = llvm.mlir.constant(3 : i32) : i32 + // GFX942: [[resVal7:%.*]] = llvm.extractelement [[resVec1]][[[c3]] : i32] : vector<4xi8> - // CHECK: llvm.insertvalue [[resVal3]], {{.*}}[3] - // CHECK: llvm.insertvalue [[resVal7]], {{.*}}[7] + // GFX942: llvm.insertvalue [[resVal3]], {{.*}}[3] + // GFX942: llvm.insertvalue [[resVal7]], {{.*}}[7] - // CHECK: llvm.return + // GFX942: llvm.return %0 = ttg.convert_layout %arg0 : tensor<128x32xf8E4M3FNUZ, #mfma> -> tensor<128x32xf8E4M3FNUZ, #dotop0> tt.return } @@ -96,12 +97,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr #dotop0 = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: mfma_dot_cvt_bf8_mfma32 + // GFX942-LABEL: mfma_dot_cvt_bf8_mfma32 tt.func public @mfma_dot_cvt_bf8_mfma32(%arg0: tensor<128x32xf8E5M2, #mfma>) { - // CHECK-NOT: store - // CHECK-NOT: load - // CHECK: rocdl.ds_bpermute - // CHECK: llvm.return + // GFX942-NOT: store + // GFX942-NOT: load + // GFX942: rocdl.ds_bpermute + // GFX942: llvm.return %0 = ttg.convert_layout %arg0 : tensor<128x32xf8E5M2, #mfma> -> tensor<128x32xf8E5M2, #dotop0> tt.return } @@ -113,61 +114,61 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr #dotop0 = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: mfma_dot_cvt_f8_mfma16 + // GFX942-LABEL: mfma_dot_cvt_f8_mfma16 tt.func public @mfma_dot_cvt_f8_mfma16(%arg0: tensor<128x32xf8E4M3FNUZ, #mfma>) { - // CHECK-NOT: store - // CHECK-NOT: load + // GFX942-NOT: store + // GFX942-NOT: load - // CHECK: [[val3:%.*]] = llvm.extractvalue %arg0[3] - // CHECK: [[val7:%.*]] = llvm.extractvalue %arg0[7] + // GFX942: [[val3:%.*]] = llvm.extractvalue %arg0[3] + // GFX942: [[val7:%.*]] = llvm.extractvalue %arg0[7] - // CHECK-DAG: [[c16:%.*]] = llvm.mlir.constant(16 : i32) - // CHECK-DAG: [[c32:%.*]] = llvm.mlir.constant(32 : i32) - // CHECK-DAG: [[c48:%.*]] = llvm.mlir.constant(48 : i32) - // CHECK-DAG: [[c64:%.*]] = llvm.mlir.constant(64 : i32) + // GFX942-DAG: [[c16:%.*]] = llvm.mlir.constant(16 : i32) + // GFX942-DAG: [[c32:%.*]] = llvm.mlir.constant(32 : i32) + // GFX942-DAG: [[c48:%.*]] = llvm.mlir.constant(48 : i32) + // GFX942-DAG: [[c64:%.*]] = llvm.mlir.constant(64 : i32) - // CHECK: [[threadId:%.*]] = rocdl.workitem.id.x - // CHECK: [[laneId:%.*]] = llvm.urem [[threadId]], [[c64]] - // CHECK: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]] + // GFX942: [[threadId:%.*]] = rocdl.workitem.id.x + // GFX942: [[laneId:%.*]] = llvm.urem [[threadId]], [[c64]] + // GFX942: [[mask0:%.*]] = llvm.icmp "slt" [[laneId]], [[c32]] - // CHECK: [[laneIdRem:%.*]] = llvm.urem [[laneId]], [[c32]] - // CHECK: [[mask1:%.*]] = llvm.icmp "slt" [[laneIdRem]], [[c16]] + // GFX942: [[laneIdRem:%.*]] = llvm.urem [[laneId]], [[c32]] + // GFX942: [[mask1:%.*]] = llvm.icmp "slt" [[laneIdRem]], [[c16]] - // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c16]] - // CHECK: [[addr16:%.*]] = llvm.urem [[shflLaneId]], [[c64]] + // GFX942: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c16]] + // GFX942: [[addr16:%.*]] = llvm.urem [[shflLaneId]], [[c64]] - // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c32]] - // CHECK: [[addr32:%.*]] = llvm.urem [[shflLaneId]], [[c64]] + // GFX942: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c32]] + // GFX942: [[addr32:%.*]] = llvm.urem [[shflLaneId]], [[c64]] - // CHECK: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c48]] - // CHECK: [[addr48:%.*]] = llvm.urem [[shflLaneId]], [[c64]] + // GFX942: [[shflLaneId:%.*]] = llvm.add [[laneId]], [[c48]] + // GFX942: [[addr48:%.*]] = llvm.urem [[shflLaneId]], [[c64]] - // CHECK: [[vec0:%.*]] = llvm.insertelement [[val3]], {{.*}} : vector<4xi8> - // CHECK: [[vec1:%.*]] = llvm.insertelement [[val7]], {{.*}} : vector<4xi8> + // GFX942: [[vec0:%.*]] = llvm.insertelement [[val3]], {{.*}} : vector<4xi8> + // GFX942: [[vec1:%.*]] = llvm.insertelement [[val7]], {{.*}} : vector<4xi8> - // CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]] - // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) - // CHECK: [[addr:%.*]] = llvm.shl [[addr16]], [[c2]] - // CHECK: [[bShflVec0_16:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]] - // CHECK: [[shflVec0_16:%.*]] = llvm.bitcast [[bShflVec0_16]] + // GFX942: [[bvec0:%.*]] = llvm.bitcast [[vec0]] + // GFX942: [[c2:%.*]] = llvm.mlir.constant(2 : i32) + // GFX942: [[addr:%.*]] = llvm.shl [[addr16]], [[c2]] + // GFX942: [[bShflVec0_16:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]] + // GFX942: [[shflVec0_16:%.*]] = llvm.bitcast [[bShflVec0_16]] - // CHECK: [[bvec0:%.*]] = llvm.bitcast [[vec0]] - // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) - // CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]] - // CHECK: [[bShflVec0_32:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]] - // CHECK: [[shflVec0_32:%.*]] = llvm.bitcast [[bShflVec0_32]] + // GFX942: [[bvec0:%.*]] = llvm.bitcast [[vec0]] + // GFX942: [[c2:%.*]] = llvm.mlir.constant(2 : i32) + // GFX942: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]] + // GFX942: [[bShflVec0_32:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec0]] + // GFX942: [[shflVec0_32:%.*]] = llvm.bitcast [[bShflVec0_32]] - // CHECK: [[bvec1:%.*]] = llvm.bitcast [[vec1]] - // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) - // CHECK: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]] - // CHECK: [[bShflVec1_32:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]] - // CHECK: [[shflVec1_32:%.*]] = llvm.bitcast [[bShflVec1_32]] + // GFX942: [[bvec1:%.*]] = llvm.bitcast [[vec1]] + // GFX942: [[c2:%.*]] = llvm.mlir.constant(2 : i32) + // GFX942: [[addr:%.*]] = llvm.shl [[addr32]], [[c2]] + // GFX942: [[bShflVec1_32:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]] + // GFX942: [[shflVec1_32:%.*]] = llvm.bitcast [[bShflVec1_32]] - // CHECK: [[bvec1:%.*]] = llvm.bitcast [[vec1]] - // CHECK: [[c2:%.*]] = llvm.mlir.constant(2 : i32) - // CHECK: [[addr:%.*]] = llvm.shl [[addr48]], [[c2]] - // CHECK: [[bShflVec1_48:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]] - // CHECK: [[shflVec1_48:%.*]] = llvm.bitcast [[bShflVec1_48]] + // GFX942: [[bvec1:%.*]] = llvm.bitcast [[vec1]] + // GFX942: [[c2:%.*]] = llvm.mlir.constant(2 : i32) + // GFX942: [[addr:%.*]] = llvm.shl [[addr48]], [[c2]] + // GFX942: [[bShflVec1_48:%.*]] = rocdl.ds_bpermute [[addr]], [[bvec1]] + // GFX942: [[shflVec1_48:%.*]] = llvm.bitcast [[bShflVec1_48]] // Input (8 values): (vec0, vec1) // Output (8 values shuffled, '>> n' - take the value from (lane + n) % 64): @@ -177,23 +178,23 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr // lanes 32-47: (vec1 >> 32, vec1 >> 48) (mask0=0, mask1=1) // lanes 48-63: (vec1 >> 48, vec1 ) (mask0=0, mask1=0) - // CHECK-DAG: [[mask0_true:%.*]] = llvm.select [[mask1]], [[vec0]], [[shflVec0_16]] : i1, vector<4xi8> - // CHECK-DAG: [[mask0_false:%.*]] = llvm.select [[mask1]], [[shflVec1_32]], [[shflVec1_48]] : i1, vector<4xi8> - // CHECK: [[resVec0:%.*]] = llvm.select [[mask0]], [[mask0_true]], [[mask0_false]] : i1, vector<4xi8> + // GFX942-DAG: [[mask0_true:%.*]] = llvm.select [[mask1]], [[vec0]], [[shflVec0_16]] : i1, vector<4xi8> + // GFX942-DAG: [[mask0_false:%.*]] = llvm.select [[mask1]], [[shflVec1_32]], [[shflVec1_48]] : i1, vector<4xi8> + // GFX942: [[resVec0:%.*]] = llvm.select [[mask0]], [[mask0_true]], [[mask0_false]] : i1, vector<4xi8> - // CHECK-DAG: [[mask0_true:%.*]] = llvm.select [[mask1]], [[shflVec0_16]], [[shflVec0_32]] : i1, vector<4xi8> - // CHECK-DAG: [[mask0_false:%.*]] = llvm.select [[mask1]], [[shflVec1_48]], [[vec1]] : i1, vector<4xi8> - // CHECK: [[resVec1:%.*]] = llvm.select [[mask0]], [[mask0_true]], [[mask0_false]] : i1, vector<4xi8> + // GFX942-DAG: [[mask0_true:%.*]] = llvm.select [[mask1]], [[shflVec0_16]], [[shflVec0_32]] : i1, vector<4xi8> + // GFX942-DAG: [[mask0_false:%.*]] = llvm.select [[mask1]], [[shflVec1_48]], [[vec1]] : i1, vector<4xi8> + // GFX942: [[resVec1:%.*]] = llvm.select [[mask0]], [[mask0_true]], [[mask0_false]] : i1, vector<4xi8> - // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) - // CHECK: [[resVal3:%.*]] = llvm.extractelement [[resVec0]][[[c3]] : i32] : vector<4xi8> - // CHECK: [[c3:%.*]] = llvm.mlir.constant(3 : i32) : i32 - // CHECK: [[resVal7:%.*]] = llvm.extractelement [[resVec1]][[[c3]] : i32] : vector<4xi8> + // GFX942: [[c3:%.*]] = llvm.mlir.constant(3 : i32) + // GFX942: [[resVal3:%.*]] = llvm.extractelement [[resVec0]][[[c3]] : i32] : vector<4xi8> + // GFX942: [[c3:%.*]] = llvm.mlir.constant(3 : i32) : i32 + // GFX942: [[resVal7:%.*]] = llvm.extractelement [[resVec1]][[[c3]] : i32] : vector<4xi8> - // CHECK: llvm.insertvalue [[resVal3]], {{.*}}[3] - // CHECK: llvm.insertvalue [[resVal7]], {{.*}}[7] + // GFX942: llvm.insertvalue [[resVal3]], {{.*}}[3] + // GFX942: llvm.insertvalue [[resVal7]], {{.*}}[7] - // CHECK: llvm.return + // GFX942: llvm.return %0 = ttg.convert_layout %arg0 : tensor<128x32xf8E4M3FNUZ, #mfma> -> tensor<128x32xf8E4M3FNUZ, #dotop0> tt.return } @@ -205,13 +206,27 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr #dotop0 = #ttg.dot_op<{opIdx = 0, parent = #mfma, kWidth=8}> module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { - // CHECK-LABEL: mfma_dot_cvt_bf8_mfma16 + // GFX942-LABEL: mfma_dot_cvt_bf8_mfma16 tt.func public @mfma_dot_cvt_bf8_mfma16(%arg0: tensor<128x32xf8E5M2, #mfma>) { - // CHECK-NOT: store - // CHECK-NOT: load - // CHECK: rocdl.ds_bpermute - // CHECK: llvm.return + // GFX942-NOT: store + // GFX942-NOT: load + // GFX942: rocdl.ds_bpermute + // GFX942: llvm.return %0 = ttg.convert_layout %arg0 : tensor<128x32xf8E5M2, #mfma> -> tensor<128x32xf8E5M2, #dotop0> tt.return } } + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [0, 1]}> +#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 16], [0, 32], [0, 64]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 8]], warp = [[32, 0], [64, 0]], block = []}> +#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + // GFX950-LABEL: mfma_linear_permlane_swap + tt.func public @mfma_linear_permlane_swap(%arg0: tensor<128x128xf16, #mma>) attributes {noinline = false} { + // GFX950-COUNT-16: llvm.call_intrinsic "llvm.amdgcn.permlane32.swap" + %1 = ttg.convert_layout %arg0: tensor<128x128xf16, #mma> -> tensor<128x128xf16, #linear> + tt.return + } +} diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp index ef55a3448950..7af92231b4a1 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -8,6 +8,7 @@ using ::mlir::triton::gpu::AMDMfmaEncodingAttr; using ::mlir::triton::gpu::AMDWmmaEncodingAttr; using ::mlir::triton::gpu::DotOperandEncodingAttr; using ::mlir::triton::gpu::MemDescType; +using ::triton::gpu::LinearEncodingAttr; namespace SharedToDotOperandMFMA { Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, @@ -180,6 +181,118 @@ struct ConvertLayoutOpMFMAToDotOpConversion const TargetInfoBase &targetInfo; }; +// Match MFMA->Linear Layout conversion +static bool matchMFMAAndLinearLayoutCase(RankedTensorType srcTy, + RankedTensorType dstTy) { + auto mfmaLayout = dyn_cast(srcTy.getEncoding()); + auto linearLayout = dyn_cast(dstTy.getEncoding()); + if (!mfmaLayout || !linearLayout) + return false; + + std::optional srcLL = + mlir::triton::gpu::chooseMfmaLikeStoreLayout(srcTy); + if (!srcLL) + return false; + + MLIRContext *ctx = linearLayout.getContext(); + StringAttr kLane = StringAttr::get(ctx, "lane"); + StringAttr kRegister = StringAttr::get(ctx, "register"); + auto srcBase = srcLL.value().getBases(); + auto srcReg = srcBase.lookup(kRegister); + auto srcLane = srcBase.lookup(kLane); + auto dstBases = linearLayout.getLinearLayout().getBases(); + auto dstReg = dstBases.lookup(kRegister); + auto dstLane = dstBases.lookup(kLane); + return dstReg == srcReg && dstLane == srcLane; +}; + +struct ConvertLayoutOpMFMAToLinearConversion + : public ConvertOpToLLVMPattern { +public: + explicit ConvertLayoutOpMFMAToLinearConversion( + LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, + PatternBenefit benefit) + : ConvertOpToLLVMPattern(typeConverter, + benefit), + targetInfo(targetInfo) {} + + LogicalResult + matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto srcType = cast(op.getSrc().getType()); + auto dstType = cast(op.getType()); + + if (!matchMFMAAndLinearLayoutCase(srcType, dstType)) + return failure(); + + auto loc = op.getLoc(); + auto b = TritonLLVMOpBuilder(loc, rewriter); + + SmallVector inVals = + unpackLLElements(loc, adaptor.getSrc(), rewriter); + if (inVals.empty() || inVals.size() % 8 != 0) + return failure(); + + auto mfmaLayout = dyn_cast(srcType.getEncoding()); + assert(mfmaLayout.getMDim() == 32 && "Expected MFMA size 32"); + assert(triton::gpu::lookupThreadsPerWarp(rewriter) == 64 && + "Expected warp size 64 for MFMA"); + + auto elemTy = srcType.getElementType(); + auto vecTy = vec_ty(elemTy, 2); + + SmallVector outVals; + auto idx0 = b.i32_val(0); + auto idx1 = b.i32_val(1); + // Convert MFMA layout to a MFMA-like linear layout where each thread + // holds 8 consecutive elements + for (size_t idx = 0; idx < inVals.size(); idx += 8) { + SmallVector inVecs; + for (size_t vIdx = 0; vIdx < 4; vIdx++) { + Value vec = b.undef(vecTy); + vec = b.insert_element(vecTy, vec, inVals[idx + vIdx * 2 + 0], idx0); + vec = b.insert_element(vecTy, vec, inVals[idx + vIdx * 2 + 1], idx1); + inVecs.push_back(vec); + } + + Value resVec0, resVec1, resVec2, resVec3; + + // Swap the row 2 and 3 of vec0 and the row 0 and 1 of vec2 + MLIRContext *ctx = rewriter.getContext(); + Type retType = struct_ty({i32_ty, i32_ty}); + Value falseVal = b.false_val(); + Value perm = + LLVM::createLLVMIntrinsicCallOp( + rewriter, loc, "llvm.amdgcn.permlane32.swap", retType, + ValueRange{b.bitcast(inVecs[0], i32_ty), + b.bitcast(inVecs[2], i32_ty), falseVal, falseVal}) + ->getResult(0); + resVec0 = b.bitcast(b.extract_val(i32_ty, perm, 0), vecTy); + resVec2 = b.bitcast(b.extract_val(i32_ty, perm, 1), vecTy); + + // Swap the row 2 and 3 of vec1 and the row 0 and 1 of vec3 + perm = LLVM::createLLVMIntrinsicCallOp( + rewriter, loc, "llvm.amdgcn.permlane32.swap", retType, + ValueRange{b.bitcast(inVecs[1], i32_ty), + b.bitcast(inVecs[3], i32_ty), falseVal, falseVal}) + ->getResult(0); + resVec1 = b.bitcast(b.extract_val(i32_ty, perm, 0), vecTy); + resVec3 = b.bitcast(b.extract_val(i32_ty, perm, 1), vecTy); + + for (Value res : {resVec0, resVec1, resVec2, resVec3}) + for (Value idx : {idx0, idx1}) + outVals.push_back(b.extract_element(elemTy, res, idx)); + } + + Value result = packLLElements(loc, getTypeConverter(), outVals, rewriter, + op.getType()); + rewriter.replaceOp(op, result); + return success(); + } + +protected: + const TargetInfoBase &targetInfo; +}; } // namespace void mlir::triton::AMD::populateConvertLayoutOpToLLVMPatterns( @@ -187,4 +300,6 @@ void mlir::triton::AMD::populateConvertLayoutOpToLLVMPatterns( RewritePatternSet &patterns, PatternBenefit benefit) { patterns.add(typeConverter, targetInfo, benefit); + patterns.add(typeConverter, targetInfo, + benefit); } diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/OptimizeEpilogue.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/OptimizeEpilogue.cpp index a613a54b79a0..a14e42d2d93f 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/OptimizeEpilogue.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/OptimizeEpilogue.cpp @@ -69,25 +69,17 @@ static triton::StoreOp convertMfmaLayoutForCDNA4(PatternRewriter &rewriter, auto mfmaLayout = cast(valType.getEncoding()); - bool mfma32 = mfmaLayout.getMDim() == 32 && mfmaLayout.getNDim() == 32; + // Create a new layout where each thread holds 8 consecutive elements, in + // order to enable wide 128-bit global stores. + std::optional mfma8Layout = + triton::gpu::chooseMfmaLikeStoreLayout(valType); - if (valType.getRank() != 2 || - (!valType.getElementType().isF16() && - !valType.getElementType().isBF16()) || - mfmaLayout.getVersionMajor() != 4 || !mfmaLayout.getIsTransposed() || - !mfma32) { + if (!mfma8Layout) return rewriter.create(oldStOp.getLoc(), ptr, val, mask, oldStOp.getCache(), oldStOp.getEvict()); - } - - // Create a new layout where each thread holds 8 consecutive elements, in - // order to enable wide 128-bit global stores. - triton::LinearLayout mfma8Layout = - chooseMfmaLikeStoreLayout(mfmaLayout, valType.getShape()); - Attribute newEncoding = triton::gpu::LinearEncodingAttr::get( - mfmaLayout.getContext(), mfma8Layout); + mfmaLayout.getContext(), mfma8Layout.value()); auto newPtrType = RankedTensorType::get( ptrType.getShape(), ptrType.getElementType(), newEncoding); Value newPtr = rewriter.create(ptr.getLoc(), From 0f7bbc28b1946c078730b6bea23a9ee48e74b08b Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Fri, 16 May 2025 16:01:40 -0700 Subject: [PATCH 36/44] [AMD] Use composition to swap columns for mfma like store layout (#6844) This commit improves how we create the mfma-like layout for optimizing global store by using linear layout composition. Along the way fixes a few implemenation issues. --------- Co-authored-by: Yi Qian --- .../TritonGPU/IR/LinearLayoutConversions.cpp | 49 +++++++++---------- test/TritonGPU/amd/amd-optimize-epilogue.mlir | 27 +++++++++- .../ConvertLayoutOpToLLVM.cpp | 17 ++----- .../OptimizeEpilogue.cpp | 37 ++++++-------- 4 files changed, 65 insertions(+), 65 deletions(-) diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index d1397549de27..43796875be51 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -1537,7 +1537,7 @@ std::optional chooseMfmaLikeStoreLayout(RankedTensorType valType) { auto mfmaLayout = cast(valType.getEncoding()); - // Currently support transposed [B]F16 MFMA32x32 on CDNA4 + // We currently only support transposed [B]F16 MFMA32x32 on CDNA4. bool isMfma32 = mfmaLayout.getMDim() == 32 && mfmaLayout.getNDim() == 32; Type elemType = valType.getElementType(); if (!(valType.getRank() == 2 && (elemType.isF16() || elemType.isBF16()) && @@ -1545,32 +1545,27 @@ chooseMfmaLikeStoreLayout(RankedTensorType valType) { isMfma32)) return {}; - MLIRContext *ctx = mfmaLayout.getContext(); - StringAttr kRegister = S("register"); - StringAttr kLane = S("lane"); - StringAttr kWarp = S("warp"); - StringAttr kBlock = S("block"); - - SmallVector order = getDefaultMmaOrder(mfmaLayout); - auto standardOutDims = standardOutDimNames(ctx, 2); - // We make each thread handle 8 consecutive elements to enable 128-bit - // global stores for [b]f16 types and keep the thread pattern in each lane - // similar to the canonical mfmaLayout. - LinearLayout mfma8Layout = LinearLayout::empty(); - mfma8Layout = - LinearLayout({{kRegister, {{1, 0}, {2, 0}, {4, 0}}}, - {kLane, {{0, 1}, {0, 2}, {0, 4}, {0, 8}, {0, 16}, {8, 0}}}, - {kWarp, {}}, - {kBlock, {}}}, - {standardOutDims[order[0]], standardOutDims[order[1]]}); - - LinearLayout warpLayout = - identityStandardND(kWarp, mfmaLayout.getWarpsPerCTA(), order); - LinearLayout ctaLayout = mfma8Layout.transposeOuts(standardOutDims) * - warpLayout.transposeOuts(standardOutDims); - mfma8Layout = combineCtaCgaWithShape(ctaLayout, mfmaLayout.getCTALayout(), - valType.getShape()); - return mfma8Layout; + auto valShape = valType.getShape(); + LinearLayout mfmaLL = mfmaLayout.toLinearLayout(valShape); + auto mfmaOutDims = llvm::to_vector(mfmaLL.getOutDimNames()); + StringAttr dimM = mfmaOutDims[0]; + StringAttr dimN = mfmaOutDims[1]; + + auto swapLL = LinearLayout::empty(); + // The rows are kept as is with an identity linear layout. + swapLL *= LinearLayout::identity1D(valShape[0], dimM, dimM); + // In transposed mfma32 layout, each thread holds 4 consecutive values along N + // dim. We want to exchange column 4-7 (owned by thread 32-63) and column 8-11 + // (owned by thread 0-31) every 16 columns to make each thread holds 8 + // elements. This would mean exchange the 2nd and 3rd basis vector from an + // identity linear layout. + std::vector> dimNBases(mfmaLL.getOutDimSizeLog2(dimN)); + std::generate(dimNBases.begin(), dimNBases.end(), + [i = 0]() mutable { return std::vector{1 << i++}; }); + std::swap(dimNBases[2], dimNBases[3]); + swapLL *= LinearLayout({{dimN, dimNBases}}, {dimN}); + + return mfmaLL.compose(swapLL); } LinearLayout getScaleTMEMStoreLinearLayout(RankedTensorType scaleType, diff --git a/test/TritonGPU/amd/amd-optimize-epilogue.mlir b/test/TritonGPU/amd/amd-optimize-epilogue.mlir index 9c0d91881f2f..e84485fcd5bc 100644 --- a/test/TritonGPU/amd/amd-optimize-epilogue.mlir +++ b/test/TritonGPU/amd/amd-optimize-epilogue.mlir @@ -43,7 +43,7 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} // ----- // CHECK{LITERAL}: #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 16], [0, 32], [0, 64]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 8]], warp = [[32, 0], [64, 0]], block = []}> -// CHECK-LABEL: store_dword +// CHECK-LABEL: store_dword_128x128 // CHECK-NOT: ttg.convert_layout %{{.*}} : tensor<128x128xf32, #mma> -> tensor<128x128xf32, #blocked> // CHECK-DAG: %[[PTR:.+]] = ttg.convert_layout %{{.*}} : tensor<128x128x!tt.ptr, #mma> -> tensor<128x128x!tt.ptr, #linear> // CHECK-DAG: %[[VAL:.+]] = ttg.convert_layout %{{.*}} : tensor<128x128xf16, #mma> -> tensor<128x128xf16, #linear> @@ -51,7 +51,7 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 4], warpsPerCTA = [4, 1], order = [0, 1]}> #mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}> module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { - tt.func public @store_dword(%arg0: !tt.ptr) attributes {noinline = false} { + tt.func public @store_dword_128x128(%arg0: !tt.ptr) attributes {noinline = false} { %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #mma> %cst_0 = arith.constant dense<1.230000e+02> : tensor<128x128xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> %cst_1 = arith.constant dense<1.230000e+02> : tensor<128x128xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> @@ -63,3 +63,26 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} tt.return } } + +// ----- +// CHECK{LITERAL}: #linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 16], [0, 128], [64, 0], [128, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 8]], warp = [[0, 32], [0, 64], [32, 0]], block = []}> +// CHECK-LABEL: store_dword_256x256 +// CHECK-NOT: ttg.convert_layout %{{.*}} : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked> +// CHECK-DAG: %[[PTR:.+]] = ttg.convert_layout %{{.*}} : tensor<256x256x!tt.ptr, #mma> -> tensor<256x256x!tt.ptr, #linear> +// CHECK-DAG: %[[VAL:.+]] = ttg.convert_layout %{{.*}} : tensor<256x256xf16, #mma> -> tensor<256x256xf16, #linear> +// CHECK: tt.store %[[PTR]], %[[VAL]] : tensor<256x256x!tt.ptr, #linear> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 32], warpsPerCTA = [8, 1], order = [1, 0]}> +#mma = #ttg.amd_mfma<{versionMajor = 4, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [32, 32], isTransposed = true}> +module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func public @store_dword_256x256(%arg0: !tt.ptr) attributes {noinline = false} { + %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma> + %cst_0 = arith.constant dense<1.230000e+02> : tensor<256x256xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> + %cst_1 = arith.constant dense<1.230000e+02> : tensor<256x256xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> + %0 = tt.dot %cst_0, %cst_1, %cst : tensor<256x256xf32, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<256x256xf32, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma> + %1 = ttg.convert_layout %0 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked> + %2 = arith.truncf %1 : tensor<256x256xf32, #blocked> to tensor<256x256xf16, #blocked> + %3 = tt.splat %arg0 : !tt.ptr -> tensor<256x256x!tt.ptr, #blocked> + tt.store %3, %2 : tensor<256x256x!tt.ptr, #blocked> + tt.return + } +} diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp index 7af92231b4a1..1d2e2e039491 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp @@ -189,21 +189,10 @@ static bool matchMFMAAndLinearLayoutCase(RankedTensorType srcTy, if (!mfmaLayout || !linearLayout) return false; - std::optional srcLL = + std::optional storeLL = mlir::triton::gpu::chooseMfmaLikeStoreLayout(srcTy); - if (!srcLL) - return false; - - MLIRContext *ctx = linearLayout.getContext(); - StringAttr kLane = StringAttr::get(ctx, "lane"); - StringAttr kRegister = StringAttr::get(ctx, "register"); - auto srcBase = srcLL.value().getBases(); - auto srcReg = srcBase.lookup(kRegister); - auto srcLane = srcBase.lookup(kLane); - auto dstBases = linearLayout.getLinearLayout().getBases(); - auto dstReg = dstBases.lookup(kRegister); - auto dstLane = dstBases.lookup(kLane); - return dstReg == srcReg && dstLane == srcLane; + return linearLayout.getLinearLayout() == + storeLL.value_or(LinearLayout::empty()); }; struct ConvertLayoutOpMFMAToLinearConversion diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/OptimizeEpilogue.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/OptimizeEpilogue.cpp index a14e42d2d93f..0c3bb3e44966 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/OptimizeEpilogue.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/OptimizeEpilogue.cpp @@ -59,27 +59,23 @@ bool isOneOperandElementwiseOp(Operation *op) { return false; } -static triton::StoreOp convertMfmaLayoutForCDNA4(PatternRewriter &rewriter, - Value ptr, Value val, - Value mask, - triton::StoreOp oldStOp) { +// Tries to optimize oldStoreOp with v_permlane*_swap instruction when possible. +// Returns null store op if not suitable. +static triton::StoreOp +usePermlaneSwapToOptimizeStore(PatternRewriter &rewriter, Value ptr, Value val, + Value mask, triton::StoreOp oldStoreOp) { auto ptrType = cast(ptr.getType()); auto valType = cast(val.getType()); - auto mfmaLayout = - cast(valType.getEncoding()); - // Create a new layout where each thread holds 8 consecutive elements, in // order to enable wide 128-bit global stores. - std::optional mfma8Layout = + std::optional storeLL = triton::gpu::chooseMfmaLikeStoreLayout(valType); + if (!storeLL) + return nullptr; - if (!mfma8Layout) - return rewriter.create(oldStOp.getLoc(), ptr, val, mask, - oldStOp.getCache(), - oldStOp.getEvict()); Attribute newEncoding = triton::gpu::LinearEncodingAttr::get( - mfmaLayout.getContext(), mfma8Layout.value()); + oldStoreOp.getContext(), storeLL.value()); auto newPtrType = RankedTensorType::get( ptrType.getShape(), ptrType.getElementType(), newEncoding); Value newPtr = rewriter.create(ptr.getLoc(), @@ -99,9 +95,9 @@ static triton::StoreOp convertMfmaLayoutForCDNA4(PatternRewriter &rewriter, newMaskType, mask); } - return rewriter.create(oldStOp.getLoc(), newPtr, newVal, - newMask, oldStOp.getCache(), - oldStOp.getEvict()); + return rewriter.create(oldStoreOp.getLoc(), newPtr, newVal, + newMask, oldStoreOp.getCache(), + oldStoreOp.getEvict()); } // convert(val) : xmma -> blocked @@ -195,12 +191,9 @@ class BypassEpilogueSMEM : public mlir::OpRewritePattern { newMask = rewriter.create( mask.getLoc(), newMaskType, mask); } - triton::StoreOp newStoreOp; - if (auto mfmaLayout = - dyn_cast(newEncoding)) { - newStoreOp = - convertMfmaLayoutForCDNA4(rewriter, newPtr, newVal, newMask, stOp); - } else { + triton::StoreOp newStoreOp = + usePermlaneSwapToOptimizeStore(rewriter, newPtr, newVal, newMask, stOp); + if (!newStoreOp) { newStoreOp = rewriter.create( stOp.getLoc(), newPtr, newVal, newMask, stOp.getCache(), stOp.getEvict()); From aebdfd79318ea8a50d05e8ade2214dc2878d136b Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Fri, 16 May 2025 14:46:06 +0000 Subject: [PATCH 37/44] [ASYNCCOPY] Simplify swizzling calculations to get better codegen from the backend --- .../Conversion/TritonGPUToLLVM/Utility.h | 6 +- include/triton/Tools/Sys/GetEnv.hpp | 1 + lib/Conversion/TritonGPUToLLVM/Utility.cpp | 19 ++- .../amd/buffer_load_to_local_to_llvm.mlir | 6 +- .../TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 129 +++++++++++++++++- 5 files changed, 151 insertions(+), 10 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 1159a48ca471..09ef9904b17e 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -698,13 +698,15 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, Type elemLlvmTy, std::optional maxVecElems, const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter, const TargetInfoBase &target, - std::function perVectorCallback); + std::function perVectorCallback, + bool forceLane0 = false); [[nodiscard]] bool emitTransferBetweenRegistersAndShared( LinearLayout ®Layout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy, std::optional maxVecElems, const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter, const TargetInfoBase &target, - std::function perVectorCallback); + std::function perVectorCallback, + bool forceLane0 = false); SmallVector loadSharedToDistributed(triton::gpu::LocalLoadOp localLoadOp, Type elemLlvmTy, diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index 8dcd2917b6b5..112161cd5b1a 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ -38,6 +38,7 @@ inline const std::set CACHE_INVALIDATING_ENV_VARS = { "TRITON_HIP_ENABLE_F16_ASYNC_PINGPONG", "TRITON_HIP_USE_BLOCK_PINGPONG", "TRITON_HIP_USE_IN_THREAD_TRANSPOSE", + "TRITON_HIP_ASYNC_FAST_SWIZZLE", "TRITON_LLVM_DEBUG_ONLY", "TRITON_ENABLE_ASAN", "TRITON_OVERRIDE_ARCH", diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index e09a08105926..8f691b673d6d 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -409,7 +409,8 @@ bool emitTransferBetweenRegistersAndShared( LinearLayout ®Layout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy, std::optional maxVecElems, const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter, const TargetInfoBase &target, - std::function perVectorCallback) { + std::function perVectorCallback, + bool forceLane0) { MLIRContext *ctx = rewriter.getContext(); auto b = TritonLLVMOpBuilder(loc, rewriter); @@ -452,6 +453,17 @@ bool emitTransferBetweenRegistersAndShared( auto withCTAOffset = triton::gpu::getNumCTAs(sharedTy.getEncoding()) > 1; auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); + if (forceLane0) { + laneId = b.i32_val(0); + // NFC it's copied from getLaneAndWarpId but we add a shuffleIdx(0) to the + // tid so LLVM sees that warpId is a scalar + // This is not optimal as it adds a readlane which is not necessary but + // better than getting readfirstlanes for every direct-to-lds load + Value tid = target.shuffleIdx(rewriter, loc, getThreadId(rewriter, loc), 0); + int threadsPerWarp = triton::gpu::lookupThreadsPerWarp(rewriter); + Value warpSizeVal = b.i32_val(threadsPerWarp); + warpId = b.udiv(tid, warpSizeVal); + } Value blockId = withCTAOffset ? target.getClusterCTAId(rewriter, loc) : b.i32_val(0); @@ -486,12 +498,13 @@ bool emitTransferBetweenRegistersAndShared( Type elemLlvmTy, std::optional maxVecElems, const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter, const TargetInfoBase &target, - std::function perVectorCallback) { + std::function perVectorCallback, + bool forceLane0) { auto regLayout = triton::gpu::toLinearLayout(registerTy.getShape(), registerTy.getEncoding()); return emitTransferBetweenRegistersAndShared( regLayout, sharedTy, elemLlvmTy, maxVecElems, smemObj, loc, rewriter, - target, perVectorCallback); + target, perVectorCallback, forceLane0); } SmallVector loadSharedToDistributed(triton::gpu::LocalLoadOp localLoadOp, diff --git a/test/Conversion/amd/buffer_load_to_local_to_llvm.mlir b/test/Conversion/amd/buffer_load_to_local_to_llvm.mlir index 04fc1397d626..2892fcec625c 100644 --- a/test/Conversion/amd/buffer_load_to_local_to_llvm.mlir +++ b/test/Conversion/amd/buffer_load_to_local_to_llvm.mlir @@ -145,10 +145,12 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr %arg2: !ttg.memdesc<64xf32, #shared, #smem, mutable>) { %0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #blocked> // The first constant 0 skips the LDS offset which is also 0 - // COMMON: llvm.getelementptr + // COMMON: rocdl.make.buffer.rsrc + // COMMON: llvm.select // COMMON: llvm.mlir.constant(0 : i32) : i32 // COMMON: %[[aux_ca:.*]] = llvm.mlir.constant(0 : i32) : i32 - // COMMON: rocdl.raw.ptr.buffer.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_ca]] + // COMMON: llvm.mlir.constant(0 : i32) : i32 + // COMMON-: rocdl.raw.ptr.buffer.load.lds {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[aux_ca]] %1 = amdgpu.buffer_load_to_local %arg0[%0] cacheModifier = ca into %arg2: [tensor<64xi32, #blocked>] -> <64xf32, #shared, #smem, mutable> // COMMON: llvm.getelementptr // COMMON: %[[aux_cg:.*]] = llvm.mlir.constant(3 : i32) : i32 diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index a668b6031782..96e0cc00c629 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -14,6 +14,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" #include "triton/Dialect/Triton/IR/Types.h" +#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" using namespace mlir; @@ -238,6 +239,114 @@ struct DirectToLdsLoadConversionBase : public LoadStoreConversionBase { } } + SmallVector emitSharedBaseAddr(RewriterBase &rewriter, Operation *op, + RankedTensorType srcTy, + MemDescType dstTy, bool hasSwizzling, + Type resElemTy, Value llDst, + VectorType &vecTy) const { + auto emitSharedAddresses = [&](MemDescType dstTy, + SmallVector &shmemAddrs, + VectorType &vecTy, bool forceLane0) { + auto loc = op->getLoc(); + auto smemObj = mlir::LLVM::getSharedMemoryObjectFromStruct( + loc, llDst, resElemTy, rewriter); + bool ok = emitTransferBetweenRegistersAndShared( + srcTy, dstTy, resElemTy, {}, smemObj, loc, rewriter, targetInfo, + [&](VectorType vecTy_, Value shmemAddr) { + vecTy = vecTy_; + shmemAddrs.push_back(shmemAddr); + }, + forceLane0); + assert(ok); + }; + + if (hasSwizzling) { + // Rewrite dstTy to be coalesced + auto dstEnc = cast(dstTy.getEncoding()); + auto flatSharedEnc = SwizzledSharedEncodingAttr::get( + op->getContext(), dstEnc.getVec(), 1, 1, dstEnc.getOrder(), + dstEnc.getCTALayout()); + dstTy = MemDescType::get(dstTy.getShape(), dstTy.getElementType(), + flatSharedEnc, dstTy.getMemorySpace()); + } + SmallVector ldsAddrs; + emitSharedAddresses(dstTy, ldsAddrs, vecTy, true); + return ldsAddrs; + } + + SmallVector emitSwizzleOffsets(Operation *op, RewriterBase &rewriter, + RankedTensorType srcTy, + MemDescType dstTy, VectorType vecTy, + int numberOfLoads) const { + auto loc = op->getLoc(); + TritonLLVMOpBuilder b(loc, rewriter); + + // Compute swizzle offsets + auto regLayout = + triton::gpu::toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); + auto shape = dstTy.getShape(); + LinearLayout sharedLayout = + triton::gpu::toLinearLayout(shape, dstTy.getEncoding()); + LinearLayout regToSharedLayout = regLayout.invertAndCompose(sharedLayout); + + auto dstEnc = cast(dstTy.getEncoding()); + auto flatSharedEnc = SwizzledSharedEncodingAttr::get( + srcTy.getContext(), dstEnc.getVec(), 1, 1, dstEnc.getOrder(), + dstEnc.getCTALayout()); + auto flatDst = MemDescType::get(dstTy.getShape(), dstTy.getElementType(), + flatSharedEnc, dstTy.getMemorySpace()); + + auto regToSharedFlat = regLayout.invertAndCompose( + triton::gpu::toLinearLayout(shape, flatDst.getEncoding())); + // llvm::outs() << "Flat: " << regToSharedFlat << "\n"; + + MLIRContext *ctx = rewriter.getContext(); + StringAttr kBlock = str_attr("block"); + StringAttr kRegister = str_attr("register"); + StringAttr kLane = str_attr("lane"); + StringAttr kWarp = str_attr("warp"); + auto [laneId, warpId] = getLaneAndWarpId(rewriter, loc); + SmallVector swizzleOffsets; + for (int i = 0; i < numberOfLoads; i++) { + auto regId = b.i32_val(i * vecTy.getNumElements()); + + // for (int l = 0; l < 64; l++) { + // SmallVector> input = { + // {kRegister, i * vecTy.getNumElements()}, + // {kLane, l}, + // {kWarp, 0}, + // {kBlock, 0}}; + + // auto swizzOff = regToSharedLayout.apply(input)[0].second; + // auto flatOff = regToSharedFlat.apply(input)[0].second; + + // auto laneOff = (swizzOff - flatOff) / vecTy.getNumElements(); + + // llvm::outs() << l << ": " << swizzOff << ", " << flatOff << " = " + // << laneOff << "\n"; + // } + + auto swizzleOffset = + llvm::to_vector(llvm::drop_end(llvm::make_second_range( + applyLinearLayout(loc, rewriter, regToSharedLayout, + {{kRegister, regId}, + {kLane, laneId}, + {kWarp, warpId}, + {kBlock, b.i32_val(0)}}))))[0]; + auto flatOffset = llvm::to_vector(llvm::drop_end(llvm::make_second_range( + applyLinearLayout(loc, rewriter, regToSharedFlat, + {{kRegister, regId}, + {kLane, laneId}, + {kWarp, warpId}, + {kBlock, b.i32_val(0)}}))))[0]; + auto laneOffet = b.sdiv(b.sub(swizzleOffset, flatOffset), + b.i32_val(vecTy.getNumElements())); + swizzleOffsets.push_back(laneOffet); + } + + return swizzleOffsets; + } + // Emits the computation to get the lane id offset which holds the source // pointers/offsets we need to store to shared memory Value emitSwizzledLaneOffset(RewriterBase &rewriter, TritonLLVMOpBuilder &b, @@ -523,6 +632,11 @@ struct BufferLoadToLocalOpConversion llDst, coalescedShmemAddr, swizzledShmemAddr, vecTy); assert(vecTy.getNumElements() == vec); + auto ldsBaseAddresses = emitSharedBaseAddr( + rewriter, op, ptrType, dstTy, hasSwizzling, resElemTy, llDst, vecTy); + auto swizzleOffsets = emitSwizzleOffsets(op, rewriter, ptrType, dstTy, + vecTy, ldsBaseAddresses.size()); + int vecBytes = (vecTy.getNumElements() * vecTy.getElementTypeBitWidth()) / 8; assert(llvm::isPowerOf2_32(vecBytes)); @@ -532,9 +646,14 @@ struct BufferLoadToLocalOpConversion // based on the collected shared addresses and vector size Value rsrcDesc = bufferEmitter.createResourceDescriptor(llPtr, llStride); - for (int i = 0; i < coalescedShmemAddr.size(); i++) { + bool useFastSwizzling = tools::getBoolEnv("TRITON_HIP_ASYNC_FAST_SWIZZLE"); + + for (int i = 0; i < ldsBaseAddresses.size(); i++) { auto srcIdx = i * vec; auto offsetIn = offsetElems[srcIdx]; + auto ldsDst = + useFastSwizzling ? ldsBaseAddresses[i] : coalescedShmemAddr[i]; + Value pred = mask ? maskElems[srcIdx] : b.true_val(); if (hasSwizzling) { @@ -542,6 +661,11 @@ struct BufferLoadToLocalOpConversion Value laneOffset = emitSwizzledLaneOffset(rewriter, b, loc, coalescedShmemAddr[i], swizzledShmemAddr[i], vecBytesVal); + + if (useFastSwizzling) { + laneOffset = swizzleOffsets[i]; + } + // laneId + laneOffset will always stay inside the warp [0, // threadsPerWarp) because we only swizzle inside a warp Value swizzledLaneId = b.add(getLaneId(rewriter, loc), laneOffset); @@ -561,8 +685,7 @@ struct BufferLoadToLocalOpConversion } auto bufferLoadToLds = bufferEmitter.emitLoadToLds( - vecTy, vecBytesVal, rsrcDesc, offsetIn, coalescedShmemAddr[i], pred, - op.getCache()); + vecTy, vecBytesVal, rsrcDesc, offsetIn, ldsDst, pred, op.getCache()); LLVM::AMD::addAsyncCopyAliasScope(bufferLoadToLds); if (!otherElems.empty()) { Value storeVal = packElementRangeIntoVector( From 6527f107d1d53c182328e32a0b6f7c18fc136127 Mon Sep 17 00:00:00 2001 From: Jungwook Park Date: Mon, 19 May 2025 18:46:18 -0500 Subject: [PATCH 38/44] Code cleanup avoid wrongly enabled. --- .../TritonAMDGPUTransforms/BlockPingpong.cpp | 57 ++++--------------- 1 file changed, 12 insertions(+), 45 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp index 7a21bbb4e737..a2fbc874d539 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp @@ -1032,44 +1032,10 @@ LogicalResult Pingponger::transformFAv3(OpBuilder &builder, Location loc) { } LogicalResult Pingponger::transformFP4s(OpBuilder &builder, Location loc) { - //FIXME: support nonscale. + // FIXME: support nonscale. if (lLoadOps.size() != 4) return failure(); -//#define OBO - -#if defined (OBO) - builder.setInsertionPointAfter(forOp); - - // FIXME: This is duplicated code, need to refactorize. - auto i32ty = builder.getIntegerType(32); - auto workIDX = builder.create(loc, i32ty); - workIDX->moveBefore(forOp); - builder.setInsertionPointAfter(workIDX); - auto constZero = builder.create(loc, 0, 32); - auto constWarpSize = builder.create(loc, 256, 32); - auto warpIDX = builder.create(loc, workIDX, constWarpSize); - auto warpLow = builder.create(loc, arith::CmpIPredicate::eq, - warpIDX, constZero); - auto warpHigh = builder.create(loc, arith::CmpIPredicate::ne, - warpIDX, constZero); - - builder.setInsertionPointAfter(dotSOps[0]); - updateOpInsertion(dotSOps[0]); - - appendOp(builder.create(loc, 0)); - appendOp(builder.create(loc, warpLow)); - appendOp(builder.create(loc, 0)); - - for (int i = 0; i < 4; i++) - appendOp(lLoadOps[i]); - appendOp(dotSOps[0]); - - appendOp(builder.create(loc, 0)); - appendOp(builder.create(loc, warpHigh)); - - -#else auto tokens = asyncWaitOps[0].getAsyncToken(); Operation *aWait = asyncWaitOps[0]; builder.setInsertionPointToStart(forOp.getBody()); @@ -1105,17 +1071,14 @@ LogicalResult Pingponger::transformFP4s(OpBuilder &builder, Location loc) { appendOp(builder.create(loc, 0)); appendOp(builder.create(loc)); appendOp(builder.create(loc, 0)); - + appendOp(lLoadOps[1]); appendOp(lLoadOps[3]); appendOp(dotSOps[0]); -#endif - return success(); } - LogicalResult Pingponger::transformFP4(OpBuilder &builder, Location loc) { builder.setInsertionPointAfter(forOp); @@ -1267,8 +1230,8 @@ void Pingponger::getDotPingponged() { } // FIXME: place tile size restriction here and obtain kWidth - if (dotSOps.size() == 1) { - kWidth = 16; + if (dotSOps.size() == 1 && numWarps == 8 && numStages == 2 && + asyncCopyOps.size() > 0) { auto dotSType = dotSOps[0].getType(); auto dotSShape = dotSType.getShape(); auto aType = dotSOps[0].getA().getType(); @@ -1277,7 +1240,9 @@ void Pingponger::getDotPingponged() { int64_t tileSize = dotSShape[0] * dotSShape[1] * aShape[1]; // 256x256x256 (128xi8) - if (tileSize == 8388608 && aShape[1] == 128 && elemWidth == 8){ + if (tileSize == 8388608 && aShape[0] == 256 && aShape[1] == 128 && + elemWidth == 8) { + kWidth = 16; if (transformFP4(builder, dotSOps[0]->getLoc()).failed()) { LDBG("Encountered failure when trying to execute the two ping pong " "cluster transformation"); @@ -1285,17 +1250,19 @@ void Pingponger::getDotPingponged() { } } // 128x128x512 (256xi8) - else if (tileSize == 4194304 && aShape[1] == 256 && elemWidth == 8){ + else if (tileSize == 4194304 && aShape[0] == 128 && aShape[1] == 256 && + elemWidth == 8) { if (transformFP4s(builder, dotSOps[0]->getLoc()).failed()) { LDBG("Encountered failure when trying to execute the two ping pong " "cluster transformation"); return; } } - + addAsymmetricSyncToLoop(builder, loc); return; - } + } else if (dotSOps.size() == 1) + return; // Determine if we have a persistent GEMM. This will decide how we interpret // any memory operations that we find in conditionals. From 5c4b1fb73078c3bf5044db2352c8bb7a4a631270 Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Tue, 20 May 2025 08:53:52 +0000 Subject: [PATCH 39/44] [FA] Disable pipelining for causal loop --- fa/flash-attention.py | 10 ++++++---- .../lib/TritonAMDGPUTransforms/FourStagePipeliner.cpp | 6 ------ 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/fa/flash-attention.py b/fa/flash-attention.py index 97dbb8b5e9ed..de81cbf8f9a2 100644 --- a/fa/flash-attention.py +++ b/fa/flash-attention.py @@ -243,9 +243,11 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, OFFS_M: tl.constexpr, OFFS_N: tl.constexpr, PRE_LOAD_V: tl.constexpr, MASK_STEPS: tl.constexpr, ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr, PADDED_HEAD: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, - QK_SCALE: tl.constexpr, INT8_GEMM: tl.constexpr, USE_P_SCALE: tl.constexpr, INT8_KV: tl.constexpr): + QK_SCALE: tl.constexpr, INT8_GEMM: tl.constexpr, USE_P_SCALE: tl.constexpr, INT8_KV: tl.constexpr, + ENABLE_PIPELINING: tl.constexpr): # loop over k, v, and update accumulator - for start_n in range(block_min, block_max, BLOCK_N): + num_stages: tl.constexpr = None if ENABLE_PIPELINING else 1 # Set num_stages==1 if we want to disable pipelining + for start_n in tl.range(block_min, block_max, BLOCK_N, num_stages=num_stages): # For padded blocks, we will overrun the tensor size if # we load all BLOCK_N. For others, the blocks are all within range. if MASK_STEPS: @@ -674,7 +676,7 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, L, Out, stride_qz, stride_qh # _, MASK_STEPS, ... PRE_LOAD_V, False, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD, ACTUAL_BLOCK_DMODEL, QK_SCALE, INT8_GEMM, USE_P_SCALE, - INT8_KV) + INT8_KV, True) block_min = block_max block_max = n_blocks * BLOCK_N @@ -698,7 +700,7 @@ def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, L, Out, stride_qz, stride_qh p_scale, IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, # _, MASK_STEPS, ... PRE_LOAD_V, True, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, PADDED_HEAD, ACTUAL_BLOCK_DMODEL, - QK_SCALE, INT8_GEMM, USE_P_SCALE, INT8_KV) + QK_SCALE, INT8_GEMM, USE_P_SCALE, INT8_KV, False) if INT8 and not INT8_KV: if USE_P_SCALE: diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/FourStagePipeliner.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/FourStagePipeliner.cpp index e30da8cf74a5..4c515ba308a6 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/FourStagePipeliner.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/FourStagePipeliner.cpp @@ -74,12 +74,6 @@ FourStagePipeliner::FourStagePipeliner(scf::ForOp _forOp, int _numStages, } bool FourStagePipeliner::checkPrecondition(scf::ForOp forOp, int numStages) { - // Skip the second loop (causual loop) - static bool isFirst = true; - if (!isFirst) - return false; - isFirst = false; - unsigned dotCount{}; unsigned reduceCount{}; From 18ae32bc4adf213efa1ccc3f7c9063214ed3eed3 Mon Sep 17 00:00:00 2001 From: Ilya Veselov Date: Fri, 16 May 2025 14:15:30 +0000 Subject: [PATCH 40/44] [AMD] Add an option to force async copy overlapping Signed-off-by: Ilya Veselov --- include/triton/Tools/Sys/GetEnv.hpp | 1 + .../TritonAMDGPUTransforms/StreamPipeline.cpp | 20 ++++++++++++++++--- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/include/triton/Tools/Sys/GetEnv.hpp b/include/triton/Tools/Sys/GetEnv.hpp index 112161cd5b1a..0e0bde48ef52 100644 --- a/include/triton/Tools/Sys/GetEnv.hpp +++ b/include/triton/Tools/Sys/GetEnv.hpp @@ -35,6 +35,7 @@ inline const std::set CACHE_INVALIDATING_ENV_VARS = { "TRITON_HIP_LOCAL_PREFETCH", "TRITON_HIP_USE_ASYNC_COPY", "TRITON_HIP_ASYNC_COPY_BYPASS_PERMUTE", + "TRITON_HIP_ASYNC_COPY_OVERLAP", "TRITON_HIP_ENABLE_F16_ASYNC_PINGPONG", "TRITON_HIP_USE_BLOCK_PINGPONG", "TRITON_HIP_USE_IN_THREAD_TRANSPOSE", diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp index 96dd6ef51130..14a35f1e61ef 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/StreamPipeline.cpp @@ -122,10 +122,10 @@ class StreamPipeliner { public: StreamPipeliner(scf::ForOp _forOp, int _numStages, int _globalPrefetch, int _localPrefetch, bool _useAsyncCopy, - bool _useF16BlockPingpong) + bool _useF16BlockPingpong, bool _useAsyncCopyOverlap) : forOp(_forOp), numStages(_numStages), numBuffers(1), useAsyncCopy(_useAsyncCopy), useF16BlockPingpong(_useF16BlockPingpong), - schedule(numStages), + useAsyncCopyOverlap(_useAsyncCopyOverlap), schedule(numStages), axisInfoAnalysis(forOp->getParentOfType()) { int lastStage = numStages - 1; stages[SCHED_GLOBAL_LOAD] = 0; @@ -181,6 +181,9 @@ class StreamPipeliner { // Whether or not we are intend to ping-pong. bool useF16BlockPingpong; + // Move AsyncCopy before AsyncWait. + bool useAsyncCopyOverlap; + // Stage for each SchedType Op int stages[SCHED_SIZE]; // Cluster for each SchedType Op @@ -297,6 +300,14 @@ LogicalResult StreamPipeliner::initSchedule(int maxIndirectionLevel) { computeCluster = localLoadCluster; } + if (useAsyncCopyOverlap) { + globalLoadCluster = 0; + localStoreCluster = 1; + asyncWaitCluster = 2; + localLoadCluster = 3; + computeCluster = 3; + } + // Make assignments std::array clusterVec; std::generate(clusterVec.begin(), clusterVec.end(), @@ -1072,6 +1083,9 @@ struct PipelinePass : public TritonAMDGPUStreamPipelineBase { // between MXFP4 and FP16. bool useF16BlockPingpong = triton::tools::getBoolEnv("TRITON_HIP_ENABLE_F16_ASYNC_PINGPONG"); + bool useAsyncCopyOverlap = + triton::tools::getBoolEnv("TRITON_HIP_ASYNC_COPY_OVERLAP") & + useAsyncCopy; SmallVector loops; getOperation()->walk([&](scf::ForOp forOp) { labelLoadOpsForTritonDot(forOp); @@ -1092,7 +1106,7 @@ struct PipelinePass : public TritonAMDGPUStreamPipelineBase { } else { StreamPipeliner sp(forOp, tt::getNumStagesOrDefault(forOp, numStages), globalPrefetch, localPrefetch, useAsyncCopy, - useF16BlockPingpong); + useF16BlockPingpong, useAsyncCopyOverlap); (void)sp.pipelineLoop(); } } From c5ceb64faf114a4cbedc76aaab6c8d7f470f3835 Mon Sep 17 00:00:00 2001 From: ravil-mobile Date: Mon, 19 May 2025 12:58:31 +0000 Subject: [PATCH 41/44] [AMD] Improved CanonicalizePointers for ExtractSlice --- third_party/amd/backend/compiler.py | 1 + .../ExtractSliceOpToLLVM.cpp | 96 +++++++++---------- .../amd/lib/TritonAMDGPUToLLVM/Utility.cpp | 39 ++++++++ .../amd/lib/TritonAMDGPUToLLVM/Utility.h | 17 ++++ .../TritonAMDGPUTransforms/BlockPingpong.cpp | 64 ++++++++----- .../CanonicalizePointers.cpp | 53 ++++++++-- .../ConvertToBufferOps.cpp | 4 + 7 files changed, 193 insertions(+), 81 deletions(-) diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index d2d6b2073c66..07f4dfd75ee7 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -263,6 +263,7 @@ def make_ttgir(mod, metadata, options): use_block_pingpong = is_pingpong_schedule_enabled(options.arch) if use_block_pingpong and options.num_stages in [2, 4]: amd.passes.ttgpuir.add_block_pingpong(pm, options.num_stages, use_async_copy) + passes.ttgpuir.add_remove_layout_conversions(pm) if knobs.amd.use_buffer_ops: amd.passes.ttgpuir.add_canonicalize_pointers(pm) diff --git a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp index 07cf91870fed..ed915577bf85 100644 --- a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp @@ -1,3 +1,4 @@ +#include "../TritonAMDGPUToLLVM/Utility.h" #include "Dialect/TritonAMDGPU/IR/Dialect.h" #include "TritonAMDGPUToLLVM/GCNAsmFormat.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" @@ -49,6 +50,7 @@ using namespace mlir::triton; // clang-format on namespace { + struct ExtractSliceOpConversion : public ConvertOpToLLVMPattern { explicit ExtractSliceOpConversion(LLVMTypeConverter &typeConverter, @@ -60,61 +62,61 @@ struct ExtractSliceOpConversion ConversionPatternRewriter &rewriter) const { Location loc = op->getLoc(); auto srcTy = cast(op.getSource().getType()); - auto srcLayout = srcTy.getEncoding(); + auto dstTy = cast(op.getType()); auto srcShape = srcTy.getShape(); - auto resultTy = cast(op.getType()); - auto vals = unpackLLElements(loc, adaptor.getSource(), rewriter); - auto elemsPerThread = triton::gpu::getElemsPerThread(srcTy); - auto contigPerThread = triton::gpu::getContigPerThread(srcTy); - auto totalContigPerThread = product(contigPerThread); - auto order = triton::gpu::getOrder(srcTy); + auto dstShape = dstTy.getShape(); - // Calculate valid total number of workers in each dimension + auto vals = unpackLLElements(loc, adaptor.getSource(), rewriter); auto shapePerCTATile = triton::gpu::getShapePerCTATile(srcTy); - shapePerCTATile[0] = - std::min(static_cast(srcShape[0]), shapePerCTATile[0]); - shapePerCTATile[1] = - std::min(static_cast(srcShape[1]), shapePerCTATile[1]); - - // Rank == 2 checked in the verifier - SmallVector sizes; - for (auto i = 0; i < 2; ++i) { - sizes.push_back(resultTy.getDimSize(i)); - } + auto srcCTAShape = LLVM::AMD::multiDimElementwise( + srcShape, shapePerCTATile, std::divides()); + auto dstCTAShape = LLVM::AMD::multiDimElementwise( + dstShape, shapePerCTATile, std::divides()); + auto numCTATiles = std::accumulate(dstCTAShape.begin(), dstCTAShape.end(), + 1, std::multiplies<>()); auto offsets = op.getStaticOffsets(); + auto firstTileCoordinate = + LLVM::AMD::multiDimElementwise( + offsets, shapePerCTATile, std::divides()); - // Calculate offsets and sizes in terms of CTA units. - std::array CTAOffsets{offsets[0] / shapePerCTATile[0], - offsets[1] / shapePerCTATile[1]}; - std::array CTASizes{sizes[0] / shapePerCTATile[0], - sizes[1] / shapePerCTATile[1]}; - std::array CTAPerShape{srcShape[0] / shapePerCTATile[0], - srcShape[1] / shapePerCTATile[1]}; - - // The diagram above illustrates the graphical representation of the - // skipElems, tensorStride, and lastIdx variables. - auto skipElems = CTAOffsets[order[1]] * (elemsPerThread[order[0]] * - contigPerThread[order[1]]) + - CTAOffsets[order[0]] * totalContigPerThread; - auto tensorStride = - (CTAPerShape[order[0]] - CTASizes[order[0]]) * totalContigPerThread; - auto lastIdx = - (CTAOffsets[order[1]] + CTASizes[order[1]] - 1) * - elemsPerThread[order[0]] * contigPerThread[order[1]] + - (CTAOffsets[order[0]] + CTASizes[order[0]]) * totalContigPerThread; - - assert(lastIdx <= vals.size()); + Attribute srcEncoding = srcTy.getEncoding(); + Attribute dstEncoding = dstTy.getEncoding(); + auto linearLayoutSrc = triton::gpu::toLinearLayout(srcShape, srcEncoding); + auto linearLayoutDst = triton::gpu::toLinearLayout(dstShape, dstEncoding); + auto srcCTAOrder = + LLVM::AMD::getCTATileOrder(srcTy.getContext(), linearLayoutSrc); + auto dstCTAOrder = + LLVM::AMD::getCTATileOrder(srcTy.getContext(), linearLayoutDst); + + unsigned elemsPerThreadPerCTA = + triton::gpu::getTotalElemsPerThread(srcTy) / + std::accumulate(srcCTAShape.begin(), srcCTAShape.end(), 1, + std::multiplies<>()); + + // 1. Process CTA tiles in the destination tensor according to the + // destination's linear layout order of CTA tiles. + // 2. For each tile position in the destination tensor, compute its + // corresponding position in the source tensor. + // 3. Copy the values from the source tile to the destination slice. SmallVector resultVals; - for (int i = skipElems; i < lastIdx; i += tensorStride) { - for (int j = 0; j < totalContigPerThread * CTASizes[order[0]]; ++j, ++i) { - assert(i < lastIdx); - resultVals.push_back(vals[i]); + for (size_t i = 0; i < numCTATiles; i++) { + auto coordInDstTensor = + mlir::LLVM::delinearize(i, dstCTAShape, dstCTAOrder); + auto coordInSrcTensor = + LLVM::AMD::multiDimElementwise( + coordInDstTensor, firstTileCoordinate, std::plus()); + auto linearIdxInSrcTensor = + mlir::LLVM::linearize(coordInSrcTensor, srcCTAShape, srcCTAOrder); + + for (size_t j = 0; j < elemsPerThreadPerCTA; j++) { + resultVals.push_back( + vals[linearIdxInSrcTensor * elemsPerThreadPerCTA + j]); } } Value ret = packLLElements(loc, this->getTypeConverter(), resultVals, - rewriter, resultTy); + rewriter, dstTy); rewriter.replaceOp(op, ret); return success(); @@ -124,11 +126,7 @@ struct ExtractSliceOpConversion matchAndRewrite(amdgpu::ExtractSliceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto srcTy = op.getSource().getType(); - if (isa( - op.getSource().getType().getEncoding())) { - return processLayout(op, adaptor, rewriter); - } - return failure(); + return processLayout(op, adaptor, rewriter); } }; } // namespace diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp index 11ff747b6675..a78a10bf4f59 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp @@ -755,4 +755,43 @@ void addLocalLoadNoAliasScope(AliasAnalysisOpInterface llLoadOp) { llLoadOp.setAliasScopes(aliasScopes); } +SmallVector getCTATileOrder(MLIRContext *ctx, + const LinearLayout &layout) { + auto llEnc = triton::gpu::LinearEncodingAttr::get(ctx, layout); + auto regDim = StringAttr::get(ctx, "register"); + auto &bases = layout.getBases().find(regDim)->second; + + // Compute number of CTA tiles in a layout. + unsigned totalElems = layout.getTotalOutDimSize(); + auto ctaShape = llEnc.getShapePerCTATile(); + unsigned elemsPerCTA = + std::accumulate(ctaShape.begin(), ctaShape.end(), 1, std::multiplies<>()); + assert((totalElems % elemsPerCTA) == 0 && + "Total elements must be divisible by elemsPerCTA"); + unsigned numCTAs = totalElems / elemsPerCTA; + + // To determine the CTA tile order, start by identifying the register basis + // vector that corresponds to the first element of the second CTA tile. The + // nonzero index in the logical tensor it maps to indicates the most minor + // dimension. Then, for each subsequent basis register (first element of + // some CTA tile), extract the next nonzero index to build the full dimension + // order. + unsigned totalPerThread = + product(llEnc.basesPerDim(regDim, /*skipBroadcast=*/false)) / numCTAs; + unsigned startIndex = static_cast(std::log2(totalPerThread)); + + llvm::SmallSetVector order; + for (unsigned i = startIndex; i < bases.size(); ++i) { + auto it = std::find_if(bases[i].begin(), bases[i].end(), + [](unsigned v) { return v != 0; }); + if (it != bases[i].end()) + order.insert(std::distance(bases[i].begin(), it)); + } + + // Append any dims missing from our default order. + for (unsigned dim : llEnc.getOrder()) + order.insert(dim); + + return SmallVector(order.begin(), order.end()); +} } // namespace mlir::LLVM::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h index dda259360c61..f2d1e62f97c6 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h @@ -137,6 +137,23 @@ void addLocalLoadNoAliasScope(AliasAnalysisOpInterface llLoadOp); // Attaches the "AsyncCopies" alias scope to llLoadDirectToLdsOp void addAsyncCopyAliasScope(AliasAnalysisOpInterface llLoadDirectToLdsOp); +// Determine the order in which CTA tiles are laid out across the tensor. +SmallVector getCTATileOrder(MLIRContext *ctx, + const LinearLayout &layout); + +template +std::vector multiDimElementwise(const ArrayRef &lhs, + const ArrayRef &rhs, BinaryOp op) { + assert(lhs.size() == rhs.size() && "Input dimensions must match"); + std::vector result; + result.reserve(lhs.size()); + for (size_t i = 0, n = lhs.size(); i < n; ++i) { + unsigned a = static_cast(lhs[i]); + unsigned b = static_cast(rhs[i]); + result.push_back(op(a, b)); + } + return result; +} } // namespace mlir::LLVM::AMD #endif // TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_UTILITY_H_ diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp index 7a7d907234bb..e01ed34cc922 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp @@ -1518,6 +1518,12 @@ LogicalResult Pingponger::transformFP4(OpBuilder &builder, Location loc) { builder.setInsertionPointAfter(dotSOps[0]); if (sliceDotScaled(builder, loc, dotSOps[0], 4).failed()) return failure(); + + if (genAsyncCopySlices(builder).failed()) { + LDBG("failed to slice global-to-local async copies"); + return failure(); + } + updateOpInsertion(dotSliceOps[0]); appendOp(builder.create(loc, 0)); @@ -1660,7 +1666,8 @@ void Pingponger::getDotPingponged() { } // FIXME: place tile size restriction here and obtain kWidth - if (dotSOps.size() == 1) { + if (dotSOps.size() == 1 && numWarps == 8 && numStages == 2 && + asyncCopyOps.size() > 0) { kWidth = 16; auto dotSType = dotSOps[0].getType(); auto dotSShape = dotSType.getShape(); @@ -1668,36 +1675,45 @@ void Pingponger::getDotPingponged() { auto aShape = aType.getShape(); auto elemWidth = aType.getElementTypeBitWidth(); int64_t tileSize = dotSShape[0] * dotSShape[1] * aShape[1]; - if (tileSize != 8388608 || aShape[1] != 128 || elemWidth != 8) { - LDBG("encountered large matrix for scale dot: " - << "TileSize==" << tileSize << "; aShape[1]==" << aShape[1] - << "; elemWidth: " << elemWidth); - return; - } - - if (transformFP4(builder, dotSOps[0]->getLoc()).failed()) { - LDBG("Encountered failure when trying to execute the two ping pong " - "cluster transformation"); - return; - } - if (llvm::failed(genAsyncCopySlices(builder))) { - LDBG("failed to slice global-to-local async copies"); - } - - auto updateSignature = updateForOpSignature(builder); - if (llvm::failed(updateSignature)) { - LDBG("failed to update forOp signature"); - } + // 256x256x256 (128xi8) + if (tileSize == 8388608 && aShape[0] == 256 && aShape[1] == 128 && + elemWidth == 8) { + kWidth = 16; + if (transformFP4(builder, dotSOps[0]->getLoc()).failed()) { + LDBG("Encountered failure when trying to execute the two ping pong " + "cluster transformation"); + return; + } - if (llvm::succeeded(updateSignature)) { - if (llvm::failed(adjustRefinedAsyncTokens(builder))) { + auto updateSignature = updateForOpSignature(builder); + if (llvm::failed(updateSignature)) { LDBG("failed to update forOp signature"); } + + if (llvm::succeeded(updateSignature)) { + if (llvm::failed(adjustRefinedAsyncTokens(builder))) { + LDBG("failed to update forOp signature"); + } + } + + forOp->walk([](ttg::AsyncCommitGroupOp groupOp) { + auto users = groupOp.getResult().getUsers(); + if (users.empty()) { + SmallVector toDeleteVec; + for (auto token : groupOp.getInputTokens()) { + toDeleteVec.push_back(token.getDefiningOp()); + } + groupOp->erase(); + llvm::for_each(toDeleteVec, [](Operation *op) { op->erase(); }); + } + }); } + addAsymmetricSyncToLoop(builder, loc); return; - } + } else if (dotSOps.size() == 1) + return; // Determine if we have a persistent GEMM. This will decide how we interpret // any memory operations that we find in conditionals. diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp index f782c237b59e..733944394538 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp @@ -1176,20 +1176,57 @@ class ConvertExtractSliceOp } Location loc = extractSliceOp->getLoc(); - + RankedTensorType resultType = extractSliceOp.getResult().getType(); const FatPointers::FatPtrAttrs &fatPtrAttrs = fatPtrs.at({fatPtrBase, fatPtrOffset}); - auto newSrc = createTensorPointer(rewriter, fatPtrBase, fatPtrOffset, loc, - fatPtrAttrs); - RankedTensorType resType = extractSliceOp.getResult().getType(); - tt::amdgpu::ExtractSliceOp newExtractSliceOp = + Value newFatPtrOffset = nullptr; + auto origFatOffsetType = dyn_cast(fatPtrOffset.getType()); + auto slicedFatOffsetType = RankedTensorType::get( + resultType.getShape(), origFatOffsetType.getElementType(), + origFatOffsetType.getEncoding()); + + tt::amdgpu::ExtractSliceOp slicedFatPtrOffset = rewriter.create( - loc, Type{resType}, Value{newSrc}, + loc, Type{slicedFatOffsetType}, Value{fatPtrOffset}, extractSliceOp.getStaticOffsetsAttr()); - rewriter.replaceOp(extractSliceOp, newExtractSliceOp); - fatPtrs[{fatPtrBase, newExtractSliceOp}] = + + auto newResultPtrType = + RankedTensorType::get(resultType.getShape(), fatPtrBase.getType(), + origFatOffsetType.getEncoding()); + + // Scalar case: we only need to `tt.addptr %basePtr, %offset` + if (!origFatOffsetType) { + auto addPtrOp = rewriter.create( + loc, newResultPtrType, fatPtrBase, slicedFatPtrOffset); + for (const auto &attribute : fatPtrAttrs.attributes) + addPtrOp->setAttr(attribute.getFirst(), attribute.getSecond()); + newFatPtrOffset = addPtrOp.getResult(); + } + + // Tensor case: splat the scalar pointer and add the (tensor) offset: + // ``` + // %tensorBasePtr = tt.splat %basePtr + // %tensorPtr = tt.addptr %tensorBasePtr, %offset + // ``` + if (fatPtrAttrs.canNarrow) + fatPtrOffset = createTruncIOffset(rewriter, loc, fatPtrOffset, + rewriter.getI32Type()); + + tt::SplatOp tensorPtr = + rewriter.create(loc, newResultPtrType, fatPtrBase); + tt::AddPtrOp addPtrOp = rewriter.create( + loc, newResultPtrType, tensorPtr, slicedFatPtrOffset); + + for (const auto &attribute : fatPtrAttrs.attributes) + addPtrOp->setAttr(attribute.getFirst(), attribute.getSecond()); + newFatPtrOffset = addPtrOp.getResult(); + + assert(newFatPtrOffset); + rewriter.replaceOp(extractSliceOp, newFatPtrOffset); + fatPtrs[{fatPtrBase, newFatPtrOffset}] = fatPtrs.at({fatPtrBase, fatPtrOffset}); + return success(); } }; diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp index 7aa0bf102ca8..3cfd5a5ccb2d 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp @@ -200,6 +200,10 @@ bool verifyNonNegativeExpr( return verifyNonSmallerByAssumption(op.getLhs(), assumptions, op.getRhs()); }) + .Case([&](auto op) { + return verifyNonNegativeExpr(op->getOperand(0), assumptions, + solver); + }) .Default([&](Operation *) { // Conservatively assume that the expression is negative LDBG(" Unhandled op, cannot assume non-negative"); From a981b011df2ebabae772849be9d256d52521e96f Mon Sep 17 00:00:00 2001 From: plognjen Date: Tue, 20 May 2025 17:31:10 +0200 Subject: [PATCH 42/44] [AMD] Add a Concat op to AMDGPU dialect (#6590) The "concat" operation combines a list of source n-dimensional tensors into a single larger destination tensor. All source tensors must have the same shape, element type, and encoding. The concatenation dimension is inferred from the source and destination shapes provided by the user. For example, two tensors of shape 64x128 can produce a destination shape of 128x128, indicating concatenation along dimension 0; or 64x256, indicating concatenation along dimension 1. Generally, source tensors passed as op arguments can be arranged into the resulting shape in multiple ways. For example, given four tensors of shape 64x64: concat s0<64x64>, s1<64x64>, s2<64x64>, s3<64x64> -> <128x128> They can be laid out in different configurations within the result tensor: 1) s0 s1 s2 s3 2) s0 s2 s1 s3 From a logical tensor perspective, the source tensors are treated as elements of a tensor of tensors. In other words, the 1-D array of input tensors is conceptually reshaped into an n-D grid. The semantics of this op assume a row-major order (or its n-D generalization), meaning the fastest-varying dimension is filled first, and the slowest-varying dimension is filled last. In the example above, this corresponds to layout 1). The source and destination tensors must have identical linear layouts at the CTA tile level. That is, all base vectors for input dimensions must match, except for the register input dimension. The register basis must align on the subset that defines the logical tensor shape of a single CTA tile. This ensures that the concatenation is a no-op, meaning no data rearrangement among threads is required to assemble the destination tensor with the given shape and layout. However, the order of CTA tiles within the layout does not need to match between source and destination layouts. It is the responsibility of the op's lowering logic to handle this correctly. This op is designed to work on logical tensors directly, avoiding the need for complex layout reinterpretation or reshaping. For example, the `tt.join` operation only supports concatenation along the innermost dimension, and requires that the resulting innermost dimension provide 2 elements per thread, distributed across registers. In contrast, this `concat` op imposes no constraints on the concatenation dimension or the size of dimensions. --------- Co-authored-by: Ognjen Plavsic Co-authored-by: Lei Zhang --- .github/workflows/integration-tests-amd.yml | 2 +- test/Conversion/amd/invalid_concat_op.mlir | 174 ++++++++++++++ test/TritonGPU/amd/amd-concat-op.mlir | 105 ++++++++ .../TritonAMDGPU/IR/TritonAMDGPUOps.td | 69 ++++++ .../PatternTritonAMDGPUToLLVM.h | 3 + .../lib/Dialect/TritonAMDGPU/IR/Dialect.cpp | 92 +++++++ .../TritonAMDGPUDialectToLLVM/CMakeLists.txt | 1 + .../ConcatOpToLLVM.cpp | 171 +++++++++++++ .../TritonAMDGPUToLLVMPatterns.cpp | 1 + .../amd/python/test/test_extract_slice.py | 112 --------- .../test/test_extract_slice_concat_op.py | 227 ++++++++++++++++++ 11 files changed, 844 insertions(+), 113 deletions(-) create mode 100644 test/Conversion/amd/invalid_concat_op.mlir create mode 100644 test/TritonGPU/amd/amd-concat-op.mlir create mode 100644 third_party/amd/lib/TritonAMDGPUDialectToLLVM/ConcatOpToLLVM.cpp delete mode 100644 third_party/amd/python/test/test_extract_slice.py create mode 100644 third_party/amd/python/test/test_extract_slice_concat_op.py diff --git a/.github/workflows/integration-tests-amd.yml b/.github/workflows/integration-tests-amd.yml index 7098c20c53c3..92f228b273b3 100644 --- a/.github/workflows/integration-tests-amd.yml +++ b/.github/workflows/integration-tests-amd.yml @@ -109,7 +109,7 @@ jobs: echo "Could not find '${INSTRUMENTATION_LIB_DIR}'" ; exit -1 fi pytest --capture=tee-sys -rfs python/tutorials/06-fused-attention.py - pytest --capture=tee-sys -rfs third_party/amd/python/test/test_extract_slice.py + pytest --capture=tee-sys -rfs third_party/amd/python/test/test_extract_slice_concat_op.py TRITON_ALWAYS_COMPILE=1 pytest --capture=tee-sys -rfs third_party/amd/python/test/test_scalarize_packed_fops.py cd python/test/unit pytest --capture=tee-sys -rfs -n 12 language runtime \ diff --git a/test/Conversion/amd/invalid_concat_op.mlir b/test/Conversion/amd/invalid_concat_op.mlir new file mode 100644 index 000000000000..2b359dc059ed --- /dev/null +++ b/test/Conversion/amd/invalid_concat_op.mlir @@ -0,0 +1,174 @@ +// RUN: triton-opt -split-input-file %s --convert-triton-amdgpu-to-llvm='arch=gfx942' -verify-diagnostics + + +// Invalid ranks +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func @invalid_concat( + %arg0: tensor<32x64xf32, #blocked>, + %arg1: tensor<32x64xf32, #blocked>, + %arg2: tensor<32x64xf32, #blocked>, + %arg3: tensor<32x64xf32, #blocked>, + %arg4: tensor<32x64xf32, #blocked>, + %arg5: tensor<32x64xf32, #blocked>, + %arg6: tensor<32x64xf32, #blocked>, + %arg7: tensor<32x64xf32, #blocked>) { + + // expected-error @+1 {{Source and destination tensors must have the same rank.}} + %1 = amdgpu.concat %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7: + tensor<32x64xf32, #blocked>,tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked> -> tensor<256xf32, #blocked> + tt.return + } +} + +// ----- + +// Invalid shapes 1 +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func @invalid_concat( + %arg0: tensor<32x64xf32, #blocked>, + %arg1: tensor<32x64xf32, #blocked>, + %arg2: tensor<32x64xf32, #blocked>, + %arg3: tensor<32x64xf32, #blocked>, + %arg4: tensor<32x64xf32, #blocked>, + %arg5: tensor<32x64xf32, #blocked>, + %arg6: tensor<32x64xf32, #blocked>, + %arg7: tensor<32x64xf32, #blocked>) { + + // expected-error @+1 {{Source and destination tensor shapes don't match.}} + %1 = amdgpu.concat %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7: + tensor<32x64xf32, #blocked>,tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked> -> tensor<257x128xf32, #blocked> + tt.return + } +} + +// ----- + +// Invalid shapes 2 +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func @invalid_concat( + %arg0: tensor<32x64xf32, #blocked>, + %arg1: tensor<32x64xf32, #blocked>, + %arg2: tensor<32x64xf32, #blocked>, + %arg3: tensor<32x64xf32, #blocked>, + %arg4: tensor<32x64xf32, #blocked>, + %arg5: tensor<32x64xf32, #blocked>, + %arg6: tensor<32x64xf32, #blocked>, + %arg7: tensor<32x64xf32, #blocked>) { + + // expected-error @+1 {{Number of source tiles (8) doesn't match required count (16).}} + %1 = amdgpu.concat %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7: + tensor<32x64xf32, #blocked>,tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked> -> tensor<256x128xf32, #blocked> + tt.return + } +} + + +// ----- + +// Invalid shapes 3 +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func @invalid_concat( + %arg0: tensor<32x64xf32, #blocked>, + %arg1: tensor<32x64xf32, #blocked>, + %arg2: tensor<32x64xf32, #blocked>, + %arg3: tensor<32x64xf32, #blocked>, + %arg4: tensor<32x64xf32, #blocked>, + %arg5: tensor<32x64xf32, #blocked>, + %arg6: tensor<32x64xf32, #blocked>, + %arg7: tensor<32x64xf32, #blocked>) { + + // expected-error @+1 {{CTA tile shapes must match between source and destination tensors.}} + %1 = amdgpu.concat %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7: + tensor<32x64xf32, #blocked>,tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked> -> tensor<128x128xf32, #blocked1> + tt.return + } +} + +// ----- + +// Different types +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func @invalid_concat( + %arg0: tensor<32x64xf32, #blocked1>, + %arg1: tensor<32x64xf32, #blocked>, + %arg2: tensor<32x64xf32, #blocked>, + %arg3: tensor<32x64xf32, #blocked>, + %arg4: tensor<32x64xf32, #blocked>, + %arg5: tensor<32x64xf32, #blocked>, + %arg6: tensor<32x64xf32, #blocked>, + %arg7: tensor<32x64xf32, #blocked>) { + + // expected-error @+1 {{All sources must have identical tensor types.}} + %1 = amdgpu.concat %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7: + tensor<32x64xf32, #blocked1>,tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked> -> tensor<128x128xf32, #blocked> + tt.return + } +} + +// ----- + +// Invalid element types +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func @invalid_concat( + %arg0: tensor<32x64xf32, #blocked>, + %arg1: tensor<32x64xf32, #blocked>, + %arg2: tensor<32x64xf32, #blocked>, + %arg3: tensor<32x64xf32, #blocked>, + %arg4: tensor<32x64xf32, #blocked>, + %arg5: tensor<32x64xf32, #blocked>, + %arg6: tensor<32x64xf32, #blocked>, + %arg7: tensor<32x64xf32, #blocked>) { + + // expected-error @+1 {{Element types of sources and destination must match.}} + %1 = amdgpu.concat %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7: + tensor<32x64xf32, #blocked>,tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked>, tensor<32x64xf32, #blocked> -> tensor<256x64xf16, #blocked> + tt.return + } +} + + +// ----- + +// Different layouts 1 +#src_layout = #ttg.linear<{register=[[0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [64, 0]], lane=[[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 4]], warp=[[0, 32], [32, 0]], block=[]}> +#dst_layout = #ttg.linear<{register=[[0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [0, 128], [64, 0], [128, 0]], lane=[[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 4], [0, 0]], warp=[[0, 32], [32, 0]], block=[]}> +module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func @invalid_concat( + %arg0: tensor<128x128xf32, #src_layout>, + %arg1: tensor<128x128xf32, #src_layout>, + %arg2: tensor<128x128xf32, #src_layout>, + %arg3: tensor<128x128xf32, #src_layout>) { + + // expected-error @+1 {{Lane and warp dim basis must match between source and destination layout.}} + %1 = amdgpu.concat %arg0, %arg1, %arg2, %arg3: + tensor<128x128xf32, #src_layout>, tensor<128x128xf32, #src_layout>, tensor<128x128xf32, #src_layout>, tensor<128x128xf32, #src_layout> -> tensor<256x256xf32, #dst_layout> + tt.return + } +} + +// ----- + +// Different layouts 2 +#src_layout = #ttg.linear<{register=[[0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [64, 0]], lane=[[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 4]], warp=[[0, 32], [32, 0]], block=[]}> +#dst_layout = #ttg.linear<{register=[[0, 0], [0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [0, 128], [64, 0], [128, 0]], lane=[[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 4]], warp=[[0, 32], [32, 0]], block=[]}> +module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func @invalid_concat( + %arg0: tensor<128x128xf32, #src_layout>, + %arg1: tensor<128x128xf32, #src_layout>, + %arg2: tensor<128x128xf32, #src_layout>, + %arg3: tensor<128x128xf32, #src_layout>) { + + // expected-error @+1 {{Register basis must match on a CTA tile between source and destination.}} + %1 = amdgpu.concat %arg0, %arg1, %arg2, %arg3: + tensor<128x128xf32, #src_layout>, tensor<128x128xf32, #src_layout>, tensor<128x128xf32, #src_layout>, tensor<128x128xf32, #src_layout> -> tensor<256x256xf32, #dst_layout> + tt.return + } +} diff --git a/test/TritonGPU/amd/amd-concat-op.mlir b/test/TritonGPU/amd/amd-concat-op.mlir new file mode 100644 index 000000000000..715b32587bd2 --- /dev/null +++ b/test/TritonGPU/amd/amd-concat-op.mlir @@ -0,0 +1,105 @@ +// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm='arch=gfx942' | FileCheck %s + +// ----- + +#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func @concat_blocked( + %arg0: tensor<32x64xf32, #blocked1>, + %arg1: tensor<32x64xf32, #blocked1>, + %arg2: tensor<32x64xf32, #blocked1>, + %arg3: tensor<32x64xf32, #blocked1>, + %arg4: tensor<32x64xf32, #blocked1>, + %arg5: tensor<32x64xf32, #blocked1>, + %arg6: tensor<32x64xf32, #blocked1>, + %arg7: tensor<32x64xf32, #blocked1>) { + // CHECK: llvm.func @concat_blocked + + // CHECK-COUNT-8: %{{.*}} = llvm.extractvalue %arg0[{{.*}}] : !llvm.struct + // CHECK-COUNT-8: %{{.*}} = llvm.extractvalue %arg1[{{.*}}] : !llvm.struct + // CHECK-COUNT-8: %{{.*}} = llvm.extractvalue %arg2[{{.*}}] : !llvm.struct + // CHECK-COUNT-8: %{{.*}} = llvm.extractvalue %arg3[{{.*}}] : !llvm.struct + // CHECK-COUNT-8: %{{.*}} = llvm.extractvalue %arg4[{{.*}}] : !llvm.struct + // CHECK-COUNT-8: %{{.*}} = llvm.extractvalue %arg5[{{.*}}] : !llvm.struct + // CHECK-COUNT-8: %{{.*}} = llvm.extractvalue %arg6[{{.*}}] : !llvm.struct + // CHECK-COUNT-8: %{{.*}} = llvm.extractvalue %arg7[{{.*}}] : !llvm.struct + + // CHECK-COUNT-64: %{{[0-9]*}} = llvm.insertvalue %{{.*}} : !llvm.struct + + %1 = amdgpu.concat %arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7: + tensor<32x64xf32, #blocked1>,tensor<32x64xf32, #blocked1>, tensor<32x64xf32, #blocked1>, tensor<32x64xf32, #blocked1>, tensor<32x64xf32, #blocked1>, tensor<32x64xf32, #blocked1>, tensor<32x64xf32, #blocked1>, tensor<32x64xf32, #blocked1> -> tensor<128x128xf32, #blocked1> + tt.return + } +} + +// ----- + +#src_layout = #ttg.linear<{register=[[0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [64, 0]], lane=[[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 4]], warp=[[0, 32], [32, 0]], block=[]}> +#dst_layout = #ttg.linear<{register=[[0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [0, 128], [64, 0], [128, 0]], lane=[[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 4]], warp=[[0, 32], [32, 0]], block=[]}> +module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func @concat_ll_2d_1( + %arg0: tensor<128x128xf32, #src_layout>, + %arg1: tensor<128x128xf32, #src_layout>, + %arg2: tensor<128x128xf32, #src_layout>, + %arg3: tensor<128x128xf32, #src_layout>){ + // CHECK: llvm.func @concat_ll_2d_1 + + // CHECK-COUNT-64: %{{.*}} = llvm.extractvalue %arg0[{{.*}}] : !llvm.struct + // CHECK-COUNT-64: %{{.*}} = llvm.extractvalue %arg1[{{.*}}] : !llvm.struct + // CHECK-COUNT-64: %{{.*}} = llvm.extractvalue %arg2[{{.*}}] : !llvm.struct + // CHECK-COUNT-64: %{{.*}} = llvm.extractvalue %arg3[{{.*}}] : !llvm.struct + // CHECK-COUNT-256: %{{.*}} = llvm.insertvalue %{{.*}} : !llvm.struct + + %1 = amdgpu.concat %arg0, %arg1, %arg2, %arg3: + tensor<128x128xf32, #src_layout>, tensor<128x128xf32, #src_layout>, tensor<128x128xf32, #src_layout>, tensor<128x128xf32, #src_layout> -> tensor<256x256xf32, #dst_layout> + tt.return + } +} + +// ----- + +#src_layout = #ttg.linear<{register=[[1, 0], [2, 0], [4, 0]], lane=[[0, 1], [0, 2], [0, 4], [0, 8], [8, 0], [16, 0]], warp=[[0, 16]], block=[]}> +#dst_layout = #ttg.linear<{register=[[1, 0], [2, 0], [4, 0], [32, 0], [0, 32]], lane=[[0, 1], [0, 2], [0, 4], [0, 8], [8, 0], [16, 0]], warp=[[0, 16]], block=[]}> +module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func @concat_ll_2d_2( + %arg0: tensor<32x32xf32, #src_layout>, + %arg1: tensor<32x32xf32, #src_layout>, + %arg2: tensor<32x32xf32, #src_layout>, + %arg3: tensor<32x32xf32, #src_layout>){ + // CHECK: llvm.func @concat_ll_2d_2 + + // CHECK-COUNT-8: %{{.*}} = llvm.extractvalue %arg0[{{.*}}] : !llvm.struct + // CHECK-COUNT-8: %{{.*}} = llvm.extractvalue %arg1[{{.*}}] : !llvm.struct + // CHECK-COUNT-8: %{{.*}} = llvm.extractvalue %arg2[{{.*}}] : !llvm.struct + // CHECK-COUNT-8: %{{.*}} = llvm.extractvalue %arg3[{{.*}}] : !llvm.struct + // CHECK-COUNT-32: %{{.*}} = llvm.insertvalue %{{.*}} : !llvm.struct + + %1 = amdgpu.concat %arg0, %arg1, %arg2, %arg3: + tensor<32x32xf32, #src_layout>, tensor<32x32xf32, #src_layout>, tensor<32x32xf32, #src_layout>, tensor<32x32xf32, #src_layout> -> tensor<64x64xf32, #dst_layout> + tt.return + } +} + +// ----- + +#src_layout = #ttg.linear<{register=[[1]], lane=[[2], [4], [8], [16], [32], [64]], warp=[[128]], block=[]}> +#dst_layout = #ttg.linear<{register=[[1], [256], [512]], lane=[[2], [4], [8], [16], [32], [64]], warp=[[128]], block=[]}> +module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func @concat_ll_1d( + %arg0: tensor<256xf32, #src_layout>, + %arg1: tensor<256xf32, #src_layout>, + %arg2: tensor<256xf32, #src_layout>, + %arg3: tensor<256xf32, #src_layout>){ + // CHECK: llvm.func @concat_ll_1d + + // CHECK-COUNT-2: %{{.*}} = llvm.extractvalue %arg0[{{.*}}] : !llvm.struct + // CHECK-COUNT-2: %{{.*}} = llvm.extractvalue %arg1[{{.*}}] : !llvm.struct + // CHECK-COUNT-2: %{{.*}} = llvm.extractvalue %arg2[{{.*}}] : !llvm.struct + // CHECK-COUNT-2: %{{.*}} = llvm.extractvalue %arg3[{{.*}}] : !llvm.struct + // CHECK-COUNT-8: %{{.*}} = llvm.insertvalue %{{.*}} : !llvm.struct + + %1 = amdgpu.concat %arg0, %arg1, %arg2, %arg3: + tensor<256xf32, #src_layout>, tensor<256xf32, #src_layout>, tensor<256xf32, #src_layout>, tensor<256xf32, #src_layout> -> tensor<1024xf32, #dst_layout> + tt.return + } +} diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td index 17d9409468d8..b487c1402332 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td @@ -119,6 +119,75 @@ def ExtractSliceOp : TT_AMDGPU_Op<"extract_slice", [Pure]> { let hasVerifier = 1; } +def ConcatOp : TT_AMDGPU_Op<"concat", [Pure]> { + let summary = "concat operation"; + let description = [{ + The "concat" operation combines a list of source n-dimensional tensors into a single larger destination tensor. + + All source tensors must have the same shape, element type, and encoding. + The concatenation dimension is inferred from the source and destination shapes provided by the user. + For example, two tensors of shape 64x128 can produce a destination shape of 128x128, + indicating concatenation along dimension 0; or 64x256, indicating concatenation along dimension 1. + + Generally, source tensors passed as op arguments can be arranged into the resulting shape in multiple ways. + For example, given four tensors of shape 64x64: + concat s0<64x64>, s1<64x64>, s2<64x64>, s3<64x64> -> <128x128> + + They can be laid out in different configurations within the result tensor: + 1) s0 s1 2) s0 s2 + s2 s3 s1 s3 + + From a logical tensor perspective, the source tensors are treated as elements of a tensor of tensors. + In other words, the 1-D array of input tensors is conceptually reshaped into an n-D grid. + The semantics of this op assume a row-major order (or its n-D generalization), + meaning the fastest-varying dimension is filled first, and the slowest-varying dimension is filled last. + In the example above, this corresponds to layout 1). + + The source and destination tensors must have identical linear layouts at the CTA tile level. + That is, all base vectors for input dimensions must match, except for the register input dimension. + The register basis must align on the subset that defines the logical tensor shape of a single CTA tile. + + This ensures that the concatenation is a no-op, meaning no data rearrangement among threads is required + to assemble the destination tensor with the given shape and layout. + However, the order of CTA tiles within the layout does not need to match between source and destination layouts. + It is the responsibility of the op's lowering logic to handle this correctly. + + This op is designed to work on logical tensors directly, avoiding the need for complex layout reinterpretation or reshaping. + For example, the `tt.join` operation only supports concatenation along the innermost dimension, + and requires that the resulting innermost dimension provide 2 elements per thread, distributed across registers. + In contrast, this `concat` op imposes no constraints on the concatenation dimension or the size of dimensions. + + * sources: a list of the input tensors. + + Example 1: + + ```mlir + #blocked = #ttg.blocked<{sizePerThread = [1, 8], + threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> + %0 = amdgpu.concat %arg0, %arg1: tensor<32x64xf32, #blocked>,tensor<32x64xf32, #blocked>, + -> tensor<64x64xf32, #blocked> + ``` + + Example 2: + ```mlir + #src_layout = #ttg.linear<{register=[[0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [64, 0]], lane=[[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 4]], warp=[[0, 32], [32, 0]], block=[]}> + #dst_layout = #ttg.linear<{register=[[0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [0, 128], [64, 0], [128, 0]], lane=[[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 4]], warp=[[0, 32], [32, 0]], block=[]}> + %0 = amdgpu.concat %arg0, %arg1, %arg2, %arg3 : tensor<128x128xf16, #src_layout>, tensor<128x128xf16, #src_layout>, tensor<128x128xf16, #src_layout>, + tensor<128x128xf16, #src_layout> -> tensor<256x256xf16, #dst_layout> + ``` + + }]; + + let arguments = (ins Variadic:$sources); + let results = (outs AnyRankedTensor:$result); + + let assemblyFormat = [{ + $sources attr-dict `:` type($sources) `->` type($result) + }]; + + let hasVerifier = 1; +} + //===----------------------------------------------------------------------===// // InstructionSchedHint //===----------------------------------------------------------------------===// diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h b/third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h index 6763de2eba22..724849f01bbf 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h +++ b/third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h @@ -11,6 +11,9 @@ void populateExtractSliceOpToLLVMPatterns( void populateInThreadTransposeOpToTTGPatterns(mlir::RewritePatternSet &patterns, mlir::PatternBenefit benefit); +void populateConcatOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, + mlir::RewritePatternSet &patterns, + mlir::PatternBenefit benefit); } // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp index 7543805fc084..586ebfda9dc1 100644 --- a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp +++ b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp @@ -302,4 +302,96 @@ InThreadTransposeOp::deduceOutputLayout(ArrayRef shape, return transposedLL; } +LogicalResult ConcatOp::verify() { + auto sources = getSources(); + auto result = getResult(); + + auto srcType = cast(sources.front().getType()); + auto dstType = cast(result.getType()); + + auto srcShape = srcType.getShape(); + auto dstShape = dstType.getShape(); + unsigned rank = srcShape.size(); + + // 1) Shape related checks. + if (rank != dstShape.size()) + return emitError() + << "Source and destination tensors must have the same rank."; + + unsigned numTiles = 1; + for (int i = 0; i < rank; ++i) { + if (dstShape[i] % srcShape[i] != 0) + return emitError() << "Source and destination tensor shapes don't match."; + numTiles *= dstShape[i] / srcShape[i]; + } + + if (numTiles != sources.size()) + return emitError() << "Number of source tiles (" << sources.size() + << ") doesn't match required count (" << numTiles + << ")."; + + // 2) Check that all sources have same type and element type match. + for (auto src : sources) { + auto curr = dyn_cast(src.getType()); + if (curr != srcType) + return emitError() << "All sources must have identical tensor types."; + } + + if (dstType.getElementType() != srcType.getElementType()) + return emitError() + << "Element types of sources and destination must match."; + + // 3) Verify that source and destination layout match on a CTA tile. + auto srcLL = triton::gpu::toLinearLayout(srcShape, srcType.getEncoding()); + auto dstLL = triton::gpu::toLinearLayout(dstShape, dstType.getEncoding()); + + auto getBases = [&](StringRef name) { + auto key = StringAttr::get(getContext(), name); + return std::pair{srcLL.getBases().lookup(key), + dstLL.getBases().lookup(key)}; + }; + + auto [regSrc, regDst] = getBases("register"); + auto [laneSrc, laneDst] = getBases("lane"); + auto [warpSrc, warpDst] = getBases("warp"); + + auto shapeCTASrc = mlir::triton::gpu::getShapePerCTATile(srcType); + auto shapeCTADst = mlir::triton::gpu::getShapePerCTATile(dstType); + if (shapeCTASrc != shapeCTADst) + return emitError() << "CTA tile shapes must match between source and " + "destination tensors."; + + unsigned numCTAs = 1; + for (int d = 0; d < rank; ++d) + numCTAs *= srcShape[d] / shapeCTASrc[d]; + unsigned elemsPerThread = + triton::gpu::getTotalElemsPerThread(srcType) / numCTAs; + unsigned regCompareLen = std::log2(elemsPerThread); + + auto compareBasis = [&](auto &srcBasis, auto &dstBasis, StringRef message, + int limit = -1) { + int n = (limit < 0 ? srcBasis.size() + : std::min(srcBasis.size(), limit)); + for (size_t i = 0; i < n; ++i) { + if (srcBasis[i] != dstBasis[i]) { + emitError() << message; + return false; + } + } + return true; + }; + + if (laneSrc != laneDst || warpSrc != warpDst) { + return emitError() << "Lane and warp dim basis must match between source " + "and destination layout."; + } + + if (!compareBasis(regSrc, regDst, + "Register basis must match on a CTA tile between source " + "and destination.", + regCompareLen)) + return failure(); + + return success(); +} } // namespace mlir::triton::amdgpu diff --git a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/CMakeLists.txt index 693bd41bc55a..35310b86eecd 100644 --- a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/CMakeLists.txt +++ b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/CMakeLists.txt @@ -2,6 +2,7 @@ add_triton_library(TritonAMDGPUDialectToLLVM TritonAMDGPUToLLVMPatterns.cpp ExtractSliceOpToLLVM.cpp InThreadTransposeOpToTTG.cpp + ConcatOpToLLVM.cpp DEPENDS TritonAMDGPUIR diff --git a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ConcatOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ConcatOpToLLVM.cpp new file mode 100644 index 000000000000..9d75b3b7d204 --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ConcatOpToLLVM.cpp @@ -0,0 +1,171 @@ +#include "Dialect/TritonAMDGPU/IR/Dialect.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace { + +template +std::vector multiDimElementwise(ArrayRef lhs, ArrayRef rhs, + BinaryOp op) { + assert(lhs.size() == rhs.size() && "Input dimensions must match"); + std::vector result; + result.reserve(lhs.size()); + for (size_t i = 0, n = lhs.size(); i < n; ++i) { + unsigned a = static_cast(lhs[i]); + unsigned b = static_cast(rhs[i]); + result.push_back(op(a, b)); + } + return result; +} + +template unsigned getNumElements(const ArrayRef shape) { + return std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<>()); +} + +// Determine the order in which CTA tiles are laid out across the tensor. +// That is, create vector of dimensions from fastest to slowest varying. +SmallVector getCTATileOrder(MLIRContext *ctx, + const LinearLayout &layout) { + auto llEnc = triton::gpu::LinearEncodingAttr::get(ctx, layout); + auto regDim = StringAttr::get(ctx, "register"); + auto &bases = layout.getBases().find(regDim)->second; + + // Compute number of CTA tiles in a layout. + unsigned totalElems = layout.getTotalOutDimSize(); + auto ctaShape = llEnc.getShapePerCTATile(); + unsigned elemsPerCTA = + std::accumulate(ctaShape.begin(), ctaShape.end(), 1, std::multiplies<>()); + assert((totalElems % elemsPerCTA) == 0 && + "Total elements must be divisible by elemsPerCTA"); + unsigned numCTAs = totalElems / elemsPerCTA; + + // To determine the CTA tile order, start by identifying the register basis + // vector that corresponds to the first element of the second CTA tile. The + // nonzero index in the logical tensor it maps to indicates the fastest + // varying dimension. Then, for each subsequent basis register (first element + // of some CTA tile), extract the next nonzero index to build the full + // dimension order. + unsigned registersPerThreadPerCTA = + product(llEnc.basesPerDim(regDim, /*skipBroadcast=*/false)) / numCTAs; + unsigned startIndex = + static_cast(std::log2(registersPerThreadPerCTA)); + + llvm::SmallSetVector order; + for (unsigned i = startIndex; i < bases.size(); ++i) { + auto range = llvm::make_range(bases[i].begin(), bases[i].end()); + auto it = llvm::find_if(range, [](unsigned v) { return v != 0; }); + if (it != bases[i].end()) + order.insert(std::distance(bases[i].begin(), it)); + } + + // Append any dims missing from our default order. + for (unsigned dim : llEnc.getOrder()) + order.insert(dim); + + return order.takeVector(); +} + +struct ConcatOpConversion : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(amdgpu::ConcatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + RankedTensorType resultType = + cast(op.getResult().getType()); + + ArrayRef dstShape = resultType.getShape(); + Attribute dstEncoding = resultType.getEncoding(); + + Value srcVal = op.getSources()[0]; + RankedTensorType srcType = cast(srcVal.getType()); + ArrayRef srcShape = srcType.getShape(); + Attribute srcEncoding = srcType.getEncoding(); + + MLIRContext *context = resultType.getContext(); + auto linearLayoutSrc = triton::gpu::toLinearLayout(srcShape, srcEncoding); + auto linearLayoutDst = triton::gpu::toLinearLayout(dstShape, dstEncoding); + auto srcCTAOrder = getCTATileOrder(context, linearLayoutSrc); + auto dstCTAOrder = getCTATileOrder(context, linearLayoutSrc); + + auto rank = srcShape.size(); + auto shapePerCTATile = triton::gpu::getShapePerCTATile(resultType); + auto sources = adaptor.getSources(); + + unsigned totalElems = ::getNumElements(dstShape); + unsigned elemsPerTile = ::getNumElements(shapePerCTATile); + unsigned numCTATiles = totalElems / elemsPerTile; + + // Default order is fastest to slowest varying dimension. + std::vector defaultOrder(rank); + std::iota(defaultOrder.rbegin(), defaultOrder.rend(), 0); + + auto dstCTAShape = multiDimElementwise( + dstShape, shapePerCTATile, std::divides()); + auto srcCTAShape = multiDimElementwise( + srcShape, shapePerCTATile, std::divides()); + auto srcToDstShape = multiDimElementwise( + dstShape, srcShape, std::divides()); + + unsigned elemsPerThreadPerCTA = + triton::gpu::getTotalElemsPerThread(srcType) / + ::getNumElements(srcCTAShape); + + llvm::SmallVector resultVals; + llvm::SmallVector> unpackedSources; + unpackedSources.reserve(sources.size()); + + for (size_t i = 0; i < sources.size(); i++) { + Value currSrc = sources[i]; + unpackedSources.push_back(unpackLLElements(loc, currSrc, rewriter)); + } + + // Traverse CTA tiles in the result tensor + for (int i = 0; i < numCTATiles; ++i) { + auto currTileIdx = mlir::LLVM::delinearize(i, dstCTAShape, dstCTAOrder); + // The n-dim destination tensor is built by arranging n-dim source tensors + // into a destination tensor shape. Determine which source tensor contains + // the current CTA tile. + auto multiDimSrcIdx = multiDimElementwise( + currTileIdx, srcCTAShape, std::divides()); + // Compute linear index of the current source tensor. + // Concat operands are laid out in the destination tensor + // in fastest slowest varying dimension order. + auto linearSrcIdx = + mlir::LLVM::linearize(multiDimSrcIdx, srcToDstShape, defaultOrder); + + // After determining which source tensor the current CTA tile belongs to, + // compute the index of this CTA tile within that source tensor, + // considering the source tensors may include CTA tiles. + auto multiDimSrcCTAIdx = multiDimElementwise( + currTileIdx, srcCTAShape, std::modulus()); + auto linearSrcCTAIdx = + mlir::LLVM::linearize(multiDimSrcCTAIdx, srcCTAShape, srcCTAOrder); + auto unpackedElements = unpackedSources[linearSrcIdx]; + + auto startIt = + unpackedElements.begin() + linearSrcCTAIdx * elemsPerThreadPerCTA; + auto endIt = startIt + elemsPerThreadPerCTA; + llvm::append_range(resultVals, llvm::make_range(startIt, endIt)); + } + + Value packedResult = packLLElements(loc, this->getTypeConverter(), + resultVals, rewriter, resultType); + + rewriter.replaceOp(op, packedResult); + return success(); + } +}; +} // namespace + +namespace mlir::triton::AMD { +void populateConcatOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, + mlir::RewritePatternSet &patterns, + mlir::PatternBenefit benefit) { + patterns.add(typeConverter, benefit); +} +} // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp index a84d84b2819d..c0cf0fb5fefa 100644 --- a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp +++ b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp @@ -7,5 +7,6 @@ void populateTritonAMDGPUToLLVMPatterns(LLVMTypeConverter &typeConverter, PatternBenefit benefit) { populateExtractSliceOpToLLVMPatterns(typeConverter, patterns, benefit); populateInThreadTransposeOpToTTGPatterns(patterns, benefit); + populateConcatOpToLLVMPatterns(typeConverter, patterns, benefit); } } // namespace mlir::triton::AMD diff --git a/third_party/amd/python/test/test_extract_slice.py b/third_party/amd/python/test/test_extract_slice.py deleted file mode 100644 index c52d5d3a6e5a..000000000000 --- a/third_party/amd/python/test/test_extract_slice.py +++ /dev/null @@ -1,112 +0,0 @@ -import pytest -import torch - -import triton - -from triton._internal_testing import is_hip - -num_ctas_list = [1] - -GPU_DIALECT = "ttg" - -if is_hip(): - THREADS_PER_WARP = triton.runtime.driver.active.get_current_target().warp_size -else: - THREADS_PER_WARP = 32 - - -class BlockedLayout: - - def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga, cta_split_num, cta_order): - self.sz_per_thread = size_per_thread - self.threads_per_warp = threads_per_warp - self.warps_per_cta = warps_per_cta - self.order = order - self.ctas_per_cga = ctas_per_cga - self.cta_split_num = cta_split_num - self.cta_order = cta_order - - def __str__(self): - return f"#{GPU_DIALECT}.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" - - -# ----------------------- -# test extract slice -# ----------------------- - -extract_layout = [ - BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), - BlockedLayout([2, 2], [64, 1], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), - BlockedLayout([2, 2], [16, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), - BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), - # FIXME(Lezcano): This layout errors out - #BlockedLayout([1, 8], [16, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), -] -blocked_layout = [ - BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), - BlockedLayout([2, 2], [16, 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), - BlockedLayout([2, 2], [16, 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), - BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), - BlockedLayout([1, 8], [16, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), -] - - -@pytest.mark.parametrize("M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset", - [[256, 256, 256, 32, 0, 32], [128, 128, 128, 64, 0, 64]]) -@pytest.mark.parametrize("dtype", [torch.float16]) -@pytest.mark.parametrize("extract_layout", extract_layout) -@pytest.mark.parametrize("blocked_layout", blocked_layout) -def test_extract_slice(dtype, M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset, blocked_layout, - extract_layout, device='cuda'): - if not is_hip(): - pytest.skip("extract_slice is AMD specific instruction.") - - ir = f""" - #blocked = {blocked_layout} - #extract_layout = {extract_layout} - module attributes {{"ttg.num-ctas" = 1, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = {str(64)} : i32}} {{ - tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ - %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked> - %cst_n = arith.constant dense<{N_tile_size}> : tensor<{M_tile_size}x1xi32, #blocked> - %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> - %42 = tt.make_range {{end = {M_tile_size} : i32, start = 0 : i32}} : tensor<{M_tile_size}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> - %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> - %2 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #blocked> - %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M}x1xi32, #blocked> - %43 = tt.expand_dims %42 {{axis = 1 : i32}} : tensor<{M_tile_size}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M_tile_size}x1xi32, #blocked> - %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #blocked> - %44 = arith.muli %43, %cst_n : tensor<{M_tile_size}x1xi32, #blocked> - %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{M}xi32, #blocked> - %7 = tt.broadcast %6 : tensor<1x{M}xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked> - %8 = tt.broadcast %5 : tensor<{M}x1xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked> - %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #blocked> - %33 = tt.make_range {{end = {N_tile_size} : i32, start = 0 : i32}} : tensor<{N_tile_size}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> - %34 = tt.splat %arg1 : !tt.ptr -> tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked> - %37 = tt.expand_dims %33 {{axis = 0 : i32}} : tensor<{N_tile_size}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{N_tile_size}xi32, #blocked> - %38 = tt.broadcast %37 : tensor<1x{N_tile_size}xi32, #blocked> -> tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> - %39 = tt.broadcast %44 : tensor<{M_tile_size}x1xi32, #blocked> -> tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> - %40 = arith.addi %38, %39 : tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> - %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> - %11 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}x!tt.ptr, #blocked> - %12 = ttg.convert_layout %11 : tensor<{M}x{N}xf16, #blocked> -> tensor<{M}x{N}xf16, #extract_layout> - %13 = amdgpu.extract_slice %12 [{M_tile_offset}, {N_tile_offset}] : tensor<{M}x{N}xf16, #extract_layout> to tensor<{M_tile_size}x{N_tile_size}xf16, #extract_layout> - %14 = ttg.convert_layout %13 : tensor<{M_tile_size}x{N_tile_size}xf16, #extract_layout> -> tensor<{M_tile_size}x{N_tile_size}xf16, #blocked> - %15 = tt.addptr %34, %40 : tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked>, tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> - tt.store %15, %14 : tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked> - tt.return - }} - }} - """ - x = torch.randn((M, N), device=device, dtype=torch.float16) - import tempfile - with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: - f.write(ir) - f.flush() - kernel = triton.compile(f.name) - - extract_slice = torch.empty((M_tile_size, N_tile_size), device=device, dtype=torch.float16) - - kernel[(1, 1, 1)](x.data_ptr(), extract_slice) - test_result = torch.equal(x[M_tile_offset:M_tile_size + M_tile_offset, N_tile_offset:N_tile_offset + N_tile_size], - extract_slice) - assert test_result diff --git a/third_party/amd/python/test/test_extract_slice_concat_op.py b/third_party/amd/python/test/test_extract_slice_concat_op.py new file mode 100644 index 000000000000..b403a69ebf29 --- /dev/null +++ b/third_party/amd/python/test/test_extract_slice_concat_op.py @@ -0,0 +1,227 @@ +import pytest +import torch + +import triton + +from triton._internal_testing import is_hip + +num_ctas_list = [1] + +GPU_DIALECT = "ttg" + +if is_hip(): + THREADS_PER_WARP = triton.runtime.driver.active.get_current_target().warp_size +else: + THREADS_PER_WARP = 32 + + +class LinearLayout: + + def __init__(self, register, lane, warp, block): + self.register = register + self.lane = lane + self.warp = warp + self.block = block + + def __str__(self): + return f"#{GPU_DIALECT}.linear<{{register={self.register}, lane={self.lane}, warp={self.warp}, block={self.block}}}>" + + +class BlockedLayout: + + def __init__(self, size_per_thread, threads_per_warp, warps_per_cta, order, ctas_per_cga, cta_split_num, cta_order): + self.sz_per_thread = size_per_thread + self.threads_per_warp = threads_per_warp + self.warps_per_cta = warps_per_cta + self.order = order + self.ctas_per_cga = ctas_per_cga + self.cta_split_num = cta_split_num + self.cta_order = cta_order + + def __str__(self): + return f"#{GPU_DIALECT}.blocked<{{sizePerThread={self.sz_per_thread}, threadsPerWarp={self.threads_per_warp}, warpsPerCTA={self.warps_per_cta}, order={self.order}, CTAsPerCGA={self.ctas_per_cga}, CTASplitNum={self.cta_split_num}, CTAOrder={self.cta_order}}}>" + + +# ----------------------- +# test extract slice +# ----------------------- + +extract_layout = [ + BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [64, 1], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [16, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + # FIXME(Lezcano): This layout errors out + #BlockedLayout([1, 8], [16, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), +] +blocked_layout = [ + BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [16, 4], [2, 2], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([2, 2], [16, 4], [2, 2], [0, 1], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 8], [16, 4], [4, 1], [1, 0], [1, 1], [1, 1], [0, 1]), + BlockedLayout([1, 8], [16, 4], [4, 1], [0, 1], [1, 1], [1, 1], [0, 1]), +] + + +@pytest.mark.parametrize("M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset", + [[256, 256, 256, 32, 0, 32], [128, 128, 128, 64, 0, 64]]) +@pytest.mark.parametrize("dtype", [torch.float16]) +@pytest.mark.parametrize("extract_layout", extract_layout) +@pytest.mark.parametrize("blocked_layout", blocked_layout) +def test_extract_slice(dtype, M, N, M_tile_size, N_tile_size, M_tile_offset, N_tile_offset, blocked_layout, + extract_layout, device='cuda'): + if not is_hip(): + pytest.skip("extract_slice is AMD specific instruction.") + + ir = f""" + #blocked = {blocked_layout} + #extract_layout = {extract_layout} + module attributes {{"ttg.num-ctas" = 1, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = {str(64)} : i32}} {{ + tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked> + %cst_n = arith.constant dense<{N_tile_size}> : tensor<{M_tile_size}x1xi32, #blocked> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> + %42 = tt.make_range {{end = {M_tile_size} : i32, start = 0 : i32}} : tensor<{M_tile_size}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> + %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> + %2 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #blocked> + %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M}x1xi32, #blocked> + %43 = tt.expand_dims %42 {{axis = 1 : i32}} : tensor<{M_tile_size}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M_tile_size}x1xi32, #blocked> + %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #blocked> + %44 = arith.muli %43, %cst_n : tensor<{M_tile_size}x1xi32, #blocked> + %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{M}xi32, #blocked> + %7 = tt.broadcast %6 : tensor<1x{M}xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked> + %8 = tt.broadcast %5 : tensor<{M}x1xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked> + %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #blocked> + %33 = tt.make_range {{end = {N_tile_size} : i32, start = 0 : i32}} : tensor<{N_tile_size}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> + %34 = tt.splat %arg1 : !tt.ptr -> tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked> + %37 = tt.expand_dims %33 {{axis = 0 : i32}} : tensor<{N_tile_size}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{N_tile_size}xi32, #blocked> + %38 = tt.broadcast %37 : tensor<1x{N_tile_size}xi32, #blocked> -> tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> + %39 = tt.broadcast %44 : tensor<{M_tile_size}x1xi32, #blocked> -> tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> + %40 = arith.addi %38, %39 : tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> + %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> + %11 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}x!tt.ptr, #blocked> + %12 = ttg.convert_layout %11 : tensor<{M}x{N}xf16, #blocked> -> tensor<{M}x{N}xf16, #extract_layout> + %13 = amdgpu.extract_slice %12 [{M_tile_offset}, {N_tile_offset}] : tensor<{M}x{N}xf16, #extract_layout> to tensor<{M_tile_size}x{N_tile_size}xf16, #extract_layout> + %14 = ttg.convert_layout %13 : tensor<{M_tile_size}x{N_tile_size}xf16, #extract_layout> -> tensor<{M_tile_size}x{N_tile_size}xf16, #blocked> + %15 = tt.addptr %34, %40 : tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked>, tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> + tt.store %15, %14 : tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked> + tt.return + }} + }} + """ + x = torch.randn((M, N), device=device, dtype=torch.float16) + import tempfile + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) + + extract_slice = torch.empty((M_tile_size, N_tile_size), device=device, dtype=torch.float16) + + kernel[(1, 1, 1)](x.data_ptr(), extract_slice) + test_result = torch.equal(x[M_tile_offset:M_tile_size + M_tile_offset, N_tile_offset:N_tile_offset + N_tile_size], + extract_slice) + assert test_result + + +# ----------------------- +# test concat op +# ----------------------- + +src_layout = [ + LinearLayout(register=[[0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [64, 0]], lane=[[1, 0], [2, 0], [4, 0], [8, 0], + [16, 0], [0, 4]], warp=[[0, 32], + [32, 0]], + block=[]), + LinearLayout(register=[[1, 0], [2, 0], [4, 0]], lane=[[0, 1], [0, 2], [0, 4], [0, 8], [8, 0], [16, 0]], + warp=[[0, 16]], block=[]), +] + +dst_layout = [ + LinearLayout(register=[[0, 1], [0, 2], [0, 8], [0, 16], [0, 64], [0, 128], [64, 0], [128, 0]], + lane=[[1, 0], [2, 0], [4, 0], [8, 0], [16, 0], [0, 4]], warp=[[0, 32], [32, 0]], block=[]), + LinearLayout(register=[[1, 0], [2, 0], [4, 0], [32, 0], [0, 32]], lane=[[0, 1], [0, 2], [0, 4], [0, 8], [8, 0], + [16, 0]], warp=[[0, 16]], block=[]), +] + + +@pytest.mark.parametrize( + "src_layout, dst_layout, M, N, M_tile_size, N_tile_size", + [[src_layout[0], dst_layout[0], 128, 128, 256, 256], [src_layout[1], dst_layout[1], 32, 32, 64, 64]]) +@pytest.mark.parametrize("dtype", [torch.float16]) +def test_concat_op(dtype, M, N, M_tile_size, N_tile_size, src_layout, dst_layout, device='cuda'): + if not is_hip(): + pytest.skip("concat op is AMD specific instruction.") + + ir = f""" + #blocked = #ttg.blocked<{{sizePerThread=[1, 8], threadsPerWarp=[16, 4], warpsPerCTA=[4, 1], order=[1, 0], CTAsPerCGA=[1, 1], CTASplitNum=[1, 1], CTAOrder=[0, 1]}}> + #src_layout = {src_layout} + #dst_layout = {dst_layout} + + module attributes {{"ttg.num-ctas" = 1, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = {str(64)} : i32}} {{ + tt.func public @kernel(%arg0: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg1: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg2: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg3: !tt.ptr {{tt.divisibility = 16 : i32}}, %arg4: !tt.ptr {{tt.divisibility = 16 : i32}}) {{ + %cst = arith.constant dense<{N}> : tensor<{M}x1xi32, #blocked> + %cst_n = arith.constant dense<{N_tile_size}> : tensor<{M_tile_size}x1xi32, #blocked> + %0 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> + %42 = tt.make_range {{end = {M_tile_size} : i32, start = 0 : i32}} : tensor<{M_tile_size}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> + %1 = tt.make_range {{end = {M} : i32, start = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> + %2 = tt.splat %arg0 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #blocked> + %100 = tt.splat %arg1 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #blocked> + %101 = tt.splat %arg2 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #blocked> + %102 = tt.splat %arg3 : !tt.ptr -> tensor<{M}x{N}x!tt.ptr, #blocked> + %4 = tt.expand_dims %0 {{axis = 1 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M}x1xi32, #blocked> + %43 = tt.expand_dims %42 {{axis = 1 : i32}} : tensor<{M_tile_size}xi32, #ttg.slice<{{dim = 1, parent = #blocked}}>> -> tensor<{M_tile_size}x1xi32, #blocked> + %5 = arith.muli %4, %cst : tensor<{M}x1xi32, #blocked> + %44 = arith.muli %43, %cst_n : tensor<{M_tile_size}x1xi32, #blocked> + %6 = tt.expand_dims %1 {{axis = 0 : i32}} : tensor<{M}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{M}xi32, #blocked> + %7 = tt.broadcast %6 : tensor<1x{M}xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked> + %8 = tt.broadcast %5 : tensor<{M}x1xi32, #blocked> -> tensor<{M}x{N}xi32, #blocked> + %9 = arith.addi %8, %7 : tensor<{M}x{N}xi32, #blocked> + %33 = tt.make_range {{end = {N_tile_size} : i32, start = 0 : i32}} : tensor<{N_tile_size}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> + %34 = tt.splat %arg4 : !tt.ptr -> tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked> + %37 = tt.expand_dims %33 {{axis = 0 : i32}} : tensor<{N_tile_size}xi32, #ttg.slice<{{dim = 0, parent = #blocked}}>> -> tensor<1x{N_tile_size}xi32, #blocked> + %38 = tt.broadcast %37 : tensor<1x{N_tile_size}xi32, #blocked> -> tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> + %39 = tt.broadcast %44 : tensor<{M_tile_size}x1xi32, #blocked> -> tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> + %40 = arith.addi %38, %39 : tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> + %10 = tt.addptr %2, %9 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> + %200 = tt.addptr %100, %9 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> + %201 = tt.addptr %101, %9 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> + %202 = tt.addptr %102, %9 : tensor<{M}x{N}x!tt.ptr, #blocked>, tensor<{M}x{N}xi32, #blocked> + %11 = tt.load %10 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}x!tt.ptr, #blocked> + %300 = tt.load %200 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}x!tt.ptr, #blocked> + %301 = tt.load %201 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}x!tt.ptr, #blocked> + %302 = tt.load %202 {{cache = 1 : i32, evict = 1 : i32, isVolatile = false}} : tensor<{M}x{N}x!tt.ptr, #blocked> + + %12 = ttg.convert_layout %11 : tensor<{M}x{N}xf16, #blocked> -> tensor<{M}x{N}xf16, #src_layout> + %400 = ttg.convert_layout %300 : tensor<{M}x{N}xf16, #blocked> -> tensor<{M}x{N}xf16, #src_layout> + %401 = ttg.convert_layout %301 : tensor<{M}x{N}xf16, #blocked> -> tensor<{M}x{N}xf16, #src_layout> + %402 = ttg.convert_layout %302 : tensor<{M}x{N}xf16, #blocked> -> tensor<{M}x{N}xf16, #src_layout> + + %13 = amdgpu.concat %12, %400, %401, %402 : tensor<{M}x{N}xf16, #src_layout>, tensor<{M}x{N}xf16, #src_layout>, tensor<{M}x{N}xf16, #src_layout>, tensor<{M}x{N}xf16, #src_layout> -> tensor<{M_tile_size}x{N_tile_size}xf16, #dst_layout> + %14 = ttg.convert_layout %13 : tensor<{M_tile_size}x{N_tile_size}xf16, #dst_layout> -> tensor<{M_tile_size}x{N_tile_size}xf16, #blocked> + %15 = tt.addptr %34, %40 : tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked>, tensor<{M_tile_size}x{N_tile_size}xi32, #blocked> + tt.store %15, %14 : tensor<{M_tile_size}x{N_tile_size}x!tt.ptr, #blocked> + tt.return + }} + }} + """ + x1 = torch.randn((M, N), device=device, dtype=torch.float16) + x2 = torch.randn((M, N), device=device, dtype=torch.float16) + x3 = torch.randn((M, N), device=device, dtype=torch.float16) + x4 = torch.randn((M, N), device=device, dtype=torch.float16) + + import tempfile + with tempfile.NamedTemporaryFile(mode='w', suffix='.ttgir') as f: + f.write(ir) + f.flush() + kernel = triton.compile(f.name) + + concat = torch.empty((M_tile_size, N_tile_size), device=device, dtype=torch.float16) + kernel[(1, 1, 1)](x1.data_ptr(), x2.data_ptr(), x3.data_ptr(), x4.data_ptr(), concat) + + top = torch.cat([x1, x2], dim=1) + bottom = torch.cat([x3, x4], dim=1) + result = torch.cat([top, bottom], dim=0) + + test_result = torch.equal(result, concat) + assert test_result From 6a6fb701b54c4d0e9789ff940c92e80627244aae Mon Sep 17 00:00:00 2001 From: Alexander Weinrauch Date: Wed, 21 May 2025 11:33:09 +0000 Subject: [PATCH 43/44] WA for incorrect strides in subview --- .../Conversion/TritonGPUToLLVM/Utility.h | 12 ++++-- lib/Conversion/TritonGPUToLLVM/Utility.cpp | 41 ++++++++++++++----- 2 files changed, 39 insertions(+), 14 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 09ef9904b17e..54930e45f8c8 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -376,9 +376,13 @@ class SharedMemoryObject { return types; } - SmallVector getStrides(triton::gpu::MemDescType memDesc, Location loc, - RewriterBase &rewriter) const { + SmallVector + getStrides(triton::gpu::MemDescType memDesc, Location loc, + RewriterBase &rewriter, + ArrayRef overwriteAllocSize = {}) const { auto allocShape = memDesc.getAllocShape(); + if (!overwriteAllocSize.empty()) + allocShape = overwriteAllocSize; auto allocShapePerCTA = triton::gpu::getAllocationShapePerCTA( memDesc.getEncoding(), allocShape); auto layoutOrder = triton::gpu::getOrder(memDesc); @@ -699,14 +703,14 @@ emitIndices(Location loc, RewriterBase &rewriter, const TargetInfoBase &target, const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter, const TargetInfoBase &target, std::function perVectorCallback, - bool forceLane0 = false); + bool forceLane0 = false, ArrayRef overwriteAllocSize = {}); [[nodiscard]] bool emitTransferBetweenRegistersAndShared( LinearLayout ®Layout, triton::gpu::MemDescType sharedTy, Type elemLlvmTy, std::optional maxVecElems, const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter, const TargetInfoBase &target, std::function perVectorCallback, - bool forceLane0 = false); + bool forceLane0 = false, ArrayRef overwriteAllocSize = {}); SmallVector loadSharedToDistributed(triton::gpu::LocalLoadOp localLoadOp, Type elemLlvmTy, diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 8f691b673d6d..94f26772bf09 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -277,7 +277,8 @@ Value getSmemVecAddr(const LinearLayout ®Layout, const SharedMemoryObject &smemObj, triton::gpu::MemDescType sharedTy, Type elemLlvmTy, Value regId, Value laneId, Value warpId, Value blockId, - Location loc, RewriterBase &rewriter) { + Location loc, RewriterBase &rewriter, + ArrayRef overwriteAllocSize) { auto b = TritonLLVMOpBuilder(loc, rewriter); MLIRContext *ctx = rewriter.getContext(); StringAttr kBlock = str_attr("block"); @@ -292,7 +293,8 @@ Value getSmemVecAddr(const LinearLayout ®Layout, auto smemBase = smemObj.getBase(); auto smemOffsets = smemObj.getOffsets(); - auto smemStrides = smemObj.getStrides(sharedTy, loc, rewriter); + auto smemStrides = + smemObj.getStrides(sharedTy, loc, rewriter, overwriteAllocSize); Value smemOffset; // When loading or storing to shared memory, we consider two cases for // performance reasons: @@ -410,7 +412,7 @@ bool emitTransferBetweenRegistersAndShared( std::optional maxVecElems, const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter, const TargetInfoBase &target, std::function perVectorCallback, - bool forceLane0) { + bool forceLane0, ArrayRef overwriteAllocSize) { MLIRContext *ctx = rewriter.getContext(); auto b = TritonLLVMOpBuilder(loc, rewriter); @@ -485,9 +487,10 @@ bool emitTransferBetweenRegistersAndShared( SmallVector ret; for (int i = 0; i < numElems / vecElems; i++) { auto regId = b.i32_val(i * vecElems); - auto vecAddr = getSmemVecAddr( - regLayout, regToSharedLayout, invertAllocSharedLayout, smemObj, - sharedTy, elemLlvmTy, regId, laneId, warpId, blockId, loc, rewriter); + auto vecAddr = + getSmemVecAddr(regLayout, regToSharedLayout, invertAllocSharedLayout, + smemObj, sharedTy, elemLlvmTy, regId, laneId, warpId, + blockId, loc, rewriter, overwriteAllocSize); perVectorCallback(vecTy, vecAddr); } return true; @@ -499,12 +502,12 @@ bool emitTransferBetweenRegistersAndShared( const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter, const TargetInfoBase &target, std::function perVectorCallback, - bool forceLane0) { + bool forceLane0, ArrayRef overwriteAllocSize) { auto regLayout = triton::gpu::toLinearLayout(registerTy.getShape(), registerTy.getEncoding()); return emitTransferBetweenRegistersAndShared( regLayout, sharedTy, elemLlvmTy, maxVecElems, smemObj, loc, rewriter, - target, perVectorCallback, forceLane0); + target, perVectorCallback, forceLane0, overwriteAllocSize); } SmallVector loadSharedToDistributed(triton::gpu::LocalLoadOp localLoadOp, @@ -515,11 +518,28 @@ SmallVector loadSharedToDistributed(triton::gpu::LocalLoadOp localLoadOp, auto srcTy = localLoadOp.getSrc().getType(); auto dstTy = localLoadOp.getResult().getType(); + // We overwrite the alloc size if we are a subview to fix subviews in the + // fastest dim + SmallVector overwriteSmemAllocSize; + auto src = localLoadOp.getSrc(); + if (auto subView = src.getDefiningOp()) { + auto subViewSrcTy = + dyn_cast(subView.getSrc().getType()); + if (subViewSrcTy) { + auto origAllocSize = subViewSrcTy.getAllocShape(); + auto srcAllocSize = srcTy.getAllocShape(); + if (origAllocSize.size() == 3 && srcAllocSize.size() == 2) { + overwriteSmemAllocSize = to_vector(origAllocSize.drop_front()); + } + } + } + auto b = TritonLLVMOpBuilder(loc, rewriter); SmallVector ret; bool success = emitTransferBetweenRegistersAndShared( dstTy, srcTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemObj, loc, - rewriter, target, [&](VectorType vecTy, Value vecAddr) { + rewriter, target, + [&](VectorType vecTy, Value vecAddr) { auto vecVal = b.load(vecTy, vecAddr); target.localLoadOpAnnotation(localLoadOp, vecVal); vecVal.setAlignment(vecTy.getNumElements() * @@ -528,7 +548,8 @@ SmallVector loadSharedToDistributed(triton::gpu::LocalLoadOp localLoadOp, for (int v = 0; v < vecTy.getNumElements(); v++) { ret.push_back(b.extract_element(elemLlvmTy, vecVal, b.i32_val(v))); } - }); + }, + false, overwriteSmemAllocSize); if (!success) llvm::report_fatal_error("Failed to emit transfer from shared to register"); From 34538bced40ac5bee5948bc7814b8896fa943038 Mon Sep 17 00:00:00 2001 From: ravil-mobile Date: Mon, 26 May 2025 13:23:17 +0000 Subject: [PATCH 44/44] [AMD] improved subviewing for async-copy-local-to-global --- .../TritonAMDGPUTransforms/BlockPingpong.cpp | 38 +++++++++++++++---- 1 file changed, 30 insertions(+), 8 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp index d7ae4708e637..163080e46be8 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp @@ -961,6 +961,14 @@ LogicalResult Pingponger::genAsyncCopySlices(OpBuilder &builder) { return resValue; }; + auto origSubViewType = subView.getType(); + auto subViewDescType = ttg::MemDescType::get( + slicedShape, origSubViewType.getElementType(), + origSubViewType.getEncoding(), origSubViewType.getMemorySpace(), + origSubViewType.getMutableMemory(), + subView.getSrc().getType().getShape()); + Value subViewSelector = subView.getOffsets().front(); + assert(slicedDim != -1); SmallVector newCommits; auto numReps = origShape[slicedDim] / slicedShape[slicedDim]; @@ -973,20 +981,34 @@ LogicalResult Pingponger::genAsyncCopySlices(OpBuilder &builder) { auto extractedMask = extract(slicedMaskType, newMask, offsetAttr); auto extractedOther = extract(slicedOtherType, newOther, offsetAttr); + SmallVector newSubviewOffset = {subViewSelector}; + llvm::for_each(offset, [&](auto off) { + newSubviewOffset.push_back( + builder.create(subView.getLoc(), off, 32)); + }); + + auto newSlicedSubView = builder.create( + subView.getLoc(), subViewDescType, subView.getSrc(), + newSubviewOffset); + auto newAsyncCopy = builder.create( - asyncCopy->getLoc(), extractedSrc, Value{subViews[rep].getResult()}, - extractedMask, extractedOther, asyncCopy.getCache(), - asyncCopy.getEvict(), asyncCopy.getIsVolatile()); + asyncCopy->getLoc(), extractedSrc, + Value{newSlicedSubView.getResult()}, extractedMask, extractedOther, + asyncCopy.getCache(), asyncCopy.getEvict(), + asyncCopy.getIsVolatile()); auto newCommit = builder.create( asyncCopy->getLoc(), newAsyncCopy.getToken()); // propagate all attributes from `mem-view` to the commit token + newSlicedSubView->setAttrs(subViews[rep]->getAttrs()); newAsyncCopy->setAttrs(subViews[rep]->getAttrs()); newCommit->setAttrs(subViews[rep]->getAttrs()); newAsyncGroups[rep].push_back(newCommit); newCommits.push_back(newCommit); + + subViews[rep]->erase(); } auto origCommitGroup = getSingleUserOf(asyncCopy); @@ -1739,11 +1761,11 @@ void Pingponger::getDotPingponged() { LDBG("failed to update forOp signature"); } - if (llvm::succeeded(updateSignature)) { - if (llvm::failed(adjustRefinedAsyncTokens(builder))) { - LDBG("failed to update forOp signature"); - } - } + // if (llvm::succeeded(updateSignature)) { + // if (llvm::failed(adjustRefinedAsyncTokens(builder))) { + // LDBG("failed to update forOp signature"); + // } + // } forOp->walk([](ttg::AsyncCommitGroupOp groupOp) { auto users = groupOp.getResult().getUsers();