diff --git a/backends/aoti/slim/c10/cuda/Exception.h b/backends/aoti/slim/c10/cuda/Exception.h index 33d8414e661..4d2ba13d3bb 100644 --- a/backends/aoti/slim/c10/cuda/Exception.h +++ b/backends/aoti/slim/c10/cuda/Exception.h @@ -8,26 +8,55 @@ #pragma once -#ifdef CUDA_AVAILABLE - #include #include #include +#include #include #include /// Checks a CUDA expression and aborts on error. /// @param EXPR The CUDA expression to check. -#define ET_CUDA_CHECK(EXPR) \ - do { \ - const cudaError_t __err = EXPR; \ - ET_CHECK_MSG( \ - __err == cudaSuccess, "CUDA error: %s", cudaGetErrorString(__err)); \ +#ifndef ET_CUDA_CHECK +#define ET_CUDA_CHECK(EXPR) \ + do { \ + const cudaError_t __err = EXPR; \ + if (__err == cudaSuccess) { \ + break; \ + } \ + ET_LOG( \ + Error, \ + "%s:%d CUDA error: %s", \ + __FILE__, \ + __LINE__, \ + cudaGetErrorString(__err)); \ + ET_CHECK_MSG(false, "CUDA error: %s", cudaGetErrorString(__err)); \ } while (0) +#endif + +/// Checks a CUDA expression and returns Error::Internal on failure. +/// @param EXPR The CUDA expression to check. +#ifndef ET_CUDA_CHECK_OR_RETURN_ERROR +#define ET_CUDA_CHECK_OR_RETURN_ERROR(EXPR) \ + do { \ + const cudaError_t __err = EXPR; \ + if (__err == cudaSuccess) { \ + break; \ + } \ + ET_LOG( \ + Error, \ + "%s:%d CUDA error: %s", \ + __FILE__, \ + __LINE__, \ + cudaGetErrorString(__err)); \ + return ::executorch::runtime::Error::Internal; \ + } while (0) +#endif /// Checks a CUDA expression and logs a warning on error (non-fatal). /// @param EXPR The CUDA expression to check. +#ifndef ET_CUDA_LOG_WARN #define ET_CUDA_LOG_WARN(EXPR) \ do { \ const cudaError_t __err = EXPR; \ @@ -36,5 +65,17 @@ ET_LOG(Error, "CUDA warning: %s", cudaGetErrorString(__err)); \ } \ } while (0) +#endif + +/// Kernel launch check macro (with return) - checks cudaGetLastError after +/// kernel launch. +#ifndef ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR +#define ET_CUDA_KERNEL_LAUNCH_CHECK_OR_RETURN_ERROR() \ + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetLastError()) +#endif -#endif // CUDA_AVAILABLE +/// Kernel launch check macro (without return) - checks cudaGetLastError after +/// kernel launch. +#ifndef ET_CUDA_KERNEL_LAUNCH_CHECK +#define ET_CUDA_KERNEL_LAUNCH_CHECK() ET_CUDA_CHECK(cudaGetLastError()) +#endif diff --git a/backends/aoti/slim/core/Storage.h b/backends/aoti/slim/core/Storage.h index 156556aa9e1..2b48acd65f5 100644 --- a/backends/aoti/slim/core/Storage.h +++ b/backends/aoti/slim/core/Storage.h @@ -12,7 +12,7 @@ #ifdef CUDA_AVAILABLE #include -#include +#include #endif #include diff --git a/backends/aoti/slim/core/targets.bzl b/backends/aoti/slim/core/targets.bzl index 408738edd35..83b81e67131 100644 --- a/backends/aoti/slim/core/targets.bzl +++ b/backends/aoti/slim/core/targets.bzl @@ -18,7 +18,7 @@ def define_common_targets(): "//executorch/backends/aoti/slim/util:size_util", "//executorch/runtime/platform:platform", "//executorch/backends/aoti/slim/c10/cuda:exception", - "//executorch/backends/cuda/runtime:guard", + "//executorch/backends/aoti/slim/cuda:guard", ], ) diff --git a/backends/aoti/slim/cuda/TARGETS b/backends/aoti/slim/cuda/TARGETS new file mode 100644 index 00000000000..08e83a5f3c4 --- /dev/null +++ b/backends/aoti/slim/cuda/TARGETS @@ -0,0 +1,6 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") +load(":targets.bzl", "define_common_targets") + +oncall("executorch") + +define_common_targets() diff --git a/backends/aoti/slim/cuda/guard.cpp b/backends/aoti/slim/cuda/guard.cpp new file mode 100644 index 00000000000..461f7ea5944 --- /dev/null +++ b/backends/aoti/slim/cuda/guard.cpp @@ -0,0 +1,159 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +namespace executorch::backends::cuda { + +namespace { +// Thread-local stream storage (private to this file) +thread_local std::unordered_map current_streams_; +} // namespace + +Error setCurrentCUDAStream(cudaStream_t stream, DeviceIndex device_index) { + if (device_index == -1) { + // Get current device if not specified + // CUDA API returns int, explicit cast to DeviceIndex (int8_t) following + // ATen + int tmp_device = -1; + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetDevice(&tmp_device)); + device_index = static_cast(tmp_device); + } + + current_streams_[device_index] = stream; + return Error::Ok; +} + +Result getCurrentCUDAStream(DeviceIndex device_index) { + if (device_index == -1) { + // CUDA API returns int, explicit cast to DeviceIndex (int8_t) following + // ATen + int tmp_device = -1; + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetDevice(&tmp_device)); + device_index = static_cast(tmp_device); + } + + auto it = current_streams_.find(device_index); + if (it != current_streams_.end()) { + return it->second; + } + + cudaStream_t stream; + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaStreamCreate(&stream)); + setCurrentCUDAStream(stream, device_index); + return stream; +} + +CUDAGuard::CUDAGuard(CUDAGuard&& other) noexcept + : original_device_index_(other.original_device_index_), + current_device_index_(other.current_device_index_) { + // Mark the moved-from object as "already restored" so its destructor doesn't + // try to restore the device + other.original_device_index_ = other.current_device_index_; +} + +CUDAGuard::~CUDAGuard() { + if (original_device_index_ != current_device_index_) { + // DeviceIndex (int8_t) implicitly widens to int for cudaSetDevice + cudaError_t err = cudaSetDevice(original_device_index_); + if (err != cudaSuccess) { + ET_LOG( + Error, + "~CUDAGuard: Failed to restore device to %d: %s", + static_cast(original_device_index_), + cudaGetErrorString(err)); + } + } +} + +Error CUDAGuard::set_index(DeviceIndex device_index) { + // CUDA API returns int, explicit cast to DeviceIndex (int8_t) following ATen + int tmp_device = -1; + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaGetDevice(&tmp_device)); + + original_device_index_ = static_cast(tmp_device); + current_device_index_ = device_index; + + if (current_device_index_ != original_device_index_) { + // DeviceIndex (int8_t) implicitly widens to int for cudaSetDevice + ET_CUDA_CHECK_OR_RETURN_ERROR(cudaSetDevice(current_device_index_)); + } + + return Error::Ok; +} + +Result CUDAGuard::create(DeviceIndex device_index) { + CUDAGuard guard; // Fixed: Removed () to create a variable, not a function + ET_CHECK_OK_OR_RETURN_ERROR(guard.set_index(device_index)); + return guard; +} + +CUDAStreamGuard::CUDAStreamGuard(CUDAStreamGuard&& other) noexcept + : device_guard_(std::move(other.device_guard_)), + original_stream_(other.original_stream_), + current_stream_(other.current_stream_), + device_index_(other.device_index_) { + // Mark the moved-from object as "already restored" so its destructor doesn't + // try to restore the stream + other.original_stream_ = other.current_stream_; +} + +CUDAStreamGuard::~CUDAStreamGuard() { + // Restore the original stream unless this object was moved-from. + // After a move, original_stream_ == current_stream_, which indicates + // the moved-from object should not restore. + // Note: nullptr is a valid stream value (represents the default stream), + // so we must restore even if original_stream_ is nullptr. + if (original_stream_ != current_stream_) { + Error err = setCurrentCUDAStream(original_stream_, device_index_); + if (err != Error::Ok) { + ET_LOG( + Error, + "~CUDAStreamGuard: Failed to restore stream for device %d", + static_cast(device_index_)); + } + } +} + +Error CUDAStreamGuard::set_stream( + cudaStream_t stream, + DeviceIndex device_index) { + auto result = getCurrentCUDAStream(device_index); + if (!result.ok()) { + ET_LOG( + Error, + "Failed to get current stream for device %d", + static_cast(device_index)); + return result.error(); + } + + original_stream_ = result.get(); + current_stream_ = stream; + device_index_ = device_index; + + ET_CHECK_OK_OR_RETURN_ERROR(setCurrentCUDAStream(stream, device_index)); + + return Error::Ok; +} + +Result CUDAStreamGuard::create( + cudaStream_t stream, + DeviceIndex device_index) { + auto guard_result = CUDAGuard::create(device_index); + ET_CHECK_OK_OR_RETURN_ERROR(guard_result.error()); + + CUDAStreamGuard stream_guard(std::move(guard_result.get())); + ET_CHECK_OK_OR_RETURN_ERROR(stream_guard.set_stream(stream, device_index)); + + return stream_guard; +} + +} // namespace executorch::backends::cuda diff --git a/backends/aoti/slim/cuda/guard.h b/backends/aoti/slim/cuda/guard.h new file mode 100644 index 00000000000..57c01acf3b2 --- /dev/null +++ b/backends/aoti/slim/cuda/guard.h @@ -0,0 +1,193 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace executorch::backends::cuda { + +using executorch::runtime::Error; +using executorch::runtime::Result; + +// Use DeviceIndex (int8_t) from slim c10 to match ATen's convention. +// CUDA APIs use int, but implicit widening (int8_t -> int) handles outbound +// calls, and explicit static_cast handles inbound conversions from CUDA APIs. +using executorch::backends::aoti::slim::c10::DeviceIndex; + +/** + * Set the current CUDA stream for the specified device. + * + * @param stream The CUDA stream to set as current + * @param device_index The device index (-1 to use current device) + * @return Error code indicating success or failure + */ +Error setCurrentCUDAStream(cudaStream_t stream, DeviceIndex device_index = -1); + +/** + * Get the current CUDA stream for the specified device. + * If no stream has been set, creates a new stream and sets it as current. + * + * @param device_index The device index (-1 to use current device) + * @return Result containing the current stream on success, or an error code on + * failure + */ +Result getCurrentCUDAStream(DeviceIndex device_index = -1); + +/** + * RAII guard that sets the current CUDA device and restores it on destruction. + * This ensures that the device is properly restored even if an exception + * occurs. + * + */ +class CUDAGuard { + private: + /** + * Private constructor - use create() factory method instead. + */ + explicit CUDAGuard() + : original_device_index_(-1), current_device_index_(-1) {} + + public: + /** + * Factory method to create a CUDAGuard. + * + * @param device_index The device index to set as current + * @return Result containing the guard on success, or an error code on failure + */ + static Result create(DeviceIndex device_index); + + // Copy is not allowed + CUDAGuard(const CUDAGuard&) = delete; + CUDAGuard& operator=(const CUDAGuard&) = delete; + + // Move constructor and assignment + CUDAGuard(CUDAGuard&& other) noexcept; + CUDAGuard& operator=(CUDAGuard&& other) = delete; + + /** + * Destructor that restores the original device if necessary. + */ + ~CUDAGuard(); + + /** + * Sets the CUDA device to the given device index. + * + * @param device_index The device index to set as current + * @return Error code indicating success or failure + */ + Error set_index(DeviceIndex device_index); + + /** + * Get the original device index before the guard was created. + * + * @return The original device index + */ + DeviceIndex original_device() const { + return original_device_index_; + } + + /** + * Get the current device index. + * + * @return The current device index + */ + DeviceIndex current_device() const { + return current_device_index_; + } + + private: + /// The original device before this guard was created + DeviceIndex original_device_index_; + /// The current device managed by this guard + DeviceIndex current_device_index_; +}; + +/** + * RAII guard that sets the current CUDA device and stream, restoring both on + * destruction. This is useful for temporarily switching to a different device + * and stream. + * + */ +class CUDAStreamGuard { + private: + // Private constructor that takes a CUDAGuard + explicit CUDAStreamGuard(CUDAGuard&& guard) + : device_guard_(std::move(guard)), + original_stream_(nullptr), + current_stream_(nullptr), + device_index_(-1) {} + + public: + /** + * Factory method to create a CUDAStreamGuard. + * + * @param stream The CUDA stream to set as current + * @param device_index The device index for the stream + * @return Result containing the guard on success, or an error code on failure + */ + static Result create( + cudaStream_t stream, + DeviceIndex device_index); + + // Copy is not allowed + CUDAStreamGuard(const CUDAStreamGuard&) = delete; + CUDAStreamGuard& operator=(const CUDAStreamGuard&) = delete; + + // Move constructor and assignment + CUDAStreamGuard(CUDAStreamGuard&& other) noexcept; + CUDAStreamGuard& operator=(CUDAStreamGuard&& other) noexcept = delete; + + /** + * Destructor that restores the original stream and device. + */ + ~CUDAStreamGuard(); + + /** + * Sets the CUDA stream to the given stream on the specified device. + * + * @param stream The CUDA stream to set as current + * @param device_index The device index for the stream + * @return Error code indicating success or failure + */ + Error set_stream(cudaStream_t stream, DeviceIndex device_index); + + /** + * Get the current guarded stream. + * + * @return The current stream + */ + cudaStream_t stream() const { + return current_stream_; + } + + /** + * Get the device index being guarded. + * + * @return The device index + */ + DeviceIndex device_index() const { + return device_index_; + } + + private: + /// The device guard that handles device switching + CUDAGuard device_guard_; + /// The original stream that was current before this guard + cudaStream_t original_stream_ = nullptr; + /// The current stream being guarded + cudaStream_t current_stream_ = nullptr; + /// The device index for this stream guard + DeviceIndex device_index_; +}; + +} // namespace executorch::backends::cuda diff --git a/backends/aoti/slim/cuda/targets.bzl b/backends/aoti/slim/cuda/targets.bzl new file mode 100644 index 00000000000..cddd69f1999 --- /dev/null +++ b/backends/aoti/slim/cuda/targets.bzl @@ -0,0 +1,27 @@ +load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime") + +def define_common_targets(): + """Define targets for SlimTensor CUDA guard module.""" + + runtime.cxx_library( + name = "guard", + srcs = [ + "guard.cpp", + ], + headers = [ + "guard.h", + ], + visibility = ["PUBLIC"], + deps = [ + "//executorch/runtime/platform:platform", + ], + exported_deps = [ + "//executorch/backends/aoti/slim/c10/core:device", + "//executorch/backends/aoti/slim/c10/cuda:exception", + "//executorch/runtime/core:core", + "//executorch/runtime/core/exec_aten:lib", + ], + external_deps = [ + ("cuda", None, "cuda-lazy"), + ], + ) diff --git a/backends/aoti/slim/cuda/test/TARGETS b/backends/aoti/slim/cuda/test/TARGETS new file mode 100644 index 00000000000..9ff3e83a8bd --- /dev/null +++ b/backends/aoti/slim/cuda/test/TARGETS @@ -0,0 +1,6 @@ +load("@fbcode_macros//build_defs:cpp_unittest.bzl", "cpp_unittest") +load(":targets.bzl", "define_common_targets") + +oncall("executorch") + +define_common_targets() diff --git a/backends/aoti/slim/cuda/test/targets.bzl b/backends/aoti/slim/cuda/test/targets.bzl new file mode 100644 index 00000000000..bf38b599637 --- /dev/null +++ b/backends/aoti/slim/cuda/test/targets.bzl @@ -0,0 +1,32 @@ +load("@fbcode_macros//build_defs:cpp_unittest.bzl", "cpp_unittest") +load("@fbcode_macros//build_defs/lib:re_test_utils.bzl", "re_test_utils") + +def cuda_slim_cpp_unittest(name): + cpp_unittest( + name = "test_" + name, + srcs = [ + "test_" + name + ".cpp", + ], + deps = [ + "//executorch/backends/aoti/slim/cuda:guard", + "//executorch/runtime/core:core", + "//executorch/runtime/core/exec_aten:lib", + "//executorch/runtime/platform:platform", + ], + external_deps = [ + ("cuda", None, "cuda-lazy"), + ], + keep_gpu_sections = True, + remote_execution = re_test_utils.remote_execution( + platform = "gpu-remote-execution", + ), + ) + +def define_common_targets(): + """Defines targets that should be shared between fbcode and xplat. + + The directory containing this targets.bzl file should also contain both + TARGETS and BUCK files that call this function. + """ + cuda_slim_cpp_unittest("cuda_guard") + cuda_slim_cpp_unittest("cuda_stream_guard") diff --git a/backends/aoti/slim/cuda/test/test_cuda_guard.cpp b/backends/aoti/slim/cuda/test/test_cuda_guard.cpp new file mode 100644 index 00000000000..c9938bf5cd8 --- /dev/null +++ b/backends/aoti/slim/cuda/test/test_cuda_guard.cpp @@ -0,0 +1,113 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +using namespace executorch::backends::cuda; +using namespace executorch::runtime; + +// TODO(gasoonjia): Multiple device tests were not included due to test +// environment limitations. These tests should be added in the future when +// multi-GPU test environments are available, + +class CUDAGuardTest : public ::testing::Test { + protected: + void SetUp() override { + et_pal_init(); + + int device_count = 0; + cudaError_t error = cudaGetDeviceCount(&device_count); + if (error != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "CUDA not available or no CUDA devices found"; + } + device_count_ = device_count; + + ASSERT_EQ(cudaGetDevice(&original_device_), cudaSuccess); + } + + void TearDown() override { + if (device_count_ > 0) { + ASSERT_EQ(cudaSetDevice(original_device_), cudaSuccess); + } + } + + int device_count_ = 0; + int original_device_ = 0; +}; + +TEST_F(CUDAGuardTest, BasicDeviceSwitching) { + int current_device; + ASSERT_EQ(cudaGetDevice(¤t_device), cudaSuccess); + + { + auto guard_result = CUDAGuard::create(0); + ASSERT_TRUE(guard_result.ok()); + CUDAGuard guard = std::move(guard_result.get()); + + int device_after_guard; + ASSERT_EQ(cudaGetDevice(&device_after_guard), cudaSuccess); + EXPECT_EQ(device_after_guard, 0); + EXPECT_EQ(guard.current_device(), 0); + EXPECT_EQ(guard.original_device(), current_device); + } + + int device_after_destruction; + ASSERT_EQ(cudaGetDevice(&device_after_destruction), cudaSuccess); + EXPECT_EQ(device_after_destruction, current_device); +} + +TEST_F(CUDAGuardTest, SameDeviceNoSwitching) { + ASSERT_EQ(cudaSetDevice(0), cudaSuccess); + + { + auto guard_result = CUDAGuard::create(0); + ASSERT_TRUE(guard_result.ok()); + CUDAGuard guard = std::move(guard_result.get()); + + int current_device; + ASSERT_EQ(cudaGetDevice(¤t_device), cudaSuccess); + EXPECT_EQ(current_device, 0); + EXPECT_EQ(guard.current_device(), 0); + EXPECT_EQ(guard.original_device(), 0); + } + + int final_device; + ASSERT_EQ(cudaGetDevice(&final_device), cudaSuccess); + EXPECT_EQ(final_device, 0); +} + +TEST_F(CUDAGuardTest, InvalidDeviceIndex) { + auto guard_result = CUDAGuard::create(999); + EXPECT_FALSE(guard_result.ok()); +} + +TEST_F(CUDAGuardTest, NegativeDeviceIndex) { + auto guard_result = CUDAGuard::create(-2); + EXPECT_FALSE(guard_result.ok()); +} + +TEST_F(CUDAGuardTest, CopyConstructorDeleted) { + static_assert( + !std::is_copy_constructible_v, + "CUDAGuard should not be copy constructible"); +} + +TEST_F(CUDAGuardTest, CopyAssignmentDeleted) { + static_assert( + !std::is_copy_assignable_v, + "CUDAGuard should not be copy assignable"); +} + +TEST_F(CUDAGuardTest, MoveAssignmentDeleted) { + static_assert( + !std::is_move_assignable_v, + "CUDAGuard should not be move assignable"); +} diff --git a/backends/aoti/slim/cuda/test/test_cuda_stream_guard.cpp b/backends/aoti/slim/cuda/test/test_cuda_stream_guard.cpp new file mode 100644 index 00000000000..613bc6ffe19 --- /dev/null +++ b/backends/aoti/slim/cuda/test/test_cuda_stream_guard.cpp @@ -0,0 +1,264 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include + +using namespace executorch::backends::cuda; +using namespace executorch::runtime; + +// TODO(gasoonjia): Multiple device tests were not included due to test +// environment limitations. These tests should be added in the future when +// multi-GPU test environments are available, + +class CUDAStreamGuardTest : public ::testing::Test { + protected: + void SetUp() override { + et_pal_init(); + + int device_count = 0; + cudaError_t error = cudaGetDeviceCount(&device_count); + if (error != cudaSuccess || device_count == 0) { + GTEST_SKIP() << "CUDA not available or no CUDA devices found"; + } + device_count_ = device_count; + + ASSERT_EQ(cudaGetDevice(&original_device_), cudaSuccess); + + ASSERT_EQ(cudaStreamCreate(&test_stream1_), cudaSuccess); + ASSERT_EQ(cudaStreamCreate(&test_stream2_), cudaSuccess); + } + + void TearDown() override { + if (test_stream1_) { + ASSERT_EQ(cudaStreamDestroy(test_stream1_), cudaSuccess); + } + if (test_stream2_) { + ASSERT_EQ(cudaStreamDestroy(test_stream2_), cudaSuccess); + } + + if (device_count_ > 0) { + ASSERT_EQ(cudaSetDevice(original_device_), cudaSuccess); + } + } + + int device_count_ = 0; + int original_device_ = 0; + cudaStream_t test_stream1_ = nullptr; + cudaStream_t test_stream2_ = nullptr; +}; + +TEST_F(CUDAStreamGuardTest, BasicStreamSwitching) { + auto guard_result = CUDAStreamGuard::create(test_stream1_, 0); + ASSERT_TRUE(guard_result.ok()); + CUDAStreamGuard guard = std::move(guard_result.get()); + + EXPECT_EQ(guard.stream(), test_stream1_); + EXPECT_EQ(guard.device_index(), 0); + + auto current_stream_result = getCurrentCUDAStream(0); + ASSERT_TRUE(current_stream_result.ok()); + EXPECT_EQ(current_stream_result.get(), test_stream1_); + + int current_device; + ASSERT_EQ(cudaGetDevice(¤t_device), cudaSuccess); + EXPECT_EQ(current_device, 0); +} + +TEST_F(CUDAStreamGuardTest, StreamSwitchingOnSameDevice) { + Error err = setCurrentCUDAStream(test_stream1_, 0); + ASSERT_EQ(err, Error::Ok); + + auto current_stream_result = getCurrentCUDAStream(0); + ASSERT_TRUE(current_stream_result.ok()); + EXPECT_EQ(current_stream_result.get(), test_stream1_); + + { + auto guard_result = CUDAStreamGuard::create(test_stream2_, 0); + ASSERT_TRUE(guard_result.ok()); + CUDAStreamGuard guard = std::move(guard_result.get()); + + auto new_stream_result = getCurrentCUDAStream(0); + ASSERT_TRUE(new_stream_result.ok()); + EXPECT_EQ(new_stream_result.get(), test_stream2_); + EXPECT_EQ(guard.stream(), test_stream2_); + } + + auto restored_stream_result = getCurrentCUDAStream(0); + ASSERT_TRUE(restored_stream_result.ok()); + EXPECT_EQ(restored_stream_result.get(), test_stream1_); +} + +TEST_F(CUDAStreamGuardTest, NestedStreamGuards) { + cudaStream_t initial_stream; + ASSERT_EQ(cudaStreamCreate(&initial_stream), cudaSuccess); + + Error err = setCurrentCUDAStream(initial_stream, 0); + ASSERT_EQ(err, Error::Ok); + + { + auto guard1_result = CUDAStreamGuard::create(test_stream1_, 0); + ASSERT_TRUE(guard1_result.ok()); + CUDAStreamGuard guard1 = std::move(guard1_result.get()); + + auto stream_result = getCurrentCUDAStream(0); + ASSERT_TRUE(stream_result.ok()); + EXPECT_EQ(stream_result.get(), test_stream1_); + + { + auto guard2_result = CUDAStreamGuard::create(test_stream2_, 0); + ASSERT_TRUE(guard2_result.ok()); + CUDAStreamGuard guard2 = std::move(guard2_result.get()); + + auto stream_result2 = getCurrentCUDAStream(0); + ASSERT_TRUE(stream_result2.ok()); + EXPECT_EQ(stream_result2.get(), test_stream2_); + } + + auto stream_result3 = getCurrentCUDAStream(0); + ASSERT_TRUE(stream_result3.ok()); + EXPECT_EQ(stream_result3.get(), test_stream1_); + } + + auto final_stream_result = getCurrentCUDAStream(0); + ASSERT_TRUE(final_stream_result.ok()); + EXPECT_EQ(final_stream_result.get(), initial_stream); + + ASSERT_EQ(cudaStreamDestroy(initial_stream), cudaSuccess); +} + +TEST_F(CUDAStreamGuardTest, SameStreamNoChange) { + Error err = setCurrentCUDAStream(test_stream1_, 0); + ASSERT_EQ(err, Error::Ok); + + { + auto guard_result = CUDAStreamGuard::create(test_stream1_, 0); + ASSERT_TRUE(guard_result.ok()); + CUDAStreamGuard guard = std::move(guard_result.get()); + + auto stream_result = getCurrentCUDAStream(0); + ASSERT_TRUE(stream_result.ok()); + EXPECT_EQ(stream_result.get(), test_stream1_); + EXPECT_EQ(guard.stream(), test_stream1_); + } + + auto final_stream_result = getCurrentCUDAStream(0); + ASSERT_TRUE(final_stream_result.ok()); + EXPECT_EQ(final_stream_result.get(), test_stream1_); +} + +TEST_F(CUDAStreamGuardTest, StreamAccessor) { + auto guard_result = CUDAStreamGuard::create(test_stream1_, 0); + ASSERT_TRUE(guard_result.ok()); + CUDAStreamGuard guard = std::move(guard_result.get()); + + EXPECT_EQ(guard.stream(), test_stream1_); + EXPECT_EQ(guard.device_index(), 0); +} + +TEST_F(CUDAStreamGuardTest, SetStreamMethod) { + auto guard_result = CUDAStreamGuard::create(test_stream1_, 0); + ASSERT_TRUE(guard_result.ok()); + CUDAStreamGuard guard = std::move(guard_result.get()); + + EXPECT_EQ(guard.stream(), test_stream1_); + + Error err = guard.set_stream(test_stream2_, 0); + EXPECT_EQ(err, Error::Ok); + + EXPECT_EQ(guard.stream(), test_stream2_); + + auto current_stream_result = getCurrentCUDAStream(0); + ASSERT_TRUE(current_stream_result.ok()); + EXPECT_EQ(current_stream_result.get(), test_stream2_); +} + +TEST_F(CUDAStreamGuardTest, MoveConstructor) { + auto guard1_result = CUDAStreamGuard::create(test_stream1_, 0); + ASSERT_TRUE(guard1_result.ok()); + CUDAStreamGuard guard1 = std::move(guard1_result.get()); + + EXPECT_EQ(guard1.stream(), test_stream1_); + EXPECT_EQ(guard1.device_index(), 0); + + CUDAStreamGuard guard2 = std::move(guard1); + + EXPECT_EQ(guard2.stream(), test_stream1_); + EXPECT_EQ(guard2.device_index(), 0); + + auto current_stream_result = getCurrentCUDAStream(0); + ASSERT_TRUE(current_stream_result.ok()); + EXPECT_EQ(current_stream_result.get(), test_stream1_); +} + +TEST_F(CUDAStreamGuardTest, MoveConstructorRestoresOnlyOnce) { + cudaStream_t initial_stream; + ASSERT_EQ(cudaStreamCreate(&initial_stream), cudaSuccess); + + Error err = setCurrentCUDAStream(initial_stream, 0); + ASSERT_EQ(err, Error::Ok); + + { + auto guard1_result = CUDAStreamGuard::create(test_stream1_, 0); + ASSERT_TRUE(guard1_result.ok()); + CUDAStreamGuard guard1 = std::move(guard1_result.get()); + + { CUDAStreamGuard guard2 = std::move(guard1); } + + auto stream_result = getCurrentCUDAStream(0); + ASSERT_TRUE(stream_result.ok()); + EXPECT_EQ(stream_result.get(), initial_stream); + } + + auto final_stream_result = getCurrentCUDAStream(0); + ASSERT_TRUE(final_stream_result.ok()); + EXPECT_EQ(final_stream_result.get(), initial_stream); + + ASSERT_EQ(cudaStreamDestroy(initial_stream), cudaSuccess); +} + +TEST_F(CUDAStreamGuardTest, InvalidDeviceIndex) { + auto guard_result = CUDAStreamGuard::create(test_stream1_, 999); + EXPECT_FALSE(guard_result.ok()); +} + +TEST_F(CUDAStreamGuardTest, NegativeDeviceIndex) { + auto guard_result = CUDAStreamGuard::create(test_stream1_, -2); + EXPECT_FALSE(guard_result.ok()); +} + +TEST_F(CUDAStreamGuardTest, CopyConstructorDeleted) { + static_assert( + !std::is_copy_constructible_v, + "CUDAStreamGuard should not be copy constructible"); +} + +TEST_F(CUDAStreamGuardTest, CopyAssignmentDeleted) { + static_assert( + !std::is_copy_assignable_v, + "CUDAStreamGuard should not be copy assignable"); +} + +TEST_F(CUDAStreamGuardTest, MoveAssignmentDeleted) { + static_assert( + !std::is_move_assignable_v, + "CUDAStreamGuard should not be move assignable"); +} + +TEST_F(CUDAStreamGuardTest, NullStreamPointer) { + auto guard_result = CUDAStreamGuard::create(nullptr, 0); + ASSERT_TRUE(guard_result.ok()); + CUDAStreamGuard guard = std::move(guard_result.get()); + + EXPECT_EQ(guard.stream(), nullptr); + + auto current_stream_result = getCurrentCUDAStream(0); + ASSERT_TRUE(current_stream_result.ok()); +}