Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 138 additions & 0 deletions tests/pytorch/test_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -3075,6 +3075,144 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)


if IS_HIP_EXTENSION:

    @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16], ids=str)
    @pytest.mark.parametrize("layout", ["TN", "NN", "NT", "TT"])
    @pytest.mark.parametrize("accumulate", [False, True])
    @pytest.mark.parametrize(
        "pad_dim",
        ["K", "M", "N", "MK", "MKN"],
        ids=lambda d: f"pad{d}",
    )
    def test_grouped_gemm_unaligned(dtype, layout, accumulate, pad_dim, capfd):
        """Test CK grouped GEMM with M, N, or K not aligned to CK tile size.

        CK constraints for bf16/fp16:
        - Contiguous dim of A/B must be dword-aligned (even for 2-byte types).
          RowMajor: contiguous dim is cols (K for A, N for B).
          ColMajor: contiguous dim is rows (M for A, K for B).
        - K tile: 64, M tile: 256, N tile: 128/256

        Each letter in ``pad_dim`` selects the unaligned size for that
        dimension; dimensions not named use the aligned default. The grouped
        result is compared against per-group reference GEMMs, and captured
        stderr is then inspected for fallback warnings.

        The two environment variables set by this test are removed in a
        ``finally`` clause so that a failing assertion, ``pytest.xfail`` or
        ``pytest.fail`` cannot leak them into later tests.
        """
        torch.manual_seed(0)
        z = 8

        # Unaligned values per dimension (all satisfy CK vector-load constraints).
        #   K: even but not multiple of tile (64). Same for all groups.
        #   M: not multiples of tile (256), varies per group.
        #   N: multiple of 16 but not multiple of tile (128).
        unaligned_k = 2016
        unaligned_m = [100, 300, 150, 200, 50, 350, 250, 180]
        unaligned_n = 2032

        # Aligned defaults.
        k_aligned = 2048
        m_aligned = 256
        n_aligned = 2048

        # Select (un)aligned values based on pad_dim.
        k_val = unaligned_k if "K" in pad_dim else k_aligned
        m_vals = unaligned_m if "M" in pad_dim else [m_aligned] * z
        n_val = unaligned_n if "N" in pad_dim else n_aligned

        total_m = sum(m_vals)
        # Every layout uses the same per-group M split.
        m_splits = m_vals

        os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1"
        os.environ["NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK"] = "1"
        try:
            # Tensor creation order (A, then B, then out) is kept fixed so the
            # seeded random stream produces the same operands on every run.
            if layout == "TN":
                A = [torch.randn(n_val, k_val, dtype=dtype, device="cuda") for _ in range(z)]
                B = [torch.randn(m, k_val, dtype=dtype, device="cuda") for m in m_vals]
                out = [torch.randn(total_m, n_val, dtype=dtype, device="cuda")]
                out_ref = [o.clone() for o in torch.split(out[0], m_vals)]
                grad = False
                single_output = True
            elif layout == "NN":
                A = [torch.randn(k_val, n_val, dtype=dtype, device="cuda") for _ in range(z)]
                B = [torch.randn(m, k_val, dtype=dtype, device="cuda") for m in m_vals]
                out = [torch.randn(total_m, n_val, dtype=dtype, device="cuda")]
                out_ref = [o.clone() for o in torch.split(out[0], m_vals)]
                grad = True
                single_output = True
            elif layout == "NT":
                A = list(
                    torch.split(torch.randn(total_m, k_val, dtype=dtype, device="cuda"), m_vals)
                )
                B = list(
                    torch.split(torch.randn(total_m, n_val, dtype=dtype, device="cuda"), m_vals)
                )
                out = [torch.randn(n_val, k_val, dtype=dtype, device="cuda") for _ in range(z)]
                out_ref = [o.clone() for o in out]
                grad = True
                single_output = False
            else:  # TT
                A = [torch.randn(n_val, k_val, dtype=dtype, device="cuda") for _ in range(z)]
                B = [torch.randn(k_val, m, dtype=dtype, device="cuda") for m in m_vals]
                out = [torch.randn(total_m, n_val, dtype=dtype, device="cuda")]
                out_ref = [o.clone() for o in torch.split(out[0], m_vals)]
                grad = False
                single_output = True

            # Reference: individual GEMMs
            for i in range(z):
                if layout == "TT":
                    # general_gemm doesn't support TT; compute reference manually.
                    ref = B[i].T.to(torch.float32) @ A[i].T.to(torch.float32)
                    if accumulate:
                        out_ref[i] = (out_ref[i].to(torch.float32) + ref).to(dtype)
                    else:
                        out_ref[i] = ref.to(dtype)
                else:
                    general_gemm(
                        A[i],
                        B[i],
                        dtype,
                        grad=grad,
                        accumulate=accumulate,
                        layout=layout,
                        out=out_ref[i],
                    )

            if single_output:
                out_ref = [torch.cat(out_ref)]

            general_grouped_gemm(
                A,
                B,
                out,
                [None] * z,
                dtype,
                m_splits=m_splits,
                grad=grad,
                accumulate=accumulate,
                layout=layout,
                single_output=single_output,
            )

            # The tolerance does not depend on the individual output, so it is
            # chosen once. This whole block is already guarded by
            # IS_HIP_EXTENSION, so that flag is not re-checked here.
            use_loose_tol = (
                accumulate
                and dtype == torch.bfloat16
                and get_device_compute_capability() == (9, 4)
            )
            tol = 4e-2 if use_loose_tol else 1.5e-2
            for o, o_ref in zip(out, out_ref):
                torch.testing.assert_close(o, o_ref, rtol=tol, atol=tol)

            # Check for CK fallback warnings from C++ (NVTE_WARN writes to std::cerr).
            # capfd captures file-descriptor-level output, including C/C++ stderr.
            captured = capfd.readouterr()
            if "Falling back" in captured.err or "Fallback" in captured.err:
                if "K" in pad_dim and layout != "NN":
                    pytest.xfail(
                        "Known CK_Tile limitation: K-padding with non-NN layouts may fall back to cuBLAS "
                        "(kPadK + ColMajor B bug, or CK_Tile stride alignment requirements)"
                    )
                else:
                    pytest.fail(f"CK_Tile grouped GEMM fell back to cuBLAS:\n{captured.err}")
        finally:
            # Always undo the environment changes, including when an assertion,
            # pytest.xfail or pytest.fail raises above.
            os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)
            os.environ.pop("NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK", None)


