ck_tile grouped gemm: more padding #574
    reason="Only enable CUTLASS/CK grouped gemm on Hopper or ROCm",
)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=str)
@pytest.mark.parametrize("layout", ["TN", "NN", "NT"])
    k_val = k_aligned
    m_vals = [m_aligned] * z
    n_val = unaligned_n
Would we want an MKN unaligned test? Would that cover something that isn't included in the current test sweep?
if pad_dim == "K":
    k_val = unaligned_k
    m_vals = [m_aligned] * z
    n_val = n_aligned
elif pad_dim == "M":
    k_val = k_aligned
    m_vals = unaligned_m
    n_val = n_aligned
elif pad_dim == "MK":
    k_val = unaligned_k
    m_vals = unaligned_m
    n_val = n_aligned
else:  # N
    k_val = k_aligned
    m_vals = [m_aligned] * z
    n_val = unaligned_n
Can we factor out this if-elif-elif-else block that seems repeated for each layout?
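One way the repeated branch could be factored out is sketched below. The helper name `select_dims` and its `aligned`/`unaligned` tuple arguments are hypothetical and not part of this PR; only the variable names `m_vals`, `n_val`, `k_val` and the `pad_dim` strings come from the test. Because it checks for each letter independently, the same helper would also accept an all-unaligned case such as "MKN".

```python
def select_dims(pad_dim, *, aligned, unaligned, z):
    """Pick (m_vals, n_val, k_val) for one padding case.

    pad_dim names the dimensions left unaligned, e.g. "K", "M", "MK", "N".
    aligned   = (m_aligned, n_aligned, k_aligned)        # scalars
    unaligned = (unaligned_m, unaligned_n, unaligned_k)  # unaligned_m is a list
    z is the number of groups.
    """
    unknown = set(pad_dim) - set("MNK")
    if unknown:
        raise ValueError(f"unknown pad dimension(s): {sorted(unknown)}")

    m_aligned, n_aligned, k_aligned = aligned
    unaligned_m, unaligned_n, unaligned_k = unaligned

    # Each dimension is chosen independently of the others.
    m_vals = list(unaligned_m) if "M" in pad_dim else [m_aligned] * z
    n_val = unaligned_n if "N" in pad_dim else n_aligned
    k_val = unaligned_k if "K" in pad_dim else k_aligned
    return m_vals, n_val, k_val


# Illustrative sizes only; the real test values are not shown in this thread.
m_vals, n_val, k_val = select_dims(
    "MK", aligned=(256, 256, 256), unaligned=([100, 130], 200, 72), z=2
)
```

Each layout branch could then call the helper once instead of repeating the if-elif chain.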
}
return launch_grouped_gemm_kernel<Kernel>(descs, ctx, stream_cfg);
// Dispatch with B's columnwise buffer as RowMajor (transB=false).
GroupedGemmRunContext ctx_nn = ctx;
nit: ctx_nn seems a bit misleading since this only rewrites B as non-transposed via columnwise_data; A can still be T or N. Maybe rename to something like ctx_b_colwise?
    grad = True
    single_output = True
else:  # NT
    # NT GEMM: out[i] = A[i]^T @ B[i], A[i]: (m_i, k), B[i]: (m_i, n), out[i]: (n, k)
nit: this comment is a little confusing. For the grouped path, the user-facing NT inputs are A=(m_i,k), B=(m_i,n), out=(n,k), but normalization swaps operands/layouts before dispatch, so the actual dispatched gemm is B^T @ A = (n,m_i) @ (m_i,k).
Right, I removed this comment in aee2c4c
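The shape mismatch behind this thread can be checked with plain shape arithmetic. The sketch below uses made-up sizes and two hypothetical helpers (`transpose`, `matmul_shape`) and does not touch the real kernels: with A=(m_i, k) and B=(m_i, n), the product as written in the code comment, A^T @ B, has shape (k, n), while the dispatched product B^T @ A has the documented output shape (n, k).

```python
def transpose(shape):
    """Shape of the transpose of a 2-D matrix."""
    rows, cols = shape
    return (cols, rows)


def matmul_shape(lhs, rhs):
    """Shape of lhs @ rhs for 2-D shapes; raises if inner dims differ."""
    if lhs[1] != rhs[0]:
        raise ValueError(f"inner dimensions differ: {lhs} @ {rhs}")
    return (lhs[0], rhs[1])


# Illustrative sizes only.
m_i, n, k = 5, 3, 4
a_shape = (m_i, k)  # user-facing A[i]
b_shape = (m_i, n)  # user-facing B[i]

as_written = matmul_shape(transpose(a_shape), b_shape)     # A^T @ B -> (k, n)
as_dispatched = matmul_shape(transpose(b_shape), a_shape)  # B^T @ A -> (n, k)

print(as_written)     # (4, 3)
print(as_dispatched)  # (3, 4)
```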
Description
Always enabling padding causes a significant (~15%) slowdown, so this PR enables it only when necessary.
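As a rough illustration of "only when necessary", the sketch below flags which dimensions are not multiples of a tile size. The function name `dims_needing_padding`, the tile sizes, and the divisibility criterion itself are assumptions for illustration; the actual condition used by the ck_tile dispatch is not shown in this description.

```python
def dims_needing_padding(m_vals, n, k, tile):
    """Return the dimensions (subset of "MNK") that are not tile-aligned.

    m_vals: per-group M sizes; n, k: shared N and K sizes.
    tile:   (tile_m, tile_n, tile_k), hypothetical kernel tile sizes.
    """
    tile_m, tile_n, tile_k = tile
    dims = ""
    if any(m % tile_m != 0 for m in m_vals):  # any group with ragged M
        dims += "M"
    if n % tile_n != 0:
        dims += "N"
    if k % tile_k != 0:
        dims += "K"
    return dims


# Hypothetical tile sizes and problem sizes.
tile = (128, 128, 64)
print(dims_needing_padding([256, 256], 256, 128, tile))  # "" (fast, unpadded path)
print(dims_needing_padding([256, 100], 256, 72, tile))   # "MK"
```

An empty result would let the caller pick the unpadded kernel and avoid the slowdown; a non-empty one would select a padded variant.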