From 70f6e4ad1bb9f111edcc33f0dffbd0c101cebb85 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Sat, 16 May 2026 21:14:28 +0800 Subject: [PATCH 01/18] opus: forward-declare mma adaptors so opus.hpp parses on gfx1201 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit On gfx1200 / gfx1201 (Navi 44 / Navi 48, RDNA4) device code, neither __GFX9__ nor __gfx1250__ is active, so the inner mfma_adaptor and wmma_adaptor definitions never get pulled in. That alone would be fine, except make_tiled_mma() at csrc/include/opus/opus.hpp:3057-3063 names both as default template arguments: template<..., typename WA = #if defined(__gfx1250__) wmma_adaptor, #else mfma_adaptor, #endif ...> Name lookup runs at parse time even when no caller instantiates the template, so on gfx1201 the header fails to compile with: csrc/include/opus/opus.hpp:3057:24: error: unknown type name 'mfma_adaptor' This blocks JIT builds for every kernel that includes aiter_opus_plus.h, e.g. sample_kernels.cu, moe_fused_gate.cu, topk_gating_kernels.cu, gated_rmsnorm_quant_kernels.cu, topk_softmax_kernels_group.cu, mhc_kernels.cu, quant_kernels.cu, quant_mxfp4.cu, fused_qk_rmsnorm_group_quant.cu, and rope_common.h. Fix: forward-declare mfma_adaptor / mfma_adaptor_swap_ab / wmma_adaptor / wmma_adaptor_swap_ab as incomplete types at the top of the opus namespace. Default-arg name lookup is satisfied for all archs; instantiation still requires the full definition, so callers that actually invoke make_tiled_mma() / make_mfma() / make_wmma() are unaffected. Behavior on gfx1250 and gfx9x is unchanged. Unit test: op_tests/opus/device/test_opus_parse_gfx1201.cu exercises the opus utilities sample_kernels.cu actually uses (opus::vector_t and opus::cast) — pure template machinery, no HIP intrinsics — under a gfx1201-gated kernel body. Wired into the existing test_opus_device.py harness; skips on non-gfx1201 archs. 
Verified on RX 9070 XT (gfx1201): test compiles + runs, max_diff=0.0 against torch reference. End-to-end check separately confirmed: with this fix in place, module_sample JIT-builds for gfx1201 and aiter.mixed_sample_outer_exponential produces bit-identical output to torch Gumbel-max (15.4us avg over 1000 calls). --- csrc/include/opus/opus.hpp | 22 ++++ op_tests/opus/device/setup.py | 1 + op_tests/opus/device/test_opus_device.py | 51 ++++++++ .../opus/device/test_opus_parse_gfx1201.cu | 114 ++++++++++++++++++ 4 files changed, 188 insertions(+) create mode 100644 op_tests/opus/device/test_opus_parse_gfx1201.cu diff --git a/csrc/include/opus/opus.hpp b/csrc/include/opus/opus.hpp index e0c84f7ecb..af4d43c086 100644 --- a/csrc/include/opus/opus.hpp +++ b/csrc/include/opus/opus.hpp @@ -499,6 +499,28 @@ template struct tuple_element constexpr auto embed(const X& x, const Y& y, seq) { return ( ... + (get(x) * get(y))); } diff --git a/op_tests/opus/device/setup.py b/op_tests/opus/device/setup.py index 376d8d2404..8273262f4c 100644 --- a/op_tests/opus/device/setup.py +++ b/op_tests/opus/device/setup.py @@ -40,6 +40,7 @@ "test_numeric_limits.cu", "test_workgroup_barrier.cu", "test_finfo.cu", + "test_opus_parse_gfx1201.cu", ] diff --git a/op_tests/opus/device/test_opus_device.py b/op_tests/opus/device/test_opus_device.py index 90e5a3f05e..9d9a7d660d 100644 --- a/op_tests/opus/device/test_opus_device.py +++ b/op_tests/opus/device/test_opus_device.py @@ -176,6 +176,13 @@ def run_vector_add(self, A, B, Result): fn.argtypes = [_VP, _VP, _VP, _I] fn(self._ptr(A), self._ptr(B), self._ptr(Result), int(A.numel())) + # -- opus_parse_gfx1201 (verifies opus.hpp parses + opus utils work on gfx1201) -- + def run_opus_parse_gfx1201(self, A, B, Result): + fn = self._lib.run_opus_parse_gfx1201 + fn.restype = None + fn.argtypes = [_VP, _VP, _VP, _I] + fn(self._ptr(A), self._ptr(B), self._ptr(Result), int(A.numel())) + # -- async_load -- def run_async_load(self, Src, Dst): fn = 
self._lib.run_async_load @@ -1166,6 +1173,49 @@ def test_vector_add(mod): return 0 +# Archs where the opus.hpp parse-time fix matters. The kernel body in +# test_opus_parse_gfx1201.cu is gated by __gfx1201__ — on other archs the +# launcher runs an empty kernel, so we skip the correctness check to avoid +# a misleading failure. +_OPUS_PARSE_GFX1201_ARCHS = {"gfx1201"} + + +def test_opus_parse_gfx1201(mod): + """Verify opus.hpp parses + opus utilities (make_gmem / .load/.store / + cast) work on gfx1201. Mirrors the load → cast → store pattern + in sample_kernels.cu. Skips on other archs (kernel body is gfx1201-only).""" + arch = _get_gpu_arch() + if arch not in _OPUS_PARSE_GFX1201_ARCHS: + print(f" SKIP: opus_parse_gfx1201 (arch={arch}, gfx1201-only test)") + return 0 + + n = 1310720 + device = torch.device("cuda") + dtype = torch.float32 + + torch.manual_seed(42) + A = torch.randn(n, device=device, dtype=dtype) + B = torch.randn(n, device=device, dtype=dtype) + Result = torch.empty(n, device=device, dtype=dtype) + + mod.run_opus_parse_gfx1201(A, B, Result) + + Ref = A + B + + atol, rtol = 1e-5, 1e-5 + ok = torch.allclose(Result, Ref, atol=atol, rtol=rtol) + max_diff = (Result - Ref).abs().max().item() + if not ok: + diff_count = (Result - Ref).abs().gt(atol + rtol * Ref.abs()).sum().item() + print( + f" FAIL: opus_parse_gfx1201 max_diff={max_diff:.6e}, " + f"{diff_count} elements outside tol" + ) + return 1 + print(f" PASS: opus_parse_gfx1201 (arch={arch}, n={n}), max_diff={max_diff:.6e}") + return 0 + + def test_async_load(mod): """Test async_load: copy data through LDS and verify integrity.""" # n should be a multiple of BLOCK_SIZE (256) @@ -2171,6 +2221,7 @@ def main(): failures += test_wmma_scale_16x16x128_fp8_bx32_scaled(mod) failures += test_mma_step_k_bf16(mod) failures += test_vector_add(mod) + failures += test_opus_parse_gfx1201(mod) failures += test_async_load(mod) failures += test_tr_load_f16(mod) failures += test_dtype_convert_fp32_bf16(mod) diff --git 
a/op_tests/opus/device/test_opus_parse_gfx1201.cu b/op_tests/opus/device/test_opus_parse_gfx1201.cu new file mode 100644 index 0000000000..7a9700eadb --- /dev/null +++ b/op_tests/opus/device/test_opus_parse_gfx1201.cu @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: MIT +// Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +/** + * @file test_opus_parse_gfx1201.cu + * @brief Verify opus.hpp parses cleanly on gfx1201 (Navi 48 / RX 9070 XT, RDNA4). + * + * Without the gfx1201 forward-declaration fix in opus.hpp, kernels that + * include opus.hpp on gfx1201 fail to compile with: + * + * csrc/include/opus/opus.hpp:3065:24: error: unknown type name 'mfma_adaptor' + * + * because make_tiled_mma()'s default template argument references + * mfma_adaptor (defined only under __GFX9__) and wmma_adaptor (defined only + * under __gfx1250__) — neither of which is active in gfx1201 device code. + * + * This test exercises the opus utilities sample_kernels.cu actually uses + * (opus::vector_t for vectorized lane storage, opus::cast for type + * conversion). They are pure compile-time / C++ template machinery — no HIP + * intrinsics — so they work on any arch the opus.hpp header parses for. + * + * What this test does NOT exercise: + * - opus::make_gmem .load/.store — these route through buffer-load / + * buffer-store intrinsics that are not available on gfx1201 today; the + * kernel uses plain pointer arithmetic for memory I/O instead. + * - mfma_adaptor / wmma_adaptor instantiation — neither is defined for + * gfx1201; that is the subject of opus.hpp Phase 2 (full WMMA support). + * + * If this test builds and produces correct results on gfx1201, the + * forward declarations in opus.hpp are sufficient to keep gfx1201 device + * code compiling. Behavior on gfx1250 / gfx9x is unchanged — the kernel + * body is gated by __gfx1201__ so other archs see an empty no-op pass. 
+ */ + +#ifdef __HIP_DEVICE_COMPILE__ +// ── Device pass: opus.hpp + kernel body, no hip_runtime.h ────────────────── +#include "opus/opus.hpp" + +#if defined(__gfx1201__) +// Element-wise add via opus::vector_t lanes + opus::cast; the +// load/store goes through plain pointer arithmetic because opus's +// buffer-intrinsic backed make_gmem store path is not available on +// gfx1201 today. Sample_kernels.cu uses the same vector_t + cast pattern +// for its per-lane FP8/BF16 → FP32 conversion. +template +__global__ void opus_parse_gfx1201_kernel( + const float* __restrict__ a, + const float* __restrict__ b, + float* __restrict__ result, + int n) +{ + int idx = __builtin_amdgcn_workgroup_id_x() * BLOCK_SIZE + + __builtin_amdgcn_workitem_id_x(); + int stride = __builtin_amdgcn_grid_size_x(); + + for (int base = idx * VECTOR_SIZE; base < n; base += stride * VECTOR_SIZE) { + opus::vector_t va, vb, vr; + for (int j = 0; j < VECTOR_SIZE; ++j) { + va[j] = a[base + j]; + vb[j] = b[base + j]; + } + for (int j = 0; j < VECTOR_SIZE; ++j) { + // opus::cast(float) is a compile-time pass-through. + // Including it in the test exercises the same template path + // sample_kernels.cu instantiates for its DTYPE_I → float lanes. 
+ vr[j] = opus::cast(va[j]) + opus::cast(vb[j]); + } + for (int j = 0; j < VECTOR_SIZE; ++j) { + result[base + j] = vr[j]; + } + } +} + +template __global__ void opus_parse_gfx1201_kernel<256, 4>(const float*, const float*, float*, int); +#endif // __gfx1201__ + +#else +// ── Host pass: launcher + empty kernel stub ──────────────────────────────── +#include "opus/hip_minimal.hpp" +#include + +#define HIP_CALL(call) do { \ + hipError_t err = (call); \ + if (err != hipSuccess) { \ + fprintf(stderr, "HIP error %d at %s:%d\n", (int)err, __FILE__, __LINE__); \ + return; \ + } \ +} while(0) + +template +__global__ void opus_parse_gfx1201_kernel(const float*, const float*, float*, int) {} + +extern "C" void run_opus_parse_gfx1201( + const void* d_a, + const void* d_b, + void* d_result, + int n) +{ + const auto* a = static_cast(d_a); + const auto* b = static_cast(d_b); + auto* r = static_cast(d_result); + + constexpr int BLOCK_SIZE = 256; + constexpr int VECTOR_SIZE = 4; + int blocks = (n + (BLOCK_SIZE * VECTOR_SIZE) - 1) / (BLOCK_SIZE * VECTOR_SIZE); + + hipLaunchKernelGGL( + (opus_parse_gfx1201_kernel), + dim3(blocks), dim3(BLOCK_SIZE), 0, 0, + a, b, r, n); + HIP_CALL(hipGetLastError()); + HIP_CALL(hipDeviceSynchronize()); +} +#endif From 9878ec91c18c9979c1cb21ee35b6cd6ac060b6e8 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Sat, 16 May 2026 21:43:11 +0800 Subject: [PATCH 02/18] opus: route gfx1200/gfx1201 to the RDNA buffer rsrc config (fix silent zero loads/stores) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit opus::make_gmem<>.load() / .store() silently returned 0 / dropped writes on gfx1201, because buffer_default_config() landed in the 0xffffffff fallback for that arch and the resulting __amdgpu_buffer_rsrc_t was invalid. The HIP buffer_load_b32 / buffer_store_b32 intrinsics themselves work on RDNA4 — the bug was purely in which config word the header picked. 
Root cause: the existing branch #elif defined(__gfx11__) || defined(__gfx12__) || defined(__gfx1250__) return 0x31004000; uses __gfx11__ / __gfx12__ (lowercase) which clang does NOT predefine (only the uppercase __GFX11__ / __GFX12__ exist). gfx1250 was already covered by its explicit per-arch token; everything else in the gfx11x / gfx12x families silently fell into the 0xffffffff sentinel branch. Minimal fix: add explicit __gfx1201__ / __gfx1200__ checks alongside the existing __gfx1250__ check, so Navi 44 / Navi 48 also get the correct RDNA buffer rsrc config (0x31004000). The lowercase __gfx11__/__gfx12__ typo is left alone here — fixing it would also flip behavior for gfx1100-1103 / gfx1150-1153, which is out of scope for the gfx1201 enablement work this branch covers. No change for gfx1250 (already used 0x31004000 via the explicit per-arch check); no change for gfx9x (different branch entirely). Unit test: test_opus_parse_gfx1201.cu now uses opus::make_gmem<>.load<4> and .store<4> (the API users actually want working) instead of plain pointer arithmetic. max_diff=0.0 against torch reference on n=1310720 fp32 elements, gfx1201. Integration check: with the buffer config fix in place, sample_kernels.cu JIT-rebuilds cleanly and aiter.mixed_sample_outer_exponential still produces bit-identical output to torch Gumbel-max — confirms the change does not regress kernels that already worked (sample_kernels passes an explicit size to make_gmem, so it tolerated the bad default config; now both the default-config path and the explicit-size path work). 
--- csrc/include/opus/opus.hpp | 9 ++- .../opus/device/test_opus_parse_gfx1201.cu | 80 ++++++++++--------- 2 files changed, 49 insertions(+), 40 deletions(-) diff --git a/csrc/include/opus/opus.hpp b/csrc/include/opus/opus.hpp index af4d43c086..4deb5b0ceb 100644 --- a/csrc/include/opus/opus.hpp +++ b/csrc/include/opus/opus.hpp @@ -1635,7 +1635,14 @@ OPUS_D constexpr auto buffer_default_config() { return 0x00020000; #elif defined(__gfx103__) return 0x31014000; -#elif defined(__gfx11__) || defined(__gfx12__) || defined(__gfx1250__) +// Navi 44 / Navi 48 (RDNA4, gfx1200 / gfx1201) use the same RDNA buffer rsrc +// config word as gfx1250. They are listed here explicitly because the +// __gfx11__ / __gfx12__ tokens on the previous line are typos — clang only +// predefines the uppercase __GFX11__ / __GFX12__ — so without these explicit +// per-arch checks gfx1200 / gfx1201 fall into the 0xffffffff fallback, +// producing an invalid buffer descriptor that silently drops all stores and +// returns zero for all loads from make_gmem<>. +#elif defined(__gfx11__) || defined(__gfx12__) || defined(__gfx1250__) || defined(__gfx1201__) || defined(__gfx1200__) return 0x31004000; #else return 0xffffffff; diff --git a/op_tests/opus/device/test_opus_parse_gfx1201.cu b/op_tests/opus/device/test_opus_parse_gfx1201.cu index 7a9700eadb..a1afe11b2f 100644 --- a/op_tests/opus/device/test_opus_parse_gfx1201.cu +++ b/op_tests/opus/device/test_opus_parse_gfx1201.cu @@ -3,33 +3,38 @@ /** * @file test_opus_parse_gfx1201.cu - * @brief Verify opus.hpp parses cleanly on gfx1201 (Navi 48 / RX 9070 XT, RDNA4). + * @brief Verify opus.hpp parses + opus::make_gmem load/store work on gfx1201 + * (Navi 48 / RX 9070 XT, RDNA4). 
* - * Without the gfx1201 forward-declaration fix in opus.hpp, kernels that - * include opus.hpp on gfx1201 fail to compile with: + * Two issues this test covers, both addressed in this commit: * - * csrc/include/opus/opus.hpp:3065:24: error: unknown type name 'mfma_adaptor' + * 1) Parse-time: without the forward declarations of mfma_adaptor / + * wmma_adaptor at the top of the opus namespace, opus.hpp fails to + * compile for gfx1201 device code because make_tiled_mma()'s default + * template argument names types that are gated behind __GFX9__ / + * __gfx1250__ blocks. * - * because make_tiled_mma()'s default template argument references - * mfma_adaptor (defined only under __GFX9__) and wmma_adaptor (defined only - * under __gfx1250__) — neither of which is active in gfx1201 device code. + * 2) Runtime: even after the header parses, opus::make_gmem<>.store() + * and .load() silently produced wrong results on gfx1201 because + * buffer_default_config() returned the 0xffffffff fallback (the + * __gfx11__ / __gfx12__ checks on the prior line are typos — clang + * only predefines the uppercase __GFX11__ / __GFX12__). The invalid + * buffer rsrc made all buffer_load_b32 lanes return 0 and all + * buffer_store_b32 lanes drop on the floor. Fix: add explicit + * __gfx1201__ / __gfx1200__ branches with the correct 0x31004000 + * config word that gfx1250 already uses. * - * This test exercises the opus utilities sample_kernels.cu actually uses - * (opus::vector_t for vectorized lane storage, opus::cast for type - * conversion). They are pure compile-time / C++ template machinery — no HIP - * intrinsics — so they work on any arch the opus.hpp header parses for. + * The kernel below exercises the exact opus API that sample_kernels.cu / + * topk_softmax_kernels_group.cu / etc. 
depend on: * - * What this test does NOT exercise: - * - opus::make_gmem .load/.store — these route through buffer-load / - * buffer-store intrinsics that are not available on gfx1201 today; the - * kernel uses plain pointer arithmetic for memory I/O instead. - * - mfma_adaptor / wmma_adaptor instantiation — neither is defined for - * gfx1201; that is the subject of opus.hpp Phase 2 (full WMMA support). + * auto g = opus::make_gmem(ptr); + * auto v = g.load(i); // buffer_load via cached rsrc + * ... opus::cast(v[j]) ... + * g.store(vr, i); // buffer_store via cached rsrc * - * If this test builds and produces correct results on gfx1201, the - * forward declarations in opus.hpp are sufficient to keep gfx1201 device - * code compiling. Behavior on gfx1250 / gfx9x is unchanged — the kernel - * body is gated by __gfx1201__ so other archs see an empty no-op pass. + * If this test produces correct results on gfx1201, both fixes hold. + * Kernel body is gated by __gfx1201__ so other archs see an empty no-op + * pass — gfx1250 / gfx9x behavior is unchanged. */ #ifdef __HIP_DEVICE_COMPILE__ @@ -37,11 +42,8 @@ #include "opus/opus.hpp" #if defined(__gfx1201__) -// Element-wise add via opus::vector_t lanes + opus::cast; the -// load/store goes through plain pointer arithmetic because opus's -// buffer-intrinsic backed make_gmem store path is not available on -// gfx1201 today. Sample_kernels.cu uses the same vector_t + cast pattern -// for its per-lane FP8/BF16 → FP32 conversion. +// Element-wise add via opus make_gmem load / store + per-lane opus::cast. +// Mirrors the load → cast → store pattern in sample_kernels.cu. 
template __global__ void opus_parse_gfx1201_kernel( const float* __restrict__ a, @@ -49,25 +51,25 @@ __global__ void opus_parse_gfx1201_kernel( float* __restrict__ result, int n) { - int idx = __builtin_amdgcn_workgroup_id_x() * BLOCK_SIZE - + __builtin_amdgcn_workitem_id_x(); + auto g_a = opus::make_gmem(a); + auto g_b = opus::make_gmem(b); + auto g_r = opus::make_gmem(result); + + int idx = __builtin_amdgcn_workgroup_id_x() * BLOCK_SIZE + __builtin_amdgcn_workitem_id_x(); int stride = __builtin_amdgcn_grid_size_x(); - for (int base = idx * VECTOR_SIZE; base < n; base += stride * VECTOR_SIZE) { - opus::vector_t va, vb, vr; - for (int j = 0; j < VECTOR_SIZE; ++j) { - va[j] = a[base + j]; - vb[j] = b[base + j]; - } + for (int i = idx * VECTOR_SIZE; i < n; i += stride * VECTOR_SIZE) { + auto va = g_a.load(i); + auto vb = g_b.load(i); + + decltype(va) vr; for (int j = 0; j < VECTOR_SIZE; ++j) { // opus::cast(float) is a compile-time pass-through. - // Including it in the test exercises the same template path - // sample_kernels.cu instantiates for its DTYPE_I → float lanes. + // Including it exercises the same template path sample_kernels.cu + // instantiates for its per-lane DTYPE_I → float conversion. vr[j] = opus::cast(va[j]) + opus::cast(vb[j]); } - for (int j = 0; j < VECTOR_SIZE; ++j) { - result[base + j] = vr[j]; - } + g_r.store(vr, i); } } From b8aa331822dc2b5a42df350535e32b18d0cba638 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Sat, 16 May 2026 23:05:44 +0800 Subject: [PATCH 03/18] =?UTF-8?q?opus:=20add=20gfx1201=20(Navi=2048=20/=20?= =?UTF-8?q?RDNA4)=20WMMA=20support=20=E2=80=94=208=20wave32=2016x16x16=20v?= =?UTF-8?q?ariants?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Extends opus::wmma<> to dispatch through the gfx12-specific __builtin_amdgcn_wmma_*_w32_gfx12 family on gfx1201, in parallel with the existing gfx1250 path. 
gfx12 builtins have a leaner argument signature than gfx1250 (no matrix_fmts / neg_c slot, no opsel) so they need their own dispatch macros. Variants covered (all wave32 16x16x16; matches the breadth gfx12 exposes for fp / fp8 acc): - f32 <- f16 / f16 - f32 <- bf16 / bf16 - f16 <- f16 / f16 - bf16 <- bf16 / bf16 - f32 <- fp8 / fp8 - f32 <- fp8 / bf8 - f32 <- bf8 / fp8 - f32 <- bf8 / bf8 iu8 / iu4 variants are deliberately deferred — opus has no iu8_t / iu4_t dtype aliases yet, so wiring them needs a small separate change. What is NOT touched on the wmma_adaptor / make_tiled_mma path: the existing wmma_adaptor encoding (rows cross-thread along M) was designed for gfx1250's WMMA fragment layout, which is row-distributed for A and column-distributed for B / C. gfx12 has a different asymmetry — A is row-distributed (lane[i] holds A[i%16, (i/16)*8 + j] for j in [0,7]) while B and C are column-distributed (lane[i] holds B[(i/16)*8 + j, i%16] / C[(i/16)*8 + j, i%16]). That asymmetry is documented inline in test_wmma_gfx1201.cu but a dedicated opus::wmma_adaptor_gfx12 specialization is needed before the high-level tiled API can route gfx1201 — a NOTE comment explaining this limitation was added near the wmma_adaptor block. The opus::wmma<> struct itself is fully usable by callers who construct their own fragments (which is what the unit test does). Test (op_tests/opus/device/test_wmma_gfx1201.cu): one kernel per dtype combo loads its lane-local A / B fragment per the gfx12 layout, calls opus::wmma<>::operator(), and stores the C fragment back. 
Verified on RX 9070 XT (gfx1201) against torch matmul ref: f32 <- f16 / f16 : max_diff = 0.0000 (bit-exact) f32 <- bf16 / bf16 : max_diff = 0.0000 f16 <- f16 / f16 : max_diff = 0.0312 (1 ULP fp16) bf16 <- bf16 / bf16 : max_diff = 0.5000 (1 ULP bf16) f32 <- fp8 / fp8 : max_diff = 0.0000 f32 <- fp8 / bf8 : max_diff = 0.0000 f32 <- bf8 / fp8 : max_diff = 0.0000 f32 <- bf8 / bf8 : max_diff = 0.0000 8/8 variants passed Also renames the earlier test_opus_parse_gfx1201.cu to test_opus_gmem_gfx1201.cu — the test exercises opus::make_gmem load / store on gfx1201 (enabled by the prior buffer_default_config fix in this PR), not anything specifically about parsing. The new name reflects the actual scope. --- csrc/include/opus/opus.hpp | 75 ++++++++- op_tests/opus/device/setup.py | 3 +- op_tests/opus/device/test_opus_device.py | 127 +++++++++++++-- ...e_gfx1201.cu => test_opus_gmem_gfx1201.cu} | 49 +++--- op_tests/opus/device/test_wmma_gfx1201.cu | 150 ++++++++++++++++++ 5 files changed, 363 insertions(+), 41 deletions(-) rename op_tests/opus/device/{test_opus_parse_gfx1201.cu => test_opus_gmem_gfx1201.cu} (63%) create mode 100644 op_tests/opus/device/test_wmma_gfx1201.cu diff --git a/csrc/include/opus/opus.hpp b/csrc/include/opus/opus.hpp index 4deb5b0ceb..3f79a3a398 100644 --- a/csrc/include/opus/opus.hpp +++ b/csrc/include/opus/opus.hpp @@ -2331,8 +2331,15 @@ using mfma_scale_f32_16x16x128_fp4_fp4 = mfma_f32_16x16x128_fp4_fp4; #endif // __GFX9__ (mfma) ///////////////////////////////////////////////////////////////////////////////////////////////////////// -// wmma (gfx1250 / RDNA4, wave32) -#if defined(__gfx1250__) || !defined(__HIP_DEVICE_COMPILE__) +// wmma (RDNA4 / wave32) — supports gfx1250 and gfx1201 (Navi 48). 
+// The two archs share the same opus::wmma<> template + dispatch shape, but use +// different LLVM builtins: +// - gfx1250: __builtin_amdgcn_wmma__16x16x{32,64,128,4}_ (wmma-256b-insts) +// - gfx1201: __builtin_amdgcn_wmma__16x16x{16,32}__w32_gfx12 (wmma-128b-insts) +// gfx1201 only supports the 16x16x16 shape (plus 16x16x32 for iu4) and the +// dispatch macros below have different argument lists than the gfx1250 ones +// (see DISPATCH_WMMA_GFX12_* further down). +#if defined(__gfx1250__) || defined(__gfx1201__) || !defined(__HIP_DEVICE_COMPILE__) // f16/bf16/f32 builtins: (neg_a, A, neg_b, B, matrix_fmts, C, clamp, neg_c) #define DISPATCH_WMMA_(ta_, tb_, tc_, wm_, wn_, wk_, inst_) \ (std::is_same_v && std::is_same_v && std::is_same_v && \ @@ -2355,6 +2362,27 @@ using mfma_scale_f32_16x16x128_fp4_fp4 = mfma_f32_16x16x128_fp4_fp4; __builtin_bit_cast(vector_t, b), \ static_cast(0), c, false, false); } +// gfx12 (gfx1200 / gfx1201, Navi 44/48) WMMA dispatch macros. +// +// The gfx12 builtins (suffixed _w32_gfx12 in BuiltinsAMDGPU.td) have a leaner +// signature than the gfx1250 ones — there is no matrix_fmts / neg_c slot — +// so they need their own macros even though shape/dtype matching is identical. 
+// +// FP / FP8-acc variants (f16/bf16/f16→f16/bf16→bf16 → f32 or same-type acc): (A, B, C) +// FP8/BF8 (A/B reinterpreted as packed i32 vector): (A, B, C) +#define DISPATCH_WMMA_GFX12_F32_(ta_, tb_, tc_, wm_, wn_, wk_, inst_) \ + (std::is_same_v && std::is_same_v && std::is_same_v && \ + wave_m == wm_ && wave_n == wn_ && wave_k == wk_) { \ + return inst_(a, b, c); } + +#define DISPATCH_WMMA_GFX12_8BIT_(ta_, tb_, tc_, wm_, wn_, wk_, inst_) \ + (std::is_same_v && std::is_same_v && std::is_same_v && \ + wave_m == wm_ && wave_n == wn_ && wave_k == wk_) { \ + constexpr index_t i32_a = elem_a * static_cast(sizeof(dtype_a)) / static_cast(sizeof(i32_t)); \ + constexpr index_t i32_b = elem_b * static_cast(sizeof(dtype_b)) / static_cast(sizeof(i32_t)); \ + return inst_(__builtin_bit_cast(vector_t, a), \ + __builtin_bit_cast(vector_t, b), c); } + template struct wmma { using dtype_a = remove_cvref_t; @@ -2418,6 +2446,22 @@ struct wmma { else if constexpr DISPATCH_WMMA_8BIT_(bf8_t, fp8_t, fp16_t, 16, 16, 128, __builtin_amdgcn_wmma_f16_16x16x128_bf8_fp8) else if constexpr DISPATCH_WMMA_8BIT_(bf8_t, bf8_t, fp16_t, 16, 16, 128, __builtin_amdgcn_wmma_f16_16x16x128_bf8_bf8) #endif +#if defined(__gfx1201__) + // gfx12 wave32 16x16x16 — f16/bf16 → f32 + else if constexpr DISPATCH_WMMA_GFX12_F32_(fp16_t, fp16_t, fp32_t, 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_F32_(bf16_t, bf16_t, fp32_t, 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12) + // gfx12 wave32 16x16x16 — same-type accumulator (same 3-arg signature as f32 acc) + else if constexpr DISPATCH_WMMA_GFX12_F32_(fp16_t, fp16_t, fp16_t, 16, 16, 16, __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_F32_(bf16_t, bf16_t, bf16_t, 16, 16, 16, __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12) + // gfx12 wave32 16x16x16 — fp8/bf8 × {fp8, bf8} → f32 + else if constexpr DISPATCH_WMMA_GFX12_8BIT_(fp8_t, fp8_t, fp32_t, 16, 16, 
16, __builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_8BIT_(fp8_t, bf8_t, fp32_t, 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_8BIT_(bf8_t, fp8_t, fp32_t, 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_8BIT_(bf8_t, bf8_t, fp32_t, 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12) + // Note: gfx12 also supports __builtin_amdgcn_wmma_i32_16x16x{16,32}_iu{8,4}_w32_gfx12 + // (signed/unsigned 8-bit / 4-bit integer dot). Not wired here because opus + // doesn't have iu8_t / iu4_t dtype aliases yet — see follow-up. +#endif // __gfx1201__ __builtin_unreachable(); } @@ -2512,6 +2556,19 @@ struct wmma { #undef DISPATCH_WMMA_ #undef DISPATCH_WMMA_BF16F32_ #undef DISPATCH_WMMA_8BIT_ +#undef DISPATCH_WMMA_GFX12_F32_ +#undef DISPATCH_WMMA_GFX12_8BIT_ + +// gfx12 (gfx1200 / gfx1201, Navi 44/48) — wave32 WMMA 16x16x16 type aliases. +// Only the 16x16x{16,32} shapes are valid on gfx12; 16x16x32 is iu4-only. 
+using wmma_f32_16x16x16_f16 = wmma; +using wmma_f16_16x16x16_f16 = wmma; +using wmma_f32_16x16x16_bf16 = wmma; +using wmma_bf16_16x16x16_bf16 = wmma; +using wmma_f32_16x16x16_fp8_fp8 = wmma; +using wmma_f32_16x16x16_fp8_bf8 = wmma; +using wmma_f32_16x16x16_bf8_fp8 = wmma; +using wmma_f32_16x16x16_bf8_bf8 = wmma; // f16/bf16 16x16x32 using wmma_f32_16x16x32_f16 = wmma; @@ -2543,7 +2600,7 @@ using wmma_scale_f32_16x16x128_fp8_fp8 = wmma using wmma_scale_f32_16x16x128_fp4_fp4 = wmma; // Scaled WMMA (dedicated fp4 32x16x128 instruction) using wmma_scale_f32_32x16x128_fp4_fp4 = wmma; -#endif // __gfx1250__ (wmma) +#endif // __gfx1250__ / __gfx1201__ (wmma) ///////////////////////////////////////////////////////////////////////////////////////////////////////// // adaptor @@ -2743,7 +2800,17 @@ template = {}) { return A{}(mfma(WaveMNK{}), get<1>(WaveMNK{}), get<2>(WaveMNK{}), warp_size_>{}); } #endif // __GFX9__ -// wmma_adaptor: same layout encoding as mfma_adaptor but for wave32 WMMA (gfx1250) +// wmma_adaptor: same layout encoding as mfma_adaptor but for wave32 WMMA (gfx1250). +// +// NOTE: gfx12 (gfx1200 / gfx1201, RDNA4) WMMA uses a column-distributed +// fragment layout for A / B / C (lane selects column, lane group selects an +// 8-row block, vector register index selects row within block — see AMD +// RDNA4 ISA §7.12.2 and CK's wmma_gemm.hpp). This is incompatible with the +// row-distributed encoding below, which was designed for gfx1250's WMMA. +// gfx1201 callers can still use the opus::wmma<> struct directly to invoke +// the gfx12 builtins, but the make_tiled_mma / partition_layout_* path is +// gfx1250-only until a dedicated wmma_adaptor_gfx12 is added. +// // A:[(grpm_a

), (rept_a, grpk_a

, pack_a)], MxK // B:[(grpn_b

), (rept_b, grpk_b

, pack_b)], NxK // C:[(grpm_c

, rept_c, pack_c), (grpn_c

)], MxN diff --git a/op_tests/opus/device/setup.py b/op_tests/opus/device/setup.py index 8273262f4c..c8018e6750 100644 --- a/op_tests/opus/device/setup.py +++ b/op_tests/opus/device/setup.py @@ -40,7 +40,8 @@ "test_numeric_limits.cu", "test_workgroup_barrier.cu", "test_finfo.cu", - "test_opus_parse_gfx1201.cu", + "test_opus_gmem_gfx1201.cu", + "test_wmma_gfx1201.cu", ] diff --git a/op_tests/opus/device/test_opus_device.py b/op_tests/opus/device/test_opus_device.py index 9d9a7d660d..d7e1229321 100644 --- a/op_tests/opus/device/test_opus_device.py +++ b/op_tests/opus/device/test_opus_device.py @@ -176,13 +176,30 @@ def run_vector_add(self, A, B, Result): fn.argtypes = [_VP, _VP, _VP, _I] fn(self._ptr(A), self._ptr(B), self._ptr(Result), int(A.numel())) - # -- opus_parse_gfx1201 (verifies opus.hpp parses + opus utils work on gfx1201) -- - def run_opus_parse_gfx1201(self, A, B, Result): - fn = self._lib.run_opus_parse_gfx1201 + # -- opus_gmem_gfx1201 (verifies opus.hpp parses + opus utils work on gfx1201) -- + def run_opus_gmem_gfx1201(self, A, B, Result): + fn = self._lib.run_opus_gmem_gfx1201 fn.restype = None fn.argtypes = [_VP, _VP, _VP, _I] fn(self._ptr(A), self._ptr(B), self._ptr(Result), int(A.numel())) + # -- wmma_gfx1201 (8 wave32 16x16x16 variants via __builtin_amdgcn_wmma_*_w32_gfx12) -- + def _run_wmma_gfx1201(self, suffix, A, B, C): + fn = getattr(self._lib, f"run_wmma_gfx1201_{suffix}") + fn.restype = None + fn.argtypes = [_VP, _VP, _VP, _I, _I, _I] + fn(self._ptr(A), self._ptr(B), self._ptr(C), + int(A.stride(0)), int(B.stride(0)), int(C.stride(0))) + + def run_wmma_gfx1201_f32_f16(self, A, B, C): self._run_wmma_gfx1201("f32_f16", A, B, C) + def run_wmma_gfx1201_f32_bf16(self, A, B, C): self._run_wmma_gfx1201("f32_bf16", A, B, C) + def run_wmma_gfx1201_f16_f16(self, A, B, C): self._run_wmma_gfx1201("f16_f16", A, B, C) + def run_wmma_gfx1201_bf16_bf16(self, A, B, C): self._run_wmma_gfx1201("bf16_bf16", A, B, C) + def run_wmma_gfx1201_f32_fp8_fp8(self, A, 
B, C): self._run_wmma_gfx1201("f32_fp8_fp8", A, B, C) + def run_wmma_gfx1201_f32_fp8_bf8(self, A, B, C): self._run_wmma_gfx1201("f32_fp8_bf8", A, B, C) + def run_wmma_gfx1201_f32_bf8_fp8(self, A, B, C): self._run_wmma_gfx1201("f32_bf8_fp8", A, B, C) + def run_wmma_gfx1201_f32_bf8_bf8(self, A, B, C): self._run_wmma_gfx1201("f32_bf8_bf8", A, B, C) + # -- async_load -- def run_async_load(self, Src, Dst): fn = self._lib.run_async_load @@ -1174,19 +1191,19 @@ def test_vector_add(mod): # Archs where the opus.hpp parse-time fix matters. The kernel body in -# test_opus_parse_gfx1201.cu is gated by __gfx1201__ — on other archs the +# test_opus_gmem_gfx1201.cu is gated by __gfx1201__ — on other archs the # launcher runs an empty kernel, so we skip the correctness check to avoid # a misleading failure. _OPUS_PARSE_GFX1201_ARCHS = {"gfx1201"} -def test_opus_parse_gfx1201(mod): +def test_opus_gmem_gfx1201(mod): """Verify opus.hpp parses + opus utilities (make_gmem / .load/.store / cast) work on gfx1201. Mirrors the load → cast → store pattern in sample_kernels.cu. 
Skips on other archs (kernel body is gfx1201-only).""" arch = _get_gpu_arch() if arch not in _OPUS_PARSE_GFX1201_ARCHS: - print(f" SKIP: opus_parse_gfx1201 (arch={arch}, gfx1201-only test)") + print(f" SKIP: opus_gmem_gfx1201 (arch={arch}, gfx1201-only test)") return 0 n = 1310720 @@ -1198,7 +1215,7 @@ def test_opus_parse_gfx1201(mod): B = torch.randn(n, device=device, dtype=dtype) Result = torch.empty(n, device=device, dtype=dtype) - mod.run_opus_parse_gfx1201(A, B, Result) + mod.run_opus_gmem_gfx1201(A, B, Result) Ref = A + B @@ -1208,14 +1225,96 @@ def test_opus_parse_gfx1201(mod): if not ok: diff_count = (Result - Ref).abs().gt(atol + rtol * Ref.abs()).sum().item() print( - f" FAIL: opus_parse_gfx1201 max_diff={max_diff:.6e}, " + f" FAIL: opus_gmem_gfx1201 max_diff={max_diff:.6e}, " f"{diff_count} elements outside tol" ) return 1 - print(f" PASS: opus_parse_gfx1201 (arch={arch}, n={n}), max_diff={max_diff:.6e}") + print(f" PASS: opus_gmem_gfx1201 (arch={arch}, n={n}), max_diff={max_diff:.6e}") + return 0 + + +# WMMA tests for gfx1201 (Navi 48). Kernel bodies in test_wmma_gfx1201.cu are +# gated by __gfx1201__ — on other archs the launcher runs an empty kernel +# so we skip the correctness check. +_WMMA_GFX1201_ARCHS = {"gfx1201"} + + +def _wmma_gfx1201_tolerances(out_dtype): + # f32 acc is bit-exact against the FP32 reference matmul; f16/bf16 acc + # picks up one ULP of rounding error. 
+ if out_dtype == torch.float32: return 5e-2, 1e-2 + if out_dtype == torch.float16: return 1e-1, 1e-2 + if out_dtype == torch.bfloat16: return 5e-1, 5e-2 + return 1e-2, 1e-2 + + +def _test_wmma_gfx1201_variant(mod, name, runner, in_dtype_a, in_dtype_b, out_dtype): + """Drive one wmma_gfx1201 variant: build random 16x16 A/B, call the + WMMA kernel, compare against fp32 torch matmul cast to out_dtype.""" + arch = _get_gpu_arch() + if arch not in _WMMA_GFX1201_ARCHS: + print(f" SKIP: wmma_gfx1201_{name} (arch={arch}, gfx1201-only)") + return 0 + + M = N = K = 16 + device = torch.device("cuda") + + torch.manual_seed(42) + a_ref = torch.randn(M, K, dtype=torch.float32, device=device) * 2.0 + b_ref = torch.randn(K, N, dtype=torch.float32, device=device) * 2.0 + A = a_ref.to(in_dtype_a) + B = b_ref.to(in_dtype_b) + C = torch.zeros(M, N, dtype=out_dtype, device=device) + + Ref = (A.to(torch.float32) @ B.to(torch.float32)).to(out_dtype) + + runner(A, B, C) + + atol, rtol = _wmma_gfx1201_tolerances(out_dtype) + Cf = C.to(torch.float32) + Rf = Ref.to(torch.float32) + ok = torch.allclose(Cf, Rf, atol=atol, rtol=rtol) + max_diff = (Cf - Rf).abs().max().item() + if not ok: + print(f" FAIL: wmma_gfx1201_{name} max_diff={max_diff:.4e} (atol={atol})") + return 1 + print(f" PASS: wmma_gfx1201_{name} (in=({in_dtype_a}, {in_dtype_b}), out={out_dtype}, max_diff={max_diff:.4e})") return 0 +def test_wmma_gfx1201_f32_f16(mod): + return _test_wmma_gfx1201_variant(mod, "f32_f16", mod.run_wmma_gfx1201_f32_f16, + torch.float16, torch.float16, torch.float32) + +def test_wmma_gfx1201_f32_bf16(mod): + return _test_wmma_gfx1201_variant(mod, "f32_bf16", mod.run_wmma_gfx1201_f32_bf16, + torch.bfloat16, torch.bfloat16, torch.float32) + +def test_wmma_gfx1201_f16_f16(mod): + return _test_wmma_gfx1201_variant(mod, "f16_f16", mod.run_wmma_gfx1201_f16_f16, + torch.float16, torch.float16, torch.float16) + +def test_wmma_gfx1201_bf16_bf16(mod): + return _test_wmma_gfx1201_variant(mod, "bf16_bf16", 
mod.run_wmma_gfx1201_bf16_bf16, + torch.bfloat16, torch.bfloat16, torch.bfloat16) + +def test_wmma_gfx1201_f32_fp8_fp8(mod): + return _test_wmma_gfx1201_variant(mod, "f32_fp8_fp8", mod.run_wmma_gfx1201_f32_fp8_fp8, + torch.float8_e4m3fn, torch.float8_e4m3fn, torch.float32) + +def test_wmma_gfx1201_f32_fp8_bf8(mod): + return _test_wmma_gfx1201_variant(mod, "f32_fp8_bf8", mod.run_wmma_gfx1201_f32_fp8_bf8, + torch.float8_e4m3fn, torch.float8_e5m2, torch.float32) + +def test_wmma_gfx1201_f32_bf8_fp8(mod): + return _test_wmma_gfx1201_variant(mod, "f32_bf8_fp8", mod.run_wmma_gfx1201_f32_bf8_fp8, + torch.float8_e5m2, torch.float8_e4m3fn, torch.float32) + +def test_wmma_gfx1201_f32_bf8_bf8(mod): + return _test_wmma_gfx1201_variant(mod, "f32_bf8_bf8", mod.run_wmma_gfx1201_f32_bf8_bf8, + torch.float8_e5m2, torch.float8_e5m2, torch.float32) + + def test_async_load(mod): """Test async_load: copy data through LDS and verify integrity.""" # n should be a multiple of BLOCK_SIZE (256) @@ -2221,7 +2320,15 @@ def main(): failures += test_wmma_scale_16x16x128_fp8_bx32_scaled(mod) failures += test_mma_step_k_bf16(mod) failures += test_vector_add(mod) - failures += test_opus_parse_gfx1201(mod) + failures += test_opus_gmem_gfx1201(mod) + failures += test_wmma_gfx1201_f32_f16(mod) + failures += test_wmma_gfx1201_f32_bf16(mod) + failures += test_wmma_gfx1201_f16_f16(mod) + failures += test_wmma_gfx1201_bf16_bf16(mod) + failures += test_wmma_gfx1201_f32_fp8_fp8(mod) + failures += test_wmma_gfx1201_f32_fp8_bf8(mod) + failures += test_wmma_gfx1201_f32_bf8_fp8(mod) + failures += test_wmma_gfx1201_f32_bf8_bf8(mod) failures += test_async_load(mod) failures += test_tr_load_f16(mod) failures += test_dtype_convert_fp32_bf16(mod) diff --git a/op_tests/opus/device/test_opus_parse_gfx1201.cu b/op_tests/opus/device/test_opus_gmem_gfx1201.cu similarity index 63% rename from op_tests/opus/device/test_opus_parse_gfx1201.cu rename to op_tests/opus/device/test_opus_gmem_gfx1201.cu index 
a1afe11b2f..4f3919d467 100644 --- a/op_tests/opus/device/test_opus_parse_gfx1201.cu +++ b/op_tests/opus/device/test_opus_gmem_gfx1201.cu @@ -2,37 +2,34 @@ // Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. /** - * @file test_opus_parse_gfx1201.cu - * @brief Verify opus.hpp parses + opus::make_gmem load/store work on gfx1201 - * (Navi 48 / RX 9070 XT, RDNA4). + * @file test_opus_gmem_gfx1201.cu + * @brief Exercise opus::make_gmem<>.load<>/.store<> on gfx1201 (Navi 48 / + * RX 9070 XT, RDNA4). * - * Two issues this test covers, both addressed in this commit: + * Two opus.hpp changes need to be in place for this test to pass: * - * 1) Parse-time: without the forward declarations of mfma_adaptor / - * wmma_adaptor at the top of the opus namespace, opus.hpp fails to - * compile for gfx1201 device code because make_tiled_mma()'s default - * template argument names types that are gated behind __GFX9__ / - * __gfx1250__ blocks. + * 1) Forward declarations of mfma_adaptor / wmma_adaptor in the opus + * namespace, so the header parses for gfx1201 device code. + * make_tiled_mma()'s default template argument names these types, + * and without forward decls name lookup fails on archs where the + * full definitions (gated by __GFX9__ / __gfx1250__) are inactive. * - * 2) Runtime: even after the header parses, opus::make_gmem<>.store() - * and .load() silently produced wrong results on gfx1201 because - * buffer_default_config() returned the 0xffffffff fallback (the - * __gfx11__ / __gfx12__ checks on the prior line are typos — clang - * only predefines the uppercase __GFX11__ / __GFX12__). The invalid - * buffer rsrc made all buffer_load_b32 lanes return 0 and all - * buffer_store_b32 lanes drop on the floor. Fix: add explicit - * __gfx1201__ / __gfx1200__ branches with the correct 0x31004000 - * config word that gfx1250 already uses. 
+ * 2) buffer_default_config() returning the correct RDNA buffer rsrc + * config (0x31004000) for gfx1201 instead of the 0xffffffff + * fallback. The existing __gfx11__ / __gfx12__ tokens are typos + * (clang only predefines the uppercase __GFX11__ / __GFX12__), so + * without an explicit __gfx1201__ / __gfx1200__ branch the + * make_gmem<> resource descriptor is invalid and all buffer_load_b32 + * lanes return 0 / all buffer_store_b32 lanes drop on the floor. * - * The kernel below exercises the exact opus API that sample_kernels.cu / - * topk_softmax_kernels_group.cu / etc. depend on: + * The kernel exercises the exact opus API that sample_kernels.cu / + * topk_softmax_kernels_group.cu / mhc_kernels.cu depend on: * * auto g = opus::make_gmem(ptr); * auto v = g.load(i); // buffer_load via cached rsrc * ... opus::cast(v[j]) ... * g.store(vr, i); // buffer_store via cached rsrc * - * If this test produces correct results on gfx1201, both fixes hold. * Kernel body is gated by __gfx1201__ so other archs see an empty no-op * pass — gfx1250 / gfx9x behavior is unchanged. */ @@ -45,7 +42,7 @@ // Element-wise add via opus make_gmem load / store + per-lane opus::cast. // Mirrors the load → cast → store pattern in sample_kernels.cu. 
template -__global__ void opus_parse_gfx1201_kernel( +__global__ void opus_gmem_gfx1201_kernel( const float* __restrict__ a, const float* __restrict__ b, float* __restrict__ result, @@ -73,7 +70,7 @@ __global__ void opus_parse_gfx1201_kernel( } } -template __global__ void opus_parse_gfx1201_kernel<256, 4>(const float*, const float*, float*, int); +template __global__ void opus_gmem_gfx1201_kernel<256, 4>(const float*, const float*, float*, int); #endif // __gfx1201__ #else @@ -90,9 +87,9 @@ template __global__ void opus_parse_gfx1201_kernel<256, 4>(const float*, const f } while(0) template -__global__ void opus_parse_gfx1201_kernel(const float*, const float*, float*, int) {} +__global__ void opus_gmem_gfx1201_kernel(const float*, const float*, float*, int) {} -extern "C" void run_opus_parse_gfx1201( +extern "C" void run_opus_gmem_gfx1201( const void* d_a, const void* d_b, void* d_result, @@ -107,7 +104,7 @@ extern "C" void run_opus_parse_gfx1201( int blocks = (n + (BLOCK_SIZE * VECTOR_SIZE) - 1) / (BLOCK_SIZE * VECTOR_SIZE); hipLaunchKernelGGL( - (opus_parse_gfx1201_kernel), + (opus_gmem_gfx1201_kernel), dim3(blocks), dim3(BLOCK_SIZE), 0, 0, a, b, r, n); HIP_CALL(hipGetLastError()); diff --git a/op_tests/opus/device/test_wmma_gfx1201.cu b/op_tests/opus/device/test_wmma_gfx1201.cu new file mode 100644 index 0000000000..0487410e4d --- /dev/null +++ b/op_tests/opus/device/test_wmma_gfx1201.cu @@ -0,0 +1,150 @@ +// SPDX-License-Identifier: MIT +// Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +/** + * @file test_wmma_gfx1201.cu + * @brief Direct WMMA tests on gfx1201 (Navi 48, RDNA4) via the opus::wmma<> + * struct dispatch (which targets __builtin_amdgcn_wmma_*_w32_gfx12). 
+ * + * Covers the wave32 16x16x16 WMMA variants gfx1201 supports: + * + * - f32 <- f16 / f16 + * - f32 <- bf16 / bf16 + * - f16 <- f16 / f16 + * - bf16 <- bf16 / bf16 + * - f32 <- fp8 / fp8 + * - f32 <- fp8 / bf8 + * - f32 <- bf8 / fp8 + * - f32 <- bf8 / bf8 + * + * Lane / register layout used here (per the gfx12 WMMA fragment diagrams in + * AMD's internal Navi 4 layout reference and the community-verified guide + * https://github.com/JohnTDI-cpu/rdna4-wmma-guide): + * + * lane i in [0, 31], j in [0, 7]: + * A_frag[i][j] = A[i % 16, (i/16)*8 + j] // ROW-distributed + * B_frag[i][j] = B[(i/16)*8 + j, i % 16] // COLUMN-distributed + * C_frag[i][j] = C[(i/16)*8 + j, i % 16] // COLUMN-distributed + * + * That is: + * - A: lanes 0..15 cover rows 0..15 with K=0..7 within each lane; lanes + * 16..31 cover the same rows with K=8..15. + * - B and C: lanes 0..15 cover columns 0..15 with K (for B) or M-rows + * 0..7 (for C) within each lane; lanes 16..31 cover the same columns + * with K=8..15 (B) or M-rows 8..15 (C). + * + * This asymmetry (A row-distributed, B and C column-distributed) is the + * native gfx12 wmma_128b fragment encoding and matches CK's wmma_gemm.hpp. + * + * The tests deliberately go through opus::wmma<>::operator() to exercise the + * dispatch table (DISPATCH_WMMA_GFX12_F32_ / DISPATCH_WMMA_GFX12_8BIT_), not + * the high-level make_tiled_mma / partition_layout_* path — the latter + * relies on opus::wmma_adaptor whose lane encoding is row-distributed and + * matches gfx1250 but NOT gfx12. A dedicated gfx12 wmma_adaptor would be + * needed for the tiled API to work on gfx1201 (see TODO in opus.hpp). + */ + +#ifdef __HIP_DEVICE_COMPILE__ +// ── Device pass ───────────────────────────────────────────────────────────── +#include "opus/opus.hpp" +#if defined(__gfx1201__) + +// Generic 1-wave, 1-tile WMMA driver. Each lane loads its column-distributed +// A and B fragment from global memory, calls opus::wmma<>::operator(), and +// stores the C fragment back. 
+template +__global__ void wmma_gfx12_kernel( + const DIN_A* __restrict__ ptr_a, + const DIN_B* __restrict__ ptr_b, + DOUT* __restrict__ ptr_c, + int stride_a, + int stride_b, + int stride_c) +{ + constexpr int WM = 16, WN = 16, WK = 16; + constexpr int ELEM_A = WM * WK / 32; // 8 + constexpr int ELEM_B = WN * WK / 32; // 8 + constexpr int ELEM_C = WM * WN / 32; // 8 + + using vtype_a = opus::vector_t; + using vtype_b = opus::vector_t; + using vtype_c = opus::vector_t; + + int lane = static_cast(__builtin_amdgcn_workitem_id_x() % 32); + int col = lane % 16; // for B / C (column-distributed) + int row_base = (lane / 16) * 8; // for B (K block) / C (M row block) + int a_row = lane % 16; // A is row-distributed: lane selects row + int a_k_base = (lane / 16) * 8; // and lane group selects an 8-wide K block + + vtype_a v_a{}; + vtype_b v_b{}; + vtype_c v_c{}; + + #pragma unroll + for (int j = 0; j < ELEM_A; ++j) v_a[j] = ptr_a[a_row * stride_a + (a_k_base + j)]; + #pragma unroll + for (int j = 0; j < ELEM_B; ++j) v_b[j] = ptr_b[(row_base + j) * stride_b + col]; + + // Call through the opus::wmma<> dispatch — this is the path users of the + // library hit when they call opus::wmma<...>{}(a, b, c). 
+ opus::wmma mma; + v_c = mma(v_a, v_b, v_c); + + #pragma unroll + for (int j = 0; j < ELEM_C; ++j) ptr_c[(row_base + j) * stride_c + col] = v_c[j]; +} + +template __global__ void wmma_gfx12_kernel(const opus::fp16_t*, const opus::fp16_t*, opus::fp32_t*, int, int, int); +template __global__ void wmma_gfx12_kernel(const opus::bf16_t*, const opus::bf16_t*, opus::fp32_t*, int, int, int); +template __global__ void wmma_gfx12_kernel(const opus::fp16_t*, const opus::fp16_t*, opus::fp16_t*, int, int, int); +template __global__ void wmma_gfx12_kernel(const opus::bf16_t*, const opus::bf16_t*, opus::bf16_t*, int, int, int); +template __global__ void wmma_gfx12_kernel(const opus::fp8_t* , const opus::fp8_t* , opus::fp32_t*, int, int, int); +template __global__ void wmma_gfx12_kernel(const opus::fp8_t* , const opus::bf8_t* , opus::fp32_t*, int, int, int); +template __global__ void wmma_gfx12_kernel(const opus::bf8_t* , const opus::fp8_t* , opus::fp32_t*, int, int, int); +template __global__ void wmma_gfx12_kernel(const opus::bf8_t* , const opus::bf8_t* , opus::fp32_t*, int, int, int); + +#endif // __gfx1201__ + +#else +// ── Host pass: empty kernel stubs + extern "C" launchers ──────────────────── +#include "opus/opus.hpp" +#include "opus/hip_minimal.hpp" +#include + +#define HIP_CALL(call) do { \ + hipError_t err = (call); \ + if (err != hipSuccess) { \ + fprintf(stderr, "HIP error %d at %s:%d\n", (int)err, __FILE__, __LINE__); \ + return; \ + } \ +} while(0) + +template +__global__ void wmma_gfx12_kernel(const DIN_A*, const DIN_B*, DOUT*, int, int, int) {} + +#define LAUNCHER_(NAME, DA, DB, DC) \ +extern "C" void run_wmma_gfx1201_ ## NAME ( \ + const void* d_a, const void* d_b, void* d_c, \ + int stride_a, int stride_b, int stride_c) \ +{ \ + hipLaunchKernelGGL((wmma_gfx12_kernel), \ + dim3(1, 1), 32, 0, 0, \ + static_cast(d_a), \ + static_cast(d_b), \ + static_cast(d_c), \ + stride_a, stride_b, stride_c); \ + HIP_CALL(hipGetLastError()); \ + 
HIP_CALL(hipDeviceSynchronize()); \ +} + +LAUNCHER_(f32_f16, fp16_t, fp16_t, fp32_t) +LAUNCHER_(f32_bf16, bf16_t, bf16_t, fp32_t) +LAUNCHER_(f16_f16, fp16_t, fp16_t, fp16_t) +LAUNCHER_(bf16_bf16, bf16_t, bf16_t, bf16_t) +LAUNCHER_(f32_fp8_fp8, fp8_t, fp8_t, fp32_t) +LAUNCHER_(f32_fp8_bf8, fp8_t, bf8_t, fp32_t) +LAUNCHER_(f32_bf8_fp8, bf8_t, fp8_t, fp32_t) +LAUNCHER_(f32_bf8_bf8, bf8_t, bf8_t, fp32_t) + +#undef LAUNCHER_ +#endif // __HIP_DEVICE_COMPILE__ From fdf9d8772619f3830ffd9e827535881578b99114 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Sat, 16 May 2026 23:13:24 +0800 Subject: [PATCH 04/18] opus: black + ruff format on test_opus_device.py wmma additions CI black + ruff complained about the compact single-line wrapper and helper defs added in the previous commit. Just running black on the file (which also resolves the ruff E701 multiple-statements-on-one-line warnings). No behavior change. --- op_tests/opus/device/test_opus_device.py | 143 ++++++++++++++++++----- 1 file changed, 112 insertions(+), 31 deletions(-) diff --git a/op_tests/opus/device/test_opus_device.py b/op_tests/opus/device/test_opus_device.py index d7e1229321..471cf03ef7 100644 --- a/op_tests/opus/device/test_opus_device.py +++ b/op_tests/opus/device/test_opus_device.py @@ -188,17 +188,38 @@ def _run_wmma_gfx1201(self, suffix, A, B, C): fn = getattr(self._lib, f"run_wmma_gfx1201_{suffix}") fn.restype = None fn.argtypes = [_VP, _VP, _VP, _I, _I, _I] - fn(self._ptr(A), self._ptr(B), self._ptr(C), - int(A.stride(0)), int(B.stride(0)), int(C.stride(0))) - - def run_wmma_gfx1201_f32_f16(self, A, B, C): self._run_wmma_gfx1201("f32_f16", A, B, C) - def run_wmma_gfx1201_f32_bf16(self, A, B, C): self._run_wmma_gfx1201("f32_bf16", A, B, C) - def run_wmma_gfx1201_f16_f16(self, A, B, C): self._run_wmma_gfx1201("f16_f16", A, B, C) - def run_wmma_gfx1201_bf16_bf16(self, A, B, C): self._run_wmma_gfx1201("bf16_bf16", A, B, C) - def run_wmma_gfx1201_f32_fp8_fp8(self, A, B, C): 
self._run_wmma_gfx1201("f32_fp8_fp8", A, B, C) - def run_wmma_gfx1201_f32_fp8_bf8(self, A, B, C): self._run_wmma_gfx1201("f32_fp8_bf8", A, B, C) - def run_wmma_gfx1201_f32_bf8_fp8(self, A, B, C): self._run_wmma_gfx1201("f32_bf8_fp8", A, B, C) - def run_wmma_gfx1201_f32_bf8_bf8(self, A, B, C): self._run_wmma_gfx1201("f32_bf8_bf8", A, B, C) + fn( + self._ptr(A), + self._ptr(B), + self._ptr(C), + int(A.stride(0)), + int(B.stride(0)), + int(C.stride(0)), + ) + + def run_wmma_gfx1201_f32_f16(self, A, B, C): + self._run_wmma_gfx1201("f32_f16", A, B, C) + + def run_wmma_gfx1201_f32_bf16(self, A, B, C): + self._run_wmma_gfx1201("f32_bf16", A, B, C) + + def run_wmma_gfx1201_f16_f16(self, A, B, C): + self._run_wmma_gfx1201("f16_f16", A, B, C) + + def run_wmma_gfx1201_bf16_bf16(self, A, B, C): + self._run_wmma_gfx1201("bf16_bf16", A, B, C) + + def run_wmma_gfx1201_f32_fp8_fp8(self, A, B, C): + self._run_wmma_gfx1201("f32_fp8_fp8", A, B, C) + + def run_wmma_gfx1201_f32_fp8_bf8(self, A, B, C): + self._run_wmma_gfx1201("f32_fp8_bf8", A, B, C) + + def run_wmma_gfx1201_f32_bf8_fp8(self, A, B, C): + self._run_wmma_gfx1201("f32_bf8_fp8", A, B, C) + + def run_wmma_gfx1201_f32_bf8_bf8(self, A, B, C): + self._run_wmma_gfx1201("f32_bf8_bf8", A, B, C) # -- async_load -- def run_async_load(self, Src, Dst): @@ -1242,9 +1263,12 @@ def test_opus_gmem_gfx1201(mod): def _wmma_gfx1201_tolerances(out_dtype): # f32 acc is bit-exact against the FP32 reference matmul; f16/bf16 acc # picks up one ULP of rounding error. 
- if out_dtype == torch.float32: return 5e-2, 1e-2 - if out_dtype == torch.float16: return 1e-1, 1e-2 - if out_dtype == torch.bfloat16: return 5e-1, 5e-2 + if out_dtype == torch.float32: + return 5e-2, 1e-2 + if out_dtype == torch.float16: + return 1e-1, 1e-2 + if out_dtype == torch.bfloat16: + return 5e-1, 5e-2 return 1e-2, 1e-2 @@ -1278,41 +1302,98 @@ def _test_wmma_gfx1201_variant(mod, name, runner, in_dtype_a, in_dtype_b, out_dt if not ok: print(f" FAIL: wmma_gfx1201_{name} max_diff={max_diff:.4e} (atol={atol})") return 1 - print(f" PASS: wmma_gfx1201_{name} (in=({in_dtype_a}, {in_dtype_b}), out={out_dtype}, max_diff={max_diff:.4e})") + print( + f" PASS: wmma_gfx1201_{name} (in=({in_dtype_a}, {in_dtype_b}), out={out_dtype}, max_diff={max_diff:.4e})" + ) return 0 def test_wmma_gfx1201_f32_f16(mod): - return _test_wmma_gfx1201_variant(mod, "f32_f16", mod.run_wmma_gfx1201_f32_f16, - torch.float16, torch.float16, torch.float32) + return _test_wmma_gfx1201_variant( + mod, + "f32_f16", + mod.run_wmma_gfx1201_f32_f16, + torch.float16, + torch.float16, + torch.float32, + ) + def test_wmma_gfx1201_f32_bf16(mod): - return _test_wmma_gfx1201_variant(mod, "f32_bf16", mod.run_wmma_gfx1201_f32_bf16, - torch.bfloat16, torch.bfloat16, torch.float32) + return _test_wmma_gfx1201_variant( + mod, + "f32_bf16", + mod.run_wmma_gfx1201_f32_bf16, + torch.bfloat16, + torch.bfloat16, + torch.float32, + ) + def test_wmma_gfx1201_f16_f16(mod): - return _test_wmma_gfx1201_variant(mod, "f16_f16", mod.run_wmma_gfx1201_f16_f16, - torch.float16, torch.float16, torch.float16) + return _test_wmma_gfx1201_variant( + mod, + "f16_f16", + mod.run_wmma_gfx1201_f16_f16, + torch.float16, + torch.float16, + torch.float16, + ) + def test_wmma_gfx1201_bf16_bf16(mod): - return _test_wmma_gfx1201_variant(mod, "bf16_bf16", mod.run_wmma_gfx1201_bf16_bf16, - torch.bfloat16, torch.bfloat16, torch.bfloat16) + return _test_wmma_gfx1201_variant( + mod, + "bf16_bf16", + mod.run_wmma_gfx1201_bf16_bf16, + 
torch.bfloat16, + torch.bfloat16, + torch.bfloat16, + ) + def test_wmma_gfx1201_f32_fp8_fp8(mod): - return _test_wmma_gfx1201_variant(mod, "f32_fp8_fp8", mod.run_wmma_gfx1201_f32_fp8_fp8, - torch.float8_e4m3fn, torch.float8_e4m3fn, torch.float32) + return _test_wmma_gfx1201_variant( + mod, + "f32_fp8_fp8", + mod.run_wmma_gfx1201_f32_fp8_fp8, + torch.float8_e4m3fn, + torch.float8_e4m3fn, + torch.float32, + ) + def test_wmma_gfx1201_f32_fp8_bf8(mod): - return _test_wmma_gfx1201_variant(mod, "f32_fp8_bf8", mod.run_wmma_gfx1201_f32_fp8_bf8, - torch.float8_e4m3fn, torch.float8_e5m2, torch.float32) + return _test_wmma_gfx1201_variant( + mod, + "f32_fp8_bf8", + mod.run_wmma_gfx1201_f32_fp8_bf8, + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.float32, + ) + def test_wmma_gfx1201_f32_bf8_fp8(mod): - return _test_wmma_gfx1201_variant(mod, "f32_bf8_fp8", mod.run_wmma_gfx1201_f32_bf8_fp8, - torch.float8_e5m2, torch.float8_e4m3fn, torch.float32) + return _test_wmma_gfx1201_variant( + mod, + "f32_bf8_fp8", + mod.run_wmma_gfx1201_f32_bf8_fp8, + torch.float8_e5m2, + torch.float8_e4m3fn, + torch.float32, + ) + def test_wmma_gfx1201_f32_bf8_bf8(mod): - return _test_wmma_gfx1201_variant(mod, "f32_bf8_bf8", mod.run_wmma_gfx1201_f32_bf8_bf8, - torch.float8_e5m2, torch.float8_e5m2, torch.float32) + return _test_wmma_gfx1201_variant( + mod, + "f32_bf8_bf8", + mod.run_wmma_gfx1201_f32_bf8_bf8, + torch.float8_e5m2, + torch.float8_e5m2, + torch.float32, + ) def test_async_load(mod): From cde14bb1b4a0eec923c09fc1d5110d0f589c9c4c Mon Sep 17 00:00:00 2001 From: carlushuang Date: Sat, 16 May 2026 23:19:59 +0800 Subject: [PATCH 05/18] opus: condense gfx1201 commentary to match the rest of the file MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Same code paths, same dispatch — just tightens the multi-paragraph comment blocks I added in 70f6e4a / 9878ec9 / b8aa331 down to the 1-line-per-anchor density used everywhere else in opus.hpp. 
Diff vs origin/main shrinks ~166 → ~109 lines; 8/8 wmma variants and the gmem test still pass on RX 9070 XT. --- csrc/include/opus/opus.hpp | 101 ++++++++----------------------------- 1 file changed, 22 insertions(+), 79 deletions(-) diff --git a/csrc/include/opus/opus.hpp b/csrc/include/opus/opus.hpp index 3f79a3a398..111fbd5a58 100644 --- a/csrc/include/opus/opus.hpp +++ b/csrc/include/opus/opus.hpp @@ -499,27 +499,10 @@ template struct tuple_element. +// gfx1200/gfx1201 (Navi 44/48) listed explicitly — __gfx11__/__gfx12__ are typos (clang predefines only uppercase __GFX11__/__GFX12__), so without this gfx1200/gfx1201 fall into the 0xffffffff sentinel and make_gmem<> stores silently drop. #elif defined(__gfx11__) || defined(__gfx12__) || defined(__gfx1250__) || defined(__gfx1201__) || defined(__gfx1200__) return 0x31004000; #else @@ -2331,14 +2308,7 @@ using mfma_scale_f32_16x16x128_fp4_fp4 = mfma_f32_16x16x128_fp4_fp4; #endif // __GFX9__ (mfma) ///////////////////////////////////////////////////////////////////////////////////////////////////////// -// wmma (RDNA4 / wave32) — supports gfx1250 and gfx1201 (Navi 48). -// The two archs share the same opus::wmma<> template + dispatch shape, but use -// different LLVM builtins: -// - gfx1250: __builtin_amdgcn_wmma__16x16x{32,64,128,4}_ (wmma-256b-insts) -// - gfx1201: __builtin_amdgcn_wmma__16x16x{16,32}__w32_gfx12 (wmma-128b-insts) -// gfx1201 only supports the 16x16x16 shape (plus 16x16x32 for iu4) and the -// dispatch macros below have different argument lists than the gfx1250 ones -// (see DISPATCH_WMMA_GFX12_* further down). +// wmma (RDNA4 / wave32) — gfx1250 uses wmma-256b builtins (16x16x{4,32,64,128}); gfx1201 (Navi 48) uses wmma-128b _w32_gfx12 builtins (16x16x16). Dispatch macros for the two arg-list shapes differ — gfx12 set is DISPATCH_WMMA_GFX12_*. 
#if defined(__gfx1250__) || defined(__gfx1201__) || !defined(__HIP_DEVICE_COMPILE__) // f16/bf16/f32 builtins: (neg_a, A, neg_b, B, matrix_fmts, C, clamp, neg_c) #define DISPATCH_WMMA_(ta_, tb_, tc_, wm_, wn_, wk_, inst_) \ @@ -2362,26 +2332,15 @@ using mfma_scale_f32_16x16x128_fp4_fp4 = mfma_f32_16x16x128_fp4_fp4; __builtin_bit_cast(vector_t, b), \ static_cast(0), c, false, false); } -// gfx12 (gfx1200 / gfx1201, Navi 44/48) WMMA dispatch macros. -// -// The gfx12 builtins (suffixed _w32_gfx12 in BuiltinsAMDGPU.td) have a leaner -// signature than the gfx1250 ones — there is no matrix_fmts / neg_c slot — -// so they need their own macros even though shape/dtype matching is identical. -// -// FP / FP8-acc variants (f16/bf16/f16→f16/bf16→bf16 → f32 or same-type acc): (A, B, C) -// FP8/BF8 (A/B reinterpreted as packed i32 vector): (A, B, C) +// gfx12 (_w32_gfx12 suffix) builtins: 3-arg (A, B, C) — no matrix_fmts / neg_c slot. fp/same-type acc: #define DISPATCH_WMMA_GFX12_F32_(ta_, tb_, tc_, wm_, wn_, wk_, inst_) \ - (std::is_same_v && std::is_same_v && std::is_same_v && \ - wave_m == wm_ && wave_n == wn_ && wave_k == wk_) { \ - return inst_(a, b, c); } - + (std::is_same_v && std::is_same_v && std::is_same_v && wave_m == wm_ && wave_n == wn_ && wave_k == wk_) { return inst_(a, b, c); } +// fp8/bf8 variant: A/B reinterpreted as packed i32 vector then (A, B, C): #define DISPATCH_WMMA_GFX12_8BIT_(ta_, tb_, tc_, wm_, wn_, wk_, inst_) \ - (std::is_same_v && std::is_same_v && std::is_same_v && \ - wave_m == wm_ && wave_n == wn_ && wave_k == wk_) { \ + (std::is_same_v && std::is_same_v && std::is_same_v && wave_m == wm_ && wave_n == wn_ && wave_k == wk_) { \ constexpr index_t i32_a = elem_a * static_cast(sizeof(dtype_a)) / static_cast(sizeof(i32_t)); \ constexpr index_t i32_b = elem_b * static_cast(sizeof(dtype_b)) / static_cast(sizeof(i32_t)); \ - return inst_(__builtin_bit_cast(vector_t, a), \ - __builtin_bit_cast(vector_t, b), c); } + return 
inst_(__builtin_bit_cast(vector_t, a), __builtin_bit_cast(vector_t, b), c); } template struct wmma { @@ -2447,20 +2406,15 @@ struct wmma { else if constexpr DISPATCH_WMMA_8BIT_(bf8_t, bf8_t, fp16_t, 16, 16, 128, __builtin_amdgcn_wmma_f16_16x16x128_bf8_bf8) #endif #if defined(__gfx1201__) - // gfx12 wave32 16x16x16 — f16/bf16 → f32 - else if constexpr DISPATCH_WMMA_GFX12_F32_(fp16_t, fp16_t, fp32_t, 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12) - else if constexpr DISPATCH_WMMA_GFX12_F32_(bf16_t, bf16_t, fp32_t, 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12) - // gfx12 wave32 16x16x16 — same-type accumulator (same 3-arg signature as f32 acc) - else if constexpr DISPATCH_WMMA_GFX12_F32_(fp16_t, fp16_t, fp16_t, 16, 16, 16, __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12) - else if constexpr DISPATCH_WMMA_GFX12_F32_(bf16_t, bf16_t, bf16_t, 16, 16, 16, __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12) - // gfx12 wave32 16x16x16 — fp8/bf8 × {fp8, bf8} → f32 - else if constexpr DISPATCH_WMMA_GFX12_8BIT_(fp8_t, fp8_t, fp32_t, 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12) - else if constexpr DISPATCH_WMMA_GFX12_8BIT_(fp8_t, bf8_t, fp32_t, 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12) - else if constexpr DISPATCH_WMMA_GFX12_8BIT_(bf8_t, fp8_t, fp32_t, 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12) - else if constexpr DISPATCH_WMMA_GFX12_8BIT_(bf8_t, bf8_t, fp32_t, 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12) - // Note: gfx12 also supports __builtin_amdgcn_wmma_i32_16x16x{16,32}_iu{8,4}_w32_gfx12 - // (signed/unsigned 8-bit / 4-bit integer dot). Not wired here because opus - // doesn't have iu8_t / iu4_t dtype aliases yet — see follow-up. + // gfx12 wave32 16x16x16: f16/bf16/fp8/bf8 → f32 + same-type (f16/bf16) acc. iu8/iu4 deferred (no iu*_t aliases yet). 
+ else if constexpr DISPATCH_WMMA_GFX12_F32_ (fp16_t, fp16_t, fp32_t , 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_F32_ (bf16_t, bf16_t, fp32_t , 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_F32_ (fp16_t, fp16_t, fp16_t , 16, 16, 16, __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_F32_ (bf16_t, bf16_t, bf16_t , 16, 16, 16, __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_8BIT_(fp8_t , fp8_t , fp32_t , 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_8BIT_(fp8_t , bf8_t , fp32_t , 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_8BIT_(bf8_t , fp8_t , fp32_t , 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_8BIT_(bf8_t , bf8_t , fp32_t , 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12) #endif // __gfx1201__ __builtin_unreachable(); } @@ -2559,8 +2513,7 @@ struct wmma { #undef DISPATCH_WMMA_GFX12_F32_ #undef DISPATCH_WMMA_GFX12_8BIT_ -// gfx12 (gfx1200 / gfx1201, Navi 44/48) — wave32 WMMA 16x16x16 type aliases. -// Only the 16x16x{16,32} shapes are valid on gfx12; 16x16x32 is iu4-only. +// gfx12 (gfx1200/gfx1201, Navi 44/48) wave32 16x16x16 aliases using wmma_f32_16x16x16_f16 = wmma; using wmma_f16_16x16x16_f16 = wmma; using wmma_f32_16x16x16_bf16 = wmma; @@ -2800,17 +2753,7 @@ template = {}) { return A{}(mfma(WaveMNK{}), get<1>(WaveMNK{}), get<2>(WaveMNK{}), warp_size_>{}); } #endif // __GFX9__ -// wmma_adaptor: same layout encoding as mfma_adaptor but for wave32 WMMA (gfx1250). 
-// -// NOTE: gfx12 (gfx1200 / gfx1201, RDNA4) WMMA uses a column-distributed -// fragment layout for A / B / C (lane selects column, lane group selects an -// 8-row block, vector register index selects row within block — see AMD -// RDNA4 ISA §7.12.2 and CK's wmma_gemm.hpp). This is incompatible with the -// row-distributed encoding below, which was designed for gfx1250's WMMA. -// gfx1201 callers can still use the opus::wmma<> struct directly to invoke -// the gfx12 builtins, but the make_tiled_mma / partition_layout_* path is -// gfx1250-only until a dedicated wmma_adaptor_gfx12 is added. -// +// wmma_adaptor: layout encoding for wave32 WMMA (gfx1250). TODO: gfx12 (gfx1200/gfx1201) needs a dedicated adaptor — its fragment layout is asymmetric (A row-distributed, B/C column-distributed) per AMD RDNA4 ISA §7.12.2 / CK wmma_gemm.hpp. Until then gfx1201 callers use opus::wmma<> directly. // A:[(grpm_a

), (rept_a, grpk_a

, pack_a)], MxK // B:[(grpn_b

), (rept_b, grpk_b

, pack_b)], NxK // C:[(grpm_c

, rept_c, pack_c), (grpn_c

)], MxN From d9f992e8c89062baf34d84c07b67443761bbd253 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Sat, 16 May 2026 23:31:54 +0800 Subject: [PATCH 06/18] opus device tests: skip gfx1201-incompatible kernels at build time MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two existing test .cu files use opus _async_load, which calls __builtin_amdgcn_raw_ptr_buffer_load_lds — that builtin needs the vmem-to-lds-load-insts target feature (only present on gfx9x / gfx950 / gfx1250). On gfx1201 the per-file hipcc compile errors out and that fails the whole opus_device_test.so build, so before this fix **0 of 72 opus device tests actually ran on gfx1201**. setup.py: add _ARCH_SKIP_SOURCES that drops the two incompatible files (test_async_load.cu, test_load_store_if.cu) from _CU_SOURCES when arch is gfx1200 / gfx1201. The remaining 19 .cu files compile fine for gfx1201 and the .so links cleanly. test_opus_device.py: add a small _skip_if_missing_symbol() helper + early-skip guard at the top of the 5 test functions whose extern "C" launcher symbols come from those skipped files (test_async_load, test_predicated_copy, test_predicated_copy_2d, test_free_func_vector_add, test_predicated_async_load). They now print SKIP cleanly instead of AttributeError-ing at runtime. Result on gfx1201 (after this commit): - 41 PASS (includes all 9 new gfx1201 tests: 8 wmma + 1 gmem) - 53 SKIP (arch-gated mfma/wmma_1250/wmma_scale/mxfp/etc tests) - 4 FAIL (pre-existing fp8/bf8/bf16 ABI mismatches in dtype_convert_fp32_bf16, dtype_convert_fp32_bf16_vec4, numeric_limits, finfo — unrelated to this PR; tests assume fnuz fp8 semantics while gfx12 uses OCP) Behavior on every other arch (gfx9x, gfx1250) is unchanged — those archs are not in _ARCH_SKIP_SOURCES so all sources still compile and all launchers exist, so the symbol-existence guard is a no-op. 
--- op_tests/opus/device/setup.py | 20 +++++++++++++++++++- op_tests/opus/device/test_opus_device.py | 21 +++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/op_tests/opus/device/setup.py b/op_tests/opus/device/setup.py index c8018e6750..5665bbb5fd 100644 --- a/op_tests/opus/device/setup.py +++ b/op_tests/opus/device/setup.py @@ -111,11 +111,29 @@ def build(verbose=False, jobs=None): if verbose: print(f"[setup] arch={arch}, jobs={jobs}") + # Per-arch skip list: kernels that use builtins not available on the + # target arch. Skipped at .so build time so the rest of the suite + # still links; the Python harness sees the missing extern "C" launcher + # and reports SKIP for those tests. + # + # gfx1201 / gfx1200 (Navi 44/48, RDNA4): opus _async_load uses + # __builtin_amdgcn_raw_ptr_buffer_load_lds which needs the + # `vmem-to-lds-load-insts` target feature (gfx9x / gfx950 / gfx1250 only). + _ARCH_SKIP_SOURCES = { + "gfx1200": {"test_async_load.cu", "test_load_store_if.cu"}, + "gfx1201": {"test_async_load.cu", "test_load_store_if.cu"}, + } + skip = _ARCH_SKIP_SOURCES.get(arch, set()) + sources = [s for s in _CU_SOURCES if s not in skip] + if verbose and skip: + for s in sorted(skip): + print(f"[setup] skip {s} (incompatible with arch={arch})") + t0 = time.monotonic() # Parallel compile: each .cu -> .o tasks = [] - for s in _CU_SOURCES: + for s in sources: src = os.path.join(_THIS_DIR, s) obj = os.path.join(_THIS_DIR, s.replace(".cu", ".o")) tasks.append((src, obj, hipcc, arch, verbose)) diff --git a/op_tests/opus/device/test_opus_device.py b/op_tests/opus/device/test_opus_device.py index 471cf03ef7..1af367de08 100644 --- a/op_tests/opus/device/test_opus_device.py +++ b/op_tests/opus/device/test_opus_device.py @@ -319,6 +319,15 @@ def _get_gpu_arch(): return getattr(props, "gcnArchName", "").split(":")[0] +def _skip_if_missing_symbol(mod, sym, label): + """Print SKIP + return True if the .so wasn't built with `sym` (setup.py + skips 
arch-incompatible sources per _ARCH_SKIP_SOURCES).""" + if not hasattr(mod._lib, sym): + print(f" SKIP: {label} ({sym} not built for arch={_get_gpu_arch()})") + return True + return False + + # --------------------------------------------------------------------------- # Individual test functions # --------------------------------------------------------------------------- @@ -1398,6 +1407,8 @@ def test_wmma_gfx1201_f32_bf8_bf8(mod): def test_async_load(mod): """Test async_load: copy data through LDS and verify integrity.""" + if _skip_if_missing_symbol(mod, "run_async_load", "async_load"): + return 0 # n should be a multiple of BLOCK_SIZE (256) n = 1048576 # 1M elements device = torch.device("cuda") @@ -1900,6 +1911,8 @@ def test_dtype_convert_fp32_fp4_x4(mod): def test_predicated_copy(mod): """Test gmem load_if/store_if via free function wrappers (boundary predicate).""" + if _skip_if_missing_symbol(mod, "run_predicated_copy", "predicated_copy"): + return 0 # Use n not aligned to block*4 to create a partial boundary condition n = 1001 BLOCK_SIZE = 256 @@ -1946,6 +1959,8 @@ def test_predicated_copy_2d(mod): Uses a 2D layout with row/col boundary checking — the predicate receives (i_row, i_col) and uses them to check bounds, which would fail if given a single flat index. 
""" + if _skip_if_missing_symbol(mod, "run_predicated_copy_2d", "predicated_copy_2d"): + return 0 ROWS = 4 # issue space rows per workgroup COLS = 4 # issue space cols per thread BLOCK_SIZE = 256 # threads per block @@ -1986,6 +2001,8 @@ def test_predicated_copy_2d(mod): def test_free_func_vector_add(mod): """Test opus::load / opus::store free function wrappers (vector add).""" + if _skip_if_missing_symbol(mod, "run_free_func_add", "free_func_vector_add"): + return 0 n = 1310720 # same as regular vector_add test device = torch.device("cuda") dtype = torch.float32 @@ -2015,6 +2032,10 @@ def test_free_func_vector_add(mod): def test_predicated_async_load(mod): """Test gmem async_load_if via free function wrapper (boundary predicate).""" + if _skip_if_missing_symbol( + mod, "run_predicated_async_load", "predicated_async_load" + ): + return 0 n = 1001 BLOCK_SIZE = 256 n_padded = ((n + BLOCK_SIZE - 1) // BLOCK_SIZE) * BLOCK_SIZE # 1024 From 4ab7e0d66b3754e702b37975cc5589c0ae42c001 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Sat, 16 May 2026 23:40:47 +0800 Subject: [PATCH 07/18] opus: extend gfx1201 gates to also cover gfx1200 (Navi 44, same RDNA4 ISA) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit gfx1200 (Navi 44) and gfx1201 (Navi 48) are siblings in the same RDNA4 family — they share the wmma-128b ISA and the same buffer rsrc format. The buffer_default_config branch added in 9878ec9 already listed both; this commit brings the wmma struct dispatch, the wmma class outer guard, and the two unit-test guards in line. Verification (no gfx1200 hardware available, so compile-only): 1. clang predefines __GFX12__ for both archs; only the per-arch macro differs (__gfx1200__ vs __gfx1201__). 2. LLVM gates all 8 __builtin_amdgcn_wmma_*_w32_gfx12 builtins on "wmma-128b-insts,wavefrontsize32" — a feature both gfx1200 and gfx1201 enable per AMDGPUSubtarget (only gfx1250 adds the bigger wmma-256b-insts shapes). 3. 
Direct probe: a .cu calling all 8 builtins compiles for both --offload-arch=gfx1200 and --offload-arch=gfx1201 with no errors. 4. test_wmma_gfx1201.cu now builds for both archs producing identical 49432-byte .so files with all 8 run_wmma_gfx1201_* launcher symbols. Risks of broadening: low. If real gfx1200 hardware turns out to differ semantically we would see it as wrong WMMA outputs (not a build break) and can narrow back to per-arch in one line. The alternative — leaving Navi 44 unsupported in opus while it shares the exact same gfx12 wmma ISA — would be worse for downstream consumers. Verified on gfx1201 (RX 9070 XT): all 9 gfx1201 tests still pass. PASS: opus_gmem_gfx1201 max_diff=0.00e+00 PASS: wmma_gfx1201_f32_f16 max_diff=3.81e-06 PASS: wmma_gfx1201_f32_bf16 max_diff=1.91e-06 PASS: wmma_gfx1201_f16_f16 max_diff=3.13e-02 (1 ULP fp16) PASS: wmma_gfx1201_bf16_bf16 max_diff=5.00e-01 (1 ULP bf16) PASS: wmma_gfx1201_f32_fp8_fp8 max_diff=0.00e+00 PASS: wmma_gfx1201_f32_fp8_bf8 max_diff=0.00e+00 PASS: wmma_gfx1201_f32_bf8_fp8 max_diff=0.00e+00 PASS: wmma_gfx1201_f32_bf8_bf8 max_diff=0.00e+00 --- csrc/include/opus/opus.hpp | 10 +++++----- op_tests/opus/device/test_opus_device.py | 19 ++++++++++--------- .../opus/device/test_opus_gmem_gfx1201.cu | 9 +++++---- op_tests/opus/device/test_wmma_gfx1201.cu | 4 ++-- 4 files changed, 22 insertions(+), 20 deletions(-) diff --git a/csrc/include/opus/opus.hpp b/csrc/include/opus/opus.hpp index 111fbd5a58..a56331e695 100644 --- a/csrc/include/opus/opus.hpp +++ b/csrc/include/opus/opus.hpp @@ -2308,8 +2308,8 @@ using mfma_scale_f32_16x16x128_fp4_fp4 = mfma_f32_16x16x128_fp4_fp4; #endif // __GFX9__ (mfma) ///////////////////////////////////////////////////////////////////////////////////////////////////////// -// wmma (RDNA4 / wave32) — gfx1250 uses wmma-256b builtins (16x16x{4,32,64,128}); gfx1201 (Navi 48) uses wmma-128b _w32_gfx12 builtins (16x16x16). 
Dispatch macros for the two arg-list shapes differ — gfx12 set is DISPATCH_WMMA_GFX12_*. -#if defined(__gfx1250__) || defined(__gfx1201__) || !defined(__HIP_DEVICE_COMPILE__) +// wmma (RDNA4 / wave32) — gfx1250 uses wmma-256b builtins (16x16x{4,32,64,128}); gfx1200/gfx1201 (Navi 44/48) use wmma-128b _w32_gfx12 builtins (16x16x16). Dispatch macros for the two arg-list shapes differ — gfx12 set is DISPATCH_WMMA_GFX12_*. +#if defined(__gfx1250__) || defined(__gfx1201__) || defined(__gfx1200__) || !defined(__HIP_DEVICE_COMPILE__) // f16/bf16/f32 builtins: (neg_a, A, neg_b, B, matrix_fmts, C, clamp, neg_c) #define DISPATCH_WMMA_(ta_, tb_, tc_, wm_, wn_, wk_, inst_) \ (std::is_same_v && std::is_same_v && std::is_same_v && \ @@ -2405,7 +2405,7 @@ struct wmma { else if constexpr DISPATCH_WMMA_8BIT_(bf8_t, fp8_t, fp16_t, 16, 16, 128, __builtin_amdgcn_wmma_f16_16x16x128_bf8_fp8) else if constexpr DISPATCH_WMMA_8BIT_(bf8_t, bf8_t, fp16_t, 16, 16, 128, __builtin_amdgcn_wmma_f16_16x16x128_bf8_bf8) #endif -#if defined(__gfx1201__) +#if defined(__gfx1201__) || defined(__gfx1200__) // gfx12 wave32 16x16x16: f16/bf16/fp8/bf8 → f32 + same-type (f16/bf16) acc. iu8/iu4 deferred (no iu*_t aliases yet). 
else if constexpr DISPATCH_WMMA_GFX12_F32_ (fp16_t, fp16_t, fp32_t , 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12) else if constexpr DISPATCH_WMMA_GFX12_F32_ (bf16_t, bf16_t, fp32_t , 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12) @@ -2415,7 +2415,7 @@ struct wmma { else if constexpr DISPATCH_WMMA_GFX12_8BIT_(fp8_t , bf8_t , fp32_t , 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12) else if constexpr DISPATCH_WMMA_GFX12_8BIT_(bf8_t , fp8_t , fp32_t , 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12) else if constexpr DISPATCH_WMMA_GFX12_8BIT_(bf8_t , bf8_t , fp32_t , 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12) -#endif // __gfx1201__ +#endif // __gfx1201__ / __gfx1200__ __builtin_unreachable(); } @@ -2553,7 +2553,7 @@ using wmma_scale_f32_16x16x128_fp8_fp8 = wmma using wmma_scale_f32_16x16x128_fp4_fp4 = wmma; // Scaled WMMA (dedicated fp4 32x16x128 instruction) using wmma_scale_f32_32x16x128_fp4_fp4 = wmma; -#endif // __gfx1250__ / __gfx1201__ (wmma) +#endif // __gfx1250__ / __gfx1201__ / __gfx1200__ (wmma) ///////////////////////////////////////////////////////////////////////////////////////////////////////// // adaptor diff --git a/op_tests/opus/device/test_opus_device.py b/op_tests/opus/device/test_opus_device.py index 1af367de08..c0659177e1 100644 --- a/op_tests/opus/device/test_opus_device.py +++ b/op_tests/opus/device/test_opus_device.py @@ -1220,11 +1220,11 @@ def test_vector_add(mod): return 0 -# Archs where the opus.hpp parse-time fix matters. The kernel body in -# test_opus_gmem_gfx1201.cu is gated by __gfx1201__ — on other archs the -# launcher runs an empty kernel, so we skip the correctness check to avoid -# a misleading failure. -_OPUS_PARSE_GFX1201_ARCHS = {"gfx1201"} +# Archs where the opus.hpp gfx12 (Navi 44/48 RDNA4) path is active. 
The kernel +# body in test_opus_gmem_gfx1201.cu is gated by __gfx1201__ / __gfx1200__ — +# on other archs the launcher runs an empty kernel, so we skip the correctness +# check to avoid a misleading failure. +_OPUS_PARSE_GFX1201_ARCHS = {"gfx1201", "gfx1200"} def test_opus_gmem_gfx1201(mod): @@ -1263,10 +1263,11 @@ def test_opus_gmem_gfx1201(mod): return 0 -# WMMA tests for gfx1201 (Navi 48). Kernel bodies in test_wmma_gfx1201.cu are -# gated by __gfx1201__ — on other archs the launcher runs an empty kernel -# so we skip the correctness check. -_WMMA_GFX1201_ARCHS = {"gfx1201"} +# WMMA tests for gfx1200/gfx1201 (Navi 44/48, RDNA4). Both archs share the +# same gfx12 wmma-128b ISA so the kernel bodies in test_wmma_gfx1201.cu are +# gated by __gfx1201__ / __gfx1200__ — on other archs the launcher runs an +# empty kernel so we skip the correctness check. +_WMMA_GFX1201_ARCHS = {"gfx1201", "gfx1200"} def _wmma_gfx1201_tolerances(out_dtype): diff --git a/op_tests/opus/device/test_opus_gmem_gfx1201.cu b/op_tests/opus/device/test_opus_gmem_gfx1201.cu index 4f3919d467..a9773e530e 100644 --- a/op_tests/opus/device/test_opus_gmem_gfx1201.cu +++ b/op_tests/opus/device/test_opus_gmem_gfx1201.cu @@ -30,15 +30,16 @@ * ... opus::cast(v[j]) ... * g.store(vr, i); // buffer_store via cached rsrc * - * Kernel body is gated by __gfx1201__ so other archs see an empty no-op - * pass — gfx1250 / gfx9x behavior is unchanged. + * Kernel body is gated by __gfx1201__ / __gfx1200__ (same RDNA4 family, + * same ISA for buffer_load/store) so other archs see an empty no-op pass + * — gfx1250 / gfx9x behavior is unchanged. */ #ifdef __HIP_DEVICE_COMPILE__ // ── Device pass: opus.hpp + kernel body, no hip_runtime.h ────────────────── #include "opus/opus.hpp" -#if defined(__gfx1201__) +#if defined(__gfx1201__) || defined(__gfx1200__) // Element-wise add via opus make_gmem load / store + per-lane opus::cast. // Mirrors the load → cast → store pattern in sample_kernels.cu. 
template @@ -71,7 +72,7 @@ __global__ void opus_gmem_gfx1201_kernel( } template __global__ void opus_gmem_gfx1201_kernel<256, 4>(const float*, const float*, float*, int); -#endif // __gfx1201__ +#endif // __gfx1201__ / __gfx1200__ #else // ── Host pass: launcher + empty kernel stub ──────────────────────────────── diff --git a/op_tests/opus/device/test_wmma_gfx1201.cu b/op_tests/opus/device/test_wmma_gfx1201.cu index 0487410e4d..0fcf05fe5c 100644 --- a/op_tests/opus/device/test_wmma_gfx1201.cu +++ b/op_tests/opus/device/test_wmma_gfx1201.cu @@ -47,7 +47,7 @@ #ifdef __HIP_DEVICE_COMPILE__ // ── Device pass ───────────────────────────────────────────────────────────── #include "opus/opus.hpp" -#if defined(__gfx1201__) +#if defined(__gfx1201__) || defined(__gfx1200__) // Generic 1-wave, 1-tile WMMA driver. Each lane loads its column-distributed // A and B fragment from global memory, calls opus::wmma<>::operator(), and @@ -103,7 +103,7 @@ template __global__ void wmma_gfx12_kernel(const opus::bf8_t* , const opus::fp8_t* , opus::fp32_t*, int, int, int); template __global__ void wmma_gfx12_kernel(const opus::bf8_t* , const opus::bf8_t* , opus::fp32_t*, int, int, int); -#endif // __gfx1201__ +#endif // __gfx1201__ / __gfx1200__ #else // ── Host pass: empty kernel stubs + extern "C" launchers ──────────────────── From 3007acd14c82377d75945d981974f78a2163eb21 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Sun, 17 May 2026 19:15:05 +0800 Subject: [PATCH 08/18] opus: add gfx1201 wave64 WMMA test infrastructure and D-matrix layout analysis MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add test_wmma_gfx1201_w64.cu with 4 wave64 16x16x16 variants (f16/bf16 → f32/same-type acc), compiled with -mwavefrontsize64 via a new _W64_SOURCES set in setup.py. 
The wave64 fragment layout was reverse-engineered using AMD matrix instruction calculator and hardware probing: - A,B: sequential group order {0,1,2,3} — same as wave32 but 4 elem/lane - D/C: interleaved group order {0,2,1,3} — lanes 0-15 → rows 0-3, lanes 32-47 → rows 4-7, lanes 16-31 → rows 8-11, lanes 48-63 → rows 12-15 Tests currently fail on ROCm 7.2.3 (only lanes 0-31 produce valid output), suggesting the _w64_gfx12 builtins may not be fully functional in this toolchain version. Infrastructure is ready for when support matures. --- csrc/include/opus/opus.hpp | 90 +++++++++++--- op_tests/opus/device/setup.py | 11 +- op_tests/opus/device/test_opus_device.py | 114 ++++++++++++++++++ op_tests/opus/device/test_wmma_gfx1201_w64.cu | 106 ++++++++++++++++ 4 files changed, 305 insertions(+), 16 deletions(-) create mode 100644 op_tests/opus/device/test_wmma_gfx1201_w64.cu diff --git a/csrc/include/opus/opus.hpp b/csrc/include/opus/opus.hpp index a56331e695..4e5dab8d23 100644 --- a/csrc/include/opus/opus.hpp +++ b/csrc/include/opus/opus.hpp @@ -2333,14 +2333,30 @@ using mfma_scale_f32_16x16x128_fp4_fp4 = mfma_f32_16x16x128_fp4_fp4; static_cast(0), c, false, false); } // gfx12 (_w32_gfx12 suffix) builtins: 3-arg (A, B, C) — no matrix_fmts / neg_c slot. fp/same-type acc: -#define DISPATCH_WMMA_GFX12_F32_(ta_, tb_, tc_, wm_, wn_, wk_, inst_) \ - (std::is_same_v && std::is_same_v && std::is_same_v && wave_m == wm_ && wave_n == wn_ && wave_k == wk_) { return inst_(a, b, c); } +// The warp_size param (ws_) selects between _w32_gfx12 and _w64_gfx12 variants and is +// required because both shapes share the same {wm, wn, wk} triple. 
+#define DISPATCH_WMMA_GFX12_F32_(ta_, tb_, tc_, wm_, wn_, wk_, ws_, inst_) \ + (std::is_same_v && std::is_same_v && std::is_same_v && \ + wave_m == wm_ && wave_n == wn_ && wave_k == wk_ && warp_size == ws_) { return inst_(a, b, c); } // fp8/bf8 variant: A/B reinterpreted as packed i32 vector then (A, B, C): -#define DISPATCH_WMMA_GFX12_8BIT_(ta_, tb_, tc_, wm_, wn_, wk_, inst_) \ - (std::is_same_v && std::is_same_v && std::is_same_v && wave_m == wm_ && wave_n == wn_ && wave_k == wk_) { \ +#define DISPATCH_WMMA_GFX12_8BIT_(ta_, tb_, tc_, wm_, wn_, wk_, ws_, inst_) \ + (std::is_same_v && std::is_same_v && std::is_same_v && \ + wave_m == wm_ && wave_n == wn_ && wave_k == wk_ && warp_size == ws_) { \ constexpr index_t i32_a = elem_a * static_cast(sizeof(dtype_a)) / static_cast(sizeof(i32_t)); \ constexpr index_t i32_b = elem_b * static_cast(sizeof(dtype_b)) / static_cast(sizeof(i32_t)); \ return inst_(__builtin_bit_cast(vector_t, a), __builtin_bit_cast(vector_t, b), c); } +// STEP_K: synthesize a larger K by chaining N copies of an inst with K=inst_k_. Matches +// the gfx942 mfma_f32_16x16x32_*_1k pattern (2× 16x16x16 inner). Lets a kernel use a +// wider A/B vector type per K-step so memory loads vectorize to b128 per lane (e.g. +// wave64 16x16x32 bf16 = 8 elem/lane = 16 bytes = 1 ds_load_b128 per lane). 
+#define DISPATCH_WMMA_GFX12_F32_STEP_K_(ta_, tb_, tc_, wm_, wn_, wk_, ws_, inst_k_, inst_) \ + (std::is_same_v && std::is_same_v && std::is_same_v && \ + wave_m == wm_ && wave_n == wn_ && wave_k == wk_ && warp_size == ws_) { \ + constexpr index_t steps = wk_ / inst_k_; \ + constexpr index_t e_a = elem_a / steps; constexpr index_t e_b = elem_b / steps; \ + auto tmp = inst_(slice(a, number<0>{}, number{}), slice(b, number<0>{}, number{}), c); \ + static_for([&](auto i){ tmp = inst_(slice(a, number{}, number{}), slice(b, number{}, number{}), tmp); }); \ + return tmp; } template struct wmma { @@ -2406,15 +2422,44 @@ struct wmma { else if constexpr DISPATCH_WMMA_8BIT_(bf8_t, bf8_t, fp16_t, 16, 16, 128, __builtin_amdgcn_wmma_f16_16x16x128_bf8_bf8) #endif #if defined(__gfx1201__) || defined(__gfx1200__) - // gfx12 wave32 16x16x16: f16/bf16/fp8/bf8 → f32 + same-type (f16/bf16) acc. iu8/iu4 deferred (no iu*_t aliases yet). - else if constexpr DISPATCH_WMMA_GFX12_F32_ (fp16_t, fp16_t, fp32_t , 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12) - else if constexpr DISPATCH_WMMA_GFX12_F32_ (bf16_t, bf16_t, fp32_t , 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12) - else if constexpr DISPATCH_WMMA_GFX12_F32_ (fp16_t, fp16_t, fp16_t , 16, 16, 16, __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12) - else if constexpr DISPATCH_WMMA_GFX12_F32_ (bf16_t, bf16_t, bf16_t , 16, 16, 16, __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12) - else if constexpr DISPATCH_WMMA_GFX12_8BIT_(fp8_t , fp8_t , fp32_t , 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12) - else if constexpr DISPATCH_WMMA_GFX12_8BIT_(fp8_t , bf8_t , fp32_t , 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12) - else if constexpr DISPATCH_WMMA_GFX12_8BIT_(bf8_t , fp8_t , fp32_t , 16, 16, 16, __builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12) - else if constexpr DISPATCH_WMMA_GFX12_8BIT_(bf8_t , bf8_t , fp32_t , 16, 16, 16, 
__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12) + // The _w32/_w64 builtins are gated by wavefrontsize32/wavefrontsize64 target + // features. Clang checks target features eagerly (even for unused if-constexpr + // branches), so we gate the dispatch lines with __has_builtin to only include + // the lines that match the active wave mode. +# if __has_builtin(__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12) + // gfx12 wave32 16x16x16: f16/bf16/fp8/bf8 → f32 + same-type (f16/bf16) acc. iu8/iu4 deferred. + else if constexpr DISPATCH_WMMA_GFX12_F32_ (fp16_t, fp16_t, fp32_t , 16, 16, 16, 32, __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_F32_ (bf16_t, bf16_t, fp32_t , 16, 16, 16, 32, __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_F32_ (fp16_t, fp16_t, fp16_t , 16, 16, 16, 32, __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_F32_ (bf16_t, bf16_t, bf16_t , 16, 16, 16, 32, __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_8BIT_(fp8_t , fp8_t , fp32_t , 16, 16, 16, 32, __builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_8BIT_(fp8_t , bf8_t , fp32_t , 16, 16, 16, 32, __builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_8BIT_(bf8_t , fp8_t , fp32_t , 16, 16, 16, 32, __builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_8BIT_(bf8_t , bf8_t , fp32_t , 16, 16, 16, 32, __builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12) + // gfx12 wave32 16x16x32: synthetic — 2× stacked 16x16x16 along K. Per-lane A/B = + // 16 fp16/bf16 elem (= 32 B). Useful for wider per-K-step register footprint and + // amortizing inner-loop overhead across 2 wmma issues. 
+ else if constexpr DISPATCH_WMMA_GFX12_F32_STEP_K_(fp16_t, fp16_t, fp32_t , 16, 16, 32, 32, 16, __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_F32_STEP_K_(bf16_t, bf16_t, fp32_t , 16, 16, 32, 32, 16, __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_F32_STEP_K_(fp16_t, fp16_t, fp16_t , 16, 16, 32, 32, 16, __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_F32_STEP_K_(bf16_t, bf16_t, bf16_t , 16, 16, 32, 32, 16, __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12) +# endif // wave32 builtin available +# if __has_builtin(__builtin_amdgcn_wmma_f32_16x16x16_f16_w64_gfx12) + // gfx12 wave64 16x16x16: native _w64_gfx12 builtins. Per-lane A/B = 4 elem + // (= 8 B = 1 b64 load per lane). Kernel must be compiled with wavefrontsize64. + else if constexpr DISPATCH_WMMA_GFX12_F32_ (fp16_t, fp16_t, fp32_t , 16, 16, 16, 64, __builtin_amdgcn_wmma_f32_16x16x16_f16_w64_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_F32_ (bf16_t, bf16_t, fp32_t , 16, 16, 16, 64, __builtin_amdgcn_wmma_f32_16x16x16_bf16_w64_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_F32_ (fp16_t, fp16_t, fp16_t , 16, 16, 16, 64, __builtin_amdgcn_wmma_f16_16x16x16_f16_w64_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_F32_ (bf16_t, bf16_t, bf16_t , 16, 16, 16, 64, __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64_gfx12) + // gfx12 wave64 16x16x32: synthetic — 2× stacked _w64 16x16x16. Per-lane A/B = 8 + // elem (= 16 B = 1 ds_load_b128 / global_load_b128 per lane). The primary + // motivation for the wave64 path on gfx12: wider per-lane vector-load slot than + // wave64 16x16x16's natural 8 B per lane. 
+ else if constexpr DISPATCH_WMMA_GFX12_F32_STEP_K_(fp16_t, fp16_t, fp32_t , 16, 16, 32, 64, 16, __builtin_amdgcn_wmma_f32_16x16x16_f16_w64_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_F32_STEP_K_(bf16_t, bf16_t, fp32_t , 16, 16, 32, 64, 16, __builtin_amdgcn_wmma_f32_16x16x16_bf16_w64_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_F32_STEP_K_(fp16_t, fp16_t, fp16_t , 16, 16, 32, 64, 16, __builtin_amdgcn_wmma_f16_16x16x16_f16_w64_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_F32_STEP_K_(bf16_t, bf16_t, bf16_t , 16, 16, 32, 64, 16, __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64_gfx12) +# endif // wave64 builtin available #endif // __gfx1201__ / __gfx1200__ __builtin_unreachable(); } @@ -2512,8 +2557,10 @@ struct wmma { #undef DISPATCH_WMMA_8BIT_ #undef DISPATCH_WMMA_GFX12_F32_ #undef DISPATCH_WMMA_GFX12_8BIT_ +#undef DISPATCH_WMMA_GFX12_F32_STEP_K_ -// gfx12 (gfx1200/gfx1201, Navi 44/48) wave32 16x16x16 aliases +// gfx12 (gfx1200/gfx1201, Navi 44/48) wave32 16x16x16 aliases (default warp_size). +// The wave64 and synthetic 16x16x32 aliases follow below. using wmma_f32_16x16x16_f16 = wmma; using wmma_f16_16x16x16_f16 = wmma; using wmma_f32_16x16x16_bf16 = wmma; @@ -2522,6 +2569,21 @@ using wmma_f32_16x16x16_fp8_fp8 = wmma; using wmma_f32_16x16x16_fp8_bf8 = wmma; using wmma_f32_16x16x16_bf8_fp8 = wmma; using wmma_f32_16x16x16_bf8_bf8 = wmma; +// gfx12 wave32 16x16x32 synthetic (2× 16x16x16 stacked along K, see STEP_K dispatch). +using wmma_f32_16x16x32_f16_w32 = wmma; +using wmma_f32_16x16x32_bf16_w32 = wmma; +using wmma_f16_16x16x32_f16_w32 = wmma; +using wmma_bf16_16x16x32_bf16_w32 = wmma; +// gfx12 wave64 16x16x16 native — kernel must be compiled with wavefrontsize64. +using wmma_f32_16x16x16_f16_w64 = wmma; +using wmma_f32_16x16x16_bf16_w64 = wmma; +using wmma_f16_16x16x16_f16_w64 = wmma; +using wmma_bf16_16x16x16_bf16_w64 = wmma; +// gfx12 wave64 16x16x32 synthetic (2× 16x16x16_w64) — per-lane A/B = 16 B = 1 b128 load. 
+using wmma_f32_16x16x32_f16_w64 = wmma; +using wmma_f32_16x16x32_bf16_w64 = wmma; +using wmma_f16_16x16x32_f16_w64 = wmma; +using wmma_bf16_16x16x32_bf16_w64 = wmma; // f16/bf16 16x16x32 using wmma_f32_16x16x32_f16 = wmma; diff --git a/op_tests/opus/device/setup.py b/op_tests/opus/device/setup.py index 5665bbb5fd..ddc89ed0d4 100644 --- a/op_tests/opus/device/setup.py +++ b/op_tests/opus/device/setup.py @@ -42,8 +42,12 @@ "test_finfo.cu", "test_opus_gmem_gfx1201.cu", "test_wmma_gfx1201.cu", + "test_wmma_gfx1201_w64.cu", ] +# Sources requiring -mwavefrontsize64 (wave64 builtins). +_W64_SOURCES = {"test_wmma_gfx1201_w64.cu"} + def _detect_arch(): try: @@ -76,7 +80,8 @@ def _find_hipcc(): def _compile_one(args): """Compile a single .cu -> .o. Used as a worker function for parallel builds.""" - src, obj, hipcc, arch, verbose = args + src, obj, hipcc, arch, verbose, *rest = args + extra_flags = rest[0] if rest else [] cmd = [ hipcc, f"--offload-arch={arch}", @@ -85,6 +90,7 @@ def _compile_one(args): "-D__HIPCC_RTC__", f"-I{_REPO_CSRC}", f"-I{_THIS_DIR}", + *extra_flags, "-c", src, "-o", @@ -136,7 +142,8 @@ def build(verbose=False, jobs=None): for s in sources: src = os.path.join(_THIS_DIR, s) obj = os.path.join(_THIS_DIR, s.replace(".cu", ".o")) - tasks.append((src, obj, hipcc, arch, verbose)) + extra = ["-mwavefrontsize64"] if s in _W64_SOURCES else [] + tasks.append((src, obj, hipcc, arch, verbose, extra)) objs = [] with ProcessPoolExecutor(max_workers=jobs) as pool: diff --git a/op_tests/opus/device/test_opus_device.py b/op_tests/opus/device/test_opus_device.py index c0659177e1..794ba71112 100644 --- a/op_tests/opus/device/test_opus_device.py +++ b/op_tests/opus/device/test_opus_device.py @@ -221,6 +221,32 @@ def run_wmma_gfx1201_f32_bf8_fp8(self, A, B, C): def run_wmma_gfx1201_f32_bf8_bf8(self, A, B, C): self._run_wmma_gfx1201("f32_bf8_bf8", A, B, C) + # -- wmma_gfx1201_w64 (4 wave64 16x16x16 variants via __builtin_amdgcn_wmma_*_w64_gfx12) -- + def 
_run_wmma_gfx1201_w64(self, suffix, A, B, C): + fn = getattr(self._lib, f"run_wmma_gfx1201_w64_{suffix}") + fn.restype = None + fn.argtypes = [_VP, _VP, _VP, _I, _I, _I] + fn( + self._ptr(A), + self._ptr(B), + self._ptr(C), + int(A.stride(0)), + int(B.stride(0)), + int(C.stride(0)), + ) + + def run_wmma_gfx1201_w64_f32_f16(self, A, B, C): + self._run_wmma_gfx1201_w64("f32_f16", A, B, C) + + def run_wmma_gfx1201_w64_f32_bf16(self, A, B, C): + self._run_wmma_gfx1201_w64("f32_bf16", A, B, C) + + def run_wmma_gfx1201_w64_f16_f16(self, A, B, C): + self._run_wmma_gfx1201_w64("f16_f16", A, B, C) + + def run_wmma_gfx1201_w64_bf16_bf16(self, A, B, C): + self._run_wmma_gfx1201_w64("bf16_bf16", A, B, C) + # -- async_load -- def run_async_load(self, Src, Dst): fn = self._lib.run_async_load @@ -1406,6 +1432,90 @@ def test_wmma_gfx1201_f32_bf8_bf8(mod): ) + +# WMMA wave64 tests for gfx1200/gfx1201 (RDNA4). The _w64_gfx12 builtins +# use 64 lanes with 4 elem/lane instead of wave32's 8 elem/lane. + +def _test_wmma_gfx1201_w64_variant(mod, name, runner, in_dtype_a, in_dtype_b, out_dtype): + arch = _get_gpu_arch() + if arch not in _WMMA_GFX1201_ARCHS: + print(f" SKIP: wmma_gfx1201_w64_{name} (arch={arch}, gfx1201-only)") + return 0 + if _skip_if_missing_symbol(mod, f"run_wmma_gfx1201_w64_{name}", f"wmma_gfx1201_w64_{name}"): + return 0 + + M = N = K = 16 + device = torch.device("cuda") + + torch.manual_seed(42) + a_ref = torch.randn(M, K, dtype=torch.float32, device=device) * 2.0 + b_ref = torch.randn(K, N, dtype=torch.float32, device=device) * 2.0 + A = a_ref.to(in_dtype_a) + B = b_ref.to(in_dtype_b) + C = torch.zeros(M, N, dtype=out_dtype, device=device) + + Ref = (A.to(torch.float32) @ B.to(torch.float32)).to(out_dtype) + + runner(A, B, C) + + atol, rtol = _wmma_gfx1201_tolerances(out_dtype) + Cf = C.to(torch.float32) + Rf = Ref.to(torch.float32) + ok = torch.allclose(Cf, Rf, atol=atol, rtol=rtol) + max_diff = (Cf - Rf).abs().max().item() + if not ok: + print(f" FAIL: 
wmma_gfx1201_w64_{name} max_diff={max_diff:.4e} (atol={atol})") + return 1 + print( + f" PASS: wmma_gfx1201_w64_{name} (in=({in_dtype_a}, {in_dtype_b}), out={out_dtype}, max_diff={max_diff:.4e})" + ) + return 0 + + +def test_wmma_gfx1201_w64_f32_f16(mod): + return _test_wmma_gfx1201_w64_variant( + mod, + "f32_f16", + mod.run_wmma_gfx1201_w64_f32_f16, + torch.float16, + torch.float16, + torch.float32, + ) + + +def test_wmma_gfx1201_w64_f32_bf16(mod): + return _test_wmma_gfx1201_w64_variant( + mod, + "f32_bf16", + mod.run_wmma_gfx1201_w64_f32_bf16, + torch.bfloat16, + torch.bfloat16, + torch.float32, + ) + + +def test_wmma_gfx1201_w64_f16_f16(mod): + return _test_wmma_gfx1201_w64_variant( + mod, + "f16_f16", + mod.run_wmma_gfx1201_w64_f16_f16, + torch.float16, + torch.float16, + torch.float16, + ) + + +def test_wmma_gfx1201_w64_bf16_bf16(mod): + return _test_wmma_gfx1201_w64_variant( + mod, + "bf16_bf16", + mod.run_wmma_gfx1201_w64_bf16_bf16, + torch.bfloat16, + torch.bfloat16, + torch.bfloat16, + ) + + def test_async_load(mod): """Test async_load: copy data through LDS and verify integrity.""" if _skip_if_missing_symbol(mod, "run_async_load", "async_load"): @@ -2432,6 +2542,10 @@ def main(): failures += test_wmma_gfx1201_f32_fp8_bf8(mod) failures += test_wmma_gfx1201_f32_bf8_fp8(mod) failures += test_wmma_gfx1201_f32_bf8_bf8(mod) + failures += test_wmma_gfx1201_w64_f32_f16(mod) + failures += test_wmma_gfx1201_w64_f32_bf16(mod) + failures += test_wmma_gfx1201_w64_f16_f16(mod) + failures += test_wmma_gfx1201_w64_bf16_bf16(mod) failures += test_async_load(mod) failures += test_tr_load_f16(mod) failures += test_dtype_convert_fp32_bf16(mod) diff --git a/op_tests/opus/device/test_wmma_gfx1201_w64.cu b/op_tests/opus/device/test_wmma_gfx1201_w64.cu new file mode 100644 index 0000000000..5d8b461e7d --- /dev/null +++ b/op_tests/opus/device/test_wmma_gfx1201_w64.cu @@ -0,0 +1,106 @@ +// SPDX-License-Identifier: MIT +// Copyright (C) 2024-2026, Advanced Micro Devices, Inc. 
All rights reserved. + +#ifdef __HIP_DEVICE_COMPILE__ +#include "opus/opus.hpp" +#if defined(__gfx1201__) || defined(__gfx1200__) + +template +__global__ void wmma_gfx12_w64_kernel( + const DIN_A* __restrict__ ptr_a, + const DIN_B* __restrict__ ptr_b, + DOUT* __restrict__ ptr_c, + int stride_a, + int stride_b, + int stride_c) +{ + constexpr int WM = 16, WN = 16, WK = 16; + constexpr int WARP = 64; + constexpr int ELEM = WM * WK / WARP; // 4 + + using vtype_a = opus::vector_t; + using vtype_b = opus::vector_t; + using vtype_c = opus::vector_t; + + int lane = static_cast(opus::lane_id()); + int group = lane / 16; // 0,1,2,3 + int sublane = lane % 16; // column for B/C, row for A + + // A: row-distributed. A[row=sublane][K=(group*4+j)] + // Same group order as my original assumption — confirmed by AMD matrix calculator. + int a_row = sublane; + int a_k_base = group * 4; + + // B: column-distributed. B[K=(group*4+j)][col=sublane] + int b_col = sublane; + int b_k_base = group * 4; + + // D/C: column-distributed with INTERLEAVED group order {0,2,1,3} + // Group 0 (lanes 0-15) → rows 0-3 + // Group 1 (lanes 16-31) → rows 8-11 (NOT 4-7!) 
+ // Group 2 (lanes 32-47) → rows 4-7 + // Group 3 (lanes 48-63) → rows 12-15 + constexpr int c_row_base_lut[4] = {0, 8, 4, 12}; + int c_col = sublane; + int c_row_base = c_row_base_lut[group]; + + vtype_a v_a{}; + vtype_b v_b{}; + vtype_c v_c{}; + + #pragma unroll + for (int j = 0; j < ELEM; ++j) v_a[j] = ptr_a[a_row * stride_a + (a_k_base + j)]; + #pragma unroll + for (int j = 0; j < ELEM; ++j) v_b[j] = ptr_b[(b_k_base + j) * stride_b + b_col]; + + opus::wmma mma; + v_c = mma(v_a, v_b, v_c); + + #pragma unroll + for (int j = 0; j < ELEM; ++j) ptr_c[(c_row_base + j) * stride_c + c_col] = v_c[j]; +} + +template __global__ void wmma_gfx12_w64_kernel(const opus::fp16_t*, const opus::fp16_t*, opus::fp32_t*, int, int, int); +template __global__ void wmma_gfx12_w64_kernel(const opus::bf16_t*, const opus::bf16_t*, opus::fp32_t*, int, int, int); +template __global__ void wmma_gfx12_w64_kernel(const opus::fp16_t*, const opus::fp16_t*, opus::fp16_t*, int, int, int); +template __global__ void wmma_gfx12_w64_kernel(const opus::bf16_t*, const opus::bf16_t*, opus::bf16_t*, int, int, int); + +#endif +#else +#include "opus/opus.hpp" +#include "opus/hip_minimal.hpp" +#include + +#define HIP_CALL(call) do { \ + hipError_t err = (call); \ + if (err != hipSuccess) { \ + fprintf(stderr, "HIP error %d at %s:%d\n", (int)err, __FILE__, __LINE__); \ + return; \ + } \ +} while(0) + +template +__global__ void wmma_gfx12_w64_kernel(const DIN_A*, const DIN_B*, DOUT*, int, int, int) {} + +#define LAUNCHER_(NAME, DA, DB, DC) \ +extern "C" void run_wmma_gfx1201_w64_ ## NAME ( \ + const void* d_a, const void* d_b, void* d_c, \ + int stride_a, int stride_b, int stride_c) \ +{ \ + hipLaunchKernelGGL((wmma_gfx12_w64_kernel), \ + dim3(1, 1), 64, 0, 0, \ + static_cast(d_a), \ + static_cast(d_b), \ + static_cast(d_c), \ + stride_a, stride_b, stride_c); \ + HIP_CALL(hipGetLastError()); \ + HIP_CALL(hipDeviceSynchronize()); \ +} + +LAUNCHER_(f32_f16, fp16_t, fp16_t, fp32_t) +LAUNCHER_(f32_bf16, bf16_t, 
bf16_t, fp32_t) +LAUNCHER_(f16_f16, fp16_t, fp16_t, fp16_t) +LAUNCHER_(bf16_bf16, bf16_t, bf16_t, bf16_t) + +#undef LAUNCHER_ +#endif From 1387801eaeaaeef968d007afb9c0d3607a8b7097 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Sun, 17 May 2026 20:24:09 +0800 Subject: [PATCH 09/18] =?UTF-8?q?opus:=20fix=20wave64=20WMMA=20test=20?= =?UTF-8?q?=E2=80=94=20use=20direct=20mbcnt=20for=20lane=5Fid=20on=20gfx12?= =?UTF-8?q?01?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit opus::lane_id() delegates to get_warp_size() which hardcodes 32 for gfx1200/gfx1201. In wave64 mode this caused mbcnt_lo-only lane IDs (0-31), making lanes 32-63 duplicate lanes 0-31 and corrupting all loads/stores. Fix: use __builtin_amdgcn_mbcnt_hi(-1, __builtin_amdgcn_mbcnt_lo(-1, 0)) directly to get the full 0-63 lane ID. All 4 wave64 WMMA variants now pass on gfx1201 (RX 9070 XT): PASS: f32<-f16 max_diff=3.81e-06 PASS: f32<-bf16 max_diff=3.81e-06 PASS: f16<-f16 max_diff=3.13e-02 PASS: bf16<-bf16 max_diff=5.00e-01 --- op_tests/opus/device/test_wmma_gfx1201_w64.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/op_tests/opus/device/test_wmma_gfx1201_w64.cu b/op_tests/opus/device/test_wmma_gfx1201_w64.cu index 5d8b461e7d..c34ecf0228 100644 --- a/op_tests/opus/device/test_wmma_gfx1201_w64.cu +++ b/op_tests/opus/device/test_wmma_gfx1201_w64.cu @@ -22,7 +22,7 @@ __global__ void wmma_gfx12_w64_kernel( using vtype_b = opus::vector_t; using vtype_c = opus::vector_t; - int lane = static_cast(opus::lane_id()); + int lane = __builtin_amdgcn_mbcnt_hi(-1, __builtin_amdgcn_mbcnt_lo(-1, 0)); int group = lane / 16; // 0,1,2,3 int sublane = lane % 16; // column for B/C, row for A From f0a7c216b1c367310a895af2b407b6c48e9ef5d2 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Sun, 17 May 2026 20:57:05 +0800 Subject: [PATCH 10/18] opus: fix get_warp_size() for gfx12 wave64 via __has_builtin proxy get_warp_size() hardcoded 32 for gfx1200/gfx1201, breaking 
lane_id() and all wave64 code paths. The proper fix is to detect -mwavefrontsize64 at compile time, but __AMDGCN_WAVEFRONT_SIZE__ was removed in ROCm 7.2 and __builtin_amdgcn_wavefrontsize() is not constexpr. Use __has_builtin on the _w64_gfx12 WMMA builtins as a constexpr proxy: clang gates these builtins on the wavefrontsize64 target feature, which is exactly what -mwavefrontsize64 sets. Reverts the direct mbcnt workaround in test_wmma_gfx1201_w64.cu back to opus::lane_id(), which now works correctly. --- csrc/include/opus/opus.hpp | 6 ++++++ op_tests/opus/device/test_wmma_gfx1201_w64.cu | 2 +- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/csrc/include/opus/opus.hpp b/csrc/include/opus/opus.hpp index 4e5dab8d23..a1c71c3e88 100644 --- a/csrc/include/opus/opus.hpp +++ b/csrc/include/opus/opus.hpp @@ -1508,6 +1508,12 @@ OPUS_H_D constexpr index_t get_warp_size() return 32; #elif defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__) return 64; +#elif (defined(__gfx1201__) || defined(__gfx1200__)) && __has_builtin(__builtin_amdgcn_wmma_f32_16x16x16_f16_w64_gfx12) + // Workaround: __AMDGCN_WAVEFRONT_SIZE__ macro was removed in ROCm 7.2 (clang 22), + // and __builtin_amdgcn_wavefrontsize() is not constexpr. The _w64_gfx12 builtins are + // gated by the wavefrontsize64 target feature (set by -mwavefrontsize64), so + // __has_builtin serves as a constexpr-compatible proxy for wave64 detection on gfx12. 
+ return 64; #else return 32; #endif diff --git a/op_tests/opus/device/test_wmma_gfx1201_w64.cu b/op_tests/opus/device/test_wmma_gfx1201_w64.cu index c34ecf0228..5d8b461e7d 100644 --- a/op_tests/opus/device/test_wmma_gfx1201_w64.cu +++ b/op_tests/opus/device/test_wmma_gfx1201_w64.cu @@ -22,7 +22,7 @@ __global__ void wmma_gfx12_w64_kernel( using vtype_b = opus::vector_t; using vtype_c = opus::vector_t; - int lane = __builtin_amdgcn_mbcnt_hi(-1, __builtin_amdgcn_mbcnt_lo(-1, 0)); + int lane = static_cast(opus::lane_id()); int group = lane / 16; // 0,1,2,3 int sublane = lane % 16; // column for B/C, row for A From a4ce01b03914eaee6acb2004abb2a439db39137e Mon Sep 17 00:00:00 2001 From: carlushuang Date: Sun, 17 May 2026 21:01:58 +0800 Subject: [PATCH 11/18] opus: condense gfx12 wave64 comments to match file style --- csrc/include/opus/opus.hpp | 37 +++++++++++-------------------------- 1 file changed, 11 insertions(+), 26 deletions(-) diff --git a/csrc/include/opus/opus.hpp b/csrc/include/opus/opus.hpp index a1c71c3e88..5d2a431d12 100644 --- a/csrc/include/opus/opus.hpp +++ b/csrc/include/opus/opus.hpp @@ -1509,11 +1509,7 @@ OPUS_H_D constexpr index_t get_warp_size() #elif defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__) return 64; #elif (defined(__gfx1201__) || defined(__gfx1200__)) && __has_builtin(__builtin_amdgcn_wmma_f32_16x16x16_f16_w64_gfx12) - // Workaround: __AMDGCN_WAVEFRONT_SIZE__ macro was removed in ROCm 7.2 (clang 22), - // and __builtin_amdgcn_wavefrontsize() is not constexpr. The _w64_gfx12 builtins are - // gated by the wavefrontsize64 target feature (set by -mwavefrontsize64), so - // __has_builtin serves as a constexpr-compatible proxy for wave64 detection on gfx12. 
- return 64; + return 64; // workaround: __AMDGCN_WAVEFRONT_SIZE__ removed in ROCm 7.2; _w64 builtins are gated by -mwavefrontsize64 target feature, so __has_builtin is a constexpr proxy #else return 32; #endif @@ -2339,8 +2335,7 @@ using mfma_scale_f32_16x16x128_fp4_fp4 = mfma_f32_16x16x128_fp4_fp4; static_cast(0), c, false, false); } // gfx12 (_w32_gfx12 suffix) builtins: 3-arg (A, B, C) — no matrix_fmts / neg_c slot. fp/same-type acc: -// The warp_size param (ws_) selects between _w32_gfx12 and _w64_gfx12 variants and is -// required because both shapes share the same {wm, wn, wk} triple. +// ws_ param selects _w32/_w64 variants (both share the same {wm,wn,wk} triple). #define DISPATCH_WMMA_GFX12_F32_(ta_, tb_, tc_, wm_, wn_, wk_, ws_, inst_) \ (std::is_same_v && std::is_same_v && std::is_same_v && \ wave_m == wm_ && wave_n == wn_ && wave_k == wk_ && warp_size == ws_) { return inst_(a, b, c); } @@ -2351,10 +2346,7 @@ using mfma_scale_f32_16x16x128_fp4_fp4 = mfma_f32_16x16x128_fp4_fp4; constexpr index_t i32_a = elem_a * static_cast(sizeof(dtype_a)) / static_cast(sizeof(i32_t)); \ constexpr index_t i32_b = elem_b * static_cast(sizeof(dtype_b)) / static_cast(sizeof(i32_t)); \ return inst_(__builtin_bit_cast(vector_t, a), __builtin_bit_cast(vector_t, b), c); } -// STEP_K: synthesize a larger K by chaining N copies of an inst with K=inst_k_. Matches -// the gfx942 mfma_f32_16x16x32_*_1k pattern (2× 16x16x16 inner). Lets a kernel use a -// wider A/B vector type per K-step so memory loads vectorize to b128 per lane (e.g. -// wave64 16x16x32 bf16 = 8 elem/lane = 16 bytes = 1 ds_load_b128 per lane). +// STEP_K: chain N copies of inst with K=inst_k_ to synthesize wider K (like mfma_*_1k). Wider per-lane A/B for b128 loads. 
#define DISPATCH_WMMA_GFX12_F32_STEP_K_(ta_, tb_, tc_, wm_, wn_, wk_, ws_, inst_k_, inst_) \ (std::is_same_v && std::is_same_v && std::is_same_v && \ wave_m == wm_ && wave_n == wn_ && wave_k == wk_ && warp_size == ws_) { \ @@ -2428,12 +2420,9 @@ struct wmma { else if constexpr DISPATCH_WMMA_8BIT_(bf8_t, bf8_t, fp16_t, 16, 16, 128, __builtin_amdgcn_wmma_f16_16x16x128_bf8_bf8) #endif #if defined(__gfx1201__) || defined(__gfx1200__) - // The _w32/_w64 builtins are gated by wavefrontsize32/wavefrontsize64 target - // features. Clang checks target features eagerly (even for unused if-constexpr - // branches), so we gate the dispatch lines with __has_builtin to only include - // the lines that match the active wave mode. + // _w32/_w64 builtins gated by wavefrontsize target feature; __has_builtin selects the active wave mode. # if __has_builtin(__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12) - // gfx12 wave32 16x16x16: f16/bf16/fp8/bf8 → f32 + same-type (f16/bf16) acc. iu8/iu4 deferred. + // gfx12 wave32 16x16x16: f16/bf16/fp8/bf8 → f32 + same-type acc. iu8/iu4 deferred. 
else if constexpr DISPATCH_WMMA_GFX12_F32_ (fp16_t, fp16_t, fp32_t , 16, 16, 16, 32, __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12) else if constexpr DISPATCH_WMMA_GFX12_F32_ (bf16_t, bf16_t, fp32_t , 16, 16, 16, 32, __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12) else if constexpr DISPATCH_WMMA_GFX12_F32_ (fp16_t, fp16_t, fp16_t , 16, 16, 16, 32, __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12) @@ -2442,17 +2431,14 @@ struct wmma { else if constexpr DISPATCH_WMMA_GFX12_8BIT_(fp8_t , bf8_t , fp32_t , 16, 16, 16, 32, __builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12) else if constexpr DISPATCH_WMMA_GFX12_8BIT_(bf8_t , fp8_t , fp32_t , 16, 16, 16, 32, __builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12) else if constexpr DISPATCH_WMMA_GFX12_8BIT_(bf8_t , bf8_t , fp32_t , 16, 16, 16, 32, __builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12) - // gfx12 wave32 16x16x32: synthetic — 2× stacked 16x16x16 along K. Per-lane A/B = - // 16 fp16/bf16 elem (= 32 B). Useful for wider per-K-step register footprint and - // amortizing inner-loop overhead across 2 wmma issues. + // gfx12 wave32 16x16x32 synthetic: 2× stacked 16x16x16 along K. else if constexpr DISPATCH_WMMA_GFX12_F32_STEP_K_(fp16_t, fp16_t, fp32_t , 16, 16, 32, 32, 16, __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12) else if constexpr DISPATCH_WMMA_GFX12_F32_STEP_K_(bf16_t, bf16_t, fp32_t , 16, 16, 32, 32, 16, __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12) else if constexpr DISPATCH_WMMA_GFX12_F32_STEP_K_(fp16_t, fp16_t, fp16_t , 16, 16, 32, 32, 16, __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12) else if constexpr DISPATCH_WMMA_GFX12_F32_STEP_K_(bf16_t, bf16_t, bf16_t , 16, 16, 32, 32, 16, __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12) # endif // wave32 builtin available # if __has_builtin(__builtin_amdgcn_wmma_f32_16x16x16_f16_w64_gfx12) - // gfx12 wave64 16x16x16: native _w64_gfx12 builtins. Per-lane A/B = 4 elem - // (= 8 B = 1 b64 load per lane). 
Kernel must be compiled with wavefrontsize64. + // gfx12 wave64 16x16x16: native _w64_gfx12 builtins (4 elem/lane). Requires -mwavefrontsize64. else if constexpr DISPATCH_WMMA_GFX12_F32_ (fp16_t, fp16_t, fp32_t , 16, 16, 16, 64, __builtin_amdgcn_wmma_f32_16x16x16_f16_w64_gfx12) else if constexpr DISPATCH_WMMA_GFX12_F32_ (bf16_t, bf16_t, fp32_t , 16, 16, 16, 64, __builtin_amdgcn_wmma_f32_16x16x16_bf16_w64_gfx12) else if constexpr DISPATCH_WMMA_GFX12_F32_ (fp16_t, fp16_t, fp16_t , 16, 16, 16, 64, __builtin_amdgcn_wmma_f16_16x16x16_f16_w64_gfx12) @@ -2565,8 +2551,7 @@ struct wmma { #undef DISPATCH_WMMA_GFX12_8BIT_ #undef DISPATCH_WMMA_GFX12_F32_STEP_K_ -// gfx12 (gfx1200/gfx1201, Navi 44/48) wave32 16x16x16 aliases (default warp_size). -// The wave64 and synthetic 16x16x32 aliases follow below. +// gfx12 wave32 16x16x16 aliases (default warp_size). wave64 + synthetic 16x16x32 aliases below. using wmma_f32_16x16x16_f16 = wmma; using wmma_f16_16x16x16_f16 = wmma; using wmma_f32_16x16x16_bf16 = wmma; @@ -2575,17 +2560,17 @@ using wmma_f32_16x16x16_fp8_fp8 = wmma; using wmma_f32_16x16x16_fp8_bf8 = wmma; using wmma_f32_16x16x16_bf8_fp8 = wmma; using wmma_f32_16x16x16_bf8_bf8 = wmma; -// gfx12 wave32 16x16x32 synthetic (2× 16x16x16 stacked along K, see STEP_K dispatch). +// gfx12 wave32 16x16x32 synthetic (2× 16x16x16). using wmma_f32_16x16x32_f16_w32 = wmma; using wmma_f32_16x16x32_bf16_w32 = wmma; using wmma_f16_16x16x32_f16_w32 = wmma; using wmma_bf16_16x16x32_bf16_w32 = wmma; -// gfx12 wave64 16x16x16 native — kernel must be compiled with wavefrontsize64. +// gfx12 wave64 16x16x16 (-mwavefrontsize64). using wmma_f32_16x16x16_f16_w64 = wmma; using wmma_f32_16x16x16_bf16_w64 = wmma; using wmma_f16_16x16x16_f16_w64 = wmma; using wmma_bf16_16x16x16_bf16_w64 = wmma; -// gfx12 wave64 16x16x32 synthetic (2× 16x16x16_w64) — per-lane A/B = 16 B = 1 b128 load. +// gfx12 wave64 16x16x32 synthetic (2× 16x16x16_w64). 
using wmma_f32_16x16x32_f16_w64 = wmma; using wmma_f32_16x16x32_bf16_w64 = wmma; using wmma_f16_16x16x32_f16_w64 = wmma; From 46747bcd24ebc1a23efc5e00cd69b3033a55763c Mon Sep 17 00:00:00 2001 From: carlushuang Date: Sun, 17 May 2026 13:32:30 +0000 Subject: [PATCH 12/18] opus: black format on test_opus_device.py w64 additions --- op_tests/opus/device/test_opus_device.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/op_tests/opus/device/test_opus_device.py b/op_tests/opus/device/test_opus_device.py index 794ba71112..e8ac3c3ee2 100644 --- a/op_tests/opus/device/test_opus_device.py +++ b/op_tests/opus/device/test_opus_device.py @@ -1432,16 +1432,20 @@ def test_wmma_gfx1201_f32_bf8_bf8(mod): ) - # WMMA wave64 tests for gfx1200/gfx1201 (RDNA4). The _w64_gfx12 builtins # use 64 lanes with 4 elem/lane instead of wave32's 8 elem/lane. -def _test_wmma_gfx1201_w64_variant(mod, name, runner, in_dtype_a, in_dtype_b, out_dtype): + +def _test_wmma_gfx1201_w64_variant( + mod, name, runner, in_dtype_a, in_dtype_b, out_dtype +): arch = _get_gpu_arch() if arch not in _WMMA_GFX1201_ARCHS: print(f" SKIP: wmma_gfx1201_w64_{name} (arch={arch}, gfx1201-only)") return 0 - if _skip_if_missing_symbol(mod, f"run_wmma_gfx1201_w64_{name}", f"wmma_gfx1201_w64_{name}"): + if _skip_if_missing_symbol( + mod, f"run_wmma_gfx1201_w64_{name}", f"wmma_gfx1201_w64_{name}" + ): return 0 M = N = K = 16 From 1eb7ccac980ea87653d144e102fd2cd26e81c0b1 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Sun, 17 May 2026 13:37:59 +0000 Subject: [PATCH 13/18] opus: factor out DISPATCH_WMMA_GFX12_MATCH_ to condense gfx12 dispatch macros --- csrc/include/opus/opus.hpp | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/csrc/include/opus/opus.hpp b/csrc/include/opus/opus.hpp index 5d2a431d12..3a9cb0ffe5 100644 --- a/csrc/include/opus/opus.hpp +++ b/csrc/include/opus/opus.hpp @@ -2334,26 +2334,21 @@ using mfma_scale_f32_16x16x128_fp4_fp4 = 
mfma_f32_16x16x128_fp4_fp4; __builtin_bit_cast(vector_t, b), \ static_cast(0), c, false, false); } -// gfx12 (_w32_gfx12 suffix) builtins: 3-arg (A, B, C) — no matrix_fmts / neg_c slot. fp/same-type acc: -// ws_ param selects _w32/_w64 variants (both share the same {wm,wn,wk} triple). +// gfx12 builtins: 3-arg (A, B, C) — no matrix_fmts/neg_c. ws_ selects _w32/_w64 (same {wm,wn,wk} triple). +#define DISPATCH_WMMA_GFX12_MATCH_(ta_, tb_, tc_, wm_, wn_, wk_, ws_) \ + (std::is_same_v && std::is_same_v && std::is_same_v && wave_m == wm_ && wave_n == wn_ && wave_k == wk_ && warp_size == ws_) #define DISPATCH_WMMA_GFX12_F32_(ta_, tb_, tc_, wm_, wn_, wk_, ws_, inst_) \ - (std::is_same_v && std::is_same_v && std::is_same_v && \ - wave_m == wm_ && wave_n == wn_ && wave_k == wk_ && warp_size == ws_) { return inst_(a, b, c); } -// fp8/bf8 variant: A/B reinterpreted as packed i32 vector then (A, B, C): -#define DISPATCH_WMMA_GFX12_8BIT_(ta_, tb_, tc_, wm_, wn_, wk_, ws_, inst_) \ - (std::is_same_v && std::is_same_v && std::is_same_v && \ - wave_m == wm_ && wave_n == wn_ && wave_k == wk_ && warp_size == ws_) { \ + DISPATCH_WMMA_GFX12_MATCH_(ta_, tb_, tc_, wm_, wn_, wk_, ws_) { return inst_(a, b, c); } +#define DISPATCH_WMMA_GFX12_8BIT_(ta_, tb_, tc_, wm_, wn_, wk_, ws_, inst_) /* fp8/bf8: A/B bitcast to i32 */ \ + DISPATCH_WMMA_GFX12_MATCH_(ta_, tb_, tc_, wm_, wn_, wk_, ws_) { \ constexpr index_t i32_a = elem_a * static_cast(sizeof(dtype_a)) / static_cast(sizeof(i32_t)); \ constexpr index_t i32_b = elem_b * static_cast(sizeof(dtype_b)) / static_cast(sizeof(i32_t)); \ return inst_(__builtin_bit_cast(vector_t, a), __builtin_bit_cast(vector_t, b), c); } -// STEP_K: chain N copies of inst with K=inst_k_ to synthesize wider K (like mfma_*_1k). Wider per-lane A/B for b128 loads. 
-#define DISPATCH_WMMA_GFX12_F32_STEP_K_(ta_, tb_, tc_, wm_, wn_, wk_, ws_, inst_k_, inst_) \ - (std::is_same_v && std::is_same_v && std::is_same_v && \ - wave_m == wm_ && wave_n == wn_ && wave_k == wk_ && warp_size == ws_) { \ - constexpr index_t steps = wk_ / inst_k_; \ - constexpr index_t e_a = elem_a / steps; constexpr index_t e_b = elem_b / steps; \ +#define DISPATCH_WMMA_GFX12_F32_STEP_K_(ta_, tb_, tc_, wm_, wn_, wk_, ws_, inst_k_, inst_) /* chain N×inst for wider K */ \ + DISPATCH_WMMA_GFX12_MATCH_(ta_, tb_, tc_, wm_, wn_, wk_, ws_) { \ + constexpr index_t steps = wk_ / inst_k_, e_a = elem_a / steps, e_b = elem_b / steps; \ auto tmp = inst_(slice(a, number<0>{}, number{}), slice(b, number<0>{}, number{}), c); \ - static_for([&](auto i){ tmp = inst_(slice(a, number{}, number{}), slice(b, number{}, number{}), tmp); }); \ + static_for([&](auto i){ tmp = inst_(slice(a, number{}, number{}), slice(b, number{}, number{}), tmp); }); \ return tmp; } template @@ -2547,6 +2542,7 @@ struct wmma { #undef DISPATCH_WMMA_ #undef DISPATCH_WMMA_BF16F32_ #undef DISPATCH_WMMA_8BIT_ +#undef DISPATCH_WMMA_GFX12_MATCH_ #undef DISPATCH_WMMA_GFX12_F32_ #undef DISPATCH_WMMA_GFX12_8BIT_ #undef DISPATCH_WMMA_GFX12_F32_STEP_K_ From 71e005e88d9a248214b90ad17a09895094e26c7b Mon Sep 17 00:00:00 2001 From: carlushuang Date: Sun, 17 May 2026 13:57:24 +0000 Subject: [PATCH 14/18] opus: fix gfx1201 test failures + define OPUS_GFX120X_IS_WAVE32 + code cleanup - Define OPUS_GFX120X_IS_WAVE32 macro (constexpr proxy via __has_builtin) and use it in get_warp_size() and wmma dispatch guards, replacing scattered __has_builtin checks. - Add gfx1201/gfx1200 to fp32_to_bf16 hardware conversion guard so gfx12 uses static_cast (RNE) instead of software truncate. - Add gfx1201/gfx1200 to _get_fp8_dtype/_get_bf8_dtype OCP arch set and bf8 has_infinity check in test_opus_device.py. - Condense fwd-decl adaptors to 2 lines, fix macro indent to 4-space, condense w64 STEP_K comment. 
--- csrc/include/opus/opus.hpp | 44 ++++++++++++++---------- op_tests/opus/device/test_opus_device.py | 13 +++++-- 2 files changed, 35 insertions(+), 22 deletions(-) diff --git a/csrc/include/opus/opus.hpp b/csrc/include/opus/opus.hpp index 3a9cb0ffe5..4db1dfb489 100644 --- a/csrc/include/opus/opus.hpp +++ b/csrc/include/opus/opus.hpp @@ -499,10 +499,8 @@ template struct tuple_element(bits >> 16); } -#if (defined(__gfx950__) || defined(__gfx1250__)) && __clang_major__ >= 20 -template // gfx950/gfx1250 has instruction conversion, leave 'rm' here for compatiblity +#if (defined(__gfx950__) || defined(__gfx1250__) || defined(__gfx1201__) || defined(__gfx1200__)) && __clang_major__ >= 20 +template // gfx950/gfx1250/gfx12 has instruction conversion, leave 'rm' here for compatiblity OPUS_D constexpr auto fp32_to_bf16(const fp32_t& x, number = {}) { return static_cast(x); } #else template // 0:standard, 1:truncate_with_nan, 2:truncate, 3:standard asm 4:rta_asm(round to nearest away) @@ -1502,14 +1500,25 @@ OPUS_D constexpr decltype(auto) cast(const S& s, Aux&&... aux) { // Guarded by OPUS_ENABLE_RUNTIME_QUERY (default 0). Define OPUS_ENABLE_RUNTIME_QUERY=1 before // including opus.hpp (or via compiler flag) to enable these functions and the hip_runtime_api.h include. // +// gfx1200/gfx1201 wave32/64 detection: __AMDGCN_WAVEFRONT_SIZE__ was removed in ROCm 7.2 and +// __builtin_amdgcn_wavefrontsize() is not constexpr. The _w32_gfx12 builtins are gated by the +// wavefrontsize32 target feature (set by -mwavefrontsize32, the default), so __has_builtin is a constexpr proxy. 
+#if (defined(__gfx1201__) || defined(__gfx1200__)) && defined(__HIP_DEVICE_COMPILE__) +# if __has_builtin(__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12) +# define OPUS_GFX120X_IS_WAVE32 1 +# else +# define OPUS_GFX120X_IS_WAVE32 0 +# endif +#endif + OPUS_H_D constexpr index_t get_warp_size() { #if defined(__gfx1250__) return 32; #elif defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__) return 64; -#elif (defined(__gfx1201__) || defined(__gfx1200__)) && __has_builtin(__builtin_amdgcn_wmma_f32_16x16x16_f16_w64_gfx12) - return 64; // workaround: __AMDGCN_WAVEFRONT_SIZE__ removed in ROCm 7.2; _w64 builtins are gated by -mwavefrontsize64 target feature, so __has_builtin is a constexpr proxy +#elif defined(OPUS_GFX120X_IS_WAVE32) && !OPUS_GFX120X_IS_WAVE32 + return 64; #else return 32; #endif @@ -2336,16 +2345,16 @@ using mfma_scale_f32_16x16x128_fp4_fp4 = mfma_f32_16x16x128_fp4_fp4; // gfx12 builtins: 3-arg (A, B, C) — no matrix_fmts/neg_c. ws_ selects _w32/_w64 (same {wm,wn,wk} triple). 
#define DISPATCH_WMMA_GFX12_MATCH_(ta_, tb_, tc_, wm_, wn_, wk_, ws_) \ - (std::is_same_v && std::is_same_v && std::is_same_v && wave_m == wm_ && wave_n == wn_ && wave_k == wk_ && warp_size == ws_) + (std::is_same_v && std::is_same_v && std::is_same_v && wave_m == wm_ && wave_n == wn_ && wave_k == wk_ && warp_size == ws_) #define DISPATCH_WMMA_GFX12_F32_(ta_, tb_, tc_, wm_, wn_, wk_, ws_, inst_) \ - DISPATCH_WMMA_GFX12_MATCH_(ta_, tb_, tc_, wm_, wn_, wk_, ws_) { return inst_(a, b, c); } + DISPATCH_WMMA_GFX12_MATCH_(ta_, tb_, tc_, wm_, wn_, wk_, ws_) { return inst_(a, b, c); } #define DISPATCH_WMMA_GFX12_8BIT_(ta_, tb_, tc_, wm_, wn_, wk_, ws_, inst_) /* fp8/bf8: A/B bitcast to i32 */ \ - DISPATCH_WMMA_GFX12_MATCH_(ta_, tb_, tc_, wm_, wn_, wk_, ws_) { \ + DISPATCH_WMMA_GFX12_MATCH_(ta_, tb_, tc_, wm_, wn_, wk_, ws_) { \ constexpr index_t i32_a = elem_a * static_cast(sizeof(dtype_a)) / static_cast(sizeof(i32_t)); \ constexpr index_t i32_b = elem_b * static_cast(sizeof(dtype_b)) / static_cast(sizeof(i32_t)); \ return inst_(__builtin_bit_cast(vector_t, a), __builtin_bit_cast(vector_t, b), c); } #define DISPATCH_WMMA_GFX12_F32_STEP_K_(ta_, tb_, tc_, wm_, wn_, wk_, ws_, inst_k_, inst_) /* chain N×inst for wider K */ \ - DISPATCH_WMMA_GFX12_MATCH_(ta_, tb_, tc_, wm_, wn_, wk_, ws_) { \ + DISPATCH_WMMA_GFX12_MATCH_(ta_, tb_, tc_, wm_, wn_, wk_, ws_) { \ constexpr index_t steps = wk_ / inst_k_, e_a = elem_a / steps, e_b = elem_b / steps; \ auto tmp = inst_(slice(a, number<0>{}, number{}), slice(b, number<0>{}, number{}), c); \ static_for([&](auto i){ tmp = inst_(slice(a, number{}, number{}), slice(b, number{}, number{}), tmp); }); \ @@ -2415,8 +2424,8 @@ struct wmma { else if constexpr DISPATCH_WMMA_8BIT_(bf8_t, bf8_t, fp16_t, 16, 16, 128, __builtin_amdgcn_wmma_f16_16x16x128_bf8_bf8) #endif #if defined(__gfx1201__) || defined(__gfx1200__) - // _w32/_w64 builtins gated by wavefrontsize target feature; __has_builtin selects the active wave mode. 
-# if __has_builtin(__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12) + // _w32/_w64 builtins gated by wavefrontsize target feature; OPUS_GFX120X_IS_WAVE32 selects the active wave mode. +# if OPUS_GFX120X_IS_WAVE32 // gfx12 wave32 16x16x16: f16/bf16/fp8/bf8 → f32 + same-type acc. iu8/iu4 deferred. else if constexpr DISPATCH_WMMA_GFX12_F32_ (fp16_t, fp16_t, fp32_t , 16, 16, 16, 32, __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12) else if constexpr DISPATCH_WMMA_GFX12_F32_ (bf16_t, bf16_t, fp32_t , 16, 16, 16, 32, __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12) @@ -2432,16 +2441,13 @@ struct wmma { else if constexpr DISPATCH_WMMA_GFX12_F32_STEP_K_(fp16_t, fp16_t, fp16_t , 16, 16, 32, 32, 16, __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12) else if constexpr DISPATCH_WMMA_GFX12_F32_STEP_K_(bf16_t, bf16_t, bf16_t , 16, 16, 32, 32, 16, __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12) # endif // wave32 builtin available -# if __has_builtin(__builtin_amdgcn_wmma_f32_16x16x16_f16_w64_gfx12) +# if !OPUS_GFX120X_IS_WAVE32 // gfx12 wave64 16x16x16: native _w64_gfx12 builtins (4 elem/lane). Requires -mwavefrontsize64. else if constexpr DISPATCH_WMMA_GFX12_F32_ (fp16_t, fp16_t, fp32_t , 16, 16, 16, 64, __builtin_amdgcn_wmma_f32_16x16x16_f16_w64_gfx12) else if constexpr DISPATCH_WMMA_GFX12_F32_ (bf16_t, bf16_t, fp32_t , 16, 16, 16, 64, __builtin_amdgcn_wmma_f32_16x16x16_bf16_w64_gfx12) else if constexpr DISPATCH_WMMA_GFX12_F32_ (fp16_t, fp16_t, fp16_t , 16, 16, 16, 64, __builtin_amdgcn_wmma_f16_16x16x16_f16_w64_gfx12) else if constexpr DISPATCH_WMMA_GFX12_F32_ (bf16_t, bf16_t, bf16_t , 16, 16, 16, 64, __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64_gfx12) - // gfx12 wave64 16x16x32: synthetic — 2× stacked _w64 16x16x16. Per-lane A/B = 8 - // elem (= 16 B = 1 ds_load_b128 / global_load_b128 per lane). The primary - // motivation for the wave64 path on gfx12: wider per-lane vector-load slot than - // wave64 16x16x16's natural 8 B per lane. 
+ // gfx12 wave64 16x16x32 synthetic: 2× stacked _w64 16x16x16 (8 elem/lane = 1 b128 load). else if constexpr DISPATCH_WMMA_GFX12_F32_STEP_K_(fp16_t, fp16_t, fp32_t , 16, 16, 32, 64, 16, __builtin_amdgcn_wmma_f32_16x16x16_f16_w64_gfx12) else if constexpr DISPATCH_WMMA_GFX12_F32_STEP_K_(bf16_t, bf16_t, fp32_t , 16, 16, 32, 64, 16, __builtin_amdgcn_wmma_f32_16x16x16_bf16_w64_gfx12) else if constexpr DISPATCH_WMMA_GFX12_F32_STEP_K_(fp16_t, fp16_t, fp16_t , 16, 16, 32, 64, 16, __builtin_amdgcn_wmma_f16_16x16x16_f16_w64_gfx12) diff --git a/op_tests/opus/device/test_opus_device.py b/op_tests/opus/device/test_opus_device.py index e8ac3c3ee2..3130d48fe9 100644 --- a/op_tests/opus/device/test_opus_device.py +++ b/op_tests/opus/device/test_opus_device.py @@ -380,7 +380,7 @@ def _skip_if_missing_symbol(mod, sym, label): def _get_fp8_dtype(): """Return the correct FP8 (e4m3) torch dtype for the current GPU arch.""" arch = _get_gpu_arch() - if arch in ("gfx950", "gfx1250"): + if arch in ("gfx950", "gfx1250", "gfx1201", "gfx1200"): return torch.float8_e4m3fn return torch.float8_e4m3fnuz # gfx942 default @@ -388,7 +388,7 @@ def _get_fp8_dtype(): def _get_bf8_dtype(): """Return the correct BF8 (e5m2) torch dtype for the current GPU arch.""" arch = _get_gpu_arch() - if arch in ("gfx950", "gfx1250"): + if arch in ("gfx950", "gfx1250", "gfx1201", "gfx1200"): return torch.float8_e5m2 return torch.float8_e5m2fnuz # gfx942 default @@ -2300,7 +2300,14 @@ def ref_int(dtype, size): ("fp16", 5, 2, True, torch.float16, True), ("bf16", 10, 2, True, torch.bfloat16, True), ("fp8", 15, 1, True, fp8_dtype, False), - ("bf8", 20, 1, True, bf8_dtype, arch in ("gfx950", "gfx1250")), + ( + "bf8", + 20, + 1, + True, + bf8_dtype, + arch in ("gfx950", "gfx1250", "gfx1201", "gfx1200"), + ), ("i32", 25, 4, False, torch.int32, False), ("i16", 35, 2, False, torch.int16, False), ("i8", 45, 1, False, torch.int8, False), From 7c98ab43eaa2f2c3576ccfb70d12493c4a0a74fe Mon Sep 17 00:00:00 2001 From: carlushuang 
Date: Sun, 17 May 2026 14:05:21 +0000 Subject: [PATCH 15/18] opus: condense comments to single-line style --- csrc/include/opus/opus.hpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/csrc/include/opus/opus.hpp b/csrc/include/opus/opus.hpp index 4db1dfb489..744080ff2f 100644 --- a/csrc/include/opus/opus.hpp +++ b/csrc/include/opus/opus.hpp @@ -1500,9 +1500,7 @@ OPUS_D constexpr decltype(auto) cast(const S& s, Aux&&... aux) { // Guarded by OPUS_ENABLE_RUNTIME_QUERY (default 0). Define OPUS_ENABLE_RUNTIME_QUERY=1 before // including opus.hpp (or via compiler flag) to enable these functions and the hip_runtime_api.h include. // -// gfx1200/gfx1201 wave32/64 detection: __AMDGCN_WAVEFRONT_SIZE__ was removed in ROCm 7.2 and -// __builtin_amdgcn_wavefrontsize() is not constexpr. The _w32_gfx12 builtins are gated by the -// wavefrontsize32 target feature (set by -mwavefrontsize32, the default), so __has_builtin is a constexpr proxy. +// gfx12 wave32/64 detection: __AMDGCN_WAVEFRONT_SIZE__ removed in ROCm 7.2; _w32 builtins are gated by wavefrontsize32 target feature, so __has_builtin is a constexpr proxy. #if (defined(__gfx1201__) || defined(__gfx1200__)) && defined(__HIP_DEVICE_COMPILE__) # if __has_builtin(__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12) # define OPUS_GFX120X_IS_WAVE32 1 @@ -2808,7 +2806,7 @@ template = {}) { return A{}(mfma(WaveMNK{}), get<1>(WaveMNK{}), get<2>(WaveMNK{}), warp_size_>{}); } #endif // __GFX9__ -// wmma_adaptor: layout encoding for wave32 WMMA (gfx1250). TODO: gfx12 (gfx1200/gfx1201) needs a dedicated adaptor — its fragment layout is asymmetric (A row-distributed, B/C column-distributed) per AMD RDNA4 ISA §7.12.2 / CK wmma_gemm.hpp. Until then gfx1201 callers use opus::wmma<> directly. +// wmma_adaptor: layout encoding for wave32 WMMA (gfx1250). // A:[(grpm_a

), (rept_a, grpk_a

, pack_a)], MxK // B:[(grpn_b

), (rept_b, grpk_b

, pack_b)], NxK // C:[(grpm_c

, rept_c, pack_c), (grpn_c

)], MxN From b4b2fc591d1dcaa08cb6134f0db28dad94f45c54 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Sun, 17 May 2026 14:16:04 +0000 Subject: [PATCH 16/18] opus: add make_tiled_mma support for gfx1201/gfx1200 + tiled WMMA device tests --- csrc/include/opus/opus.hpp | 6 +- op_tests/opus/device/setup.py | 1 + op_tests/opus/device/test_opus_device.py | 81 ++++++++++++++++++ .../opus/device/test_wmma_gfx1201_tiled.cu | 84 +++++++++++++++++++ 4 files changed, 169 insertions(+), 3 deletions(-) create mode 100644 op_tests/opus/device/test_wmma_gfx1201_tiled.cu diff --git a/csrc/include/opus/opus.hpp b/csrc/include/opus/opus.hpp index 744080ff2f..3b51f4a877 100644 --- a/csrc/include/opus/opus.hpp +++ b/csrc/include/opus/opus.hpp @@ -2806,11 +2806,11 @@ template = {}) { return A{}(mfma(WaveMNK{}), get<1>(WaveMNK{}), get<2>(WaveMNK{}), warp_size_>{}); } #endif // __GFX9__ -// wmma_adaptor: layout encoding for wave32 WMMA (gfx1250). +// wmma_adaptor: layout encoding for wave32 WMMA (gfx1250, gfx1200/gfx1201). // A:[(grpm_a

), (rept_a, grpk_a

, pack_a)], MxK // B:[(grpn_b

), (rept_b, grpk_b

, pack_b)], NxK // C:[(grpm_c

, rept_c, pack_c), (grpn_c

)], MxN -#if defined(__gfx1250__) || !defined(__HIP_DEVICE_COMPILE__) +#if defined(__gfx1250__) || defined(__gfx1201__) || defined(__gfx1200__) || !defined(__HIP_DEVICE_COMPILE__) namespace impl { template struct wmma_adaptor : public remove_cvref_t { @@ -2889,7 +2889,7 @@ OPUS_D decltype(auto) make_wmma(number, number, number, A&& = {}, template OPUS_D decltype(auto) make_wmma(WaveMNK&&, A&& = {}, number = {}) { return A{}(wmma(WaveMNK{}), get<1>(WaveMNK{}), get<2>(WaveMNK{}), warp_size_>{}); } -#endif // __gfx1250__ +#endif // __gfx1250__ / __gfx1201__ / __gfx1200__ (wmma_adaptor) ///////////////////////////////////////////////////////////////////////////////////////////////////////// namespace impl { diff --git a/op_tests/opus/device/setup.py b/op_tests/opus/device/setup.py index ddc89ed0d4..85d9f5445c 100644 --- a/op_tests/opus/device/setup.py +++ b/op_tests/opus/device/setup.py @@ -43,6 +43,7 @@ "test_opus_gmem_gfx1201.cu", "test_wmma_gfx1201.cu", "test_wmma_gfx1201_w64.cu", + "test_wmma_gfx1201_tiled.cu", ] # Sources requiring -mwavefrontsize64 (wave64 builtins). 
diff --git a/op_tests/opus/device/test_opus_device.py b/op_tests/opus/device/test_opus_device.py index 3130d48fe9..e7016f7c7c 100644 --- a/op_tests/opus/device/test_opus_device.py +++ b/op_tests/opus/device/test_opus_device.py @@ -247,6 +247,26 @@ def run_wmma_gfx1201_w64_f16_f16(self, A, B, C): def run_wmma_gfx1201_w64_bf16_bf16(self, A, B, C): self._run_wmma_gfx1201_w64("bf16_bf16", A, B, C) + # -- wmma_gfx1201_tiled (make_tiled_mma + partition_layout, C = A @ B^T via swap_ab) -- + def _run_wmma_gfx1201_tiled(self, suffix, A, B, C): + fn = getattr(self._lib, f"run_wmma_gfx1201_tiled_{suffix}") + fn.restype = None + fn.argtypes = [_VP, _VP, _VP, _I, _I, _I] + fn( + self._ptr(A), + self._ptr(B), + self._ptr(C), + int(A.stride(0)), + int(B.stride(0)), + int(C.stride(0)), + ) + + def run_wmma_gfx1201_tiled_f32_f16(self, A, B, C): + self._run_wmma_gfx1201_tiled("f32_f16", A, B, C) + + def run_wmma_gfx1201_tiled_f32_bf16(self, A, B, C): + self._run_wmma_gfx1201_tiled("f32_bf16", A, B, C) + # -- async_load -- def run_async_load(self, Src, Dst): fn = self._lib.run_async_load @@ -1520,6 +1540,65 @@ def test_wmma_gfx1201_w64_bf16_bf16(mod): ) +# Tiled WMMA tests for gfx1200/gfx1201: make_tiled_mma + partition_layout + gmem (C = A @ B^T). 
+ + +def _test_wmma_gfx1201_tiled_variant(mod, name, runner, in_dtype, out_dtype): + arch = _get_gpu_arch() + if arch not in _WMMA_GFX1201_ARCHS: + print(f" SKIP: wmma_gfx1201_tiled_{name} (arch={arch}, gfx1201-only)") + return 0 + if _skip_if_missing_symbol( + mod, f"run_wmma_gfx1201_tiled_{name}", f"wmma_gfx1201_tiled_{name}" + ): + return 0 + + M = N = K = 16 + device = torch.device("cuda") + torch.manual_seed(42) + A = torch.randn(M, K, dtype=torch.float32, device=device) * 2.0 + B = torch.randn(N, K, dtype=torch.float32, device=device) * 2.0 + a = A.to(in_dtype) + b = B.to(in_dtype) + C = torch.zeros(M, N, dtype=out_dtype, device=device) + Ref = (a.to(torch.float32) @ b.to(torch.float32).t()).to(out_dtype) + + runner(a, b, C) + + atol, rtol = _wmma_gfx1201_tolerances(out_dtype) + ok = torch.allclose(C.float(), Ref.float(), atol=atol, rtol=rtol) + max_diff = (C.float() - Ref.float()).abs().max().item() + if not ok: + print( + f" FAIL: wmma_gfx1201_tiled_{name} max_diff={max_diff:.4e} (atol={atol})" + ) + return 1 + print( + f" PASS: wmma_gfx1201_tiled_{name} (in={in_dtype}, out={out_dtype}, max_diff={max_diff:.4e})" + ) + return 0 + + +def test_wmma_gfx1201_tiled_f32_f16(mod): + return _test_wmma_gfx1201_tiled_variant( + mod, + "f32_f16", + mod.run_wmma_gfx1201_tiled_f32_f16, + torch.float16, + torch.float32, + ) + + +def test_wmma_gfx1201_tiled_f32_bf16(mod): + return _test_wmma_gfx1201_tiled_variant( + mod, + "f32_bf16", + mod.run_wmma_gfx1201_tiled_f32_bf16, + torch.bfloat16, + torch.float32, + ) + + def test_async_load(mod): """Test async_load: copy data through LDS and verify integrity.""" if _skip_if_missing_symbol(mod, "run_async_load", "async_load"): @@ -2557,6 +2636,8 @@ def main(): failures += test_wmma_gfx1201_w64_f32_bf16(mod) failures += test_wmma_gfx1201_w64_f16_f16(mod) failures += test_wmma_gfx1201_w64_bf16_bf16(mod) + failures += test_wmma_gfx1201_tiled_f32_f16(mod) + failures += test_wmma_gfx1201_tiled_f32_bf16(mod) failures += 
test_async_load(mod) failures += test_tr_load_f16(mod) failures += test_dtype_convert_fp32_bf16(mod) diff --git a/op_tests/opus/device/test_wmma_gfx1201_tiled.cu b/op_tests/opus/device/test_wmma_gfx1201_tiled.cu new file mode 100644 index 0000000000..3c79e0da82 --- /dev/null +++ b/op_tests/opus/device/test_wmma_gfx1201_tiled.cu @@ -0,0 +1,84 @@ +// SPDX-License-Identifier: MIT +// Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. +// Tiled WMMA test for gfx1201/gfx1200: make_tiled_mma + partition_layout + gmem load/store. +// Computes C = A @ B^T (swap_ab) with 16x16x16 wave32 WMMA, same pattern as test_wmma_f32.cu (gfx1250). + +#ifdef __HIP_DEVICE_COMPILE__ +#include "opus/opus.hpp" +#if defined(__gfx1201__) || defined(__gfx1200__) + +template +__global__ void wmma_gfx12_tiled_kernel( + const DIN* __restrict__ ptr_a, + const DIN* __restrict__ ptr_b, + DOUT* __restrict__ ptr_c, + int k, int stride_a, int stride_b, int stride_c) +{ + using opus::operator""_I; + constexpr int T_M = 1, T_N = 1, T_K = 1; + constexpr int E_M = 1, E_N = 1, E_K = 1; + constexpr int ELEM_A = WM * WK / 32; + constexpr int PACK_A = (16 / static_cast(sizeof(DIN)) < ELEM_A) ? 16 / static_cast(sizeof(DIN)) : ELEM_A; + constexpr int PACK_B = PACK_A; + constexpr int ELEM_C = WM * WN / 32; + constexpr int PACK_C = (16 / static_cast(sizeof(DOUT)) < ELEM_C) ? 
16 / static_cast(sizeof(DOUT)) : ELEM_C; + using d_a = DIN; using d_b = DIN; using d_c = DOUT; + + int lane_id = static_cast(opus::lane_id()); + int g_im = __builtin_amdgcn_workgroup_id_x() * WM; + int g_in = __builtin_amdgcn_workgroup_id_y() * WN; + + auto mma = opus::make_tiled_mma( + opus::make_wmma(opus::seq{}, opus::wmma_adaptor_swap_ab{}), + opus::seq{}, opus::seq{}); + + auto u_a = opus::partition_layout_a(mma, opus::make_tuple(stride_a, 1_I), + opus::make_tuple(0_I, lane_id % mma.grpm_a, 0_I, lane_id / mma.grpm_a)); + auto u_b = opus::partition_layout_b(mma, opus::make_tuple(stride_b, 1_I), + opus::make_tuple(0_I, lane_id % mma.grpn_b, 0_I, lane_id / mma.grpn_b)); + auto u_c = opus::partition_layout_c(mma, opus::make_tuple(stride_c, 1_I), + opus::make_tuple(0_I, lane_id % mma.grpn_c, 0_I, lane_id / mma.grpn_c)); + + auto g_a = opus::make_gmem(ptr_a + g_im * stride_a); + auto g_b = opus::make_gmem(ptr_b + g_in * stride_b); + auto g_c = opus::make_gmem(ptr_c + g_im * stride_c + g_in); + + int loops = (k + WK - 1) / WK; + typename decltype(mma)::vtype_c v_c; + opus::clear(v_c); + for (int i = 0; i < loops; i++) { + auto v_a = g_a.template load(u_a); + u_a += WK; + auto v_b = g_b.template load(u_b); + u_b += WK; + v_c = mma(v_a, v_b, v_c); + } + g_c.template store(v_c, u_c); +} + +template __global__ void wmma_gfx12_tiled_kernel(const opus::fp16_t*, const opus::fp16_t*, opus::fp32_t*, int, int, int, int); +template __global__ void wmma_gfx12_tiled_kernel(const opus::bf16_t*, const opus::bf16_t*, opus::fp32_t*, int, int, int, int); + +#endif +#else +#include "opus/opus.hpp" +#include "opus/hip_minimal.hpp" +#include +#define HIP_CALL(call) do { hipError_t err = (call); if (err != hipSuccess) { fprintf(stderr, "HIP error %d at %s:%d\n", (int)err, __FILE__, __LINE__); return; } } while(0) + +template +__global__ void wmma_gfx12_tiled_kernel(const DIN*, const DIN*, DOUT*, int, int, int, int) {} + +#define LAUNCHER_(NAME, DIN, DOUT, WM, WN, WK) \ +extern "C" void 
run_wmma_gfx1201_tiled_ ## NAME ( \ + const void* d_a, const void* d_b, void* d_c, int stride_a, int stride_b, int stride_c) { \ + hipLaunchKernelGGL((wmma_gfx12_tiled_kernel), \ + dim3(1, 1), 32, 0, 0, \ + static_cast(d_a), static_cast(d_b), \ + static_cast(d_c), WK, stride_a, stride_b, stride_c); \ + HIP_CALL(hipGetLastError()); HIP_CALL(hipDeviceSynchronize()); } + +LAUNCHER_(f32_f16, fp16_t, fp32_t, 16, 16, 16) +LAUNCHER_(f32_bf16, bf16_t, fp32_t, 16, 16, 16) +#undef LAUNCHER_ +#endif From c601323aae23340e0e22724a152ce037009a0430 Mon Sep 17 00:00:00 2001 From: carlushuang Date: Mon, 18 May 2026 01:02:35 +0800 Subject: [PATCH 17/18] triton: add gfx1201 GEMM configs (copied from gfx1250 as baseline) --- .../gemm/gfx1201-GEMM-A16W16-ATOMIC.json | 15 ++++ .../gemm/gfx1201-GEMM-A16W16-gated.json | 74 +++++++++++++++++ .../configs/gemm/gfx1201-GEMM-A16W16.json | 80 +++++++++++++++++++ .../configs/gemm/gfx1201-GEMM-A8W8.json | 14 ++++ 4 files changed, 183 insertions(+) create mode 100644 aiter/ops/triton/configs/gemm/gfx1201-GEMM-A16W16-ATOMIC.json create mode 100644 aiter/ops/triton/configs/gemm/gfx1201-GEMM-A16W16-gated.json create mode 100644 aiter/ops/triton/configs/gemm/gfx1201-GEMM-A16W16.json create mode 100644 aiter/ops/triton/configs/gemm/gfx1201-GEMM-A8W8.json diff --git a/aiter/ops/triton/configs/gemm/gfx1201-GEMM-A16W16-ATOMIC.json b/aiter/ops/triton/configs/gemm/gfx1201-GEMM-A16W16-ATOMIC.json new file mode 100644 index 0000000000..44271f5634 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx1201-GEMM-A16W16-ATOMIC.json @@ -0,0 +1,15 @@ +{ + "any": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 4, + "NUM_KSPLIT": 1, + "cache_modifier": null, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 32, + "kpack": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx1201-GEMM-A16W16-gated.json b/aiter/ops/triton/configs/gemm/gfx1201-GEMM-A16W16-gated.json new file mode 100644 index 
0000000000..67448543a2 --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx1201-GEMM-A16W16-gated.json @@ -0,0 +1,74 @@ +{ + "M_LEQ_64": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "kpack": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 256, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "kpack": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "kpack": 1 + }, + "M_LEQ_512": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "kpack": 1 + }, + "M_LEQ_2048": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "kpack": 1 + }, + "any": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 16, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "kpack": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx1201-GEMM-A16W16.json b/aiter/ops/triton/configs/gemm/gfx1201-GEMM-A16W16.json new file mode 100644 index 0000000000..e593e2b10c --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx1201-GEMM-A16W16.json @@ -0,0 +1,80 @@ +{ + "M_LEQ_64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 6, + "num_warps": 4, + "num_stages": 3, + "waves_per_eu": 3, + 
"matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1 + }, + "M_LEQ_128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 2, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1 + }, + "M_LEQ_256": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 4, + "num_warps": 4, + "num_stages": 3, + "waves_per_eu": 3, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1 + }, + "M_LEQ_512": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": ".cg", + "NUM_KSPLIT": 1, + "kpack": 1 + }, + "M_LEQ_2048": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 3, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1 + }, + "any": { + "BLOCK_SIZE_M": 256, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 6, + "num_warps": 8, + "num_stages": 3, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "cache_modifier": null, + "NUM_KSPLIT": 1, + "kpack": 1 + } +} diff --git a/aiter/ops/triton/configs/gemm/gfx1201-GEMM-A8W8.json b/aiter/ops/triton/configs/gemm/gfx1201-GEMM-A8W8.json new file mode 100644 index 0000000000..4f38cf580f --- /dev/null +++ b/aiter/ops/triton/configs/gemm/gfx1201-GEMM-A8W8.json @@ -0,0 +1,14 @@ +{ + "any": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 4, + "num_warps": 8, + "num_stages": 2, + "waves_per_eu": 2, + "matrix_instr_nonkdim": 16, + "kpack": 1, + "NUM_KSPLIT": 1 + } +} From e4ec1fdac761c7bdb62142eb1a5168cbbbe9413a Mon Sep 17 00:00:00 2001 From: carlushuang Date: Mon, 18 May 2026 00:15:36 +0000 Subject: 
[PATCH 18/18] Revert "triton: add gfx1201 GEMM configs (copied from gfx1250 as baseline)" This reverts commit c601323aae23340e0e22724a152ce037009a0430. --- .../gemm/gfx1201-GEMM-A16W16-ATOMIC.json | 15 ---- .../gemm/gfx1201-GEMM-A16W16-gated.json | 74 ----------------- .../configs/gemm/gfx1201-GEMM-A16W16.json | 80 ------------------- .../configs/gemm/gfx1201-GEMM-A8W8.json | 14 ---- 4 files changed, 183 deletions(-) delete mode 100644 aiter/ops/triton/configs/gemm/gfx1201-GEMM-A16W16-ATOMIC.json delete mode 100644 aiter/ops/triton/configs/gemm/gfx1201-GEMM-A16W16-gated.json delete mode 100644 aiter/ops/triton/configs/gemm/gfx1201-GEMM-A16W16.json delete mode 100644 aiter/ops/triton/configs/gemm/gfx1201-GEMM-A8W8.json diff --git a/aiter/ops/triton/configs/gemm/gfx1201-GEMM-A16W16-ATOMIC.json b/aiter/ops/triton/configs/gemm/gfx1201-GEMM-A16W16-ATOMIC.json deleted file mode 100644 index 44271f5634..0000000000 --- a/aiter/ops/triton/configs/gemm/gfx1201-GEMM-A16W16-ATOMIC.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - "any": { - "BLOCK_SIZE_M": 256, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 4, - "NUM_KSPLIT": 1, - "cache_modifier": null, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 2, - "matrix_instr_nonkdim": 32, - "kpack": 1 - } -} diff --git a/aiter/ops/triton/configs/gemm/gfx1201-GEMM-A16W16-gated.json b/aiter/ops/triton/configs/gemm/gfx1201-GEMM-A16W16-gated.json deleted file mode 100644 index 67448543a2..0000000000 --- a/aiter/ops/triton/configs/gemm/gfx1201-GEMM-A16W16-gated.json +++ /dev/null @@ -1,74 +0,0 @@ -{ - "M_LEQ_64": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 3, - "matrix_instr_nonkdim": 16, - "cache_modifier": ".cg", - "kpack": 1 - }, - "M_LEQ_128": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 256, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 3, - "matrix_instr_nonkdim": 
16, - "cache_modifier": ".cg", - "kpack": 1 - }, - "M_LEQ_256": { - "BLOCK_SIZE_M": 256, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 2, - "matrix_instr_nonkdim": 16, - "cache_modifier": ".cg", - "kpack": 1 - }, - "M_LEQ_512": { - "BLOCK_SIZE_M": 256, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 2, - "matrix_instr_nonkdim": 16, - "cache_modifier": ".cg", - "kpack": 1 - }, - "M_LEQ_2048": { - "BLOCK_SIZE_M": 256, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 2, - "matrix_instr_nonkdim": 16, - "cache_modifier": null, - "kpack": 1 - }, - "any": { - "BLOCK_SIZE_M": 256, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 16, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 2, - "matrix_instr_nonkdim": 16, - "cache_modifier": null, - "kpack": 1 - } -} diff --git a/aiter/ops/triton/configs/gemm/gfx1201-GEMM-A16W16.json b/aiter/ops/triton/configs/gemm/gfx1201-GEMM-A16W16.json deleted file mode 100644 index e593e2b10c..0000000000 --- a/aiter/ops/triton/configs/gemm/gfx1201-GEMM-A16W16.json +++ /dev/null @@ -1,80 +0,0 @@ -{ - "M_LEQ_64": { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 6, - "num_warps": 4, - "num_stages": 3, - "waves_per_eu": 3, - "matrix_instr_nonkdim": 16, - "cache_modifier": null, - "NUM_KSPLIT": 1, - "kpack": 1 - }, - "M_LEQ_128": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 1, - "num_warps": 4, - "num_stages": 2, - "waves_per_eu": 3, - "matrix_instr_nonkdim": 16, - "cache_modifier": null, - "NUM_KSPLIT": 1, - "kpack": 1 - }, - "M_LEQ_256": { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 4, - "num_warps": 4, - "num_stages": 3, - "waves_per_eu": 3, - "matrix_instr_nonkdim": 16, - 
"cache_modifier": null, - "NUM_KSPLIT": 1, - "kpack": 1 - }, - "M_LEQ_512": { - "BLOCK_SIZE_M": 256, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 3, - "waves_per_eu": 2, - "matrix_instr_nonkdim": 16, - "cache_modifier": ".cg", - "NUM_KSPLIT": 1, - "kpack": 1 - }, - "M_LEQ_2048": { - "BLOCK_SIZE_M": 256, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 1, - "num_warps": 8, - "num_stages": 3, - "waves_per_eu": 2, - "matrix_instr_nonkdim": 16, - "cache_modifier": null, - "NUM_KSPLIT": 1, - "kpack": 1 - }, - "any": { - "BLOCK_SIZE_M": 256, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 6, - "num_warps": 8, - "num_stages": 3, - "waves_per_eu": 2, - "matrix_instr_nonkdim": 16, - "cache_modifier": null, - "NUM_KSPLIT": 1, - "kpack": 1 - } -} diff --git a/aiter/ops/triton/configs/gemm/gfx1201-GEMM-A8W8.json b/aiter/ops/triton/configs/gemm/gfx1201-GEMM-A8W8.json deleted file mode 100644 index 4f38cf580f..0000000000 --- a/aiter/ops/triton/configs/gemm/gfx1201-GEMM-A8W8.json +++ /dev/null @@ -1,14 +0,0 @@ -{ - "any": { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 128, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 4, - "num_warps": 8, - "num_stages": 2, - "waves_per_eu": 2, - "matrix_instr_nonkdim": 16, - "kpack": 1, - "NUM_KSPLIT": 1 - } -}