diff --git a/csrc/include/opus/opus.hpp b/csrc/include/opus/opus.hpp index e0c84f7ecb..3b51f4a877 100644 --- a/csrc/include/opus/opus.hpp +++ b/csrc/include/opus/opus.hpp @@ -499,6 +499,9 @@ template struct tuple_element constexpr auto embed(const X& x, const Y& y, seq) { return ( ... + (get(x) * get(y))); } @@ -1101,8 +1104,8 @@ OPUS_D constexpr unsigned short fp32_to_bf16_rtn_raw(float f) else if(bits & 0xffff) { bits |= 0x10000; /* Preserve signaling NaN */ } return static_cast(bits >> 16); } -#if (defined(__gfx950__) || defined(__gfx1250__)) && __clang_major__ >= 20 -template // gfx950/gfx1250 has instruction conversion, leave 'rm' here for compatiblity +#if (defined(__gfx950__) || defined(__gfx1250__) || defined(__gfx1201__) || defined(__gfx1200__)) && __clang_major__ >= 20 +template // gfx950/gfx1250/gfx12 has instruction conversion, leave 'rm' here for compatiblity OPUS_D constexpr auto fp32_to_bf16(const fp32_t& x, number = {}) { return static_cast(x); } #else template // 0:standard, 1:truncate_with_nan, 2:truncate, 3:standard asm 4:rta_asm(round to nearest away) @@ -1497,12 +1500,23 @@ OPUS_D constexpr decltype(auto) cast(const S& s, Aux&&... aux) { // Guarded by OPUS_ENABLE_RUNTIME_QUERY (default 0). Define OPUS_ENABLE_RUNTIME_QUERY=1 before // including opus.hpp (or via compiler flag) to enable these functions and the hip_runtime_api.h include. // +// gfx12 wave32/64 detection: __AMDGCN_WAVEFRONT_SIZE__ removed in ROCm 7.2; _w32 builtins are gated by wavefrontsize32 target feature, so __has_builtin is a constexpr proxy. +#if (defined(__gfx1201__) || defined(__gfx1200__)) && defined(__HIP_DEVICE_COMPILE__) +# if __has_builtin(__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12) +# define OPUS_GFX120X_IS_WAVE32 1 +# else +# define OPUS_GFX120X_IS_WAVE32 0 +# endif +#endif + OPUS_H_D constexpr index_t get_warp_size() { #if defined(__gfx1250__) return 32; #elif defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__) return 64; +#elif defined(OPUS_GFX120X_IS_WAVE32) && !OPUS_GFX120X_IS_WAVE32 + return 64; #else return 32; #endif @@ -1613,7 +1627,8 @@ OPUS_D constexpr auto buffer_default_config() { return 0x00020000; #elif defined(__gfx103__) return 0x31014000; -#elif defined(__gfx11__) || defined(__gfx12__) || defined(__gfx1250__) +// gfx1200/gfx1201 (Navi 44/48) listed explicitly — __gfx11__/__gfx12__ are typos (clang predefines only uppercase __GFX11__/__GFX12__), so without this gfx1200/gfx1201 fall into the 0xffffffff sentinel and make_gmem<> stores silently drop. +#elif defined(__gfx11__) || defined(__gfx12__) || defined(__gfx1250__) || defined(__gfx1201__) || defined(__gfx1200__) return 0x31004000; #else return 0xffffffff; @@ -2302,8 +2317,8 @@ using mfma_scale_f32_16x16x128_fp4_fp4 = mfma_f32_16x16x128_fp4_fp4; #endif // __GFX9__ (mfma) ///////////////////////////////////////////////////////////////////////////////////////////////////////// -// wmma (gfx1250 / RDNA4, wave32) -#if defined(__gfx1250__) || !defined(__HIP_DEVICE_COMPILE__) +// wmma (RDNA4 / wave32) — gfx1250 uses wmma-256b builtins (16x16x{4,32,64,128}); gfx1200/gfx1201 (Navi 44/48) use wmma-128b _w32_gfx12 builtins (16x16x16). Dispatch macros for the two arg-list shapes differ — gfx12 set is DISPATCH_WMMA_GFX12_*. +#if defined(__gfx1250__) || defined(__gfx1201__) || defined(__gfx1200__) || !defined(__HIP_DEVICE_COMPILE__) // f16/bf16/f32 builtins: (neg_a, A, neg_b, B, matrix_fmts, C, clamp, neg_c) #define DISPATCH_WMMA_(ta_, tb_, tc_, wm_, wn_, wk_, inst_) \ (std::is_same_v && std::is_same_v && std::is_same_v && \ @@ -2326,6 +2341,23 @@ using mfma_scale_f32_16x16x128_fp4_fp4 = mfma_f32_16x16x128_fp4_fp4; __builtin_bit_cast(vector_t, b), \ static_cast(0), c, false, false); } +// gfx12 builtins: 3-arg (A, B, C) — no matrix_fmts/neg_c. ws_ selects _w32/_w64 (same {wm,wn,wk} triple). +#define DISPATCH_WMMA_GFX12_MATCH_(ta_, tb_, tc_, wm_, wn_, wk_, ws_) \ + (std::is_same_v && std::is_same_v && std::is_same_v && wave_m == wm_ && wave_n == wn_ && wave_k == wk_ && warp_size == ws_) +#define DISPATCH_WMMA_GFX12_F32_(ta_, tb_, tc_, wm_, wn_, wk_, ws_, inst_) \ + DISPATCH_WMMA_GFX12_MATCH_(ta_, tb_, tc_, wm_, wn_, wk_, ws_) { return inst_(a, b, c); } +#define DISPATCH_WMMA_GFX12_8BIT_(ta_, tb_, tc_, wm_, wn_, wk_, ws_, inst_) /* fp8/bf8: A/B bitcast to i32 */ \ + DISPATCH_WMMA_GFX12_MATCH_(ta_, tb_, tc_, wm_, wn_, wk_, ws_) { \ + constexpr index_t i32_a = elem_a * static_cast(sizeof(dtype_a)) / static_cast(sizeof(i32_t)); \ + constexpr index_t i32_b = elem_b * static_cast(sizeof(dtype_b)) / static_cast(sizeof(i32_t)); \ + return inst_(__builtin_bit_cast(vector_t, a), __builtin_bit_cast(vector_t, b), c); } +#define DISPATCH_WMMA_GFX12_F32_STEP_K_(ta_, tb_, tc_, wm_, wn_, wk_, ws_, inst_k_, inst_) /* chain N×inst for wider K */ \ + DISPATCH_WMMA_GFX12_MATCH_(ta_, tb_, tc_, wm_, wn_, wk_, ws_) { \ + constexpr index_t steps = wk_ / inst_k_, e_a = elem_a / steps, e_b = elem_b / steps; \ + auto tmp = inst_(slice(a, number<0>{}, number{}), slice(b, number<0>{}, number{}), c); \ + static_for([&](auto i){ tmp = inst_(slice(a, number{}, number{}), slice(b, number{}, number{}), tmp); }); \ + return tmp; } + template struct wmma { using dtype_a = remove_cvref_t; @@ -2389,6 +2421,37 @@ struct wmma { else if constexpr DISPATCH_WMMA_8BIT_(bf8_t, fp8_t, fp16_t, 16, 16, 128, __builtin_amdgcn_wmma_f16_16x16x128_bf8_fp8) else if constexpr DISPATCH_WMMA_8BIT_(bf8_t, bf8_t, fp16_t, 16, 16, 128, __builtin_amdgcn_wmma_f16_16x16x128_bf8_bf8) #endif +#if defined(__gfx1201__) || defined(__gfx1200__) + // _w32/_w64 builtins gated by wavefrontsize target feature; OPUS_GFX120X_IS_WAVE32 selects the active wave mode. +# if OPUS_GFX120X_IS_WAVE32 + // gfx12 wave32 16x16x16: f16/bf16/fp8/bf8 → f32 + same-type acc. iu8/iu4 deferred. + else if constexpr DISPATCH_WMMA_GFX12_F32_ (fp16_t, fp16_t, fp32_t , 16, 16, 16, 32, __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_F32_ (bf16_t, bf16_t, fp32_t , 16, 16, 16, 32, __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_F32_ (fp16_t, fp16_t, fp16_t , 16, 16, 16, 32, __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_F32_ (bf16_t, bf16_t, bf16_t , 16, 16, 16, 32, __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_8BIT_(fp8_t , fp8_t , fp32_t , 16, 16, 16, 32, __builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_8BIT_(fp8_t , bf8_t , fp32_t , 16, 16, 16, 32, __builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_8BIT_(bf8_t , fp8_t , fp32_t , 16, 16, 16, 32, __builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_8BIT_(bf8_t , bf8_t , fp32_t , 16, 16, 16, 32, __builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12) + // gfx12 wave32 16x16x32 synthetic: 2× stacked 16x16x16 along K. + else if constexpr DISPATCH_WMMA_GFX12_F32_STEP_K_(fp16_t, fp16_t, fp32_t , 16, 16, 32, 32, 16, __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_F32_STEP_K_(bf16_t, bf16_t, fp32_t , 16, 16, 32, 32, 16, __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_F32_STEP_K_(fp16_t, fp16_t, fp16_t , 16, 16, 32, 32, 16, __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_F32_STEP_K_(bf16_t, bf16_t, bf16_t , 16, 16, 32, 32, 16, __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12) +# endif // wave32 builtin available +# if !OPUS_GFX120X_IS_WAVE32 + // gfx12 wave64 16x16x16: native _w64_gfx12 builtins (4 elem/lane). Requires -mwavefrontsize64. + else if constexpr DISPATCH_WMMA_GFX12_F32_ (fp16_t, fp16_t, fp32_t , 16, 16, 16, 64, __builtin_amdgcn_wmma_f32_16x16x16_f16_w64_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_F32_ (bf16_t, bf16_t, fp32_t , 16, 16, 16, 64, __builtin_amdgcn_wmma_f32_16x16x16_bf16_w64_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_F32_ (fp16_t, fp16_t, fp16_t , 16, 16, 16, 64, __builtin_amdgcn_wmma_f16_16x16x16_f16_w64_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_F32_ (bf16_t, bf16_t, bf16_t , 16, 16, 16, 64, __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64_gfx12) + // gfx12 wave64 16x16x32 synthetic: 2× stacked _w64 16x16x16 (8 elem/lane = 1 b128 load). + else if constexpr DISPATCH_WMMA_GFX12_F32_STEP_K_(fp16_t, fp16_t, fp32_t , 16, 16, 32, 64, 16, __builtin_amdgcn_wmma_f32_16x16x16_f16_w64_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_F32_STEP_K_(bf16_t, bf16_t, fp32_t , 16, 16, 32, 64, 16, __builtin_amdgcn_wmma_f32_16x16x16_bf16_w64_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_F32_STEP_K_(fp16_t, fp16_t, fp16_t , 16, 16, 32, 64, 16, __builtin_amdgcn_wmma_f16_16x16x16_f16_w64_gfx12) + else if constexpr DISPATCH_WMMA_GFX12_F32_STEP_K_(bf16_t, bf16_t, bf16_t , 16, 16, 32, 64, 16, __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64_gfx12) +# endif // wave64 builtin available +#endif // __gfx1201__ / __gfx1200__ __builtin_unreachable(); } @@ -2483,6 +2546,35 @@ struct wmma { #undef DISPATCH_WMMA_ #undef DISPATCH_WMMA_BF16F32_ #undef DISPATCH_WMMA_8BIT_ +#undef DISPATCH_WMMA_GFX12_MATCH_ +#undef DISPATCH_WMMA_GFX12_F32_ +#undef DISPATCH_WMMA_GFX12_8BIT_ +#undef DISPATCH_WMMA_GFX12_F32_STEP_K_ + +// gfx12 wave32 16x16x16 aliases (default warp_size). wave64 + synthetic 16x16x32 aliases below. +using wmma_f32_16x16x16_f16 = wmma; +using wmma_f16_16x16x16_f16 = wmma; +using wmma_f32_16x16x16_bf16 = wmma; +using wmma_bf16_16x16x16_bf16 = wmma; +using wmma_f32_16x16x16_fp8_fp8 = wmma; +using wmma_f32_16x16x16_fp8_bf8 = wmma; +using wmma_f32_16x16x16_bf8_fp8 = wmma; +using wmma_f32_16x16x16_bf8_bf8 = wmma; +// gfx12 wave32 16x16x32 synthetic (2× 16x16x16). +using wmma_f32_16x16x32_f16_w32 = wmma; +using wmma_f32_16x16x32_bf16_w32 = wmma; +using wmma_f16_16x16x32_f16_w32 = wmma; +using wmma_bf16_16x16x32_bf16_w32 = wmma; +// gfx12 wave64 16x16x16 (-mwavefrontsize64). +using wmma_f32_16x16x16_f16_w64 = wmma; +using wmma_f32_16x16x16_bf16_w64 = wmma; +using wmma_f16_16x16x16_f16_w64 = wmma; +using wmma_bf16_16x16x16_bf16_w64 = wmma; +// gfx12 wave64 16x16x32 synthetic (2× 16x16x16_w64). +using wmma_f32_16x16x32_f16_w64 = wmma; +using wmma_f32_16x16x32_bf16_w64 = wmma; +using wmma_f16_16x16x32_f16_w64 = wmma; +using wmma_bf16_16x16x32_bf16_w64 = wmma; // f16/bf16 16x16x32 using wmma_f32_16x16x32_f16 = wmma; @@ -2514,7 +2606,7 @@ using wmma_scale_f32_16x16x128_fp8_fp8 = wmma using wmma_scale_f32_16x16x128_fp4_fp4 = wmma; // Scaled WMMA (dedicated fp4 32x16x128 instruction) using wmma_scale_f32_32x16x128_fp4_fp4 = wmma; -#endif // __gfx1250__ (wmma) +#endif // __gfx1250__ / __gfx1201__ / __gfx1200__ (wmma) ///////////////////////////////////////////////////////////////////////////////////////////////////////// // adaptor @@ -2714,11 +2806,11 @@ template = {}) { return A{}(mfma(WaveMNK{}), get<1>(WaveMNK{}), get<2>(WaveMNK{}), warp_size_>{}); } #endif // __GFX9__ -// wmma_adaptor: same layout encoding as mfma_adaptor but for wave32 WMMA (gfx1250) +// wmma_adaptor: layout encoding for wave32 WMMA (gfx1250, gfx1200/gfx1201). // A:[(grpm_a

), (rept_a, grpk_a

, pack_a)], MxK // B:[(grpn_b

), (rept_b, grpk_b

, pack_b)], NxK // C:[(grpm_c

