Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
70f6e4a
opus: forward-declare mma adaptors so opus.hpp parses on gfx1201
carlushuang May 16, 2026
9878ec9
opus: route gfx1200/gfx1201 to the RDNA buffer rsrc config (fix silen…
carlushuang May 16, 2026
b8aa331
opus: add gfx1201 (Navi 48 / RDNA4) WMMA support — 8 wave32 16x16x16 …
carlushuang May 16, 2026
fdf9d87
opus: black + ruff format on test_opus_device.py wmma additions
carlushuang May 16, 2026
cde14bb
opus: condense gfx1201 commentary to match the rest of the file
carlushuang May 16, 2026
d9f992e
opus device tests: skip gfx1201-incompatible kernels at build time
carlushuang May 16, 2026
4ab7e0d
opus: extend gfx1201 gates to also cover gfx1200 (Navi 44, same RDNA4…
carlushuang May 16, 2026
3007acd
opus: add gfx1201 wave64 WMMA test infrastructure and D-matrix layout…
carlushuang May 17, 2026
1387801
opus: fix wave64 WMMA test — use direct mbcnt for lane_id on gfx1201
carlushuang May 17, 2026
f0a7c21
opus: fix get_warp_size() for gfx12 wave64 via __has_builtin proxy
carlushuang May 17, 2026
a4ce01b
opus: condense gfx12 wave64 comments to match file style
carlushuang May 17, 2026
46747bc
opus: black format on test_opus_device.py w64 additions
carlushuang May 17, 2026
38c1a65
Merge origin/main into carhuang/opus_gfx1201_parse_fix
carlushuang May 17, 2026
1eb7cca
opus: factor out DISPATCH_WMMA_GFX12_MATCH_ to condense gfx12 dispatc…
carlushuang May 17, 2026
71e005e
opus: fix gfx1201 test failures + define OPUS_GFX120X_IS_WAVE32 + cod…
carlushuang May 17, 2026
7c98ab4
opus: condense comments to single-line style
carlushuang May 17, 2026
b4b2fc5
opus: add make_tiled_mma support for gfx1201/gfx1200 + tiled WMMA dev…
carlushuang May 17, 2026
c601323
triton: add gfx1201 GEMM configs (copied from gfx1250 as baseline)
carlushuang May 17, 2026
e4ec1fd
Revert "triton: add gfx1201 GEMM configs (copied from gfx1250 as base…
carlushuang May 18, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 101 additions & 9 deletions csrc/include/opus/opus.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,9 @@ template <std::size_t I, typename... Ts> struct tuple_element<I, const opus::tup
} // namespace std

namespace opus {
// Forward declarations of the mma adaptors. make_tiled_mma() names these types, so they must be
// declared even on archs (gfx12) where the full definitions are gated out.
struct mfma_adaptor;
struct mfma_adaptor_swap_ab;
struct wmma_adaptor;
struct wmma_adaptor_swap_ab;

/////////////////////////////////////////////////////////////////////////////////////////////////////////
// transforms
// embed: sum of get<I>(x) * get<I>(y) over every index I in Is..., folded left to right.
template<typename X, typename Y, index_t... Is>
constexpr auto embed(const X& x, const Y& y, seq<Is...>)
{
    return (... + (get<Is>(x) * get<Is>(y)));
}
Expand Down Expand Up @@ -1101,8 +1104,8 @@ OPUS_D constexpr unsigned short fp32_to_bf16_rtn_raw(float f)
else if(bits & 0xffff) { bits |= 0x10000; /* Preserve signaling NaN */ }
return static_cast<unsigned short>(bits >> 16);
}
#if (defined(__gfx950__) || defined(__gfx1250__)) && __clang_major__ >= 20
template<index_t rm = OPUS_FP32_to_BF16_DEFAULT> // gfx950/gfx1250 has instruction conversion, leave 'rm' here for compatiblity
#if (defined(__gfx950__) || defined(__gfx1250__) || defined(__gfx1201__) || defined(__gfx1200__)) && __clang_major__ >= 20
// gfx950/gfx1250/gfx12 have a conversion instruction; 'rm' is kept only for interface compatibility and is not used here.
template<index_t rm = OPUS_FP32_to_BF16_DEFAULT>
OPUS_D constexpr auto fp32_to_bf16(const fp32_t& x, number<rm> = {})
{
    return static_cast<bf16_t>(x);
}
#else
template<index_t rm = OPUS_FP32_to_BF16_DEFAULT> // 0:standard, 1:truncate_with_nan, 2:truncate, 3:standard asm 4:rta_asm(round to nearest away)
Expand Down Expand Up @@ -1497,12 +1500,23 @@ OPUS_D constexpr decltype(auto) cast(const S& s, Aux&&... aux) {
// Guarded by OPUS_ENABLE_RUNTIME_QUERY (default 0). Define OPUS_ENABLE_RUNTIME_QUERY=1 before
// including opus.hpp (or via compiler flag) to enable these functions and the hip_runtime_api.h include.
//
// gfx1200/gfx1201 wave-mode probe. __AMDGCN_WAVEFRONT_SIZE__ was removed in ROCm 7.2, and the _w32 WMMA
// builtins are gated by the wavefrontsize32 target feature, so __has_builtin on one of them serves as a
// compile-time proxy: 1 when the _w32 builtin is available, 0 otherwise. The macro stays undefined on
// host passes and on every other arch.
#if defined(__HIP_DEVICE_COMPILE__) && (defined(__gfx1200__) || defined(__gfx1201__))
#  if !__has_builtin(__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12)
#    define OPUS_GFX120X_IS_WAVE32 0
#  else
#    define OPUS_GFX120X_IS_WAVE32 1
#  endif
#endif

OPUS_H_D constexpr index_t get_warp_size()
{
#if defined(__gfx1250__)
return 32;
#elif defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
return 64;
#elif defined(OPUS_GFX120X_IS_WAVE32) && !OPUS_GFX120X_IS_WAVE32
return 64;
#else
return 32;
#endif
Expand Down Expand Up @@ -1613,7 +1627,8 @@ OPUS_D constexpr auto buffer_default_config() {
return 0x00020000;
#elif defined(__gfx103__)
return 0x31014000;
#elif defined(__gfx11__) || defined(__gfx12__) || defined(__gfx1250__)
// gfx1200/gfx1201 (Navi 44/48) listed explicitly — __gfx11__/__gfx12__ are typos (clang predefines only uppercase __GFX11__/__GFX12__), so without this gfx1200/gfx1201 fall into the 0xffffffff sentinel and make_gmem<> stores silently drop.
#elif defined(__gfx11__) || defined(__gfx12__) || defined(__gfx1250__) || defined(__gfx1201__) || defined(__gfx1200__)
return 0x31004000;
#else
return 0xffffffff;
Expand Down Expand Up @@ -2302,8 +2317,8 @@ using mfma_scale_f32_16x16x128_fp4_fp4 = mfma_f32_16x16x128_fp4_fp4;
#endif // __GFX9__ (mfma)

/////////////////////////////////////////////////////////////////////////////////////////////////////////
// wmma (gfx1250 / RDNA4, wave32)
#if defined(__gfx1250__) || !defined(__HIP_DEVICE_COMPILE__)
// wmma — gfx1250 (wave32) uses wmma-256b builtins (16x16x{4,32,64,128}); gfx1200/gfx1201 (Navi 44/48, RDNA4) use wmma-128b 16x16x16 builtins: _w32_gfx12 in wave32 mode, _w64_gfx12 in wave64 mode. Dispatch macros for the two arg-list shapes differ — gfx12 set is DISPATCH_WMMA_GFX12_*.
#if defined(__gfx1250__) || defined(__gfx1201__) || defined(__gfx1200__) || !defined(__HIP_DEVICE_COMPILE__)
// f16/bf16/f32 builtins: (neg_a, A, neg_b, B, matrix_fmts, C, clamp, neg_c)
#define DISPATCH_WMMA_(ta_, tb_, tc_, wm_, wn_, wk_, inst_) \
(std::is_same_v<dtype_a, ta_> && std::is_same_v<dtype_b, tb_> && std::is_same_v<dtype_c, tc_> && \
Expand All @@ -2326,6 +2341,23 @@ using mfma_scale_f32_16x16x128_fp4_fp4 = mfma_f32_16x16x128_fp4_fp4;
__builtin_bit_cast(vector_t<i32_t, i32_b>, b), \
static_cast<short>(0), c, false, false); }

// gfx12 (gfx1200/gfx1201) dispatch helpers. The gfx12 builtins take three arguments (A, B, C) — no
// matrix_fmts/neg_c operands. ws_ selects the _w32 or _w64 builtin for the same {wm, wn, wk} triple.
// Every helper expands to `(condition) { body }` and is written to follow `else if constexpr`.
//
// MATCH_: the parenthesised condition only — dtype_a/b/c, wave_m/n/k and warp_size all equal the arguments.
#define DISPATCH_WMMA_GFX12_MATCH_(ta_, tb_, tc_, wm_, wn_, wk_, ws_) \
    (std::is_same_v<dtype_a, ta_> && std::is_same_v<dtype_b, tb_> && std::is_same_v<dtype_c, tc_> && \
     wave_m == wm_ && wave_n == wn_ && wave_k == wk_ && warp_size == ws_)

// F32_: pass a, b, c to inst_ unchanged. Despite the name it is also used with fp16/bf16 accumulators.
#define DISPATCH_WMMA_GFX12_F32_(ta_, tb_, tc_, wm_, wn_, wk_, ws_, inst_) \
    DISPATCH_WMMA_GFX12_MATCH_(ta_, tb_, tc_, wm_, wn_, wk_, ws_) { return inst_(a, b, c); }

// 8BIT_ (fp8/bf8): reinterpret the A and B fragments as i32 vectors of the same byte size; c is passed as-is.
#define DISPATCH_WMMA_GFX12_8BIT_(ta_, tb_, tc_, wm_, wn_, wk_, ws_, inst_) \
    DISPATCH_WMMA_GFX12_MATCH_(ta_, tb_, tc_, wm_, wn_, wk_, ws_) { \
        constexpr index_t i32_a = elem_a * static_cast<index_t>(sizeof(dtype_a)) / static_cast<index_t>(sizeof(i32_t)); \
        constexpr index_t i32_b = elem_b * static_cast<index_t>(sizeof(dtype_b)) / static_cast<index_t>(sizeof(i32_t)); \
        return inst_(__builtin_bit_cast(vector_t<i32_t, i32_a>, a), __builtin_bit_cast(vector_t<i32_t, i32_b>, b), c); }

// STEP_K_: build a wider-K shape by chaining wk_/inst_k_ calls of inst_. The first call accumulates into c,
// each later call accumulates into the previous result. The static_asserts reject shapes where the integer
// divisions below would truncate and silently drop part of the A/B fragments.
// NOTE(review): slice(v, number<B>, number<E>) is assumed to select elements [B, E) — confirm against slice().
#define DISPATCH_WMMA_GFX12_F32_STEP_K_(ta_, tb_, tc_, wm_, wn_, wk_, ws_, inst_k_, inst_) \
    DISPATCH_WMMA_GFX12_MATCH_(ta_, tb_, tc_, wm_, wn_, wk_, ws_) { \
        static_assert((wk_) % (inst_k_) == 0, "wave_k must be a whole multiple of the instruction K"); \
        constexpr index_t steps = (wk_) / (inst_k_), e_a = elem_a / steps, e_b = elem_b / steps; \
        static_assert(e_a * steps == elem_a && e_b * steps == elem_b, "A/B fragments must split evenly across the K steps"); \
        auto tmp = inst_(slice(a, number<0>{}, number<e_a>{}), slice(b, number<0>{}, number<e_b>{}), c); \
        static_for<steps - 1>([&](auto i){ tmp = inst_(slice(a, number<e_a*(i+1)>{}, number<e_a*(i+2)>{}), slice(b, number<e_b*(i+1)>{}, number<e_b*(i+2)>{}), tmp); }); \
        return tmp; }


template<typename dtype_a_, typename dtype_b_, typename dtype_c_, index_t wave_m_, index_t wave_n_, index_t wave_k_, index_t warp_size_ = get_warp_size()>
struct wmma {
using dtype_a = remove_cvref_t<dtype_a_>;
Expand Down Expand Up @@ -2389,6 +2421,37 @@ struct wmma {
else if constexpr DISPATCH_WMMA_8BIT_(bf8_t, fp8_t, fp16_t, 16, 16, 128, __builtin_amdgcn_wmma_f16_16x16x128_bf8_fp8)
else if constexpr DISPATCH_WMMA_8BIT_(bf8_t, bf8_t, fp16_t, 16, 16, 128, __builtin_amdgcn_wmma_f16_16x16x128_bf8_bf8)
#endif
#if defined(__gfx1201__) || defined(__gfx1200__)
// _w32/_w64 builtins gated by wavefrontsize target feature; OPUS_GFX120X_IS_WAVE32 selects the active wave mode.
# if OPUS_GFX120X_IS_WAVE32
// gfx12 wave32 16x16x16: f16/bf16/fp8/bf8 → f32 + same-type acc. iu8/iu4 deferred.
else if constexpr DISPATCH_WMMA_GFX12_F32_ (fp16_t, fp16_t, fp32_t , 16, 16, 16, 32, __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12)
else if constexpr DISPATCH_WMMA_GFX12_F32_ (bf16_t, bf16_t, fp32_t , 16, 16, 16, 32, __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12)
else if constexpr DISPATCH_WMMA_GFX12_F32_ (fp16_t, fp16_t, fp16_t , 16, 16, 16, 32, __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12)
else if constexpr DISPATCH_WMMA_GFX12_F32_ (bf16_t, bf16_t, bf16_t , 16, 16, 16, 32, __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12)
else if constexpr DISPATCH_WMMA_GFX12_8BIT_(fp8_t , fp8_t , fp32_t , 16, 16, 16, 32, __builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12)
else if constexpr DISPATCH_WMMA_GFX12_8BIT_(fp8_t , bf8_t , fp32_t , 16, 16, 16, 32, __builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12)
else if constexpr DISPATCH_WMMA_GFX12_8BIT_(bf8_t , fp8_t , fp32_t , 16, 16, 16, 32, __builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12)
else if constexpr DISPATCH_WMMA_GFX12_8BIT_(bf8_t , bf8_t , fp32_t , 16, 16, 16, 32, __builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12)
// gfx12 wave32 16x16x32 synthetic: 2× stacked 16x16x16 along K.
else if constexpr DISPATCH_WMMA_GFX12_F32_STEP_K_(fp16_t, fp16_t, fp32_t , 16, 16, 32, 32, 16, __builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12)
else if constexpr DISPATCH_WMMA_GFX12_F32_STEP_K_(bf16_t, bf16_t, fp32_t , 16, 16, 32, 32, 16, __builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12)
else if constexpr DISPATCH_WMMA_GFX12_F32_STEP_K_(fp16_t, fp16_t, fp16_t , 16, 16, 32, 32, 16, __builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12)
else if constexpr DISPATCH_WMMA_GFX12_F32_STEP_K_(bf16_t, bf16_t, bf16_t , 16, 16, 32, 32, 16, __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12)
# endif // wave32 builtin available
# if !OPUS_GFX120X_IS_WAVE32
// gfx12 wave64 16x16x16: native _w64_gfx12 builtins (4 elem/lane). Requires -mwavefrontsize64.
else if constexpr DISPATCH_WMMA_GFX12_F32_ (fp16_t, fp16_t, fp32_t , 16, 16, 16, 64, __builtin_amdgcn_wmma_f32_16x16x16_f16_w64_gfx12)
else if constexpr DISPATCH_WMMA_GFX12_F32_ (bf16_t, bf16_t, fp32_t , 16, 16, 16, 64, __builtin_amdgcn_wmma_f32_16x16x16_bf16_w64_gfx12)
else if constexpr DISPATCH_WMMA_GFX12_F32_ (fp16_t, fp16_t, fp16_t , 16, 16, 16, 64, __builtin_amdgcn_wmma_f16_16x16x16_f16_w64_gfx12)
else if constexpr DISPATCH_WMMA_GFX12_F32_ (bf16_t, bf16_t, bf16_t , 16, 16, 16, 64, __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64_gfx12)
// gfx12 wave64 16x16x32 synthetic: 2× stacked _w64 16x16x16 (8 elem/lane = 1 b128 load).
else if constexpr DISPATCH_WMMA_GFX12_F32_STEP_K_(fp16_t, fp16_t, fp32_t , 16, 16, 32, 64, 16, __builtin_amdgcn_wmma_f32_16x16x16_f16_w64_gfx12)
else if constexpr DISPATCH_WMMA_GFX12_F32_STEP_K_(bf16_t, bf16_t, fp32_t , 16, 16, 32, 64, 16, __builtin_amdgcn_wmma_f32_16x16x16_bf16_w64_gfx12)
else if constexpr DISPATCH_WMMA_GFX12_F32_STEP_K_(fp16_t, fp16_t, fp16_t , 16, 16, 32, 64, 16, __builtin_amdgcn_wmma_f16_16x16x16_f16_w64_gfx12)
else if constexpr DISPATCH_WMMA_GFX12_F32_STEP_K_(bf16_t, bf16_t, bf16_t , 16, 16, 32, 64, 16, __builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64_gfx12)
# endif // wave64 builtin available
#endif // __gfx1201__ / __gfx1200__
__builtin_unreachable();
}

Expand Down Expand Up @@ -2483,6 +2546,35 @@ struct wmma {
#undef DISPATCH_WMMA_
#undef DISPATCH_WMMA_BF16F32_
#undef DISPATCH_WMMA_8BIT_
#undef DISPATCH_WMMA_GFX12_MATCH_
#undef DISPATCH_WMMA_GFX12_F32_
#undef DISPATCH_WMMA_GFX12_8BIT_
#undef DISPATCH_WMMA_GFX12_F32_STEP_K_

// gfx12 (gfx1200/gfx1201) aliases, grouped by A/B element type. Names without a _w32/_w64 suffix leave
// warp_size at the wmma<> default; suffixed names pin it (_w64 requires -mwavefrontsize64).
// The 16x16x32 shapes are synthetic: two 16x16x16 steps chained along K.
// fp16 inputs
using wmma_f32_16x16x16_f16       = wmma<fp16_t, fp16_t, fp32_t, 16, 16, 16>;
using wmma_f16_16x16x16_f16       = wmma<fp16_t, fp16_t, fp16_t, 16, 16, 16>;
using wmma_f32_16x16x32_f16_w32   = wmma<fp16_t, fp16_t, fp32_t, 16, 16, 32, 32>;
using wmma_f16_16x16x32_f16_w32   = wmma<fp16_t, fp16_t, fp16_t, 16, 16, 32, 32>;
using wmma_f32_16x16x16_f16_w64   = wmma<fp16_t, fp16_t, fp32_t, 16, 16, 16, 64>;
using wmma_f16_16x16x16_f16_w64   = wmma<fp16_t, fp16_t, fp16_t, 16, 16, 16, 64>;
using wmma_f32_16x16x32_f16_w64   = wmma<fp16_t, fp16_t, fp32_t, 16, 16, 32, 64>;
using wmma_f16_16x16x32_f16_w64   = wmma<fp16_t, fp16_t, fp16_t, 16, 16, 32, 64>;
// bf16 inputs
using wmma_f32_16x16x16_bf16      = wmma<bf16_t, bf16_t, fp32_t, 16, 16, 16>;
using wmma_bf16_16x16x16_bf16     = wmma<bf16_t, bf16_t, bf16_t, 16, 16, 16>;
using wmma_f32_16x16x32_bf16_w32  = wmma<bf16_t, bf16_t, fp32_t, 16, 16, 32, 32>;
using wmma_bf16_16x16x32_bf16_w32 = wmma<bf16_t, bf16_t, bf16_t, 16, 16, 32, 32>;
using wmma_f32_16x16x16_bf16_w64  = wmma<bf16_t, bf16_t, fp32_t, 16, 16, 16, 64>;
using wmma_bf16_16x16x16_bf16_w64 = wmma<bf16_t, bf16_t, bf16_t, 16, 16, 16, 64>;
using wmma_f32_16x16x32_bf16_w64  = wmma<bf16_t, bf16_t, fp32_t, 16, 16, 32, 64>;
using wmma_bf16_16x16x32_bf16_w64 = wmma<bf16_t, bf16_t, bf16_t, 16, 16, 32, 64>;
// fp8/bf8 inputs (16x16x16, fp32 accumulator)
using wmma_f32_16x16x16_fp8_fp8   = wmma<fp8_t , fp8_t , fp32_t, 16, 16, 16>;
using wmma_f32_16x16x16_fp8_bf8   = wmma<fp8_t , bf8_t , fp32_t, 16, 16, 16>;
using wmma_f32_16x16x16_bf8_fp8   = wmma<bf8_t , fp8_t , fp32_t, 16, 16, 16>;
using wmma_f32_16x16x16_bf8_bf8   = wmma<bf8_t , bf8_t , fp32_t, 16, 16, 16>;

// f16/bf16 16x16x32
using wmma_f32_16x16x32_f16 = wmma<fp16_t, fp16_t, fp32_t, 16, 16, 32>;
Expand Down Expand Up @@ -2514,7 +2606,7 @@ using wmma_scale_f32_16x16x128_fp8_fp8 = wmma<fp8_t, fp8_t, fp32_t, 16, 16, 128>
using wmma_scale_f32_16x16x128_fp4_fp4 = wmma<fp4_t, fp4_t, fp32_t, 16, 16, 128>;
// Scaled WMMA (dedicated fp4 32x16x128 instruction)
using wmma_scale_f32_32x16x128_fp4_fp4 = wmma<fp4_t, fp4_t, fp32_t, 32, 16, 128>;
#endif // __gfx1250__ (wmma)
#endif // __gfx1250__ / __gfx1201__ / __gfx1200__ (wmma)

/////////////////////////////////////////////////////////////////////////////////////////////////////////
// adaptor
Expand Down Expand Up @@ -2714,11 +2806,11 @@ template<typename d_a, typename d_b, typename d_c, typename WaveMNK /*seq<m, n,
OPUS_D decltype(auto) make_mfma(WaveMNK&&, A&& = {}, number<warp_size_> = {}) { return A{}(mfma<d_a, d_b, d_c, get<0>(WaveMNK{}), get<1>(WaveMNK{}), get<2>(WaveMNK{}), warp_size_>{}); }
#endif // __GFX9__

// wmma_adaptor: same layout encoding as mfma_adaptor but for wave32 WMMA (gfx1250)
// wmma_adaptor: layout encoding for wave32 WMMA (gfx1250, gfx1200/gfx1201).
// A:[(grpm_a<p>), (rept_a<y>, grpk_a<p>, pack_a<y>)], MxK
// B:[(grpn_b<p>), (rept_b<y>, grpk_b<p>, pack_b<y>)], NxK
// C:[(grpm_c<p>, rept_c<y>, pack_c<y>), (grpn_c<p>)], MxN
#if defined(__gfx1250__) || !defined(__HIP_DEVICE_COMPILE__)
#if defined(__gfx1250__) || defined(__gfx1201__) || defined(__gfx1200__) || !defined(__HIP_DEVICE_COMPILE__)
namespace impl {
template<typename WMMA>
struct wmma_adaptor : public remove_cvref_t<WMMA> {
Expand Down Expand Up @@ -2797,7 +2889,7 @@ OPUS_D decltype(auto) make_wmma(number<w_m>, number<w_n>, number<w_k>, A&& = {},

// make_wmma (shape-as-sequence overload): takes wave M/N/K from elements 0, 1 and 2 of WaveMNK and
// returns the result of calling a default-constructed adaptor A on the matching wmma<> instance.
template<typename d_a, typename d_b, typename d_c, typename WaveMNK, typename A = wmma_adaptor, index_t warp_size_ = get_warp_size()>
OPUS_D decltype(auto) make_wmma(WaveMNK&&, A&& = {}, number<warp_size_> = {})
{
    using wmma_t = wmma<d_a, d_b, d_c, get<0>(WaveMNK{}), get<1>(WaveMNK{}), get<2>(WaveMNK{}), warp_size_>;
    return A{}(wmma_t{});
}
#endif // __gfx1250__
#endif // __gfx1250__ / __gfx1201__ / __gfx1200__ (wmma_adaptor)

/////////////////////////////////////////////////////////////////////////////////////////////////////////
namespace impl {
Expand Down
34 changes: 31 additions & 3 deletions op_tests/opus/device/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,15 @@
"test_numeric_limits.cu",
"test_workgroup_barrier.cu",
"test_finfo.cu",
"test_opus_gmem_gfx1201.cu",
"test_wmma_gfx1201.cu",
"test_wmma_gfx1201_w64.cu",
"test_wmma_gfx1201_tiled.cu",
]

# Sources requiring -mwavefrontsize64 (wave64 builtins).
_W64_SOURCES = {"test_wmma_gfx1201_w64.cu"}


def _detect_arch():
try:
Expand Down Expand Up @@ -74,7 +81,8 @@ def _find_hipcc():

def _compile_one(args):
"""Compile a single .cu -> .o. Used as a worker function for parallel builds."""
src, obj, hipcc, arch, verbose = args
src, obj, hipcc, arch, verbose, *rest = args
extra_flags = rest[0] if rest else []
cmd = [
hipcc,
f"--offload-arch={arch}",
Expand All @@ -83,6 +91,7 @@ def _compile_one(args):
"-D__HIPCC_RTC__",
f"-I{_REPO_CSRC}",
f"-I{_THIS_DIR}",
*extra_flags,
"-c",
src,
"-o",
Expand All @@ -109,14 +118,33 @@ def build(verbose=False, jobs=None):
if verbose:
print(f"[setup] arch={arch}, jobs={jobs}")

# Per-arch skip list: kernels that use builtins not available on the
# target arch. Skipped at .so build time so the rest of the suite
# still links; the Python harness sees the missing extern "C" launcher
# and reports SKIP for those tests.
#
# gfx1201 / gfx1200 (Navi 44/48, RDNA4): opus _async_load uses
# __builtin_amdgcn_raw_ptr_buffer_load_lds which needs the
# `vmem-to-lds-load-insts` target feature (gfx9x / gfx950 / gfx1250 only).
_ARCH_SKIP_SOURCES = {
"gfx1200": {"test_async_load.cu", "test_load_store_if.cu"},
"gfx1201": {"test_async_load.cu", "test_load_store_if.cu"},
}
skip = _ARCH_SKIP_SOURCES.get(arch, set())
sources = [s for s in _CU_SOURCES if s not in skip]
if verbose and skip:
for s in sorted(skip):
print(f"[setup] skip {s} (incompatible with arch={arch})")

t0 = time.monotonic()

# Parallel compile: each .cu -> .o
tasks = []
for s in _CU_SOURCES:
for s in sources:
src = os.path.join(_THIS_DIR, s)
obj = os.path.join(_THIS_DIR, s.replace(".cu", ".o"))
tasks.append((src, obj, hipcc, arch, verbose))
extra = ["-mwavefrontsize64"] if s in _W64_SOURCES else []
tasks.append((src, obj, hipcc, arch, verbose, extra))

objs = []
with ProcessPoolExecutor(max_workers=jobs) as pool:
Expand Down
Loading
Loading