99#include " ds_kernel_utils.h"
1010#include " memory_access_utils.h"
1111
12+ #if defined(BF16_AVAILABLE) && defined(__HIP_PLATFORM_AMD__)
13+ #include < hip/hip_bfloat16.h>
14+ #endif
15+
1216namespace cg = cooperative_groups;
1317
1418namespace reduce {
@@ -374,7 +378,11 @@ DS_D_INLINE __half init<ROpType::Max>()
374378template <>
375379DS_D_INLINE __nv_bfloat16 init<ROpType::Max>()
376380{
381+ #ifdef __HIP_PLATFORM_AMD__
382+ constexpr __hip_bfloat16_raw neg_inf = {0xFF80 };
383+ #else
377384 constexpr __nv_bfloat16_raw neg_inf = {0xFF80 };
385+ #endif
378386 return __nv_bfloat16 (neg_inf);
379387}
380388#endif
@@ -526,29 +534,12 @@ here (fold is C++17 only and I don't think helps and recursion feels like
526534huge overkill that harms readability) that would be wonderful.
527535*/
528536
529- template <typename T>
530- DS_D_INLINE T shfl_xor_helper (cg::thread_block_tile<hw_warp_size>& warp, const T& value, int i)
531- {
532- return warp.shfl_xor (value, i);
533- }
534-
535- #if defined(__HIP_PLATFORM_AMD__)
536- template <>
537- DS_D_INLINE __half shfl_xor_helper<__half>(cg::thread_block_tile<hw_warp_size>& warp,
538- const __half& value,
539- int i)
540- {
541- float fvalue = __half2float (value);
542- return __half (warp.shfl_xor (fvalue, i));
543- }
544- #endif
545-
546537template <typename T, ROpType Op, int reduce_width = hw_warp_size>
547538DS_D_INLINE void _warp (cg::thread_block_tile<hw_warp_size>& warp, T* data)
548539{
549540#pragma unroll
550541 for (int i = 1 ; i < reduce_width; i *= 2 ) {
551- data[0 ] = element<Op>(data[0 ], shfl_xor_helper ( warp, data[0 ], i));
542+ data[0 ] = element<Op>(data[0 ], warp. shfl_xor ( data[0 ], i));
552543 }
553544}
554545
@@ -557,8 +548,8 @@ DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, T* data)
557548{
558549#pragma unroll
559550 for (int i = 1 ; i < reduce_width; i *= 2 ) {
560- data[0 ] = element<Op1>(data[0 ], shfl_xor_helper ( warp, data[0 ], i));
561- data[1 ] = element<Op2>(data[1 ], shfl_xor_helper ( warp, data[1 ], i));
551+ data[0 ] = element<Op1>(data[0 ], warp. shfl_xor ( data[0 ], i));
552+ data[1 ] = element<Op2>(data[1 ], warp. shfl_xor ( data[1 ], i));
562553 }
563554}
564555
@@ -567,9 +558,9 @@ DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, T* data)
567558{
568559#pragma unroll
569560 for (int i = 1 ; i < reduce_width; i *= 2 ) {
570- data[0 ] = element<Op1>(data[0 ], shfl_xor_helper ( warp, data[0 ], i));
571- data[1 ] = element<Op2>(data[1 ], shfl_xor_helper ( warp, data[1 ], i));
572- data[2 ] = element<Op3>(data[2 ], shfl_xor_helper ( warp, data[2 ], i));
561+ data[0 ] = element<Op1>(data[0 ], warp. shfl_xor ( data[0 ], i));
562+ data[1 ] = element<Op2>(data[1 ], warp. shfl_xor ( data[1 ], i));
563+ data[2 ] = element<Op3>(data[2 ], warp. shfl_xor ( data[2 ], i));
573564 }
574565}
575566
@@ -583,13 +574,39 @@ DS_D_INLINE void _warp(cg::thread_block_tile<hw_warp_size>& warp, T* data)
583574{
584575#pragma unroll
585576 for (int i = 1 ; i < reduce_width; i *= 2 ) {
586- data[0 ] = element<Op1>(data[0 ], shfl_xor_helper ( warp, data[0 ], i));
587- data[1 ] = element<Op2>(data[1 ], shfl_xor_helper ( warp, data[1 ], i));
588- data[2 ] = element<Op3>(data[2 ], shfl_xor_helper ( warp, data[2 ], i));
589- data[3 ] = element<Op4>(data[3 ], shfl_xor_helper ( warp, data[3 ], i));
577+ data[0 ] = element<Op1>(data[0 ], warp. shfl_xor ( data[0 ], i));
578+ data[1 ] = element<Op2>(data[1 ], warp. shfl_xor ( data[1 ], i));
579+ data[2 ] = element<Op3>(data[2 ], warp. shfl_xor ( data[2 ], i));
580+ data[3 ] = element<Op4>(data[3 ], warp. shfl_xor ( data[3 ], i));
590581 }
591582}
592583
#if defined(__HIP_PLATFORM_AMD__)
/*
HIP/AMD wrapper around _warp for element types whose cooperative-groups
shfl_xor is not natively supported (e.g. __half / bfloat16 wrapper classes):
non-arithmetic T is round-tripped through float for the shuffle, then
converted back. Arithmetic T dispatches straight to _warp.

BUG FIX: `reduce_width` was accepted but silently dropped — both internal
calls let _warp default to hw_warp_size, so callers passing total_warps
(the secondary reduction in _block) got a full-warp reduction instead,
diverging from the CUDA path's `_warp<T, Ops..., total_warps>`. It is now
forwarded to _warp in both branches.
*/
template <int reduce_width, typename T, ROpType... Ops>
DS_D_INLINE void _warp_with_type_conversion(cg::thread_block_tile<hw_warp_size>& warp_arg,
                                            T* data)
{
    constexpr int elems = sizeof...(Ops);
    if constexpr (!std::is_arithmetic<T>::value) {
        // Shuffle in float: convert out, reduce, convert back.
        float temp_data[elems];
#pragma unroll
        for (int i = 0; i < elems; i++) { temp_data[i] = conversion::to<float>(data[i]); }
        _warp<float, Ops..., reduce_width>(warp_arg, temp_data);
#pragma unroll
        for (int i = 0; i < elems; i++) { data[i] = conversion::to<T>(temp_data[i]); }
    } else {
        _warp<T, Ops..., reduce_width>(warp_arg, data);
    }
}
#endif  // defined(__HIP_PLATFORM_AMD__)
609+
593610/*
594611Implementation for primary block reduction that serves both `block` and
595612`partitioned_block`.
@@ -617,7 +634,11 @@ DS_D_INLINE void _block(cg::thread_block& tb,
617634#endif
618635
619636 // Always perform warp-scope reduction
637+ #ifdef __HIP_PLATFORM_AMD__
638+ _warp_with_type_conversion<hw_warp_size, T, Ops...>(warp_arg, data);
639+ #else
620640 _warp<T, Ops...>(warp_arg, data);
641+ #endif
621642
622643 // If max_warps == 1 let's skip the runtime check
623644 if (total_warps != 1 ) {
@@ -641,8 +662,12 @@ DS_D_INLINE void _block(cg::thread_block& tb,
641662 } else {
642663 init<Ops...>(data);
643664 }
644-
665+ #ifdef __HIP_PLATFORM_AMD__
666+ _warp_with_type_conversion<total_warps, T, Ops...>(warp_arg, data);
667+ #else
645668 _warp<T, Ops..., total_warps>(warp_arg, data);
669+ #endif
670+
646671
647672#pragma unroll
648673 for (int i = 0 ; i < elems; i++) {
0 commit comments