, rept_c, pack_c), (grpn_c

)], MxN -#if defined(__gfx1250__) || !defined(__HIP_DEVICE_COMPILE__) +#if defined(__gfx1250__) || defined(__gfx1201__) || defined(__gfx1200__) || !defined(__HIP_DEVICE_COMPILE__) namespace impl { template struct wmma_adaptor : public remove_cvref_t { @@ -2797,7 +2889,7 @@ OPUS_D decltype(auto) make_wmma(number, number, number, A&& = {}, template OPUS_D decltype(auto) make_wmma(WaveMNK&&, A&& = {}, number = {}) { return A{}(wmma(WaveMNK{}), get<1>(WaveMNK{}), get<2>(WaveMNK{}), warp_size_>{}); } -#endif // __gfx1250__ +#endif // __gfx1250__ / __gfx1201__ / __gfx1200__ (wmma_adaptor) ///////////////////////////////////////////////////////////////////////////////////////////////////////// namespace impl { diff --git a/op_tests/opus/device/setup.py b/op_tests/opus/device/setup.py index 376d8d2404..85d9f5445c 100644 --- a/op_tests/opus/device/setup.py +++ b/op_tests/opus/device/setup.py @@ -40,8 +40,15 @@ "test_numeric_limits.cu", "test_workgroup_barrier.cu", "test_finfo.cu", + "test_opus_gmem_gfx1201.cu", + "test_wmma_gfx1201.cu", + "test_wmma_gfx1201_w64.cu", + "test_wmma_gfx1201_tiled.cu", ] +# Sources requiring -mwavefrontsize64 (wave64 builtins). +_W64_SOURCES = {"test_wmma_gfx1201_w64.cu"} + def _detect_arch(): try: @@ -74,7 +81,8 @@ def _find_hipcc(): def _compile_one(args): """Compile a single .cu -> .o. Used as a worker function for parallel builds.""" - src, obj, hipcc, arch, verbose = args + src, obj, hipcc, arch, verbose, *rest = args + extra_flags = rest[0] if rest else [] cmd = [ hipcc, f"--offload-arch={arch}", @@ -83,6 +91,7 @@ def _compile_one(args): "-D__HIPCC_RTC__", f"-I{_REPO_CSRC}", f"-I{_THIS_DIR}", + *extra_flags, "-c", src, "-o", @@ -109,14 +118,33 @@ def build(verbose=False, jobs=None): if verbose: print(f"[setup] arch={arch}, jobs={jobs}") + # Per-arch skip list: kernels that use builtins not available on the + # target arch. Skipped at .so build time so the rest of the suite + # still links; the Python harness sees the missing extern "C" launcher + # and reports SKIP for those tests. + # + # gfx1201 / gfx1200 (Navi 44/48, RDNA4): opus _async_load uses + # __builtin_amdgcn_raw_ptr_buffer_load_lds which needs the + # `vmem-to-lds-load-insts` target feature (gfx9x / gfx950 / gfx1250 only). + _ARCH_SKIP_SOURCES = { + "gfx1200": {"test_async_load.cu", "test_load_store_if.cu"}, + "gfx1201": {"test_async_load.cu", "test_load_store_if.cu"}, + } + skip = _ARCH_SKIP_SOURCES.get(arch, set()) + sources = [s for s in _CU_SOURCES if s not in skip] + if verbose and skip: + for s in sorted(skip): + print(f"[setup] skip {s} (incompatible with arch={arch})") + t0 = time.monotonic() # Parallel compile: each .cu -> .o tasks = [] - for s in _CU_SOURCES: + for s in sources: src = os.path.join(_THIS_DIR, s) obj = os.path.join(_THIS_DIR, s.replace(".cu", ".o")) - tasks.append((src, obj, hipcc, arch, verbose)) + extra = ["-mwavefrontsize64"] if s in _W64_SOURCES else [] + tasks.append((src, obj, hipcc, arch, verbose, extra)) objs = [] with ProcessPoolExecutor(max_workers=jobs) as pool: diff --git a/op_tests/opus/device/test_opus_device.py b/op_tests/opus/device/test_opus_device.py index 90e5a3f05e..e7016f7c7c 100644 --- a/op_tests/opus/device/test_opus_device.py +++ b/op_tests/opus/device/test_opus_device.py @@ -176,6 +176,97 @@ def run_vector_add(self, A, B, Result): fn.argtypes = [_VP, _VP, _VP, _I] fn(self._ptr(A), self._ptr(B), self._ptr(Result), int(A.numel())) + # -- opus_gmem_gfx1201 (verifies opus.hpp parses + opus utils work on gfx1201) -- + def run_opus_gmem_gfx1201(self, A, B, Result): + fn = self._lib.run_opus_gmem_gfx1201 + fn.restype = None + fn.argtypes = [_VP, _VP, _VP, _I] + fn(self._ptr(A), self._ptr(B), self._ptr(Result), int(A.numel())) + + # -- wmma_gfx1201 (8 wave32 16x16x16 variants via __builtin_amdgcn_wmma_*_w32_gfx12) -- + def _run_wmma_gfx1201(self, suffix, A, B, C): + fn = getattr(self._lib, f"run_wmma_gfx1201_{suffix}") + fn.restype = None + fn.argtypes = [_VP, _VP, _VP, _I, _I, _I] + fn( + self._ptr(A), + self._ptr(B), + self._ptr(C), + int(A.stride(0)), + int(B.stride(0)), + int(C.stride(0)), + ) + + def run_wmma_gfx1201_f32_f16(self, A, B, C): + self._run_wmma_gfx1201("f32_f16", A, B, C) + + def run_wmma_gfx1201_f32_bf16(self, A, B, C): + self._run_wmma_gfx1201("f32_bf16", A, B, C) + + def run_wmma_gfx1201_f16_f16(self, A, B, C): + self._run_wmma_gfx1201("f16_f16", A, B, C) + + def run_wmma_gfx1201_bf16_bf16(self, A, B, C): + self._run_wmma_gfx1201("bf16_bf16", A, B, C) + + def run_wmma_gfx1201_f32_fp8_fp8(self, A, B, C): + self._run_wmma_gfx1201("f32_fp8_fp8", A, B, C) + + def run_wmma_gfx1201_f32_fp8_bf8(self, A, B, C): + self._run_wmma_gfx1201("f32_fp8_bf8", A, B, C) + + def run_wmma_gfx1201_f32_bf8_fp8(self, A, B, C): + self._run_wmma_gfx1201("f32_bf8_fp8", A, B, C) + + def run_wmma_gfx1201_f32_bf8_bf8(self, A, B, C): + self._run_wmma_gfx1201("f32_bf8_bf8", A, B, C) + + # -- wmma_gfx1201_w64 (4 wave64 16x16x16 variants via __builtin_amdgcn_wmma_*_w64_gfx12) -- + def _run_wmma_gfx1201_w64(self, suffix, A, B, C): + fn = getattr(self._lib, f"run_wmma_gfx1201_w64_{suffix}") + fn.restype = None + fn.argtypes = [_VP, _VP, _VP, _I, _I, _I] + fn( + self._ptr(A), + self._ptr(B), + self._ptr(C), + int(A.stride(0)), + int(B.stride(0)), + int(C.stride(0)), + ) + + def run_wmma_gfx1201_w64_f32_f16(self, A, B, C): + self._run_wmma_gfx1201_w64("f32_f16", A, B, C) + + def run_wmma_gfx1201_w64_f32_bf16(self, A, B, C): + self._run_wmma_gfx1201_w64("f32_bf16", A, B, C) + + def run_wmma_gfx1201_w64_f16_f16(self, A, B, C): + self._run_wmma_gfx1201_w64("f16_f16", A, B, C) + + def run_wmma_gfx1201_w64_bf16_bf16(self, A, B, C): + self._run_wmma_gfx1201_w64("bf16_bf16", A, B, C) + + # -- wmma_gfx1201_tiled (make_tiled_mma + partition_layout, C = A @ B^T via swap_ab) -- + def _run_wmma_gfx1201_tiled(self, suffix, A, B, C): + fn = getattr(self._lib, f"run_wmma_gfx1201_tiled_{suffix}") + fn.restype = None + fn.argtypes = [_VP, _VP, _VP, _I, _I, _I] + fn( + self._ptr(A), + self._ptr(B), + self._ptr(C), + int(A.stride(0)), + int(B.stride(0)), + int(C.stride(0)), + ) + + def run_wmma_gfx1201_tiled_f32_f16(self, A, B, C): + self._run_wmma_gfx1201_tiled("f32_f16", A, B, C) + + def run_wmma_gfx1201_tiled_f32_bf16(self, A, B, C): + self._run_wmma_gfx1201_tiled("f32_bf16", A, B, C) + # -- async_load -- def run_async_load(self, Src, Dst): fn = self._lib.run_async_load @@ -274,6 +365,15 @@ def _get_gpu_arch(): return getattr(props, "gcnArchName", "").split(":")[0] +def _skip_if_missing_symbol(mod, sym, label): + """Print SKIP + return True if the .so wasn't built with `sym` (setup.py + skips arch-incompatible sources per _ARCH_SKIP_SOURCES).""" + if not hasattr(mod._lib, sym): + print(f" SKIP: {label} ({sym} not built for arch={_get_gpu_arch()})") + return True + return False + + # --------------------------------------------------------------------------- # Individual test functions # --------------------------------------------------------------------------- @@ -300,7 +400,7 @@ def _get_gpu_arch(): def _get_fp8_dtype(): """Return the correct FP8 (e4m3) torch dtype for the current GPU arch.""" arch = _get_gpu_arch() - if arch in ("gfx950", "gfx1250"): + if arch in ("gfx950", "gfx1250", "gfx1201", "gfx1200"): return torch.float8_e4m3fn return torch.float8_e4m3fnuz # gfx942 default @@ -308,7 +408,7 @@ def _get_fp8_dtype(): def _get_bf8_dtype(): """Return the correct BF8 (e5m2) torch dtype for the current GPU arch.""" arch = _get_gpu_arch() - if arch in ("gfx950", "gfx1250"): + if arch in ("gfx950", "gfx1250", "gfx1201", "gfx1200"): return torch.float8_e5m2 return torch.float8_e5m2fnuz # gfx942 default @@ -1166,8 +1266,343 @@ def test_vector_add(mod): return 0 +# Archs where the opus.hpp gfx12 (Navi 44/48 RDNA4) path is active. The kernel +# body in test_opus_gmem_gfx1201.cu is gated by __gfx1201__ / __gfx1200__ — +# on other archs the launcher runs an empty kernel, so we skip the correctness +# check to avoid a misleading failure. +_OPUS_PARSE_GFX1201_ARCHS = {"gfx1201", "gfx1200"} + + +def test_opus_gmem_gfx1201(mod): + """Verify opus.hpp parses + opus utilities (make_gmem / .load/.store / + cast) work on gfx1201. Mirrors the load → cast → store pattern + in sample_kernels.cu. Skips on other archs (kernel body is gfx1201-only).""" + arch = _get_gpu_arch() + if arch not in _OPUS_PARSE_GFX1201_ARCHS: + print(f" SKIP: opus_gmem_gfx1201 (arch={arch}, gfx1201-only test)") + return 0 + + n = 1310720 + device = torch.device("cuda") + dtype = torch.float32 + + torch.manual_seed(42) + A = torch.randn(n, device=device, dtype=dtype) + B = torch.randn(n, device=device, dtype=dtype) + Result = torch.empty(n, device=device, dtype=dtype) + + mod.run_opus_gmem_gfx1201(A, B, Result) + + Ref = A + B + + atol, rtol = 1e-5, 1e-5 + ok = torch.allclose(Result, Ref, atol=atol, rtol=rtol) + max_diff = (Result - Ref).abs().max().item() + if not ok: + diff_count = (Result - Ref).abs().gt(atol + rtol * Ref.abs()).sum().item() + print( + f" FAIL: opus_gmem_gfx1201 max_diff={max_diff:.6e}, " + f"{diff_count} elements outside tol" + ) + return 1 + print(f" PASS: opus_gmem_gfx1201 (arch={arch}, n={n}), max_diff={max_diff:.6e}") + return 0 + + +# WMMA tests for gfx1200/gfx1201 (Navi 44/48, RDNA4). Both archs share the +# same gfx12 wmma-128b ISA so the kernel bodies in test_wmma_gfx1201.cu are +# gated by __gfx1201__ / __gfx1200__ — on other archs the launcher runs an +# empty kernel so we skip the correctness check. +_WMMA_GFX1201_ARCHS = {"gfx1201", "gfx1200"} + + +def _wmma_gfx1201_tolerances(out_dtype): + # f32 acc is bit-exact against the FP32 reference matmul; f16/bf16 acc + # picks up one ULP of rounding error. + if out_dtype == torch.float32: + return 5e-2, 1e-2 + if out_dtype == torch.float16: + return 1e-1, 1e-2 + if out_dtype == torch.bfloat16: + return 5e-1, 5e-2 + return 1e-2, 1e-2 + + +def _test_wmma_gfx1201_variant(mod, name, runner, in_dtype_a, in_dtype_b, out_dtype): + """Drive one wmma_gfx1201 variant: build random 16x16 A/B, call the + WMMA kernel, compare against fp32 torch matmul cast to out_dtype.""" + arch = _get_gpu_arch() + if arch not in _WMMA_GFX1201_ARCHS: + print(f" SKIP: wmma_gfx1201_{name} (arch={arch}, gfx1201-only)") + return 0 + + M = N = K = 16 + device = torch.device("cuda") + + torch.manual_seed(42) + a_ref = torch.randn(M, K, dtype=torch.float32, device=device) * 2.0 + b_ref = torch.randn(K, N, dtype=torch.float32, device=device) * 2.0 + A = a_ref.to(in_dtype_a) + B = b_ref.to(in_dtype_b) + C = torch.zeros(M, N, dtype=out_dtype, device=device) + + Ref = (A.to(torch.float32) @ B.to(torch.float32)).to(out_dtype) + + runner(A, B, C) + + atol, rtol = _wmma_gfx1201_tolerances(out_dtype) + Cf = C.to(torch.float32) + Rf = Ref.to(torch.float32) + ok = torch.allclose(Cf, Rf, atol=atol, rtol=rtol) + max_diff = (Cf - Rf).abs().max().item() + if not ok: + print(f" FAIL: wmma_gfx1201_{name} max_diff={max_diff:.4e} (atol={atol})") + return 1 + print( + f" PASS: wmma_gfx1201_{name} (in=({in_dtype_a}, {in_dtype_b}), out={out_dtype}, max_diff={max_diff:.4e})" + ) + return 0 + + +def test_wmma_gfx1201_f32_f16(mod): + return _test_wmma_gfx1201_variant( + mod, + "f32_f16", + mod.run_wmma_gfx1201_f32_f16, + torch.float16, + torch.float16, + torch.float32, + ) + + +def test_wmma_gfx1201_f32_bf16(mod): + return _test_wmma_gfx1201_variant( + mod, + "f32_bf16", + mod.run_wmma_gfx1201_f32_bf16, + torch.bfloat16, + torch.bfloat16, + torch.float32, + ) + + +def test_wmma_gfx1201_f16_f16(mod): + return _test_wmma_gfx1201_variant( + mod, + "f16_f16", + mod.run_wmma_gfx1201_f16_f16, + torch.float16, + torch.float16, + torch.float16, + ) + + +def test_wmma_gfx1201_bf16_bf16(mod): + return _test_wmma_gfx1201_variant( + mod, + "bf16_bf16", + mod.run_wmma_gfx1201_bf16_bf16, + torch.bfloat16, + torch.bfloat16, + torch.bfloat16, + ) + + +def test_wmma_gfx1201_f32_fp8_fp8(mod): + return _test_wmma_gfx1201_variant( + mod, + "f32_fp8_fp8", + mod.run_wmma_gfx1201_f32_fp8_fp8, + torch.float8_e4m3fn, + torch.float8_e4m3fn, + torch.float32, + ) + + +def test_wmma_gfx1201_f32_fp8_bf8(mod): + return _test_wmma_gfx1201_variant( + mod, + "f32_fp8_bf8", + mod.run_wmma_gfx1201_f32_fp8_bf8, + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.float32, + ) + + +def test_wmma_gfx1201_f32_bf8_fp8(mod): + return _test_wmma_gfx1201_variant( + mod, + "f32_bf8_fp8", + mod.run_wmma_gfx1201_f32_bf8_fp8, + torch.float8_e5m2, + torch.float8_e4m3fn, + torch.float32, + ) + + +def test_wmma_gfx1201_f32_bf8_bf8(mod): + return _test_wmma_gfx1201_variant( + mod, + "f32_bf8_bf8", + mod.run_wmma_gfx1201_f32_bf8_bf8, + torch.float8_e5m2, + torch.float8_e5m2, + torch.float32, + ) + + +# WMMA wave64 tests for gfx1200/gfx1201 (RDNA4). The _w64_gfx12 builtins +# use 64 lanes with 4 elem/lane instead of wave32's 8 elem/lane. + + +def _test_wmma_gfx1201_w64_variant( + mod, name, runner, in_dtype_a, in_dtype_b, out_dtype +): + arch = _get_gpu_arch() + if arch not in _WMMA_GFX1201_ARCHS: + print(f" SKIP: wmma_gfx1201_w64_{name} (arch={arch}, gfx1201-only)") + return 0 + if _skip_if_missing_symbol( + mod, f"run_wmma_gfx1201_w64_{name}", f"wmma_gfx1201_w64_{name}" + ): + return 0 + + M = N = K = 16 + device = torch.device("cuda") + + torch.manual_seed(42) + a_ref = torch.randn(M, K, dtype=torch.float32, device=device) * 2.0 + b_ref = torch.randn(K, N, dtype=torch.float32, device=device) * 2.0 + A = a_ref.to(in_dtype_a) + B = b_ref.to(in_dtype_b) + C = torch.zeros(M, N, dtype=out_dtype, device=device) + + Ref = (A.to(torch.float32) @ B.to(torch.float32)).to(out_dtype) + + runner(A, B, C) + + atol, rtol = _wmma_gfx1201_tolerances(out_dtype) + Cf = C.to(torch.float32) + Rf = Ref.to(torch.float32) + ok = torch.allclose(Cf, Rf, atol=atol, rtol=rtol) + max_diff = (Cf - Rf).abs().max().item() + if not ok: + print(f" FAIL: wmma_gfx1201_w64_{name} max_diff={max_diff:.4e} (atol={atol})") + return 1 + print( + f" PASS: wmma_gfx1201_w64_{name} (in=({in_dtype_a}, {in_dtype_b}), out={out_dtype}, max_diff={max_diff:.4e})" + ) + return 0 + + +def test_wmma_gfx1201_w64_f32_f16(mod): + return _test_wmma_gfx1201_w64_variant( + mod, + "f32_f16", + mod.run_wmma_gfx1201_w64_f32_f16, + torch.float16, + torch.float16, + torch.float32, + ) + + +def test_wmma_gfx1201_w64_f32_bf16(mod): + return _test_wmma_gfx1201_w64_variant( + mod, + "f32_bf16", + mod.run_wmma_gfx1201_w64_f32_bf16, + torch.bfloat16, + torch.bfloat16, + torch.float32, + ) + + +def test_wmma_gfx1201_w64_f16_f16(mod): + return _test_wmma_gfx1201_w64_variant( + mod, + "f16_f16", + mod.run_wmma_gfx1201_w64_f16_f16, + torch.float16, + torch.float16, + torch.float16, + ) + + +def test_wmma_gfx1201_w64_bf16_bf16(mod): + return _test_wmma_gfx1201_w64_variant( + mod, + "bf16_bf16", + mod.run_wmma_gfx1201_w64_bf16_bf16, + torch.bfloat16, + torch.bfloat16, + torch.bfloat16, + ) + + +# Tiled WMMA tests for gfx1200/gfx1201: make_tiled_mma + partition_layout + gmem (C = A @ B^T). + + +def _test_wmma_gfx1201_tiled_variant(mod, name, runner, in_dtype, out_dtype): + arch = _get_gpu_arch() + if arch not in _WMMA_GFX1201_ARCHS: + print(f" SKIP: wmma_gfx1201_tiled_{name} (arch={arch}, gfx1201-only)") + return 0 + if _skip_if_missing_symbol( + mod, f"run_wmma_gfx1201_tiled_{name}", f"wmma_gfx1201_tiled_{name}" + ): + return 0 + + M = N = K = 16 + device = torch.device("cuda") + torch.manual_seed(42) + A = torch.randn(M, K, dtype=torch.float32, device=device) * 2.0 + B = torch.randn(N, K, dtype=torch.float32, device=device) * 2.0 + a = A.to(in_dtype) + b = B.to(in_dtype) + C = torch.zeros(M, N, dtype=out_dtype, device=device) + Ref = (a.to(torch.float32) @ b.to(torch.float32).t()).to(out_dtype) + + runner(a, b, C) + + atol, rtol = _wmma_gfx1201_tolerances(out_dtype) + ok = torch.allclose(C.float(), Ref.float(), atol=atol, rtol=rtol) + max_diff = (C.float() - Ref.float()).abs().max().item() + if not ok: + print( + f" FAIL: wmma_gfx1201_tiled_{name} max_diff={max_diff:.4e} (atol={atol})" + ) + return 1 + print( + f" PASS: wmma_gfx1201_tiled_{name} (in={in_dtype}, out={out_dtype}, max_diff={max_diff:.4e})" + ) + return 0 + + +def test_wmma_gfx1201_tiled_f32_f16(mod): + return _test_wmma_gfx1201_tiled_variant( + mod, + "f32_f16", + mod.run_wmma_gfx1201_tiled_f32_f16, + torch.float16, + torch.float32, + ) + + +def test_wmma_gfx1201_tiled_f32_bf16(mod): + return _test_wmma_gfx1201_tiled_variant( + mod, + "f32_bf16", + mod.run_wmma_gfx1201_tiled_f32_bf16, + torch.bfloat16, + torch.float32, + ) + + def test_async_load(mod): """Test async_load: copy data through LDS and verify integrity.""" + if _skip_if_missing_symbol(mod, "run_async_load", "async_load"): + return 0 # n should be a multiple of BLOCK_SIZE (256) n = 1048576 # 1M elements device = torch.device("cuda") @@ -1670,6 +2105,8 @@ def test_dtype_convert_fp32_fp4_x4(mod): def test_predicated_copy(mod): """Test gmem load_if/store_if via free function wrappers (boundary predicate).""" + if _skip_if_missing_symbol(mod, "run_predicated_copy", "predicated_copy"): + return 0 # Use n not aligned to block*4 to create a partial boundary condition n = 1001 BLOCK_SIZE = 256 @@ -1716,6 +2153,8 @@ def test_predicated_copy_2d(mod): Uses a 2D layout with row/col boundary checking — the predicate receives (i_row, i_col) and uses them to check bounds, which would fail if given a single flat index. """ + if _skip_if_missing_symbol(mod, "run_predicated_copy_2d", "predicated_copy_2d"): + return 0 ROWS = 4 # issue space rows per workgroup COLS = 4 # issue space cols per thread BLOCK_SIZE = 256 # threads per block @@ -1756,6 +2195,8 @@ def test_predicated_copy_2d(mod): def test_free_func_vector_add(mod): """Test opus::load / opus::store free function wrappers (vector add).""" + if _skip_if_missing_symbol(mod, "run_free_func_add", "free_func_vector_add"): + return 0 n = 1310720 # same as regular vector_add test device = torch.device("cuda") dtype = torch.float32 @@ -1785,6 +2226,10 @@ def test_free_func_vector_add(mod): def test_predicated_async_load(mod): """Test gmem async_load_if via free function wrapper (boundary predicate).""" + if _skip_if_missing_symbol( + mod, "run_predicated_async_load", "predicated_async_load" + ): + return 0 n = 1001 BLOCK_SIZE = 256 n_padded = ((n + BLOCK_SIZE - 1) // BLOCK_SIZE) * BLOCK_SIZE # 1024 @@ -1934,7 +2379,14 @@ def ref_int(dtype, size): ("fp16", 5, 2, True, torch.float16, True), ("bf16", 10, 2, True, torch.bfloat16, True), ("fp8", 15, 1, True, fp8_dtype, False), - ("bf8", 20, 1, True, bf8_dtype, arch in ("gfx950", "gfx1250")), + ( + "bf8", + 20, + 1, + True, + bf8_dtype, + arch in ("gfx950", "gfx1250", "gfx1201", "gfx1200"), + ), ("i32", 25, 4, False, torch.int32, False), ("i16", 35, 2, False, torch.int16, False), ("i8", 45, 1, False, torch.int8, False), @@ -2171,6 +2623,21 @@ def main(): failures += test_wmma_scale_16x16x128_fp8_bx32_scaled(mod) failures += test_mma_step_k_bf16(mod) failures += test_vector_add(mod) + failures += test_opus_gmem_gfx1201(mod) + failures += test_wmma_gfx1201_f32_f16(mod) + failures += test_wmma_gfx1201_f32_bf16(mod) + failures += test_wmma_gfx1201_f16_f16(mod) + failures += test_wmma_gfx1201_bf16_bf16(mod) + failures += test_wmma_gfx1201_f32_fp8_fp8(mod) + failures += test_wmma_gfx1201_f32_fp8_bf8(mod) + failures += test_wmma_gfx1201_f32_bf8_fp8(mod) + failures += test_wmma_gfx1201_f32_bf8_bf8(mod) + failures += test_wmma_gfx1201_w64_f32_f16(mod) + failures += test_wmma_gfx1201_w64_f32_bf16(mod) + failures += test_wmma_gfx1201_w64_f16_f16(mod) + failures += test_wmma_gfx1201_w64_bf16_bf16(mod) + failures += test_wmma_gfx1201_tiled_f32_f16(mod) + failures += test_wmma_gfx1201_tiled_f32_bf16(mod) failures += test_async_load(mod) failures += test_tr_load_f16(mod) failures += test_dtype_convert_fp32_bf16(mod) diff --git a/op_tests/opus/device/test_opus_gmem_gfx1201.cu b/op_tests/opus/device/test_opus_gmem_gfx1201.cu new file mode 100644 index 0000000000..a9773e530e --- /dev/null +++ b/op_tests/opus/device/test_opus_gmem_gfx1201.cu @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: MIT +// Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +/** + * @file test_opus_gmem_gfx1201.cu + * @brief Exercise opus::make_gmem<>.load<>/.store<> on gfx1201 (Navi 48 / + * RX 9070 XT, RDNA4). + * + * Two opus.hpp changes need to be in place for this test to pass: + * + * 1) Forward declarations of mfma_adaptor / wmma_adaptor in the opus + * namespace, so the header parses for gfx1201 device code. + * make_tiled_mma()'s default template argument names these types, + * and without forward decls name lookup fails on archs where the + * full definitions (gated by __GFX9__ / __gfx1250__) are inactive. + * + * 2) buffer_default_config() returning the correct RDNA buffer rsrc + * config (0x31004000) for gfx1201 instead of the 0xffffffff + * fallback. The existing __gfx11__ / __gfx12__ tokens are typos + * (clang only predefines the uppercase __GFX11__ / __GFX12__), so + * without an explicit __gfx1201__ / __gfx1200__ branch the + * make_gmem<> resource descriptor is invalid and all buffer_load_b32 + * lanes return 0 / all buffer_store_b32 lanes drop on the floor. + * + * The kernel exercises the exact opus API that sample_kernels.cu / + * topk_softmax_kernels_group.cu / mhc_kernels.cu depend on: + * + * auto g = opus::make_gmem(ptr); + * auto v = g.load(i); // buffer_load via cached rsrc + * ... opus::cast(v[j]) ... + * g.store(vr, i); // buffer_store via cached rsrc + * + * Kernel body is gated by __gfx1201__ / __gfx1200__ (same RDNA4 family, + * same ISA for buffer_load/store) so other archs see an empty no-op pass + * — gfx1250 / gfx9x behavior is unchanged. + */ + +#ifdef __HIP_DEVICE_COMPILE__ +// ── Device pass: opus.hpp + kernel body, no hip_runtime.h ────────────────── +#include "opus/opus.hpp" + +#if defined(__gfx1201__) || defined(__gfx1200__) +// Element-wise add via opus make_gmem load / store + per-lane opus::cast. +// Mirrors the load → cast → store pattern in sample_kernels.cu. +template +__global__ void opus_gmem_gfx1201_kernel( + const float* __restrict__ a, + const float* __restrict__ b, + float* __restrict__ result, + int n) +{ + auto g_a = opus::make_gmem(a); + auto g_b = opus::make_gmem(b); + auto g_r = opus::make_gmem(result); + + int idx = __builtin_amdgcn_workgroup_id_x() * BLOCK_SIZE + __builtin_amdgcn_workitem_id_x(); + int stride = __builtin_amdgcn_grid_size_x(); + + for (int i = idx * VECTOR_SIZE; i < n; i += stride * VECTOR_SIZE) { + auto va = g_a.load(i); + auto vb = g_b.load(i); + + decltype(va) vr; + for (int j = 0; j < VECTOR_SIZE; ++j) { + // opus::cast(float) is a compile-time pass-through. + // Including it exercises the same template path sample_kernels.cu + // instantiates for its per-lane DTYPE_I → float conversion. + vr[j] = opus::cast(va[j]) + opus::cast(vb[j]); + } + g_r.store(vr, i); + } +} + +template __global__ void opus_gmem_gfx1201_kernel<256, 4>(const float*, const float*, float*, int); +#endif // __gfx1201__ / __gfx1200__ + +#else +// ── Host pass: launcher + empty kernel stub ──────────────────────────────── +#include "opus/hip_minimal.hpp" +#include + +#define HIP_CALL(call) do { \ + hipError_t err = (call); \ + if (err != hipSuccess) { \ + fprintf(stderr, "HIP error %d at %s:%d\n", (int)err, __FILE__, __LINE__); \ + return; \ + } \ +} while(0) + +template +__global__ void opus_gmem_gfx1201_kernel(const float*, const float*, float*, int) {} + +extern "C" void run_opus_gmem_gfx1201( + const void* d_a, + const void* d_b, + void* d_result, + int n) +{ + const auto* a = static_cast(d_a); + const auto* b = static_cast(d_b); + auto* r = static_cast(d_result); + + constexpr int BLOCK_SIZE = 256; + constexpr int VECTOR_SIZE = 4; + int blocks = (n + (BLOCK_SIZE * VECTOR_SIZE) - 1) / (BLOCK_SIZE * VECTOR_SIZE); + + hipLaunchKernelGGL( + (opus_gmem_gfx1201_kernel), + dim3(blocks), dim3(BLOCK_SIZE), 0, 0, + a, b, r, n); + HIP_CALL(hipGetLastError()); + HIP_CALL(hipDeviceSynchronize()); +} +#endif diff --git a/op_tests/opus/device/test_wmma_gfx1201.cu b/op_tests/opus/device/test_wmma_gfx1201.cu new file mode 100644 index 0000000000..0fcf05fe5c --- /dev/null +++ b/op_tests/opus/device/test_wmma_gfx1201.cu @@ -0,0 +1,150 @@ +// SPDX-License-Identifier: MIT +// Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +/** + * @file test_wmma_gfx1201.cu + * @brief Direct WMMA tests on gfx1201 (Navi 48, RDNA4) via the opus::wmma<> + * struct dispatch (which targets __builtin_amdgcn_wmma_*_w32_gfx12). + * + * Covers the wave32 16x16x16 WMMA variants gfx1201 supports: + * + * - f32 <- f16 / f16 + * - f32 <- bf16 / bf16 + * - f16 <- f16 / f16 + * - bf16 <- bf16 / bf16 + * - f32 <- fp8 / fp8 + * - f32 <- fp8 / bf8 + * - f32 <- bf8 / fp8 + * - f32 <- bf8 / bf8 + * + * Lane / register layout used here (per the gfx12 WMMA fragment diagrams in + * AMD's internal Navi 4 layout reference and the community-verified guide + * https://github.com/JohnTDI-cpu/rdna4-wmma-guide): + * + * lane i in [0, 31], j in [0, 7]: + * A_frag[i][j] = A[i % 16, (i/16)*8 + j] // ROW-distributed + * B_frag[i][j] = B[(i/16)*8 + j, i % 16] // COLUMN-distributed + * C_frag[i][j] = C[(i/16)*8 + j, i % 16] // COLUMN-distributed + * + * That is: + * - A: lanes 0..15 cover rows 0..15 with K=0..7 within each lane; lanes + * 16..31 cover the same rows with K=8..15. + * - B and C: lanes 0..15 cover columns 0..15 with K (for B) or M-rows + * 0..7 (for C) within each lane; lanes 16..31 cover the same columns + * with K=8..15 (B) or M-rows 8..15 (C). + * + * This asymmetry (A row-distributed, B and C column-distributed) is the + * native gfx12 wmma_128b fragment encoding and matches CK's wmma_gemm.hpp. + * + * The tests deliberately go through opus::wmma<>::operator() to exercise the + * dispatch table (DISPATCH_WMMA_GFX12_F32_ / DISPATCH_WMMA_GFX12_8BIT_), not + * the high-level make_tiled_mma / partition_layout_* path — the latter + * relies on opus::wmma_adaptor whose lane encoding is row-distributed and + * matches gfx1250 but NOT gfx12. A dedicated gfx12 wmma_adaptor would be + * needed for the tiled API to work on gfx1201 (see TODO in opus.hpp). + */ + +#ifdef __HIP_DEVICE_COMPILE__ +// ── Device pass ───────────────────────────────────────────────────────────── +#include "opus/opus.hpp" +#if defined(__gfx1201__) || defined(__gfx1200__) + +// Generic 1-wave, 1-tile WMMA driver. Each lane loads its column-distributed +// A and B fragment from global memory, calls opus::wmma<>::operator(), and +// stores the C fragment back. +template +__global__ void wmma_gfx12_kernel( + const DIN_A* __restrict__ ptr_a, + const DIN_B* __restrict__ ptr_b, + DOUT* __restrict__ ptr_c, + int stride_a, + int stride_b, + int stride_c) +{ + constexpr int WM = 16, WN = 16, WK = 16; + constexpr int ELEM_A = WM * WK / 32; // 8 + constexpr int ELEM_B = WN * WK / 32; // 8 + constexpr int ELEM_C = WM * WN / 32; // 8 + + using vtype_a = opus::vector_t; + using vtype_b = opus::vector_t; + using vtype_c = opus::vector_t; + + int lane = static_cast(__builtin_amdgcn_workitem_id_x() % 32); + int col = lane % 16; // for B / C (column-distributed) + int row_base = (lane / 16) * 8; // for B (K block) / C (M row block) + int a_row = lane % 16; // A is row-distributed: lane selects row + int a_k_base = (lane / 16) * 8; // and lane group selects an 8-wide K block + + vtype_a v_a{}; + vtype_b v_b{}; + vtype_c v_c{}; + + #pragma unroll + for (int j = 0; j < ELEM_A; ++j) v_a[j] = ptr_a[a_row * stride_a + (a_k_base + j)]; + #pragma unroll + for (int j = 0; j < ELEM_B; ++j) v_b[j] = ptr_b[(row_base + j) * stride_b + col]; + + // Call through the opus::wmma<> dispatch — this is the path users of the + // library hit when they call opus::wmma<...>{}(a, b, c). + opus::wmma mma; + v_c = mma(v_a, v_b, v_c); + + #pragma unroll + for (int j = 0; j < ELEM_C; ++j) ptr_c[(row_base + j) * stride_c + col] = v_c[j]; +} + +template __global__ void wmma_gfx12_kernel(const opus::fp16_t*, const opus::fp16_t*, opus::fp32_t*, int, int, int); +template __global__ void wmma_gfx12_kernel(const opus::bf16_t*, const opus::bf16_t*, opus::fp32_t*, int, int, int); +template __global__ void wmma_gfx12_kernel(const opus::fp16_t*, const opus::fp16_t*, opus::fp16_t*, int, int, int); +template __global__ void wmma_gfx12_kernel(const opus::bf16_t*, const opus::bf16_t*, opus::bf16_t*, int, int, int); +template __global__ void wmma_gfx12_kernel(const opus::fp8_t* , const opus::fp8_t* , opus::fp32_t*, int, int, int); +template __global__ void wmma_gfx12_kernel(const opus::fp8_t* , const opus::bf8_t* , opus::fp32_t*, int, int, int); +template __global__ void wmma_gfx12_kernel(const opus::bf8_t* , const opus::fp8_t* , opus::fp32_t*, int, int, int); +template __global__ void wmma_gfx12_kernel(const opus::bf8_t* , const opus::bf8_t* , opus::fp32_t*, int, int, int); + +#endif // __gfx1201__ / __gfx1200__ + +#else +// ── Host pass: empty kernel stubs + extern "C" launchers ──────────────────── +#include "opus/opus.hpp" +#include "opus/hip_minimal.hpp" +#include + +#define HIP_CALL(call) do { \ + hipError_t err = (call); \ + if (err != hipSuccess) { \ + fprintf(stderr, "HIP error %d at %s:%d\n", (int)err, __FILE__, __LINE__); \ + return; \ + } \ +} while(0) + +template +__global__ void wmma_gfx12_kernel(const DIN_A*, const DIN_B*, DOUT*, int, int, int) {} + +#define LAUNCHER_(NAME, DA, DB, DC) \ +extern "C" void run_wmma_gfx1201_ ## NAME ( \ + const void* d_a, const void* d_b, void* d_c, \ + int stride_a, int stride_b, int stride_c) \ +{ \ + hipLaunchKernelGGL((wmma_gfx12_kernel), \ + dim3(1, 1), 32, 0, 0, \ + static_cast(d_a), \ + static_cast(d_b), \ + static_cast(d_c), \ + stride_a, stride_b, stride_c); \ + HIP_CALL(hipGetLastError()); \ + HIP_CALL(hipDeviceSynchronize()); \ +} + +LAUNCHER_(f32_f16, fp16_t, fp16_t, fp32_t) +LAUNCHER_(f32_bf16, bf16_t, bf16_t, fp32_t) +LAUNCHER_(f16_f16, fp16_t, fp16_t, fp16_t) +LAUNCHER_(bf16_bf16, bf16_t, bf16_t, bf16_t) +LAUNCHER_(f32_fp8_fp8, fp8_t, fp8_t, fp32_t) +LAUNCHER_(f32_fp8_bf8, fp8_t, bf8_t, fp32_t) +LAUNCHER_(f32_bf8_fp8, bf8_t, fp8_t, fp32_t) +LAUNCHER_(f32_bf8_bf8, bf8_t, bf8_t, fp32_t) + +#undef LAUNCHER_ +#endif // __HIP_DEVICE_COMPILE__ diff --git a/op_tests/opus/device/test_wmma_gfx1201_tiled.cu b/op_tests/opus/device/test_wmma_gfx1201_tiled.cu new file mode 100644 index 0000000000..3c79e0da82 --- /dev/null +++ b/op_tests/opus/device/test_wmma_gfx1201_tiled.cu @@ -0,0 +1,84 @@ +// SPDX-License-Identifier: MIT +// Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. +// Tiled WMMA test for gfx1201/gfx1200: make_tiled_mma + partition_layout + gmem load/store. +// Computes C = A @ B^T (swap_ab) with 16x16x16 wave32 WMMA, same pattern as test_wmma_f32.cu (gfx1250). + +#ifdef __HIP_DEVICE_COMPILE__ +#include "opus/opus.hpp" +#if defined(__gfx1201__) || defined(__gfx1200__) + +template +__global__ void wmma_gfx12_tiled_kernel( + const DIN* __restrict__ ptr_a, + const DIN* __restrict__ ptr_b, + DOUT* __restrict__ ptr_c, + int k, int stride_a, int stride_b, int stride_c) +{ + using opus::operator""_I; + constexpr int T_M = 1, T_N = 1, T_K = 1; + constexpr int E_M = 1, E_N = 1, E_K = 1; + constexpr int ELEM_A = WM * WK / 32; + constexpr int PACK_A = (16 / static_cast(sizeof(DIN)) < ELEM_A) ? 16 / static_cast(sizeof(DIN)) : ELEM_A; + constexpr int PACK_B = PACK_A; + constexpr int ELEM_C = WM * WN / 32; + constexpr int PACK_C = (16 / static_cast(sizeof(DOUT)) < ELEM_C) ? 16 / static_cast(sizeof(DOUT)) : ELEM_C; + using d_a = DIN; using d_b = DIN; using d_c = DOUT; + + int lane_id = static_cast(opus::lane_id()); + int g_im = __builtin_amdgcn_workgroup_id_x() * WM; + int g_in = __builtin_amdgcn_workgroup_id_y() * WN; + + auto mma = opus::make_tiled_mma( + opus::make_wmma(opus::seq{}, opus::wmma_adaptor_swap_ab{}), + opus::seq{}, opus::seq{}); + + auto u_a = opus::partition_layout_a(mma, opus::make_tuple(stride_a, 1_I), + opus::make_tuple(0_I, lane_id % mma.grpm_a, 0_I, lane_id / mma.grpm_a)); + auto u_b = opus::partition_layout_b(mma, opus::make_tuple(stride_b, 1_I), + opus::make_tuple(0_I, lane_id % mma.grpn_b, 0_I, lane_id / mma.grpn_b)); + auto u_c = opus::partition_layout_c(mma, opus::make_tuple(stride_c, 1_I), + opus::make_tuple(0_I, lane_id % mma.grpn_c, 0_I, lane_id / mma.grpn_c)); + + auto g_a = opus::make_gmem(ptr_a + g_im * stride_a); + auto g_b = opus::make_gmem(ptr_b + g_in * stride_b); + auto g_c = opus::make_gmem(ptr_c + g_im * stride_c + g_in); + + int loops = (k + WK - 1) / WK; + typename decltype(mma)::vtype_c v_c; + opus::clear(v_c); + for (int i = 0; i < loops; i++) { + auto v_a = g_a.template load(u_a); + u_a += WK; + auto v_b = g_b.template load(u_b); + u_b += WK; + v_c = mma(v_a, v_b, v_c); + } + g_c.template store(v_c, u_c); +} + +template __global__ void wmma_gfx12_tiled_kernel(const opus::fp16_t*, const opus::fp16_t*, opus::fp32_t*, int, int, int, int); +template __global__ void wmma_gfx12_tiled_kernel(const opus::bf16_t*, const opus::bf16_t*, opus::fp32_t*, int, int, int, int); + +#endif +#else +#include "opus/opus.hpp" +#include "opus/hip_minimal.hpp" +#include +#define HIP_CALL(call) do { hipError_t err = (call); if (err != hipSuccess) { fprintf(stderr, "HIP error %d at %s:%d\n", (int)err, __FILE__, __LINE__); return; } } while(0) + +template +__global__ void wmma_gfx12_tiled_kernel(const DIN*, const DIN*, DOUT*, int, int, int, int) {} + +#define LAUNCHER_(NAME, DIN, DOUT, WM, WN, WK) \ +extern "C" void run_wmma_gfx1201_tiled_ ## NAME ( \ + const void* d_a, const void* d_b, void* d_c, int stride_a, int stride_b, int stride_c) { \ + hipLaunchKernelGGL((wmma_gfx12_tiled_kernel), \ + dim3(1, 1), 32, 0, 0, \ + static_cast(d_a), static_cast(d_b), \ + static_cast(d_c), WK, stride_a, stride_b, stride_c); \ + HIP_CALL(hipGetLastError()); HIP_CALL(hipDeviceSynchronize()); } + +LAUNCHER_(f32_f16, fp16_t, fp32_t, 16, 16, 16) +LAUNCHER_(f32_bf16, bf16_t, fp32_t, 16, 16, 16) +#undef LAUNCHER_ +#endif diff --git a/op_tests/opus/device/test_wmma_gfx1201_w64.cu b/op_tests/opus/device/test_wmma_gfx1201_w64.cu new file mode 100644 index 0000000000..5d8b461e7d --- /dev/null +++ b/op_tests/opus/device/test_wmma_gfx1201_w64.cu @@ -0,0 +1,106 @@ +// SPDX-License-Identifier: MIT +// Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +#ifdef __HIP_DEVICE_COMPILE__ +#include "opus/opus.hpp" +#if defined(__gfx1201__) || defined(__gfx1200__) + +template +__global__ void wmma_gfx12_w64_kernel( + const DIN_A* __restrict__ ptr_a, + const DIN_B* __restrict__ ptr_b, + DOUT* __restrict__ ptr_c, + int stride_a, + int stride_b, + int stride_c) +{ + constexpr int WM = 16, WN = 16, WK = 16; + constexpr int WARP = 64; + constexpr int ELEM = WM * WK / WARP; // 4 + + using vtype_a = opus::vector_t; + using vtype_b = opus::vector_t; + using vtype_c = opus::vector_t; + + int lane = static_cast(opus::lane_id()); + int group = lane / 16; // 0,1,2,3 + int sublane = lane % 16; // column for B/C, row for A + + // A: row-distributed. A[row=sublane][K=(group*4+j)] + // Same group order as my original assumption — confirmed by AMD matrix calculator. + int a_row = sublane; + int a_k_base = group * 4; + + // B: column-distributed. B[K=(group*4+j)][col=sublane] + int b_col = sublane; + int b_k_base = group * 4; + + // D/C: column-distributed with INTERLEAVED group order {0,2,1,3} + // Group 0 (lanes 0-15) → rows 0-3 + // Group 1 (lanes 16-31) → rows 8-11 (NOT 4-7!) + // Group 2 (lanes 32-47) → rows 4-7 + // Group 3 (lanes 48-63) → rows 12-15 + constexpr int c_row_base_lut[4] = {0, 8, 4, 12}; + int c_col = sublane; + int c_row_base = c_row_base_lut[group]; + + vtype_a v_a{}; + vtype_b v_b{}; + vtype_c v_c{}; + + #pragma unroll + for (int j = 0; j < ELEM; ++j) v_a[j] = ptr_a[a_row * stride_a + (a_k_base + j)]; + #pragma unroll + for (int j = 0; j < ELEM; ++j) v_b[j] = ptr_b[(b_k_base + j) * stride_b + b_col]; + + opus::wmma mma; + v_c = mma(v_a, v_b, v_c); + + #pragma unroll + for (int j = 0; j < ELEM; ++j) ptr_c[(c_row_base + j) * stride_c + c_col] = v_c[j]; +} + +template __global__ void wmma_gfx12_w64_kernel(const opus::fp16_t*, const opus::fp16_t*, opus::fp32_t*, int, int, int); +template __global__ void wmma_gfx12_w64_kernel(const opus::bf16_t*, const opus::bf16_t*, opus::fp32_t*, int, int, int); +template __global__ void wmma_gfx12_w64_kernel(const opus::fp16_t*, const opus::fp16_t*, opus::fp16_t*, int, int, int); +template __global__ void wmma_gfx12_w64_kernel(const opus::bf16_t*, const opus::bf16_t*, opus::bf16_t*, int, int, int); + +#endif +#else +#include "opus/opus.hpp" +#include "opus/hip_minimal.hpp" +#include + +#define HIP_CALL(call) do { \ + hipError_t err = (call); \ + if (err != hipSuccess) { \ + fprintf(stderr, "HIP error %d at %s:%d\n", (int)err, __FILE__, __LINE__); \ + return; \ + } \ +} while(0) + +template +__global__ void wmma_gfx12_w64_kernel(const DIN_A*, const DIN_B*, DOUT*, int, int, int) {} + +#define LAUNCHER_(NAME, DA, DB, DC) \ +extern "C" void run_wmma_gfx1201_w64_ ## NAME ( \ + const void* d_a, const void* d_b, void* d_c, \ + int stride_a, int stride_b, int stride_c) \ +{ \ + hipLaunchKernelGGL((wmma_gfx12_w64_kernel), \ + dim3(1, 1), 64, 0, 0, \ + static_cast(d_a), \ + static_cast(d_b), \ + static_cast(d_c), \ + stride_a, stride_b, stride_c); \ + HIP_CALL(hipGetLastError()); \ + HIP_CALL(hipDeviceSynchronize()); \ +} + +LAUNCHER_(f32_f16, fp16_t, fp16_t, fp32_t) +LAUNCHER_(f32_bf16, bf16_t, bf16_t, fp32_t) +LAUNCHER_(f16_f16, fp16_t, fp16_t, fp16_t) +LAUNCHER_(bf16_bf16, bf16_t, bf16_t, bf16_t) + +#undef LAUNCHER_ +#endif