From 7e18bb7dbde5975f5ee700fbedce1ac8e5a3626c Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Wed, 20 May 2026 21:13:14 +0000 Subject: [PATCH 1/5] add shape specialization for ck tile fp8 ggemm --- .../gemm/ck_grouped_gemm/ck_grouped_gemm.cpp | 1 + .../ck_grouped_gemm/ck_grouped_gemm_common.h | 1 + .../ck_grouped_gemm/ck_grouped_gemm_fp16.cpp | 10 ++--- .../ck_grouped_gemm/ck_grouped_gemm_fp8.cpp | 41 ++++++++++++++----- 4 files changed, 38 insertions(+), 15 deletions(-) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp index 5684be1cd..10231d2ab 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp @@ -156,6 +156,7 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, B_use, D, static_cast(n), + static_cast(kA), group_num, transA_use, transB_use, diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h index c89f10232..7675ad088 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h @@ -75,6 +75,7 @@ struct GroupedGemmRunContext { const NVTETensor* B = nullptr; NVTETensor* D = nullptr; int64_t N = 0; + int64_t K = 0; int group_num = 0; bool transA = false; diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp index 660dbefb8..42ddaea99 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp @@ -196,7 +196,7 @@ class GroupedGemmRunner : public RunnerInterface { } }; -#define MAKE_RUNNER(TileCfg_) \ +#define MAKE_FP16_RUNNER(TileCfg_) \ TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.accumulate, accum_option, { \ using Runner = GroupedGemmRunner::type; if (ctx.N % 256 == 0) { - MAKE_RUNNER(TileCfg_256x256x64); + MAKE_FP16_RUNNER(TileCfg_256x256x64); } else if (ctx.N % 128 == 0) { - MAKE_RUNNER(TileCfg_256x128x64); + MAKE_FP16_RUNNER(TileCfg_256x128x64); } else { - MAKE_RUNNER(TileCfg_256x128x64_padding); + MAKE_FP16_RUNNER(TileCfg_256x128x64_padding); } }); }); @@ -249,7 +249,7 @@ bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, return runner->run(s, ctx); } -#undef MAKE_RUNNER +#undef MAKE_FP16_RUNNER } // namespace grouped_gemm } // namespace transformer_engine diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp index 50b701c05..94d2481f3 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp @@ -45,6 +45,13 @@ struct TileCfg_128x128x128_16x16x128_2x2x1 { static constexpr ck_tile::index_t TilePartitionerM01 = 8; }; +struct TileCfg_256x256x128_16x16x128_2x2x1 + : TileCfg_128x128x128_16x16x128_2x2x1 { + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 128; +}; + // gfx950 device compilation cannot instantiate the literal 32x32x16 FP8 tile // configuration due to an unsupported warp GEMM dispatcher configuration. // See: ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp for supported variants. @@ -290,6 +297,17 @@ struct FP8TileCfg { using type = TileCfg_128x128x128_16x16x128_2x2x1; }; +#define MAKE_FP8_RUNNER(TileCfg_) \ + using Runner = QuantGroupedGemmRunner; \ + runner = std::make_unique() + template static bool ck_tile_grouped_gemm_fp8_dispatch_arch(DType a_dtype, DType b_dtype, @@ -299,7 +317,6 @@ static bool ck_tile_grouped_gemm_fp8_dispatch_arch(DType a_dtype, std::unique_ptr runner = nullptr; using CTypeLayout = RowMajor; - using TileCfg = typename FP8TileCfg::type; TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transA, kTransA, { using ALayout = std::conditional_t; @@ -315,15 +332,17 @@ static bool ck_tile_grouped_gemm_fp8_dispatch_arch(DType a_dtype, TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { using CType = typename TETypeToCKType::type; - using Runner = QuantGroupedGemmRunner; - runner = std::make_unique(); + + if constexpr (Arch == GPUArch::GFX950) { + if (ctx.K >= 2048 || ctx.N >= 2048) { + MAKE_FP8_RUNNER(TileCfg_256x256x128_16x16x128_2x2x1); + } else { + MAKE_FP8_RUNNER(TileCfg_128x128x128_16x16x128_2x2x1); + } + } else { + using TileCfg = typename FP8TileCfg::type; + MAKE_FP8_RUNNER(TileCfg); + } }); }); }); @@ -337,6 +356,8 @@ static bool ck_tile_grouped_gemm_fp8_dispatch_arch(DType a_dtype, return runner->run(s, ctx); } +#undef MAKE_FP8_RUNNER + bool ck_tile_grouped_gemm_fp8_dispatch(DType a_dtype, DType b_dtype, DType d_dtype, From 7c335647ccbd03617c9ae42663348673bc81edee Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Thu, 21 May 2026 03:37:50 +0000 Subject: [PATCH 2/5] Fix CK grouped GEMM FP8 dtype gating for columnwise operands --- .../gemm/ck_grouped_gemm/ck_grouped_gemm.cpp | 7 +++++-- .../gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp | 10 +++++++--- transformer_engine/common/gemm/cublaslt_gemm.cu | 15 +++++++++++++-- 3 files changed, 25 insertions(+), 7 deletions(-) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp index 10231d2ab..7f8a7067c 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp @@ -94,8 +94,11 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, } } - const auto a_dtype = convertNVTETensorCheck(A_use[0])->dtype(); - const auto b_dtype = convertNVTETensorCheck(B_use[0])->dtype(); + const auto& A0_data = use_a_colwise_data ? A0_te->columnwise_data : A0_te->data; + const auto& B0_data = use_b_colwise_data ? B0_te->columnwise_data : B0_te->data; + + const auto a_dtype = A0_data.dtype; + const auto b_dtype = B0_data.dtype; Tensor* D0_te = convertNVTETensorCheck(D[0]); const auto d_dtype = D0_te->dtype(); diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp index 94d2481f3..81c809571 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp @@ -268,7 +268,9 @@ class QuantGroupedGemmRunner : public RunnerInterface { if (descs.empty()) { return false; } - return launch_grouped_gemm_kernel(descs, ctx, stream_cfg); + + const bool launched = launch_grouped_gemm_kernel(descs, ctx, stream_cfg); + return launched; } }; @@ -334,7 +336,7 @@ static bool ck_tile_grouped_gemm_fp8_dispatch_arch(DType a_dtype, using CType = typename TETypeToCKType::type; if constexpr (Arch == GPUArch::GFX950) { - if (ctx.K >= 2048 || ctx.N >= 2048) { + if (ctx.K >= 2048 && ctx.N >= 2048) { MAKE_FP8_RUNNER(TileCfg_256x256x128_16x16x128_2x2x1); } else { MAKE_FP8_RUNNER(TileCfg_128x128x128_16x16x128_2x2x1); @@ -353,7 +355,8 @@ static bool ck_tile_grouped_gemm_fp8_dispatch_arch(DType a_dtype, return false; } - return runner->run(s, ctx); + const bool ok = runner->run(s, ctx); + return ok; } #undef MAKE_FP8_RUNNER @@ -375,3 +378,4 @@ bool ck_tile_grouped_gemm_fp8_dispatch(DType a_dtype, } // namespace grouped_gemm } // namespace transformer_engine + diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 7326f330f..03df8751b 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1160,9 +1160,20 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor auto *inputB = transformer_engine::convertNVTETensorCheck(B[0]); auto *OutputD = transformer_engine::convertNVTETensorCheck(D[0]); #ifdef __HIP_PLATFORM_AMD__ - auto A_dt = inputA->data.dtype; - auto B_dt = inputB->data.dtype; + auto effective_dtype = [](const transformer_engine::Tensor* t) { + if (is_fp8_dtype(t->data.dtype)) { + return t->data.dtype; + } + if (t->has_columnwise_data() && is_fp8_dtype(t->columnwise_data.dtype)) { + return t->columnwise_data.dtype; + } + return t->data.dtype; + }; + + auto A_dt = effective_dtype(inputA); + auto B_dt = effective_dtype(inputB); auto D_dt = OutputD->data.dtype; + return ( (is_fp8_dtype(A_dt) && is_fp8_dtype(B_dt)) ) || From bb493187eba56e611b331f60bc5b2ec20407f162 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Thu, 21 May 2026 13:06:53 +0000 Subject: [PATCH 3/5] Add FP8 grouped GEMM tile selection guards and support for N-padding --- .../ck_grouped_gemm/ck_grouped_gemm_fp8.cpp | 31 ++++++++++++------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp index 81c809571..c2d307d26 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp @@ -22,9 +22,9 @@ enum class GPUArch { UNKNOWN }; -struct TileCfg_128x128x128_16x16x128_2x2x1 { - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; +struct TileCfg_256x256x128_16x16x128_2x2x1 { + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; static constexpr ck_tile::index_t K_Tile = 128; static constexpr ck_tile::index_t M_Warp = 2; @@ -45,11 +45,15 @@ struct TileCfg_128x128x128_16x16x128_2x2x1 { static constexpr ck_tile::index_t TilePartitionerM01 = 8; }; -struct TileCfg_256x256x128_16x16x128_2x2x1 +struct TileCfg_128x128x128_16x16x128_2x2x1 + : TileCfg_256x256x128_16x16x128_2x2x1 { + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; +}; + +struct TileCfg_128x128x128_16x16x128_2x2x1_padding : TileCfg_128x128x128_16x16x128_2x2x1 { - static constexpr ck_tile::index_t M_Tile = 256; - static constexpr ck_tile::index_t N_Tile = 256; - static constexpr ck_tile::index_t K_Tile = 128; + static constexpr bool kPadN = true; }; // gfx950 device compilation cannot instantiate the literal 32x32x16 FP8 tile @@ -57,8 +61,9 @@ struct TileCfg_256x256x128_16x16x128_2x2x1 // See: ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp for supported variants. // // To preserve the existing type name in shared template code, this struct -// inherits from the gfx950-safe 16x16x128 configuration in the gfx950 device -// compilation path, effectively reusing those parameters without redefining them. +// inherits from the gfx950-safe 128x128x128 16x16x128 configuration in the +// gfx950 device compilation path, effectively reusing those parameters without +// redefining them. // // In all other compilation paths, the struct overrides the relevant fields to // provide the intended 32x32x16 configuration. @@ -336,10 +341,14 @@ static bool ck_tile_grouped_gemm_fp8_dispatch_arch(DType a_dtype, using CType = typename TETypeToCKType::type; if constexpr (Arch == GPUArch::GFX950) { - if (ctx.K >= 2048 && ctx.N >= 2048) { + if (ctx.K % 128 != 0) { + NVTE_WARN("ck_tile_grouped_gemm: (FP8) K must be a multiple of 128. Falling back."); + } else if (ctx.N % 256 == 0) { MAKE_FP8_RUNNER(TileCfg_256x256x128_16x16x128_2x2x1); - } else { + } else if (ctx.N % 128 == 0) { MAKE_FP8_RUNNER(TileCfg_128x128x128_16x16x128_2x2x1); + } else { + MAKE_FP8_RUNNER(TileCfg_128x128x128_16x16x128_2x2x1_padding); } } else { using TileCfg = typename FP8TileCfg::type; From 942c0b999e753336d0e6c1ea8b7106045766bf1c Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Thu, 21 May 2026 15:15:42 +0000 Subject: [PATCH 4/5] add k padding support --- .../ck_grouped_gemm/ck_grouped_gemm_fp8.cpp | 82 +++++++++++++------ 1 file changed, 56 insertions(+), 26 deletions(-) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp index c2d307d26..57f20804b 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp @@ -51,9 +51,25 @@ struct TileCfg_128x128x128_16x16x128_2x2x1 static constexpr ck_tile::index_t N_Tile = 128; }; -struct TileCfg_128x128x128_16x16x128_2x2x1_padding +struct TileCfg_256x256x128_16x16x128_2x2x1_kpad + : TileCfg_256x256x128_16x16x128_2x2x1 { + static constexpr bool kPadK = true; +}; + +struct TileCfg_128x128x128_16x16x128_2x2x1_kpad + : TileCfg_128x128x128_16x16x128_2x2x1 { + static constexpr bool kPadK = true; +}; + +struct TileCfg_128x128x128_16x16x128_2x2x1_npad + : TileCfg_128x128x128_16x16x128_2x2x1 { + static constexpr bool kPadN = true; +}; + +struct TileCfg_128x128x128_16x16x128_2x2x1_nkpad : TileCfg_128x128x128_16x16x128_2x2x1 { static constexpr bool kPadN = true; + static constexpr bool kPadK = true; }; // gfx950 device compilation cannot instantiate the literal 32x32x16 FP8 tile @@ -325,37 +341,52 @@ static bool ck_tile_grouped_gemm_fp8_dispatch_arch(DType a_dtype, using CTypeLayout = RowMajor; - TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transA, kTransA, { - using ALayout = std::conditional_t; + // FP8 grouped GEMM is only compiled for CK's preferred NT presentation: + // transA=false, transB=true + // which maps to: + // ALayout=RowMajor, BLayout=ColMajor. + // + // The caller is responsible for rewriting other FP8 layouts into this form + // using columnwise_data when needed. Reject anything that did not normalize + // successfully so we do not instantiate unreachable/unsupported layout variants. + if (ctx.transA || !ctx.transB) { + return false; + } - TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transB, kTransB, { - using BLayout = std::conditional_t; + using ALayout = RowMajor; + using BLayout = ColMajor; - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(a_dtype, a_te_type, { - using AType = typename TETypeToCKType::type; + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(a_dtype, a_te_type, { + using AType = typename TETypeToCKType::type; - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(b_dtype, b_te_type, { - using BType = typename TETypeToCKType::type; + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(b_dtype, b_te_type, { + using BType = typename TETypeToCKType::type; - TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { - using CType = typename TETypeToCKType::type; + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { + using CType = typename TETypeToCKType::type; - if constexpr (Arch == GPUArch::GFX950) { - if (ctx.K % 128 != 0) { - NVTE_WARN("ck_tile_grouped_gemm: (FP8) K must be a multiple of 128. Falling back."); - } else if (ctx.N % 256 == 0) { - MAKE_FP8_RUNNER(TileCfg_256x256x128_16x16x128_2x2x1); - } else if (ctx.N % 128 == 0) { - MAKE_FP8_RUNNER(TileCfg_128x128x128_16x16x128_2x2x1); - } else { - MAKE_FP8_RUNNER(TileCfg_128x128x128_16x16x128_2x2x1_padding); - } + if constexpr (Arch == GPUArch::GFX950) { + if (ctx.N % 256 == 0) { + if (ctx.K % 128 == 0) { + MAKE_FP8_RUNNER(TileCfg_256x256x128_16x16x128_2x2x1); } else { - using TileCfg = typename FP8TileCfg::type; - MAKE_FP8_RUNNER(TileCfg); + MAKE_FP8_RUNNER(TileCfg_256x256x128_16x16x128_2x2x1_kpad); } - }); - }); + } else if (ctx.N % 128 == 0) { + if (ctx.K % 128 == 0) { + MAKE_FP8_RUNNER(TileCfg_128x128x128_16x16x128_2x2x1); + } else { + MAKE_FP8_RUNNER(TileCfg_128x128x128_16x16x128_2x2x1_kpad); + } + } else if (ctx.K % 128 == 0) { + MAKE_FP8_RUNNER(TileCfg_128x128x128_16x16x128_2x2x1_npad); + } else { + MAKE_FP8_RUNNER(TileCfg_128x128x128_16x16x128_2x2x1_nkpad); + } + } else { + using TileCfg = typename FP8TileCfg::type; + MAKE_FP8_RUNNER(TileCfg); + } }); }); }); @@ -387,4 +418,3 @@ bool ck_tile_grouped_gemm_fp8_dispatch(DType a_dtype, } // namespace grouped_gemm } // namespace transformer_engine - From 5e2dd3358d340f638af0aac7051f76605c5ad87d Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Thu, 21 May 2026 18:52:41 +0000 Subject: [PATCH 5/5] padding needs to be across experts --- .../ck_grouped_gemm/ck_grouped_gemm_fp8.cpp | 75 +++++++++++++++++-- 1 file changed, 69 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp index 57f20804b..fead8a048 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp @@ -320,6 +320,67 @@ struct FP8TileCfg { using type = TileCfg_128x128x128_16x16x128_2x2x1; }; +struct FP8GroupedShapeAlignment { + bool all_n_256_aligned = true; + bool all_n_128_aligned = true; + bool all_k_128_aligned = true; +}; + +static FP8GroupedShapeAlignment get_fp8_grouped_shape_alignment( + const GroupedGemmRunContext& ctx) { + FP8GroupedShapeAlignment alignment; + + for (int i = 0; i < ctx.group_num; ++i) { + const transformer_engine::Tensor* const A_te = + transformer_engine::convertNVTETensorCheck(ctx.A[i]); + const transformer_engine::Tensor* const B_te = + transformer_engine::convertNVTETensorCheck(ctx.B[i]); + + int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0; + + if (ctx.use_a_columnwise_data) { + if (!get_columnwise_storage_2d_dims(A_te->columnwise_data, Ad0, Ad1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected 2D columnwise_data for A in group ", i); + } + } else { + if (!get_flat_2d_dims(*A_te, Ad0, Ad1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for normalized A in group ", i); + } + } + + if (ctx.use_b_columnwise_data) { + if (!get_columnwise_storage_2d_dims(B_te->columnwise_data, Bd0, Bd1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected 2D columnwise_data for B in group ", i); + } + } else { + if (!get_flat_2d_dims(*B_te, Bd0, Bd1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for normalized B in group ", i); + } + } + + const int64_t K = ctx.transA ? Ad0 : Ad1; + const int64_t N = ctx.transB ? Bd0 : Bd1; + + if (N % 256 != 0) { + alignment.all_n_256_aligned = false; + } + if (N % 128 != 0) { + alignment.all_n_128_aligned = false; + } + if (K % 128 != 0) { + alignment.all_k_128_aligned = false; + } + + if (!alignment.all_n_256_aligned && + !alignment.all_n_128_aligned && + !alignment.all_k_128_aligned) { + break; + } + } + + return alignment; +} + #define MAKE_FP8_RUNNER(TileCfg_) \ using Runner = QuantGroupedGemmRunner::type; - if constexpr (Arch == GPUArch::GFX950) { - if (ctx.N % 256 == 0) { - if (ctx.K % 128 == 0) { + if constexpr (Arch == GPUArch::GFX950) { + const auto alignment = get_fp8_grouped_shape_alignment(ctx); + + if (alignment.all_n_256_aligned) { + if (alignment.all_k_128_aligned) { MAKE_FP8_RUNNER(TileCfg_256x256x128_16x16x128_2x2x1); } else { MAKE_FP8_RUNNER(TileCfg_256x256x128_16x16x128_2x2x1_kpad); } - } else if (ctx.N % 128 == 0) { - if (ctx.K % 128 == 0) { + } else if (alignment.all_n_128_aligned) { + if (alignment.all_k_128_aligned) { MAKE_FP8_RUNNER(TileCfg_128x128x128_16x16x128_2x2x1); } else { MAKE_FP8_RUNNER(TileCfg_128x128x128_16x16x128_2x2x1_kpad); } - } else if (ctx.K % 128 == 0) { + } else if (alignment.all_k_128_aligned) { MAKE_FP8_RUNNER(TileCfg_128x128x128_16x16x128_2x2x1_npad); } else { MAKE_FP8_RUNNER(TileCfg_128x128x128_16x16x128_2x2x1_nkpad);