@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize(
Expand Down
4 changes: 4 additions & 0 deletions transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,10 @@ else()
gemm/ck_grouped_gemm/ck_grouped_gemm.cpp
gemm/ck_grouped_gemm/ck_grouped_gemm_fp8.cpp
gemm/ck_grouped_gemm/ck_grouped_gemm_fp16.cpp
gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_nn.cpp
gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_nt.cpp
gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_tn.cpp
gemm/ck_grouped_gemm/ck_grouped_gemm_fp16_tt.cpp
amd_detail/system.cpp)
list(APPEND transformer_engine_cuda_sources
fused_attn_rocm/fused_attn_aotriton.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,16 @@ static inline bool launch_grouped_gemm_kernel(const DescContainer& descs,

if (!Kernel::IsSupportedArgument(kargs)) {
NVTE_WARN("ck_tile_grouped_gemm: CK_Tile kernel arguments not supported for this config. "
"Falling back.");
"transA=", ctx.transA, " transB=", ctx.transB,
" accumulate=", ctx.accumulate, " groups=", ctx.group_num,
". Falling back. "
"CK_Tile constraints for bf16/fp16: "
"contiguous dim of A and B must be dword-aligned (even).");
for (size_t i = 0; i < descs.size(); ++i) {
NVTE_WARN(" group ", i, ": M=", descs[i].M, " N=", descs[i].N, " K=", descs[i].K,
" stride_A=", descs[i].stride_A, " stride_B=", descs[i].stride_B,
" stride_E=", descs[i].stride_E);
}
return false;
}

Expand Down
Loading
Loading