diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp index 5684be1cd..7f8a7067c 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm.cpp @@ -94,8 +94,11 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, } } - const auto a_dtype = convertNVTETensorCheck(A_use[0])->dtype(); - const auto b_dtype = convertNVTETensorCheck(B_use[0])->dtype(); + const auto& A0_data = use_a_colwise_data ? A0_te->columnwise_data : A0_te->data; + const auto& B0_data = use_b_colwise_data ? B0_te->columnwise_data : B0_te->data; + + const auto a_dtype = A0_data.dtype; + const auto b_dtype = B0_data.dtype; Tensor* D0_te = convertNVTETensorCheck(D[0]); const auto d_dtype = D0_te->dtype(); @@ -156,6 +159,7 @@ bool ck_tile_grouped_gemm(const NVTETensor* A, B_use, D, static_cast(n), + static_cast(kA), group_num, transA_use, transB_use, diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h index c89f10232..7675ad088 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_common.h @@ -75,6 +75,7 @@ struct GroupedGemmRunContext { const NVTETensor* B = nullptr; NVTETensor* D = nullptr; int64_t N = 0; + int64_t K = 0; int group_num = 0; bool transA = false; diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp index 660dbefb8..42ddaea99 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp @@ -196,7 +196,7 @@ class GroupedGemmRunner : public RunnerInterface { } }; -#define MAKE_RUNNER(TileCfg_) \ +#define MAKE_FP16_RUNNER(TileCfg_) \ TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.accumulate, accum_option, { \ using Runner = GroupedGemmRunner::type; if (ctx.N % 256 == 0) { - MAKE_RUNNER(TileCfg_256x256x64); + MAKE_FP16_RUNNER(TileCfg_256x256x64); } else if (ctx.N % 128 == 0) { - MAKE_RUNNER(TileCfg_256x128x64); + MAKE_FP16_RUNNER(TileCfg_256x128x64); } else { - MAKE_RUNNER(TileCfg_256x128x64_padding); + MAKE_FP16_RUNNER(TileCfg_256x128x64_padding); } }); }); @@ -249,7 +249,7 @@ bool ck_tile_grouped_gemm_fp16_dispatch(DType a_dtype, return runner->run(s, ctx); } -#undef MAKE_RUNNER +#undef MAKE_FP16_RUNNER } // namespace grouped_gemm } // namespace transformer_engine diff --git a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp index 50b701c05..fead8a048 100644 --- a/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp +++ b/transformer_engine/common/gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp @@ -22,9 +22,9 @@ enum class GPUArch { UNKNOWN }; -struct TileCfg_128x128x128_16x16x128_2x2x1 { - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; +struct TileCfg_256x256x128_16x16x128_2x2x1 { + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; static constexpr ck_tile::index_t K_Tile = 128; static constexpr ck_tile::index_t M_Warp = 2; @@ -45,13 +45,41 @@ struct TileCfg_128x128x128_16x16x128_2x2x1 { static constexpr ck_tile::index_t TilePartitionerM01 = 8; }; +struct TileCfg_128x128x128_16x16x128_2x2x1 + : TileCfg_256x256x128_16x16x128_2x2x1 { + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; +}; + +struct TileCfg_256x256x128_16x16x128_2x2x1_kpad + : TileCfg_256x256x128_16x16x128_2x2x1 { + static constexpr bool kPadK = true; +}; + +struct TileCfg_128x128x128_16x16x128_2x2x1_kpad + : TileCfg_128x128x128_16x16x128_2x2x1 { + static constexpr bool kPadK = true; +}; + +struct TileCfg_128x128x128_16x16x128_2x2x1_npad + : TileCfg_128x128x128_16x16x128_2x2x1 { + static constexpr bool kPadN = true; +}; + +struct TileCfg_128x128x128_16x16x128_2x2x1_nkpad + : TileCfg_128x128x128_16x16x128_2x2x1 { + static constexpr bool kPadN = true; + static constexpr bool kPadK = true; +}; + // gfx950 device compilation cannot instantiate the literal 32x32x16 FP8 tile // configuration due to an unsupported warp GEMM dispatcher configuration. // See: ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp for supported variants. // // To preserve the existing type name in shared template code, this struct -// inherits from the gfx950-safe 16x16x128 configuration in the gfx950 device -// compilation path, effectively reusing those parameters without redefining them. +// inherits from the gfx950-safe 128x128x128 16x16x128 configuration in the +// gfx950 device compilation path, effectively reusing those parameters without +// redefining them. // // In all other compilation paths, the struct overrides the relevant fields to // provide the intended 32x32x16 configuration. @@ -261,7 +289,9 @@ class QuantGroupedGemmRunner : public RunnerInterface { if (descs.empty()) { return false; } - return launch_grouped_gemm_kernel(descs, ctx, stream_cfg); + + const bool launched = launch_grouped_gemm_kernel(descs, ctx, stream_cfg); + return launched; } }; @@ -290,6 +320,78 @@ struct FP8TileCfg { using type = TileCfg_128x128x128_16x16x128_2x2x1; }; +struct FP8GroupedShapeAlignment { + bool all_n_256_aligned = true; + bool all_n_128_aligned = true; + bool all_k_128_aligned = true; +}; + +static FP8GroupedShapeAlignment get_fp8_grouped_shape_alignment( + const GroupedGemmRunContext& ctx) { + FP8GroupedShapeAlignment alignment; + + for (int i = 0; i < ctx.group_num; ++i) { + const transformer_engine::Tensor* const A_te = + transformer_engine::convertNVTETensorCheck(ctx.A[i]); + const transformer_engine::Tensor* const B_te = + transformer_engine::convertNVTETensorCheck(ctx.B[i]); + + int64_t Ad0 = 0, Ad1 = 0, Bd0 = 0, Bd1 = 0; + + if (ctx.use_a_columnwise_data) { + if (!get_columnwise_storage_2d_dims(A_te->columnwise_data, Ad0, Ad1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected 2D columnwise_data for A in group ", i); + } + } else { + if (!get_flat_2d_dims(*A_te, Ad0, Ad1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for normalized A in group ", i); + } + } + + if (ctx.use_b_columnwise_data) { + if (!get_columnwise_storage_2d_dims(B_te->columnwise_data, Bd0, Bd1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected 2D columnwise_data for B in group ", i); + } + } else { + if (!get_flat_2d_dims(*B_te, Bd0, Bd1)) { + NVTE_ERROR("ck_tile_grouped_gemm: expected rank>=2 for normalized B in group ", i); + } + } + + const int64_t K = ctx.transA ? Ad0 : Ad1; + const int64_t N = ctx.transB ? Bd0 : Bd1; + + if (N % 256 != 0) { + alignment.all_n_256_aligned = false; + } + if (N % 128 != 0) { + alignment.all_n_128_aligned = false; + } + if (K % 128 != 0) { + alignment.all_k_128_aligned = false; + } + + if (!alignment.all_n_256_aligned && + !alignment.all_n_128_aligned && + !alignment.all_k_128_aligned) { + break; + } + } + + return alignment; +} + +#define MAKE_FP8_RUNNER(TileCfg_) \ + using Runner = QuantGroupedGemmRunner; \ + runner = std::make_unique() + template static bool ck_tile_grouped_gemm_fp8_dispatch_arch(DType a_dtype, DType b_dtype, @@ -299,33 +401,55 @@ static bool ck_tile_grouped_gemm_fp8_dispatch_arch(DType a_dtype, std::unique_ptr runner = nullptr; using CTypeLayout = RowMajor; - using TileCfg = typename FP8TileCfg::type; - - TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transA, kTransA, { - using ALayout = std::conditional_t; - - TRANSFORMER_ENGINE_SWITCH_CONDITION(ctx.transB, kTransB, { - using BLayout = std::conditional_t; - - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(a_dtype, a_te_type, { - using AType = typename TETypeToCKType::type; - - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(b_dtype, b_te_type, { - using BType = typename TETypeToCKType::type; - - TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { - using CType = typename TETypeToCKType::type; - using Runner = QuantGroupedGemmRunner; - runner = std::make_unique(); - }); - }); + + // FP8 grouped GEMM is only compiled for CK's preferred NT presentation: + // transA=false, transB=true + // which maps to: + // ALayout=RowMajor, BLayout=ColMajor. + // + // The caller is responsible for rewriting other FP8 layouts into this form + // using columnwise_data when needed. Reject anything that did not normalize + // successfully so we do not instantiate unreachable/unsupported layout variants. + if (ctx.transA || !ctx.transB) { + return false; + } + + using ALayout = RowMajor; + using BLayout = ColMajor; + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(a_dtype, a_te_type, { + using AType = typename TETypeToCKType::type; + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(b_dtype, b_te_type, { + using BType = typename TETypeToCKType::type; + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(d_dtype, d_te_type, { + using CType = typename TETypeToCKType::type; + + if constexpr (Arch == GPUArch::GFX950) { + const auto alignment = get_fp8_grouped_shape_alignment(ctx); + + if (alignment.all_n_256_aligned) { + if (alignment.all_k_128_aligned) { + MAKE_FP8_RUNNER(TileCfg_256x256x128_16x16x128_2x2x1); + } else { + MAKE_FP8_RUNNER(TileCfg_256x256x128_16x16x128_2x2x1_kpad); + } + } else if (alignment.all_n_128_aligned) { + if (alignment.all_k_128_aligned) { + MAKE_FP8_RUNNER(TileCfg_128x128x128_16x16x128_2x2x1); + } else { + MAKE_FP8_RUNNER(TileCfg_128x128x128_16x16x128_2x2x1_kpad); + } + } else if (alignment.all_k_128_aligned) { + MAKE_FP8_RUNNER(TileCfg_128x128x128_16x16x128_2x2x1_npad); + } else { + MAKE_FP8_RUNNER(TileCfg_128x128x128_16x16x128_2x2x1_nkpad); + } + } else { + using TileCfg = typename FP8TileCfg::type; + MAKE_FP8_RUNNER(TileCfg); + } }); }); }); @@ -334,9 +458,12 @@ static bool ck_tile_grouped_gemm_fp8_dispatch_arch(DType a_dtype, return false; } - return runner->run(s, ctx); + const bool ok = runner->run(s, ctx); + return ok; } +#undef MAKE_FP8_RUNNER + bool ck_tile_grouped_gemm_fp8_dispatch(DType a_dtype, DType b_dtype, DType d_dtype, diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 7326f330f..03df8751b 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -1160,9 +1160,20 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor auto *inputB = transformer_engine::convertNVTETensorCheck(B[0]); auto *OutputD = transformer_engine::convertNVTETensorCheck(D[0]); #ifdef __HIP_PLATFORM_AMD__ - auto A_dt = inputA->data.dtype; - auto B_dt = inputB->data.dtype; + auto effective_dtype = [](const transformer_engine::Tensor* t) { + if (is_fp8_dtype(t->data.dtype)) { + return t->data.dtype; + } + if (t->has_columnwise_data() && is_fp8_dtype(t->columnwise_data.dtype)) { + return t->columnwise_data.dtype; + } + return t->data.dtype; + }; + + auto A_dt = effective_dtype(inputA); + auto B_dt = effective_dtype(inputB); auto D_dt = OutputD->data.dtype; + return ( (is_fp8_dtype(A_dt) && is_fp8_dtype(B_dt)) ) ||