From 4232bcd3e4480337c8e25de087ab3b5f0af9531d Mon Sep 17 00:00:00 2001 From: Zach Vincze Date: Wed, 28 Jan 2026 14:58:24 -0500 Subject: [PATCH 01/10] Avoid branching in casting implementations --- include/core/detail/casting.hpp | 90 +++++++++++++++++++++------------ src/op_composite.cpp | 2 +- 2 files changed, 58 insertions(+), 34 deletions(-) diff --git a/include/core/detail/casting.hpp b/include/core/detail/casting.hpp index ea21b0bc..bc6bedf9 100644 --- a/include/core/detail/casting.hpp +++ b/include/core/detail/casting.hpp @@ -36,36 +36,48 @@ __device__ __host__ T ScalarSaturateCast(U v) { constexpr bool bigToSmall = !smallToBig; if constexpr (std::is_integral_v && std::is_floating_point_v) { - // Any float -> any integral - return static_cast(std::clamp(std::round(v), static_cast(std::numeric_limits::min()), - static_cast(std::numeric_limits::max()))); - } else if constexpr (std::is_integral_v && std::is_integral_v && std::is_signed_v && std::is_signed_v && - smallToBig) { - // Any integral signed -> Any integral unsigned, small -> big or equal - return v <= 0 ? 0 : static_cast(v); - } else if constexpr (std::is_integral_v && std::is_integral_v && - ((std::is_signed_v && std::is_signed_v) || - (std::is_unsigned_v && std::is_unsigned_v)) && - bigToSmall) { - // Any integral signed -> Any integral signed, big -> small - // Any integral unsigned -> Any integral unsigned, big -> small - return v <= std::numeric_limits::min() - ? std::numeric_limits::min() - : (v >= std::numeric_limits::max() ? std::numeric_limits::max() : static_cast(v)); - } else if constexpr (std::is_integral_v && std::is_unsigned_v && std::is_integral_v && - std::is_signed_v) { - // Any integral unsigned -> Any integral signed, small -> big or equal - return v >= std::numeric_limits::max() ? 
std::numeric_limits::max() : static_cast(v); - } else if constexpr (std::is_integral_v && std::is_signed_v && std::is_integral_v && - std::is_unsigned_v && bigToSmall) { - // Any integral signed -> Any integral unsigned, big -> small - return v <= static_cast(std::numeric_limits::min()) - ? std::numeric_limits::min() - : (v >= static_cast(std::numeric_limits::max()) ? std::numeric_limits::max() - : static_cast(v)); - } else { - // All other cases fall into this - return v; + // Float -> integral: clamp then round + constexpr U minVal = static_cast(std::numeric_limits::min()); + constexpr U maxVal = static_cast(std::numeric_limits::max()); +#ifdef __HIP_DEVICE_COMPILE__ + return static_cast(rintf(fminf(fmaxf(v, minVal), maxVal))); +#else + return static_cast(std::round(std::clamp(v, minVal, maxVal))); +#endif + } + + else if constexpr (std::is_integral_v && std::is_integral_v && std::is_signed_v && std::is_unsigned_v && + smallToBig) { + // Signed -> unsigned, small to big: clamp negative to 0 + // Branchless: max(v, 0) handles negative values + return static_cast(max(v, U{0})); + } + + else if constexpr (std::is_integral_v && std::is_integral_v && + ((std::is_signed_v && std::is_signed_v) || + (std::is_unsigned_v && std::is_unsigned_v)) && + bigToSmall) { + // Same signedness, big -> small: clamp to [min, max] + constexpr U minVal = static_cast(std::numeric_limits::min()); + constexpr U maxVal = static_cast(std::numeric_limits::max()); + return static_cast(min(max(v, minVal), maxVal)); + } + + else if constexpr (std::is_integral_v && std::is_unsigned_v && std::is_integral_v && std::is_signed_v) { + // Unsigned -> signed: clamp to max (can't exceed min since unsigned) + constexpr U maxVal = static_cast(std::numeric_limits::max()); + return static_cast(min(v, maxVal)); + } + + else if constexpr (std::is_integral_v && std::is_signed_v && std::is_integral_v && std::is_unsigned_v && + bigToSmall) { + // Signed -> unsigned, big -> small: clamp to [0, max] + constexpr U 
maxVal = static_cast(std::numeric_limits::max()); + return static_cast(min(max(v, U{0}), maxVal)); + } + + else { + return static_cast(v); } } @@ -117,9 +129,21 @@ __device__ __host__ T ScalarRangeCast(U v) { else if constexpr (std::is_integral_v && std::is_floating_point_v && std::is_unsigned_v) { // float to unsigned integers - return v >= T{1} ? std::numeric_limits::max() - : v <= T{0} ? 0 - : static_cast(lrintf(static_cast(std::numeric_limits::max()) * v)); + constexpr U scale = static_cast(std::numeric_limits::max()); + + if constexpr (sizeof(T) <= 2) { + // 8/16 bit integer cases. These can be represented exactly in floating point. +#ifdef __HIP_DEVICE_COMPILE__ + return static_cast(__float2int_rn(__saturatef(v) * scale)); +#else + return static_cast(lrintf(fminf(fmaxf(v, 0.0f), 1.0f) * scale)); +#endif + } else { + // 32/64 bit integer cases. + return v >= U{1} ? std::numeric_limits::max() + : v <= U{-1} ? std::numeric_limits::min() + : static_cast(std::round(v * scale)); + } } else if constexpr (std::is_floating_point_v && std::is_integral_v && std::is_signed_v) { diff --git a/src/op_composite.cpp b/src/op_composite.cpp index 6d49a156..fedc7a9b 100644 --- a/src/op_composite.cpp +++ b/src/op_composite.cpp @@ -40,7 +40,7 @@ void dispatch_composite_masktype(hipStream_t stream, const Tensor& foreground, c switch (device) { case eDeviceType::GPU: { - dim3 block(64, 16); + dim3 block(32, 8); dim3 grid((outputWrapper.width() + block.x - 1) / block.x, (outputWrapper.height() + block.y - 1) / block.y, outputWrapper.batches()); Kernels::Device::composite<<>>(fgWrapper, bgWrapper, maskWrapper, outputWrapper); From 77cabc7a032143d0a13d08550228ef675d69ec03 Mon Sep 17 00:00:00 2001 From: Zach Vincze Date: Tue, 3 Feb 2026 12:22:49 -0500 Subject: [PATCH 02/10] Add more tests for Saturate cast --- .../tests/core/detail/test_saturate_cast.cpp | 51 +++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 
tests/roccv/cpp/src/tests/core/detail/test_saturate_cast.cpp diff --git a/tests/roccv/cpp/src/tests/core/detail/test_saturate_cast.cpp b/tests/roccv/cpp/src/tests/core/detail/test_saturate_cast.cpp new file mode 100644 index 00000000..015265c5 --- /dev/null +++ b/tests/roccv/cpp/src/tests/core/detail/test_saturate_cast.cpp @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#include + +#include "test_helpers.hpp" + +using namespace roccv::detail; +using namespace roccv::tests; +using namespace roccv; + +int main(int argc, char **argv) { + TEST_CASES_BEGIN(); + + EXPECT_EQ(SaturateCast(1.0f), 1); + EXPECT_EQ(SaturateCast(-1.0f), -1); + EXPECT_EQ(SaturateCast(1.0f), 1); + EXPECT_EQ(SaturateCast(-1.0f), 0); + EXPECT_EQ(SaturateCast(1), 1.0f); + EXPECT_EQ(SaturateCast(-1), -1.0f); + EXPECT_EQ(SaturateCast(1), 1.0); + EXPECT_EQ(SaturateCast(-1), -1.0); + + // Test numeric limits + EXPECT_EQ(SaturateCast(std::numeric_limits::max()), std::numeric_limits::max()); + EXPECT_EQ(SaturateCast(std::numeric_limits::max()), std::numeric_limits::max()); + + // Test vectorized types + EXPECT_TRUE((SaturateCast(uchar4{255, 128, 0, 255}) == float4{255.0f, 128.0f, 0.0f, 255.0f})); + EXPECT_TRUE((SaturateCast(char4{-128, -128, -128, -128}) == float4{-128.0f, -128.0f, -128.0f, -128.0f})); + + TEST_CASES_END(); +} \ No newline at end of file From d887102cacc81425561b102d364c9646d5e5d5b2 Mon Sep 17 00:00:00 2001 From: Zach Vincze Date: Fri, 6 Feb 2026 15:08:08 -0500 Subject: [PATCH 03/10] Fix issues with float -> integer saturate casts --- include/core/detail/casting.hpp | 25 ++++++++++++----- include/core/detail/type_traits.hpp | 3 +++ .../tests/core/detail/test_saturate_cast.cpp | 27 ++++++++++--------- 3 files changed, 37 insertions(+), 18 deletions(-) diff --git a/include/core/detail/casting.hpp b/include/core/detail/casting.hpp index bc6bedf9..d9c8ac14 100644 --- a/include/core/detail/casting.hpp +++ b/include/core/detail/casting.hpp @@ -21,8 +21,6 @@ #pragma once -#include - #include "core/detail/type_traits.hpp" namespace roccv::detail { @@ -36,14 +34,29 @@ __device__ __host__ T ScalarSaturateCast(U v) { constexpr bool bigToSmall = !smallToBig; if constexpr (std::is_integral_v && std::is_floating_point_v) { - // Float -> integral: clamp then round - constexpr U minVal = static_cast(std::numeric_limits::min()); + // Float -> integral: clamp to 
[min, max] then round. + constexpr U minVal = static_cast(std::numeric_limits::lowest()); constexpr U maxVal = static_cast(std::numeric_limits::max()); + + if constexpr (sizeof(T) <= 2) { + // 8/16 bit integer cases. These can be represented exactly in floating point. +#ifdef __HIP_DEVICE_COMPILE__ + return static_cast(rintf(fminf(fmaxf(v, minVal), maxVal))); +#else + return static_cast(std::round(std::clamp(v, minVal, maxVal))); +#endif + } else { + // 32/64 bit integer cases. #ifdef __HIP_DEVICE_COMPILE__ - return static_cast(rintf(fminf(fmaxf(v, minVal), maxVal))); + U rounded = rintf(v); #else - return static_cast(std::round(std::clamp(v, minVal, maxVal))); + U rounded = std::round(v); #endif + + return rounded >= maxVal ? std::numeric_limits::max() + : rounded <= minVal ? std::numeric_limits::min() + : static_cast(rounded); + } } else if constexpr (std::is_integral_v && std::is_integral_v && std::is_signed_v && std::is_unsigned_v && diff --git a/include/core/detail/type_traits.hpp b/include/core/detail/type_traits.hpp index dcf77eb0..32f14d58 100644 --- a/include/core/detail/type_traits.hpp +++ b/include/core/detail/type_traits.hpp @@ -20,6 +20,7 @@ */ #include + #include #pragma once @@ -83,6 +84,8 @@ DEFINE_TYPE_TRAITS_0_TO_4(int, signed int); DEFINE_TYPE_TRAITS_0_TO_4(short, signed short); DEFINE_TYPE_TRAITS_0_TO_4(ushort, unsigned short); DEFINE_TYPE_TRAITS_0_TO_4(double, double); +DEFINE_TYPE_TRAITS_0_TO_4(long, signed long); +DEFINE_TYPE_TRAITS_0_TO_4(ulong, unsigned long); /** * @brief Returns the number of elements in a HIP vectorized type. 
For example: uchar3 will return 3, int2 will diff --git a/tests/roccv/cpp/src/tests/core/detail/test_saturate_cast.cpp b/tests/roccv/cpp/src/tests/core/detail/test_saturate_cast.cpp index 015265c5..84be5ee2 100644 --- a/tests/roccv/cpp/src/tests/core/detail/test_saturate_cast.cpp +++ b/tests/roccv/cpp/src/tests/core/detail/test_saturate_cast.cpp @@ -30,22 +30,25 @@ using namespace roccv; int main(int argc, char **argv) { TEST_CASES_BEGIN(); - EXPECT_EQ(SaturateCast(1.0f), 1); - EXPECT_EQ(SaturateCast(-1.0f), -1); - EXPECT_EQ(SaturateCast(1.0f), 1); - EXPECT_EQ(SaturateCast(-1.0f), 0); - EXPECT_EQ(SaturateCast(1), 1.0f); - EXPECT_EQ(SaturateCast(-1), -1.0f); - EXPECT_EQ(SaturateCast(1), 1.0); - EXPECT_EQ(SaturateCast(-1), -1.0); + TEST_CASE(EXPECT_EQ(SaturateCast(1.0f), 1)); + TEST_CASE(EXPECT_EQ(SaturateCast(-1.0f), -1)); + TEST_CASE(EXPECT_EQ(SaturateCast(1.0f), 1)); + TEST_CASE(EXPECT_EQ(SaturateCast(-1.0f), 0)); + TEST_CASE(EXPECT_EQ(SaturateCast(1), 1.0f)); + TEST_CASE(EXPECT_EQ(SaturateCast(-1), -1.0f)); + TEST_CASE(EXPECT_EQ(SaturateCast(1), 1.0)); + TEST_CASE(EXPECT_EQ(SaturateCast(-1), -1.0)); // Test numeric limits - EXPECT_EQ(SaturateCast(std::numeric_limits::max()), std::numeric_limits::max()); - EXPECT_EQ(SaturateCast(std::numeric_limits::max()), std::numeric_limits::max()); + TEST_CASE(EXPECT_EQ(SaturateCast(std::numeric_limits::max()), std::numeric_limits::max())); + TEST_CASE(EXPECT_EQ(SaturateCast(std::numeric_limits::max()), std::numeric_limits::max())); + TEST_CASE(EXPECT_EQ(SaturateCast(std::numeric_limits::max()), std::numeric_limits::max())); + TEST_CASE(EXPECT_EQ(SaturateCast(std::numeric_limits::lowest()), 0UL)); // Test vectorized types - EXPECT_TRUE((SaturateCast(uchar4{255, 128, 0, 255}) == float4{255.0f, 128.0f, 0.0f, 255.0f})); - EXPECT_TRUE((SaturateCast(char4{-128, -128, -128, -128}) == float4{-128.0f, -128.0f, -128.0f, -128.0f})); + TEST_CASE(EXPECT_TRUE((SaturateCast(uchar4{255, 128, 0, 255}) == float4{255.0f, 128.0f, 0.0f, 
255.0f}))); + TEST_CASE(EXPECT_TRUE( + (SaturateCast(char4{-128, -128, -128, -128}) == float4{-128.0f, -128.0f, -128.0f, -128.0f}))); TEST_CASES_END(); } \ No newline at end of file From 146a1f9740d28eef6b0c7b150e8781fb01ec91a2 Mon Sep 17 00:00:00 2001 From: Zach Vincze Date: Fri, 6 Feb 2026 15:28:08 -0500 Subject: [PATCH 04/10] Undo changes to composite --- src/op_composite.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/op_composite.cpp b/src/op_composite.cpp index fedc7a9b..6d49a156 100644 --- a/src/op_composite.cpp +++ b/src/op_composite.cpp @@ -40,7 +40,7 @@ void dispatch_composite_masktype(hipStream_t stream, const Tensor& foreground, c switch (device) { case eDeviceType::GPU: { - dim3 block(32, 8); + dim3 block(64, 16); dim3 grid((outputWrapper.width() + block.x - 1) / block.x, (outputWrapper.height() + block.y - 1) / block.y, outputWrapper.batches()); Kernels::Device::composite<<>>(fgWrapper, bgWrapper, maskWrapper, outputWrapper); From e9e9f0b8b8fd7a516f95762a4526fcae12a2dc00 Mon Sep 17 00:00:00 2001 From: Zach Vincze Date: Fri, 6 Feb 2026 16:18:22 -0500 Subject: [PATCH 05/10] Review fixes --- include/core/detail/casting.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/include/core/detail/casting.hpp b/include/core/detail/casting.hpp index d9c8ac14..688db24d 100644 --- a/include/core/detail/casting.hpp +++ b/include/core/detail/casting.hpp @@ -21,6 +21,8 @@ #pragma once +#include + #include "core/detail/type_traits.hpp" namespace roccv::detail { @@ -153,9 +155,7 @@ __device__ __host__ T ScalarRangeCast(U v) { #endif } else { // 32/64 bit integer cases. - return v >= U{1} ? std::numeric_limits::max() - : v <= U{-1} ? std::numeric_limits::min() - : static_cast(std::round(v * scale)); + return v >= U{1} ? std::numeric_limits::max() : v <= U{0} ? 
0 : static_cast(std::round(v * scale)); } } From 13a78bec09a2d3e04fa7ad3a50a1cde242fe7cfb Mon Sep 17 00:00:00 2001 From: Zach Vincze Date: Fri, 6 Feb 2026 16:21:04 -0500 Subject: [PATCH 06/10] Add another test case for RangeCast --- tests/roccv/cpp/src/tests/core/detail/test_range_cast.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/roccv/cpp/src/tests/core/detail/test_range_cast.cpp b/tests/roccv/cpp/src/tests/core/detail/test_range_cast.cpp index 7a2ccf5b..2d9ea7dd 100644 --- a/tests/roccv/cpp/src/tests/core/detail/test_range_cast.cpp +++ b/tests/roccv/cpp/src/tests/core/detail/test_range_cast.cpp @@ -37,6 +37,8 @@ int main(int argc, char **argv) { TEST_CASE(EXPECT_EQ(RangeCast(-1.0f), std::numeric_limits::min())); TEST_CASE(EXPECT_EQ(RangeCast(1.0f), std::numeric_limits::max())); TEST_CASE(EXPECT_EQ(RangeCast(-1.0f), 0)); + TEST_CASE(EXPECT_EQ(RangeCast(0.0f), 0)); + // Test unsigned/signed integer -> float casting TEST_CASE(EXPECT_EQ(RangeCast(std::numeric_limits::max()), 1.0f)); From 883a1f09f47c0807323ad1e01bcb44f5c5cef036 Mon Sep 17 00:00:00 2001 From: Zach Vincze Date: Mon, 27 Apr 2026 21:02:41 -0400 Subject: [PATCH 07/10] Fix double precision issue in scalar range casting + add unified rounding helper --- include/core/detail/casting.hpp | 84 ++++++++++++++----- .../tests/core/detail/test_saturate_cast.cpp | 3 + 2 files changed, 67 insertions(+), 20 deletions(-) diff --git a/include/core/detail/casting.hpp b/include/core/detail/casting.hpp index 688db24d..b7a322ee 100644 --- a/include/core/detail/casting.hpp +++ b/include/core/detail/casting.hpp @@ -22,11 +22,50 @@ #pragma once #include +#include #include "core/detail/type_traits.hpp" namespace roccv::detail { +/** + * @brief Rounds a floating-point value to the nearest integer using IEEE + * half-to-even rounding (the default rounding mode). Matches the semantics of + * __float2int_rn on device. 
Selects single- vs double-precision based on the + * argument type to avoid silent precision loss when U is double. + */ +template +__device__ __host__ inline U IEEERound(U v) { + static_assert(std::is_floating_point_v, "IEEERound requires a floating-point input"); +#ifdef __HIP_DEVICE_COMPILE__ + if constexpr (std::is_same_v) { + return rintf(v); + } else { + return rint(v); + } +#else + return std::rint(v); +#endif +} + +/** + * @brief Clamps v to [lo, hi]. Uses fminf/fmin/fmaxf/fmax on device to avoid + * the branchy std::clamp implementation. + */ +template +__device__ __host__ inline U FpClamp(U v, U lo, U hi) { + static_assert(std::is_floating_point_v, "FpClamp requires a floating-point input"); +#ifdef __HIP_DEVICE_COMPILE__ + if constexpr (std::is_same_v) { + return fminf(fmaxf(v, lo), hi); + } else { + return fmin(fmax(v, lo), hi); + } +#else + return std::clamp(v, lo, hi); +#endif +} + /** * @brief ScalarSaturateCast is for implementation purposes only. Use SaturateCast directly. */ @@ -36,25 +75,17 @@ __device__ __host__ T ScalarSaturateCast(U v) { constexpr bool bigToSmall = !smallToBig; if constexpr (std::is_integral_v && std::is_floating_point_v) { - // Float -> integral: clamp to [min, max] then round. + // Float -> integral: clamp to [min, max] then round (IEEE half-to-even). constexpr U minVal = static_cast(std::numeric_limits::lowest()); constexpr U maxVal = static_cast(std::numeric_limits::max()); if constexpr (sizeof(T) <= 2) { // 8/16 bit integer cases. These can be represented exactly in floating point. -#ifdef __HIP_DEVICE_COMPILE__ - return static_cast(rintf(fminf(fmaxf(v, minVal), maxVal))); -#else - return static_cast(std::round(std::clamp(v, minVal, maxVal))); -#endif + return static_cast(IEEERound(FpClamp(v, minVal, maxVal))); } else { - // 32/64 bit integer cases. -#ifdef __HIP_DEVICE_COMPILE__ - U rounded = rintf(v); -#else - U rounded = std::round(v); -#endif - + // 32/64 bit integer cases. 
maxVal may round up to an unrepresentable + // value when cast back, so compare against the rounded source. + const U rounded = IEEERound(v); return rounded >= maxVal ? std::numeric_limits::max() : rounded <= minVal ? std::numeric_limits::min() : static_cast(rounded); @@ -136,10 +167,19 @@ __device__ __host__ T ScalarRangeCast(U v) { } else if constexpr (std::is_integral_v && std::is_floating_point_v && std::is_signed_v) { - // Float to signed integers - return v >= T{1} ? std::numeric_limits::max() - : v <= T{-1} ? std::numeric_limits::min() - : static_cast(std::round(static_cast(std::numeric_limits::max()) * v)); + // Float to signed integer. Map [-1, 1] -> [min, max] with IEEE half-to-even rounding. + constexpr U scale = static_cast(std::numeric_limits::max()); + + if constexpr (sizeof(T) <= 2) { + // 8/16 bit signed cases. These can be represented exactly in floating point, + // so clamp first then round. + return static_cast(IEEERound(FpClamp(v, U{-1}, U{1}) * scale)); + } else { + // 32/64 bit signed cases. + return v >= U{1} ? std::numeric_limits::max() + : v <= U{-1} ? std::numeric_limits::min() + : static_cast(IEEERound(scale * v)); + } } else if constexpr (std::is_integral_v && std::is_floating_point_v && std::is_unsigned_v) { @@ -149,13 +189,17 @@ __device__ __host__ T ScalarRangeCast(U v) { if constexpr (sizeof(T) <= 2) { // 8/16 bit integer cases. These can be represented exactly in floating point. #ifdef __HIP_DEVICE_COMPILE__ - return static_cast(__float2int_rn(__saturatef(v) * scale)); + if constexpr (std::is_same_v) { + return static_cast(__float2int_rn(__saturatef(v) * scale)); + } else { + return static_cast(IEEERound(FpClamp(v, U{0}, U{1}) * scale)); + } #else - return static_cast(lrintf(fminf(fmaxf(v, 0.0f), 1.0f) * scale)); + return static_cast(IEEERound(FpClamp(v, U{0}, U{1}) * scale)); #endif } else { // 32/64 bit integer cases. - return v >= U{1} ? std::numeric_limits::max() : v <= U{0} ? 
0 : static_cast(std::round(v * scale)); + return v >= U{1} ? std::numeric_limits::max() : v <= U{0} ? T{0} : static_cast(IEEERound(v * scale)); } } diff --git a/tests/roccv/cpp/src/tests/core/detail/test_saturate_cast.cpp b/tests/roccv/cpp/src/tests/core/detail/test_saturate_cast.cpp index 84be5ee2..5d63b184 100644 --- a/tests/roccv/cpp/src/tests/core/detail/test_saturate_cast.cpp +++ b/tests/roccv/cpp/src/tests/core/detail/test_saturate_cast.cpp @@ -28,6 +28,9 @@ using namespace roccv::tests; using namespace roccv; int main(int argc, char **argv) { + (void)argc; + (void)argv; + TEST_CASES_BEGIN(); TEST_CASE(EXPECT_EQ(SaturateCast(1.0f), 1)); From bbeefcf59a18bd68611a6dc8a4a6d99b10bc79e6 Mon Sep 17 00:00:00 2001 From: Zach Vincze Date: Mon, 27 Apr 2026 21:12:30 -0400 Subject: [PATCH 08/10] Improve vector saturate/range/static cast constexpr branching --- include/core/detail/casting.hpp | 73 +++++++++++++++++++++------------ 1 file changed, 47 insertions(+), 26 deletions(-) diff --git a/include/core/detail/casting.hpp b/include/core/detail/casting.hpp index b7a322ee..1ad80abf 100644 --- a/include/core/detail/casting.hpp +++ b/include/core/detail/casting.hpp @@ -141,18 +141,25 @@ __device__ __host__ T ScalarSaturateCast(U v) { template && HasTypeTraits) && (NumElements <= NumElements)>> __device__ __host__ T SaturateCast(U v) { + using B = BaseType; if constexpr (std::is_same_v) { return v; + } else if constexpr (NumElements == 1) { + return T{ScalarSaturateCast(GetElement(v, 0))}; + } else if constexpr (NumElements == 2) { + return T{ScalarSaturateCast(GetElement(v, 0)), + ScalarSaturateCast(GetElement(v, 1))}; + } else if constexpr (NumElements == 3) { + return T{ScalarSaturateCast(GetElement(v, 0)), + ScalarSaturateCast(GetElement(v, 1)), + ScalarSaturateCast(GetElement(v, 2))}; + } else { + static_assert(NumElements == 4, "SaturateCast supports up to 4-element vectors"); + return T{ScalarSaturateCast(GetElement(v, 0)), + ScalarSaturateCast(GetElement(v, 1)), + 
ScalarSaturateCast(GetElement(v, 2)), + ScalarSaturateCast(GetElement(v, 3))}; } - - T ret{}; - - GetElement(ret, 0) = ScalarSaturateCast>(GetElement(v, 0)); - if constexpr (NumElements >= 2) GetElement(ret, 1) = ScalarSaturateCast>(GetElement(v, 1)); - if constexpr (NumElements >= 3) GetElement(ret, 2) = ScalarSaturateCast>(GetElement(v, 2)); - if constexpr (NumElements >= 4) GetElement(ret, 3) = ScalarSaturateCast>(GetElement(v, 3)); - - return ret; } /** @@ -243,18 +250,25 @@ __device__ __host__ T ScalarRangeCast(U v) { template && HasTypeTraits) && NumElements <= NumElements>> __device__ __host__ T RangeCast(U v) { + using B = BaseType; if constexpr (std::is_same_v) { return v; + } else if constexpr (NumElements == 1) { + return T{ScalarRangeCast(GetElement(v, 0))}; + } else if constexpr (NumElements == 2) { + return T{ScalarRangeCast(GetElement(v, 0)), + ScalarRangeCast(GetElement(v, 1))}; + } else if constexpr (NumElements == 3) { + return T{ScalarRangeCast(GetElement(v, 0)), + ScalarRangeCast(GetElement(v, 1)), + ScalarRangeCast(GetElement(v, 2))}; + } else { + static_assert(NumElements == 4, "RangeCast supports up to 4-element vectors"); + return T{ScalarRangeCast(GetElement(v, 0)), + ScalarRangeCast(GetElement(v, 1)), + ScalarRangeCast(GetElement(v, 2)), + ScalarRangeCast(GetElement(v, 3))}; } - - T ret{}; - - GetElement(ret, 0) = ScalarRangeCast>(GetElement(v, 0)); - if constexpr (NumElements >= 2) GetElement(ret, 1) = ScalarRangeCast>(GetElement(v, 1)); - if constexpr (NumElements >= 3) GetElement(ret, 2) = ScalarRangeCast>(GetElement(v, 2)); - if constexpr (NumElements >= 4) GetElement(ret, 3) = ScalarRangeCast>(GetElement(v, 3)); - - return ret; } /** @@ -268,21 +282,28 @@ __device__ __host__ T RangeCast(U v) { template && HasTypeTraits) && NumElements <= NumElements>> __device__ __host__ T StaticCast(U v) { + using B = BaseType; if constexpr (std::is_same_v) { // Both same type, just return the value. 
return v; } else if constexpr (!IsCompound && !IsCompound) { // Both scalar values. Reduces to a standard static cast. return static_cast(v); + } else if constexpr (NumElements == 1) { + return T{StaticCast(GetElement(v, 0))}; + } else if constexpr (NumElements == 2) { + return T{StaticCast(GetElement(v, 0)), + StaticCast(GetElement(v, 1))}; + } else if constexpr (NumElements == 3) { + return T{StaticCast(GetElement(v, 0)), + StaticCast(GetElement(v, 1)), + StaticCast(GetElement(v, 2))}; } else { - // Vector types. Perform casting on each element. - T ret{}; - GetElement(ret, 0) = StaticCast>(GetElement(v, 0)); - if constexpr (NumElements >= 2) GetElement(ret, 1) = StaticCast>(GetElement(v, 1)); - if constexpr (NumElements >= 3) GetElement(ret, 2) = StaticCast>(GetElement(v, 2)); - if constexpr (NumElements >= 4) GetElement(ret, 3) = StaticCast>(GetElement(v, 3)); - - return ret; + static_assert(NumElements == 4, "StaticCast supports up to 4-element vectors"); + return T{StaticCast(GetElement(v, 0)), + StaticCast(GetElement(v, 1)), + StaticCast(GetElement(v, 2)), + StaticCast(GetElement(v, 3))}; } } } // namespace roccv::detail \ No newline at end of file From 0b10ce35cdb2b8447d76fb61f5be80c3cd77663c Mon Sep 17 00:00:00 2001 From: Zach Vincze Date: Mon, 27 Apr 2026 21:23:14 -0400 Subject: [PATCH 09/10] Add additional tests for casting helpers + introduce static cast tests --- .../src/tests/core/detail/test_range_cast.cpp | 65 ++++++++++++++ .../tests/core/detail/test_saturate_cast.cpp | 75 ++++++++++++++++ .../tests/core/detail/test_static_cast.cpp | 86 +++++++++++++++++++ 3 files changed, 226 insertions(+) create mode 100644 tests/roccv/cpp/src/tests/core/detail/test_static_cast.cpp diff --git a/tests/roccv/cpp/src/tests/core/detail/test_range_cast.cpp b/tests/roccv/cpp/src/tests/core/detail/test_range_cast.cpp index 55001a05..35fc843e 100644 --- a/tests/roccv/cpp/src/tests/core/detail/test_range_cast.cpp +++ 
b/tests/roccv/cpp/src/tests/core/detail/test_range_cast.cpp @@ -60,6 +60,71 @@ int main(int argc, char **argv) { TEST_CASE(EXPECT_EQ(RangeCast(std::numeric_limits::max()), 1.0f)); TEST_CASE(EXPECT_EQ(RangeCast(0), 0.0f)); + // ----- 8/16-bit signed fast path ----- + TEST_CASE(EXPECT_EQ(RangeCast(1.0f), 127)); + TEST_CASE(EXPECT_EQ(RangeCast(-1.0f), -127)); + TEST_CASE(EXPECT_EQ(RangeCast(0.0f), 0)); + TEST_CASE(EXPECT_EQ(RangeCast(2.0f), 127)); // out-of-range positive clamps + TEST_CASE(EXPECT_EQ(RangeCast(-2.0f), -127)); // out-of-range negative clamps + TEST_CASE(EXPECT_EQ(RangeCast(1.0f), 32767)); + TEST_CASE(EXPECT_EQ(RangeCast(-1.0f), -32767)); + TEST_CASE(EXPECT_EQ(RangeCast(2.0f), 32767)); + TEST_CASE(EXPECT_EQ(RangeCast(-2.0f), -32767)); + + // ----- 8/16-bit unsigned fast path ----- + TEST_CASE(EXPECT_EQ(RangeCast(1.0f), 255)); + TEST_CASE(EXPECT_EQ(RangeCast(0.0f), 0)); + TEST_CASE(EXPECT_EQ(RangeCast(2.0f), 255)); // clamp positive + TEST_CASE(EXPECT_EQ(RangeCast(-0.5f), 0)); // clamp negative + TEST_CASE(EXPECT_EQ(RangeCast(1.0f), 65535)); + TEST_CASE(EXPECT_EQ(RangeCast(-1.0f), 0)); + + // ----- Rounding mode: must be IEEE half-to-even ----- + TEST_CASE(EXPECT_EQ(RangeCast(0.5f / 255.0f), 0)); // would be 1 with std::round + TEST_CASE(EXPECT_EQ(RangeCast(1.5f / 255.0f), 2)); // round half to even + TEST_CASE(EXPECT_EQ(RangeCast(2.5f / 255.0f), 2)); // round half to even (down) + TEST_CASE(EXPECT_EQ(RangeCast(0.5f / 127.0f), 0)); // signed: same rounding rule + TEST_CASE(EXPECT_EQ(RangeCast(-0.5f / 127.0f), 0)); + TEST_CASE(EXPECT_EQ(RangeCast(-1.5f / 127.0f), -2)); + + // ----- Double precision in float -> int ----- + TEST_CASE(EXPECT_EQ(RangeCast(0.5), std::numeric_limits::max() / 2 + 1)); + TEST_CASE(EXPECT_EQ(RangeCast(0.0), 0)); + TEST_CASE(EXPECT_EQ(RangeCast(0.0), 0u)); + + // ----- int -> float clamping: signed min hits the -1.008... clamp ----- + // numeric_limits::min() / max() = -128 / 127 = -1.0078..., must clamp to -1. 
+ TEST_CASE(EXPECT_EQ(RangeCast(int8_t{-128}), -1.0f)); + TEST_CASE(EXPECT_EQ(RangeCast(int8_t{127}), 1.0f)); + TEST_CASE(EXPECT_EQ(RangeCast(int8_t{0}), 0.0f)); + TEST_CASE(EXPECT_EQ(RangeCast(int16_t{-32768}), -1.0f)); + TEST_CASE(EXPECT_EQ(RangeCast(int16_t{32767}), 1.0f)); + + // ----- uint -> float ----- + TEST_CASE(EXPECT_EQ(RangeCast(uint8_t{255}), 1.0f)); + TEST_CASE(EXPECT_EQ(RangeCast(uint8_t{0}), 0.0f)); + TEST_CASE(EXPECT_EQ(RangeCast(uint16_t{65535}), 1.0f)); + + // ----- Integer -> integer falls back to SaturateCast ----- + TEST_CASE(EXPECT_EQ(RangeCast(int32_t{300}), 127)); + TEST_CASE(EXPECT_EQ(RangeCast(int32_t{-1}), 0)); + + // ----- Vector types ----- + // float -> uchar4: 0.0 -> 0, 0.5 -> 128, 1.0 -> 255 (with banker's rounding at half) + TEST_CASE(EXPECT_TRUE( + (RangeCast(float4{0.0f, 0.5f, 1.0f, -0.5f}) == uchar4{0, 128, 255, 0}))); + // uchar4 -> float4: 0 -> 0.0, 255 -> 1.0 + { + float4 result = RangeCast(uchar4{0, 128, 255, 64}); + TEST_CASE(EXPECT_EQ(result.x, 0.0f)); + TEST_CASE(EXPECT_EQ(result.z, 1.0f)); + TEST_CASE(EXPECT_TRUE(std::abs(result.y - (128.0f / 255.0f)) < 1e-6f)); + TEST_CASE(EXPECT_TRUE(std::abs(result.w - (64.0f / 255.0f)) < 1e-6f)); + } + // 2- and 3-element vectors + TEST_CASE(EXPECT_TRUE((RangeCast(float2{0.5f, -10.0f}) == uchar2{128, 0}))); + TEST_CASE(EXPECT_TRUE((RangeCast(float3{1.0f, -1.0f, 0.0f}) == char3{127, -127, 0}))); + // clang-format on TEST_CASES_END(); diff --git a/tests/roccv/cpp/src/tests/core/detail/test_saturate_cast.cpp b/tests/roccv/cpp/src/tests/core/detail/test_saturate_cast.cpp index 5d63b184..eca96d44 100644 --- a/tests/roccv/cpp/src/tests/core/detail/test_saturate_cast.cpp +++ b/tests/roccv/cpp/src/tests/core/detail/test_saturate_cast.cpp @@ -53,5 +53,80 @@ int main(int argc, char **argv) { TEST_CASE(EXPECT_TRUE( (SaturateCast(char4{-128, -128, -128, -128}) == float4{-128.0f, -128.0f, -128.0f, -128.0f}))); + // ----- Rounding mode: must be IEEE half-to-even (banker's rounding) ----- + // 
These regression-guard against accidentally switching back to std::round + // (half-away-from-zero), which would diverge from the device fast-paths. + TEST_CASE(EXPECT_EQ(SaturateCast(0.5f), 0)); // halfway -> nearest even (0) + TEST_CASE(EXPECT_EQ(SaturateCast(1.5f), 2)); // halfway -> nearest even (2) + TEST_CASE(EXPECT_EQ(SaturateCast(2.5f), 2)); // halfway -> nearest even (2) + TEST_CASE(EXPECT_EQ(SaturateCast(-0.5f), 0)); + TEST_CASE(EXPECT_EQ(SaturateCast(-1.5f), -2)); + TEST_CASE(EXPECT_EQ(SaturateCast(-2.5f), -2)); + // Same rounding rules in the 8/16-bit clamp-then-round path. + TEST_CASE(EXPECT_EQ(SaturateCast(0.5f), 0)); + TEST_CASE(EXPECT_EQ(SaturateCast(1.5f), 2)); + TEST_CASE(EXPECT_EQ(SaturateCast(2.5f), 2)); + TEST_CASE(EXPECT_EQ(SaturateCast(-1.5f), -2)); + // Non-half values should still round to nearest as expected. + TEST_CASE(EXPECT_EQ(SaturateCast(1.4f), 1)); + TEST_CASE(EXPECT_EQ(SaturateCast(1.6f), 2)); + TEST_CASE(EXPECT_EQ(SaturateCast(-1.4f), -1)); + TEST_CASE(EXPECT_EQ(SaturateCast(-1.6f), -2)); + + // ----- Double precision: must NOT be silently truncated to float ----- + TEST_CASE(EXPECT_EQ(SaturateCast(1234567890.7), 1234567891)); + TEST_CASE(EXPECT_EQ(SaturateCast(-1234567890.7), -1234567891)); + TEST_CASE(EXPECT_EQ(SaturateCast(16777217.0), 16777217)); // 2^24+1, not exact in float + TEST_CASE(EXPECT_EQ(SaturateCast(1234567890.5), 1234567890)); // half-to-even + + // ----- Float clamping: out-of-range floats clamp to numeric limits ----- + TEST_CASE(EXPECT_EQ(SaturateCast(300.0f), 255)); + TEST_CASE(EXPECT_EQ(SaturateCast(-1.0f), 0)); + TEST_CASE(EXPECT_EQ(SaturateCast(-100.0f), 0)); + TEST_CASE(EXPECT_EQ(SaturateCast(200.0f), 127)); + TEST_CASE(EXPECT_EQ(SaturateCast(-200.0f), -128)); + TEST_CASE(EXPECT_EQ(SaturateCast(40000.0f), 32767)); + TEST_CASE(EXPECT_EQ(SaturateCast(-40000.0f), -32768)); + TEST_CASE(EXPECT_EQ(SaturateCast(70000.0f), 65535)); + TEST_CASE(EXPECT_EQ(SaturateCast(-1.0f), 0)); + + // ----- Integer narrowing: same 
signedness ----- + TEST_CASE(EXPECT_EQ(SaturateCast(int32_t{300}), 127)); + TEST_CASE(EXPECT_EQ(SaturateCast(int32_t{-300}), -128)); + TEST_CASE(EXPECT_EQ(SaturateCast(int32_t{42}), 42)); // in-range, passthrough + TEST_CASE(EXPECT_EQ(SaturateCast(uint32_t{300}), 255)); + TEST_CASE(EXPECT_EQ(SaturateCast(uint32_t{42}), 42)); + TEST_CASE(EXPECT_EQ(SaturateCast(int64_t{-100000}), -32768)); + + // ----- Integer cross-signedness narrowing ----- + // Signed -> unsigned, big -> small: clamp negatives to 0, big to max + TEST_CASE(EXPECT_EQ(SaturateCast(int32_t{-1}), 0)); + TEST_CASE(EXPECT_EQ(SaturateCast(int32_t{300}), 255)); + TEST_CASE(EXPECT_EQ(SaturateCast(int32_t{42}), 42)); + // Unsigned -> signed: clamp values exceeding signed max + TEST_CASE(EXPECT_EQ(SaturateCast(uint32_t{300}), 127)); + TEST_CASE(EXPECT_EQ(SaturateCast(uint32_t{42}), 42)); + TEST_CASE(EXPECT_EQ(SaturateCast(uint32_t{70000}), 32767)); + + // ----- Integer cross-signedness widening ----- + // Signed -> unsigned, small to big: clamp negatives to 0 + TEST_CASE(EXPECT_EQ(SaturateCast(int8_t{-1}), 0u)); + TEST_CASE(EXPECT_EQ(SaturateCast(int8_t{-128}), 0u)); + TEST_CASE(EXPECT_EQ(SaturateCast(int8_t{42}), 42u)); + // Unsigned -> signed widening: always representable, no clamping + TEST_CASE(EXPECT_EQ(SaturateCast(uint8_t{255}), 255)); + TEST_CASE(EXPECT_EQ(SaturateCast(uint8_t{0}), 0)); + + // ----- Same-type early-return path ----- + TEST_CASE(EXPECT_EQ(SaturateCast(int{42}), 42)); + TEST_CASE(EXPECT_EQ(SaturateCast(1.5f), 1.5f)); + TEST_CASE(EXPECT_EQ(SaturateCast(uint8_t{200}), 200)); + + // ----- Additional vector coverage: 2- and 3-element types, integer narrowing ----- + TEST_CASE(EXPECT_TRUE((SaturateCast(int2{300, -50}) == uchar2{255, 0}))); + TEST_CASE(EXPECT_TRUE( + (SaturateCast(float3{300.0f, -10.0f, 127.5f}) == uchar3{255, 0, 128}))); // 127.5 rounds to even (128) + TEST_CASE(EXPECT_TRUE((SaturateCast(float4{200.0f, -200.0f, 0.5f, -0.5f}) == char4{127, -128, 0, 0}))); + TEST_CASES_END(); 
} \ No newline at end of file diff --git a/tests/roccv/cpp/src/tests/core/detail/test_static_cast.cpp b/tests/roccv/cpp/src/tests/core/detail/test_static_cast.cpp new file mode 100644 index 00000000..c7f2d3a9 --- /dev/null +++ b/tests/roccv/cpp/src/tests/core/detail/test_static_cast.cpp @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2026 Advanced Micro Devices, Inc. All rights reserved. + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. 
+ */ + +#include + +#include "test_helpers.hpp" + +using namespace roccv::detail; +using namespace roccv::tests; +using namespace roccv; + +int main(int argc, char **argv) { + (void)argc; + (void)argv; + TEST_CASES_BEGIN(); + + // ----- Scalar same-type early return ----- + TEST_CASE(EXPECT_EQ(StaticCast(int{42}), 42)); + TEST_CASE(EXPECT_EQ(StaticCast(1.5f), 1.5f)); + TEST_CASE(EXPECT_EQ(StaticCast(2.5), 2.5)); + + // ----- Scalar conversions: behave exactly like static_cast ----- + // Float -> int: truncates toward zero, no clamping or rounding. + TEST_CASE(EXPECT_EQ(StaticCast(3.7f), 3)); + TEST_CASE(EXPECT_EQ(StaticCast(-3.7f), -3)); + TEST_CASE(EXPECT_EQ(StaticCast(0.999f), 0)); + // int -> float: exact for small values. + TEST_CASE(EXPECT_EQ(StaticCast(int{42}), 42.0f)); + TEST_CASE(EXPECT_EQ(StaticCast(int{-42}), -42.0f)); + // Widening / narrowing integer conversions follow C++ rules (no clamping). + TEST_CASE(EXPECT_EQ(StaticCast(int8_t{-1}), -1)); + TEST_CASE(EXPECT_EQ(StaticCast(int32_t{300}), static_cast(300))); + // double -> float + TEST_CASE(EXPECT_EQ(StaticCast(1.5), 1.5f)); + + // ----- Vector same-type early return ----- + TEST_CASE(EXPECT_TRUE((StaticCast(float4{1.0f, 2.0f, 3.0f, 4.0f}) == float4{1.0f, 2.0f, 3.0f, 4.0f}))); + TEST_CASE(EXPECT_TRUE((StaticCast(uchar4{1, 2, 3, 4}) == uchar4{1, 2, 3, 4}))); + + // ----- Vector conversions across base types (same arity) ----- + TEST_CASE(EXPECT_TRUE((StaticCast(uchar4{1, 2, 3, 4}) == float4{1.0f, 2.0f, 3.0f, 4.0f}))); + TEST_CASE(EXPECT_TRUE((StaticCast(float4{1.7f, -2.7f, 3.3f, -3.3f}) == int4{1, -2, 3, -3}))); + TEST_CASE(EXPECT_TRUE((StaticCast(uchar3{10, 20, 30}) == float3{10.0f, 20.0f, 30.0f}))); + TEST_CASE(EXPECT_TRUE((StaticCast(int2{-5, 5}) == float2{-5.0f, 5.0f}))); + + // ----- Partial-element extraction (NumElements < NumElements) ----- + // Per the enable_if (NumElements <= NumElements), narrower vectors are allowed. 
+ TEST_CASE(EXPECT_TRUE((StaticCast(float4{1.0f, 2.0f, 3.0f, 4.0f}) == float2{1.0f, 2.0f}))); + TEST_CASE(EXPECT_TRUE((StaticCast(float4{1.0f, 2.0f, 3.0f, 4.0f}) == float3{1.0f, 2.0f, 3.0f}))); + TEST_CASE(EXPECT_TRUE((StaticCast(uchar4{10, 20, 30, 40}) == uchar2{10, 20}))); + // Cross-type partial extraction + TEST_CASE(EXPECT_TRUE((StaticCast(float4{1.7f, -2.7f, 3.3f, -3.3f}) == int2{1, -2}))); + + // ----- Scalar destination from compound source ----- + // NumElements == 1 with compound U: takes element 0 only. + TEST_CASE(EXPECT_EQ(StaticCast(float4{7.0f, 1.0f, 2.0f, 3.0f}), 7.0f)); + TEST_CASE(EXPECT_EQ(StaticCast(float2{4.7f, 9.0f}), 4)); + + // ----- No clamping on overflow (this is what distinguishes StaticCast from SaturateCast) ----- + // int32 -> uint8 with out-of-range input (300): the result wraps modulo 2^8 per the C++ + // integer conversion rules, and specifically does NOT clamp like SaturateCast would. + // We only assert that the values DIFFER from the saturate-cast behaviour to lock + // in StaticCast's pass-through semantics. + TEST_CASE(EXPECT_NE(static_cast(StaticCast(int32_t{300})), + static_cast(SaturateCast(int32_t{300})))); + + TEST_CASES_END(); +} From 3d54d7d6f3eb3303bee60ef6269567d7975ea9da Mon Sep 17 00:00:00 2001 From: Zach Vincze Date: Mon, 27 Apr 2026 21:44:37 -0400 Subject: [PATCH 10/10] Use fmed3f for floating point clamping --- include/core/detail/casting.hpp | 44 +++++++++++++-------------------- 1 file changed, 17 insertions(+), 27 deletions(-) diff --git a/include/core/detail/casting.hpp b/include/core/detail/casting.hpp index 1ad80abf..25e5bf6f 100644 --- a/include/core/detail/casting.hpp +++ b/include/core/detail/casting.hpp @@ -49,15 +49,18 @@ __device__ __host__ inline U IEEERound(U v) { } /** - * @brief Clamps v to [lo, hi]. Uses fminf/fmin/fmaxf/fmax on device to avoid - * the branchy std::clamp implementation. + * @brief Clamps v to [lo, hi]. + * @param[in] v The value to clamp. + * @param[in] lo The lower bound of the clamp.
+ * @param[in] hi The upper bound of the clamp. + * @return The value v clamped to [lo, hi]. */ template __device__ __host__ inline U FpClamp(U v, U lo, U hi) { static_assert(std::is_floating_point_v, "FpClamp requires a floating-point input"); #ifdef __HIP_DEVICE_COMPILE__ if constexpr (std::is_same_v) { - return fminf(fmaxf(v, lo), hi); + return __builtin_amdgcn_fmed3f(v, lo, hi); } else { return fmin(fmax(v, lo), hi); } @@ -147,18 +150,14 @@ __device__ __host__ T SaturateCast(U v) { } else if constexpr (NumElements == 1) { return T{ScalarSaturateCast(GetElement(v, 0))}; } else if constexpr (NumElements == 2) { - return T{ScalarSaturateCast(GetElement(v, 0)), - ScalarSaturateCast(GetElement(v, 1))}; + return T{ScalarSaturateCast(GetElement(v, 0)), ScalarSaturateCast(GetElement(v, 1))}; } else if constexpr (NumElements == 3) { - return T{ScalarSaturateCast(GetElement(v, 0)), - ScalarSaturateCast(GetElement(v, 1)), + return T{ScalarSaturateCast(GetElement(v, 0)), ScalarSaturateCast(GetElement(v, 1)), ScalarSaturateCast(GetElement(v, 2))}; } else { static_assert(NumElements == 4, "SaturateCast supports up to 4-element vectors"); - return T{ScalarSaturateCast(GetElement(v, 0)), - ScalarSaturateCast(GetElement(v, 1)), - ScalarSaturateCast(GetElement(v, 2)), - ScalarSaturateCast(GetElement(v, 3))}; + return T{ScalarSaturateCast(GetElement(v, 0)), ScalarSaturateCast(GetElement(v, 1)), + ScalarSaturateCast(GetElement(v, 2)), ScalarSaturateCast(GetElement(v, 3))}; } } @@ -256,18 +255,14 @@ __device__ __host__ T RangeCast(U v) { } else if constexpr (NumElements == 1) { return T{ScalarRangeCast(GetElement(v, 0))}; } else if constexpr (NumElements == 2) { - return T{ScalarRangeCast(GetElement(v, 0)), - ScalarRangeCast(GetElement(v, 1))}; + return T{ScalarRangeCast(GetElement(v, 0)), ScalarRangeCast(GetElement(v, 1))}; } else if constexpr (NumElements == 3) { - return T{ScalarRangeCast(GetElement(v, 0)), - ScalarRangeCast(GetElement(v, 1)), + return 
T{ScalarRangeCast(GetElement(v, 0)), ScalarRangeCast(GetElement(v, 1)), ScalarRangeCast(GetElement(v, 2))}; } else { static_assert(NumElements == 4, "RangeCast supports up to 4-element vectors"); - return T{ScalarRangeCast(GetElement(v, 0)), - ScalarRangeCast(GetElement(v, 1)), - ScalarRangeCast(GetElement(v, 2)), - ScalarRangeCast(GetElement(v, 3))}; + return T{ScalarRangeCast(GetElement(v, 0)), ScalarRangeCast(GetElement(v, 1)), + ScalarRangeCast(GetElement(v, 2)), ScalarRangeCast(GetElement(v, 3))}; } } @@ -292,17 +287,12 @@ __device__ __host__ T StaticCast(U v) { } else if constexpr (NumElements == 1) { return T{StaticCast(GetElement(v, 0))}; } else if constexpr (NumElements == 2) { - return T{StaticCast(GetElement(v, 0)), - StaticCast(GetElement(v, 1))}; + return T{StaticCast(GetElement(v, 0)), StaticCast(GetElement(v, 1))}; } else if constexpr (NumElements == 3) { - return T{StaticCast(GetElement(v, 0)), - StaticCast(GetElement(v, 1)), - StaticCast(GetElement(v, 2))}; + return T{StaticCast(GetElement(v, 0)), StaticCast(GetElement(v, 1)), StaticCast(GetElement(v, 2))}; } else { static_assert(NumElements == 4, "StaticCast supports up to 4-element vectors"); - return T{StaticCast(GetElement(v, 0)), - StaticCast(GetElement(v, 1)), - StaticCast(GetElement(v, 2)), + return T{StaticCast(GetElement(v, 0)), StaticCast(GetElement(v, 1)), StaticCast(GetElement(v, 2)), StaticCast(GetElement(v, 3))}; } }