CK Tile Group GEMM gfx1250#576
Conversation
| // Currently only support cutlass group gemm on Hopper Arch | ||
| if (!(is_hopper && use_cutlass)) { | ||
| //if (!(is_hopper && use_cutlass)) { | ||
| if (!use_cutlass) { |
| using type = TileCfg_256x256x64_WMMA; | ||
| }; | ||
|
|
||
| template <GPUArch Arch> |
There was a problem hiding this comment.
Why does it need template over reguler if-else or switch-case?
There was a problem hiding this comment.
The template is needed because the arch selection affects CK kernel template instantiation, not just runtime control flow. GPUArch must be a compile-time value so if constexpr can prune unsupported tile/kernel combinations for a given architecture. In this case, it prevents the MFMA configs from being instantiated for gfx1250.
There was a problem hiding this comment.
I didn't compile it with gfx1250 arch only but I was still puzzled about this templated dispatch. In line 298, you still rely on runtime detect_gpu_arch() to branch to specific ck_tile_grouped_gemm_fp16_dispatch_arch<arch_id>'s. So I presume all three arches verions will still be instantiated? And I didn't see any compile time guarding?
| if (arch == 125 || arch == 1250) { | ||
| return GPUArch::GFX1250; | ||
| } |
There was a problem hiding this comment.
Why do we want to host two possible arch ids for gfx1250? Is it because in some docker image, it shows 125 but in other docker images it shows 1250?
| static constexpr ck_tile::index_t M_Warp_Tile = 16; | ||
| static constexpr ck_tile::index_t N_Warp_Tile = 16; | ||
| static constexpr ck_tile::index_t K_Warp_Tile = 32; | ||
|
|
||
| static constexpr bool kPadM = true; | ||
| static constexpr bool kPadN = true; | ||
| static constexpr bool kPadK = true; |
There was a problem hiding this comment.
so the difference btw TileCfg_256x256x64_MFMA and TileCfg_256x256x64_WMMA is inside M, N, K warp tile and kPads?
| using type = TileCfg_256x256x64_WMMA; | ||
| }; | ||
|
|
||
| template <GPUArch Arch> |
There was a problem hiding this comment.
I didn't compile it with gfx1250 arch only but I was still puzzled about this templated dispatch. In line 298, you still rely on runtime detect_gpu_arch() to branch to specific ck_tile_grouped_gemm_fp16_dispatch_arch<arch_id>'s. So I presume all three arches verions will still be instantiated? And I didn't see any compile time guarding?
| COMPILE_OPTIONS "-g0;-dopt=on") | ||
| else() | ||
| set(CK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/aiter/3rdparty/composable_kernel) | ||
| set(CK_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/rocm_libraries/projects/composablekernel) |
There was a problem hiding this comment.
nit: Will the whole rocm_libraries too big? Do we have a way to have sparse check out for this ck subdir?
Description
Extend the present CK tile grouped GEMM (F16/F8) implementation for compatibility with gfx1250. Replaces 3rdparty/aiter with 3rdparty/rocm-libraries for the gfx1250 changes from CK.
Fixes #16490
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: