CK Tile MXFP8 Group GEMM gfx1250#578
Conversation
…rash; remaining issue is numerical validation vs BF16 sequential reference.
| test_dequantize_mxfp8.cu | ||
| test_dequantize_nvfp4.cu | ||
| test_cast_nvfp4_transpose.cu | ||
| test_ck_grouped_mxfp8.cu |
There was a problem hiding this comment.
It should be for non CUDA only
| // Currently only support cutlass group gemm on Hopper Arch | ||
| if (!(is_hopper && use_cutlass)) { | ||
| // if (!(is_hopper && use_cutlass)) { | ||
| if (!use_cutlass) { |
| delay_wgrad_compute, | ||
| ): | ||
| os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1" | ||
| os.environ["NVTE_ROCM_ENABLE_MXFP8"] = "1" |
There was a problem hiding this comment.
I think this should only be set when the recipe we are testing is mxfp8.
There was a problem hiding this comment.
Good point. Looking at the parametrization, MXFP8BlockScaling is only added to fp8_recipes when NVTE_ROCM_ENABLE_MXFP8=1 is already set before test collection. So setting it inside this test is redundant and also broader than intended. Removed in 746afea
|
|
||
| // Treat TE tensors as generalized 2D matrices by flattening: | ||
| // (D1, D2, ..., Dn) -> (D1*...*D(n-1), Dn), consistent with TE Tensor::flat_*_dim. | ||
| static inline bool get_flat_2d_dims(const transformer_engine::Tensor& t, |
There was a problem hiding this comment.
Re-use get_flat_2d_dims from ck_grouped_gemm_common.h
There was a problem hiding this comment.
I think some portion of the code is already present in ck_grouped_gemm_common.h inside ck_grouped_gemm folder. What was the reasoning behind having a separate directory for ck_mx_grouped_gemm?
There was a problem hiding this comment.
No, there really was not a good reason for this. I agree that it makes more sense to keep it all under the same directory, and re-use the common functions already defined in the shared header. I have made these changes in 175855d
| #ifndef CK_TILE_USE_OCP_FP8 | ||
| #define CK_TILE_USE_OCP_FP8 1 | ||
| #endif |
There was a problem hiding this comment.
Just curious, where is this macro used?
| static float to_float(const bf16_t& x) { return static_cast<float>(x); } | ||
| static float to_float(const ck_tile::bfloat16_t& x) { return static_cast<float>(x); } |
There was a problem hiding this comment.
is ck_tile::bfloat16_t same as our bf16_t?
| setenv("NVTE_ROCM_ENABLE_MXFP8", "1", 0); | ||
| } | ||
|
|
||
| static float to_float(float x) { return x; } |
There was a problem hiding this comment.
Why do we need a float to float?
| static float to_float(const bf16_t& x) { return static_cast<float>(x); } | ||
| static float to_float(const ck_tile::bfloat16_t& x) { return static_cast<float>(x); } | ||
|
|
||
| __device__ __host__ __forceinline__ float ref_gelu_unused(float x) { |
| size_t a_idx = 0; | ||
| size_t b_idx = 0; | ||
|
|
||
| if (use_mxfp8) { |
There was a problem hiding this comment.
Based on your test name I presume you wanted to test mxfp8 but here it looks like you wanted to cover non-mxfp8 as well?
|
|
||
| cudaDeviceProp prop; | ||
| NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop, 0)); | ||
| #ifdef __HIP_PLATFORM_AMD__ |
There was a problem hiding this comment.
Probably not needed since NV upstream do not have this file
Description
This PR integrates CK Tile MXFP8 grouped GEMM backend with TDM into TE. Replaces 3rdparty/aiter with 3rdparty/rocm-libraries for the gfx1250 changes from CK.
Fixes # (16490)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: