From 542862f4b43c77eee4b2a30b2e6905a84d94bc50 Mon Sep 17 00:00:00 2001 From: Andy Chang Date: Mon, 16 Mar 2026 22:28:28 -0500 Subject: [PATCH] Add ReduceScatter SDMA implementation and benchmarks - Add ReduceScatterSdma C++ class (hpp/cpp) leveraging SdmaReduceScatterKernel - Add pybind11 binding for ReduceScatterSdmaHandle - Add Python wrapper class ReduceScatterSdma in ccl module - Add standalone and overlap benchmark scripts for AG/RS - Add GEMM standalone baseline script - Update AllGather test scripts with GEMM overlap support - Update CMakeLists.txt to include new source file --- .../reducescatter_sdma_class.hpp | 125 ++++++ python/mori/ccl/__init__.py | 3 +- python/mori/ccl/collective.py | 58 +++ src/collective/CMakeLists.txt | 1 + .../core/reducescatter_sdma_class.cpp | 377 +++++++++++++++++ src/pybind/mori.cpp | 143 +++++++ tests/python/ccl/bench_ag_overlap_sweep.sh | 77 ++++ tests/python/ccl/bench_allgather_sweep.sh | 68 ++++ tests/python/ccl/bench_gemm_standalone.py | 97 +++++ tests/python/ccl/bench_reducescatter_sweep.sh | 68 ++++ tests/python/ccl/bench_rs_overlap_sweep.sh | 76 ++++ tests/python/ccl/test_allgather_overlap.py | 20 +- tests/python/ccl/test_rccl_allgather.py | 26 +- tests/python/ccl/test_rccl_reducescatter.py | 352 ++++++++++++++++ .../python/ccl/test_reducescatter_overlap.py | 378 ++++++++++++++++++ 15 files changed, 1849 insertions(+), 20 deletions(-) create mode 100644 include/mori/collective/reducescatter/reducescatter_sdma_class.hpp create mode 100644 src/collective/core/reducescatter_sdma_class.cpp create mode 100755 tests/python/ccl/bench_ag_overlap_sweep.sh create mode 100755 tests/python/ccl/bench_allgather_sweep.sh create mode 100644 tests/python/ccl/bench_gemm_standalone.py create mode 100755 tests/python/ccl/bench_reducescatter_sweep.sh create mode 100755 tests/python/ccl/bench_rs_overlap_sweep.sh create mode 100644 tests/python/ccl/test_rccl_reducescatter.py create mode 100644 tests/python/ccl/test_reducescatter_overlap.py diff --git a/include/mori/collective/reducescatter/reducescatter_sdma_class.hpp b/include/mori/collective/reducescatter/reducescatter_sdma_class.hpp new file mode 100644 index 00000000..a7f197dd --- /dev/null +++ b/include/mori/collective/reducescatter/reducescatter_sdma_class.hpp @@ -0,0 +1,125 @@ +// Copyright © Advanced Micro Devices, Inc. All rights reserved. +// +// MIT License +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#ifndef REDUCESCATTER_SDMA_CLASS_HPP +#define REDUCESCATTER_SDMA_CLASS_HPP + +#include +#include +#include +#include +#include + +#include "mori/application/application.hpp" +#include "mori/shmem/shmem.hpp" +#include "mori/collective/collective_pub.hpp" + +namespace mori { +namespace collective { + +struct CrossPeBarrier; + +template +class ReduceScatterSdma { +private: + int myPe_; + int npes_; + size_t dtype_size_; + int max_blocks_; + + // SDMA completion flags + application::SymmMemObjPtr flagsObj_; + std::unique_ptr flags_; + + // Device-scope barrier for block-0-to-all broadcast + CrossPeBarrier* barrierPtr_; + std::unique_ptr barrierMem_; + + // Transit buffer (gather buffer): npes * chunkSize slots for SDMA scatter + void* transit_buffer_; + size_t transit_buffer_size_; + application::SymmMemObjPtr transit_buffer_obj_; + std::unique_ptr transit_buffer_ptr_; + + // Async state + std::atomic async_in_progress_; + T* async_input_; + T* async_output_; + size_t async_total_count_; + hipStream_t async_stream_; + double async_start_time_; + + bool copy_output_to_user_; + + ReduceScatterSdma(const ReduceScatterSdma&) = delete; + ReduceScatterSdma& operator=(const ReduceScatterSdma&) = delete; + + bool ensure_buffer_size(void*& buffer, + std::unique_ptr& buffer_ptr, + size_t& current_size, + application::SymmMemObjPtr& buffer_obj, + size_t required_size, + const char* buffer_name); + + void copy_result_to_user(T* output, size_t total_count, hipStream_t stream); + +public: + /** + * @param myPe Current PE ID + * @param npes Total number of PEs + * @param transit_buffer_size Transit buffer size in bytes (default 512MB) + * @param copy_output_to_user If true, copy reduced shard to user output buffer + */ + ReduceScatterSdma(int myPe, int npes, size_t transit_buffer_size = 512 * 1024 * 1024, + bool copy_output_to_user = true); + + ReduceScatterSdma(int myPe, int npes, size_t input_buffer_size, size_t output_buffer_size, + bool copy_output_to_user = true); + + ~ReduceScatterSdma(); + + /** + * @brief Synchronous ReduceScatter via SDMA + * @param input Input data — total_count elements per rank + * @param output Output data — total_count/npes reduced elements per rank + * @param total_count Number of input elements per PE + * @param stream HIP stream + */ + bool operator()(T* input, T* output, size_t total_count, hipStream_t stream = nullptr); + + bool start_async(T* input, T* output, size_t total_count, hipStream_t stream = nullptr); + double wait_async(hipStream_t stream = nullptr); + bool is_async_in_progress() const { return async_in_progress_; } + void cancel_async(); + + application::SymmMemObjPtr getFlagsObj() const { return flagsObj_; } + void* getTransitBuffer() const { return transit_buffer_; } + size_t getTransitBufferSize() const { return transit_buffer_size_; } + application::SymmMemObjPtr getTransitBufferObj() const { return transit_buffer_obj_; } + + void resetFlags(); +}; + +} // namespace collective +} // namespace mori + +#endif // REDUCESCATTER_SDMA_CLASS_HPP diff --git a/python/mori/ccl/__init__.py b/python/mori/ccl/__init__.py index 70e2fc40..e6710f2a 100644 --- a/python/mori/ccl/__init__.py +++ b/python/mori/ccl/__init__.py @@ -23,5 +23,6 @@ from .collective import All2allSdma from .collective import AllgatherSdma from .collective import AllreduceSdma +from .collective import ReduceScatterSdma -__all__ = ['All2allSdma', 'AllgatherSdma', 'AllreduceSdma'] +__all__ = ['All2allSdma', 'AllgatherSdma', 'AllreduceSdma', 'ReduceScatterSdma'] diff --git a/python/mori/ccl/collective.py b/python/mori/ccl/collective.py index e8b9082d..a1451fe1 100644 --- a/python/mori/ccl/collective.py +++ b/python/mori/ccl/collective.py @@ -410,3 +410,61 @@ def get_output_transit_buffer(self, device=None): while an operation is in progress. """ return self._handle.get_output_transit_buffer(device) + + +def _cpp_reducescatter_factory(entity_name: str): + """Factory function to get C++ entities from mori_cpp module""" + return getattr(mori_cpp, entity_name) + + +class ReduceScatterSdma: + """Python wrapper for ReduceScatterSdma C++ class. + + Performs ReduceScatter via SDMA: each rank contributes total_count + elements; the result is total_count/npes reduced elements per rank. + """ + + def __init__(self, my_pe: int, npes: int, + input_buffer_size: Optional[int] = None, + output_buffer_size: Optional[int] = None, + transit_buffer_size: Optional[int] = None, + copy_output_to_user: bool = True): + self.my_pe = my_pe + self.npes = npes + handle_class = _cpp_reducescatter_factory("ReduceScatterSdmaHandle") + + if input_buffer_size is not None and output_buffer_size is not None: + self._handle = handle_class(my_pe, npes, input_buffer_size, output_buffer_size, copy_output_to_user) + elif transit_buffer_size is not None: + self._handle = handle_class(my_pe, npes, transit_buffer_size, copy_output_to_user) + else: + self._handle = handle_class(my_pe, npes, 512 * 1024 * 1024, copy_output_to_user) + + def __call__(self, input_data, output_data, count: int, stream=None) -> bool: + """Execute ReduceScatter SDMA operation. + + Args: + input_data: Input CUDA tensor (total_count elements per rank) + output_data: Output CUDA tensor (total_count/npes elements per rank) + count: Total number of input elements per PE + stream: Optional HIP stream + """ + return self._handle(input_data, output_data, count, stream) + + def start_async(self, input_data, output_data, count: int, stream=None) -> bool: + return self._handle.start_async(input_data, output_data, count, stream) + + def wait_async(self, stream=None) -> float: + return self._handle.wait_async(stream) + + def is_async_in_progress(self) -> bool: + return self._handle.is_async_in_progress() + + def cancel_async(self): + self._handle.cancel_async() + + def reset_flags(self): + self._handle.reset_flags() + + def get_transit_buffer(self, device=None, dtype=None): + return self._handle.get_transit_buffer(device, dtype) diff --git a/src/collective/CMakeLists.txt b/src/collective/CMakeLists.txt index 47c549a3..0c92730c 100644 --- a/src/collective/CMakeLists.txt +++ b/src/collective/CMakeLists.txt @@ -7,6 +7,7 @@ set(COLLECTIVE_SOURCES core/oneshot_all2all_sdma_class.cpp core/oneshot_allgather_sdma_class.cpp core/twoshot_allreduce_sdma_class.cpp + core/reducescatter_sdma_class.cpp inter_node/executors/ring_1d_executor.cpp inter_node/executors/one_shot_executor.cpp # Note: intra_node_executor is header-only template class diff --git a/src/collective/core/reducescatter_sdma_class.cpp b/src/collective/core/reducescatter_sdma_class.cpp new file mode 100644 index 00000000..89e03fed --- /dev/null +++ b/src/collective/core/reducescatter_sdma_class.cpp @@ -0,0 +1,377 @@ +// Copyright © Advanced Micro Devices, Inc. All rights reserved. +// +// MIT License +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +#include "mori/collective/reducescatter/reducescatter_sdma_class.hpp" +#include "mori/collective/allreduce/twoshot_sdma_kernel.hpp" +#include "mori/shmem/shmem.hpp" +#include +#include +#include +#include +#include +#include + +namespace mori { +namespace collective { + +// --------------------------------------------------------------------------- +// Delegating constructor +// --------------------------------------------------------------------------- +template +ReduceScatterSdma::ReduceScatterSdma(int myPe, int npes, + size_t transit_buffer_size, + bool copy_output_to_user) + : ReduceScatterSdma(myPe, npes, 0, transit_buffer_size, copy_output_to_user) { +} + +// --------------------------------------------------------------------------- +// Main constructor +// --------------------------------------------------------------------------- +template +ReduceScatterSdma::ReduceScatterSdma(int myPe, int npes, + size_t /*input_buffer_size*/, + size_t output_buffer_size, + bool copy_output_to_user) + : myPe_(myPe), + npes_(npes), + dtype_size_(sizeof(T)), + max_blocks_(getDeviceMaxBlocks()), + flags_(nullptr, ShmemDeleter()), + barrierPtr_(nullptr), + barrierMem_(nullptr, ShmemDeleter()), + transit_buffer_(nullptr), + transit_buffer_size_(output_buffer_size), + transit_buffer_ptr_(nullptr, ShmemDeleter()), + async_in_progress_(false), + async_input_(nullptr), + async_output_(nullptr), + async_total_count_(0), + async_stream_(nullptr), + async_start_time_(0.0), + copy_output_to_user_(copy_output_to_user) { + + // 1. Allocate SDMA completion flags + size_t flagsSize = npes_ * sizeof(uint64_t); + void* flags = shmem::ShmemMalloc(flagsSize); + if (!flags) throw std::runtime_error("Failed to allocate flags memory"); + flags_.reset(static_cast(flags)); + memset(flags_.get(), 0, flagsSize); + flagsObj_ = shmem::ShmemQueryMemObjPtr(flags_.get()); + if (!flagsObj_.IsValid()) + throw std::runtime_error("Failed to get valid flags memory object"); + + // 2. Allocate CrossPeBarrier + size_t barrierSize = sizeof(CrossPeBarrier); + void* bMem = shmem::ShmemMalloc(barrierSize); + if (!bMem) throw std::runtime_error("Failed to allocate barrier memory"); + barrierMem_.reset(bMem); + barrierPtr_ = reinterpret_cast(bMem); + hipError_t me = hipMemset(bMem, 0, barrierSize); + if (me != hipSuccess) + throw std::runtime_error("Failed to zero-init barrier memory"); + + // 3. Allocate transit buffer (gather buffer for SDMA scatter + reduce) + transit_buffer_ = shmem::ShmemMalloc(transit_buffer_size_); + if (!transit_buffer_) + throw std::runtime_error("Failed to allocate transit buffer"); + transit_buffer_ptr_.reset(transit_buffer_); + + transit_buffer_obj_ = + shmem::ShmemSymmetricRegister(transit_buffer_, transit_buffer_size_); + if (!transit_buffer_obj_.IsValid()) + throw std::runtime_error("Failed to register transit buffer"); + + printf("ReduceScatterSdma(SDMA) initialized: PE %d of %d, max_blocks=%d\n", + myPe_, npes_, max_blocks_); + printf(" Flags: %zu bytes at %p\n", flagsSize, flags_.get()); + printf(" Barrier: %zu bytes at %p\n", barrierSize, bMem); + printf(" Transit buffer: %.2f MB at %p\n", + transit_buffer_size_ / (1024.0 * 1024.0), transit_buffer_); +} + +// --------------------------------------------------------------------------- +template +ReduceScatterSdma::~ReduceScatterSdma() { + if (async_in_progress_) { + cancel_async(); + } + if (flags_) { + printf("ReduceScatterSdma destroyed: PE %d\n", myPe_); + } +} + +// --------------------------------------------------------------------------- +template +bool ReduceScatterSdma::ensure_buffer_size(void*& buffer, + std::unique_ptr& buffer_ptr, + size_t& current_size, + application::SymmMemObjPtr& buffer_obj, + size_t required_size, + const char* buffer_name) { + if (required_size <= current_size) { + return true; + } + + printf("PE %d: %s too small: required %.2f MB, current %.2f MB\n", + myPe_, buffer_name, + required_size / (1024.0 * 1024.0), + current_size / (1024.0 * 1024.0)); + + buffer_ptr.reset(); + + current_size = required_size; + buffer = shmem::ShmemMalloc(current_size); + if (buffer == nullptr) { + fprintf(stderr, "PE %d: Failed to reallocate %s of size %.2f MB\n", + myPe_, buffer_name, current_size / (1024.0 * 1024.0)); + return false; + } + buffer_ptr.reset(buffer); + + buffer_obj = shmem::ShmemSymmetricRegister(buffer, current_size); + if (!buffer_obj.IsValid()) { + fprintf(stderr, "PE %d: Failed to re-register %s\n", myPe_, buffer_name); + return false; + } + + printf("PE %d: %s reallocated to %.2f MB\n", + myPe_, buffer_name, current_size / (1024.0 * 1024.0)); + return true; +} + +// --------------------------------------------------------------------------- +// Copy the reduced shard (slot[myPe]) from transit buffer to user output. +// ReduceScatter output = total_count / npes elements per rank. +// --------------------------------------------------------------------------- +template +void ReduceScatterSdma::copy_result_to_user(T* output, size_t total_count, hipStream_t stream) { + using P = typename packed_t::P; + constexpr int pack_size = P::size; + const size_t elementCountPerRank = + ((total_count / npes_ + pack_size - 1) / pack_size) * pack_size; + const size_t bytes = elementCountPerRank * dtype_size_; + + if (!output) throw std::runtime_error("Output pointer is null"); + if (!transit_buffer_) throw std::runtime_error("Transit buffer is null"); + + uint8_t* src = reinterpret_cast(transit_buffer_) + + static_cast(myPe_) * bytes; + + hipError_t err = stream + ? hipMemcpyAsync(output, src, bytes, hipMemcpyDeviceToDevice, stream) + : hipMemcpy(output, src, bytes, hipMemcpyDeviceToDevice); + if (err != hipSuccess) { + fprintf(stderr, "PE %d: copy_result_to_user failed: %s\n", + myPe_, hipGetErrorString(err)); + throw std::runtime_error("Output copy failed"); + } +} + +// --------------------------------------------------------------------------- +// operator() +// --------------------------------------------------------------------------- +template +bool ReduceScatterSdma::operator()(T* input, T* output, + size_t total_count, hipStream_t stream) { + if (async_in_progress_) { + printf("PE %d: Cannot execute sync operation while async is in progress\n", myPe_); + return false; + } + + try { + constexpr int pack_size = packed_t::P::size; + int threads = 512; + int packedPerRank = static_cast( + ((total_count / npes_ + pack_size - 1) / pack_size)); + int blocks = std::min(max_blocks_, + (packedPerRank + threads - 1) / threads); + if (blocks < 1) blocks = 1; + + SdmaReduceScatterKernel<<>>( + myPe_, npes_, + input, + transit_buffer_obj_, + flagsObj_, + barrierPtr_, + total_count); + + hipError_t err = hipGetLastError(); + if (err != hipSuccess) { + fprintf(stderr, "PE %d: SdmaReduceScatter launch failed: %s\n", + myPe_, hipGetErrorString(err)); + return false; + } + + if (copy_output_to_user_) { + copy_result_to_user(output, total_count, stream); + } + + } catch (const std::exception& e) { + fprintf(stderr, "PE %d: ReduceScatter failed: %s\n", myPe_, e.what()); + return false; + } + return true; +} + +// --------------------------------------------------------------------------- +// Async API +// --------------------------------------------------------------------------- +template +bool ReduceScatterSdma::start_async(T* input, T* output, + size_t total_count, hipStream_t stream) { + bool expected = false; + if (!async_in_progress_.compare_exchange_strong(expected, true)) { + printf("PE %d: Another async operation is already in progress\n", myPe_); + return false; + } + + async_input_ = input; + async_output_ = output; + async_total_count_ = total_count; + async_stream_ = stream; + async_start_time_ = MPI_Wtime(); + + try { + size_t required_size = total_count * dtype_size_; + if (!ensure_buffer_size(transit_buffer_, transit_buffer_ptr_, + transit_buffer_size_, transit_buffer_obj_, + required_size, "transit buffer")) { + async_in_progress_ = false; + return false; + } + + constexpr int pack_size = packed_t::P::size; + int threads = 512; + int packedPerRank = static_cast( + ((total_count / npes_ + pack_size - 1) / pack_size)); + int blocks = std::min(max_blocks_, + (packedPerRank + threads - 1) / threads); + if (blocks < 1) blocks = 1; + + SdmaReduceScatterKernel<<>>( + myPe_, npes_, + input, + transit_buffer_obj_, + flagsObj_, + barrierPtr_, + total_count); + + hipError_t kernel_err = hipGetLastError(); + if (kernel_err != hipSuccess) { + fprintf(stderr, "PE %d: Async kernel launch failed: %s\n", + myPe_, hipGetErrorString(kernel_err)); + throw std::runtime_error("Kernel launch failed"); + } + + return true; + + } catch (const std::exception& e) { + fprintf(stderr, "PE %d: Failed to start async operation: %s\n", myPe_, e.what()); + async_in_progress_ = false; + return false; + } +} + +template +double ReduceScatterSdma::wait_async(hipStream_t stream) { + if (!async_in_progress_) { + printf("PE %d: No async operation in progress\n", myPe_); + return -1.0; + } + + try { + hipStream_t wait_stream = (stream != nullptr) ? stream : async_stream_; + + if (copy_output_to_user_) { + copy_result_to_user(async_output_, async_total_count_, wait_stream); + } + + if (wait_stream != nullptr) { + hipError_t err = hipStreamSynchronize(wait_stream); + if (err != hipSuccess) { + fprintf(stderr, "PE %d: Stream synchronization failed: %s\n", + myPe_, hipGetErrorString(err)); + throw std::runtime_error("Stream synchronization failed"); + } + } else { + hipError_t err = hipDeviceSynchronize(); + if (err != hipSuccess) { + fprintf(stderr, "PE %d: Device synchronization failed: %s\n", + myPe_, hipGetErrorString(err)); + throw std::runtime_error("Device synchronization failed"); + } + } + + double end_time = MPI_Wtime(); + double duration = end_time - async_start_time_; + + async_in_progress_ = false; + async_input_ = nullptr; + async_output_ = nullptr; + async_total_count_ = 0; + async_stream_ = nullptr; + async_start_time_ = 0.0; + + return duration; + + } catch (const std::exception& e) { + fprintf(stderr, "PE %d: Async wait failed: %s\n", myPe_, e.what()); + cancel_async(); + return -1.0; + } +} + +template +void ReduceScatterSdma::cancel_async() { + if (async_in_progress_) { + printf("PE %d: Cancelling async operation\n", myPe_); + async_in_progress_ = false; + async_input_ = nullptr; + async_output_ = nullptr; + async_total_count_ = 0; + async_stream_ = nullptr; + async_start_time_ = 0.0; + } +} + +// --------------------------------------------------------------------------- +template +void ReduceScatterSdma::resetFlags() { + if (flags_) { + memset(flags_.get(), 0, npes_ * sizeof(uint64_t)); + } +} + +// --------------------------------------------------------------------------- +// Explicit instantiations +// --------------------------------------------------------------------------- +template class ReduceScatterSdma; +template class ReduceScatterSdma; +template class ReduceScatterSdma; +template class ReduceScatterSdma; +template class ReduceScatterSdma; +template class ReduceScatterSdma; +template class ReduceScatterSdma; +template class ReduceScatterSdma<__hip_bfloat16>; + +} // namespace collective +} // namespace mori diff --git a/src/pybind/mori.cpp b/src/pybind/mori.cpp index 70242791..dc3fdc26 100644 --- a/src/pybind/mori.cpp +++ b/src/pybind/mori.cpp @@ -40,6 +40,7 @@ #include "mori/shmem/shmem.hpp" #include "src/pybind/torch_utils.hpp" #include "mori/collective/collective.hpp" +#include "mori/collective/reducescatter/reducescatter_sdma_class.hpp" /* ---------------------------------------------------------------------------------------------- */ /* Ops APIs */ @@ -1422,6 +1423,148 @@ void RegisterMoriCcl(pybind11::module_& m) { py::arg("device") = py::none(), "Get output transit buffer as a PyTorch tensor (bf16)"); + // ========================================================================= + // Bind ReduceScatterSdma class (uint32_t version) + // ========================================================================= + py::class_>(m, "ReduceScatterSdmaHandle") + .def(py::init(), + py::arg("my_pe"), + py::arg("npes"), + py::arg("input_buffer_size"), + py::arg("output_buffer_size"), + py::arg("copy_output_to_user") = true, + "Initialize ReduceScatterSdma with PE ID, number of PEs, and buffer sizes") + .def(py::init(), + py::arg("my_pe"), + py::arg("npes"), + py::arg("transit_buffer_size") = 512 * 1024 * 1024, + py::arg("copy_output_to_user") = true, + "Initialize ReduceScatterSdma with PE ID, number of PEs, and transit buffer size") + .def("__call__", + [](mori::collective::ReduceScatterSdma& self, + const torch::Tensor& input_tensor, + const torch::Tensor& output_tensor, + size_t count, + py::object stream_obj) -> bool { + + if (input_tensor.dim() != 1) + throw std::runtime_error("Input tensor must be 1-dimensional"); + if (output_tensor.dim() != 1) + throw std::runtime_error("Output tensor must be 1-dimensional"); + if (!input_tensor.is_cuda()) + throw std::runtime_error("Input tensor must be CUDA tensor"); + if (!output_tensor.is_cuda()) + throw std::runtime_error("Output tensor must be CUDA tensor"); + + size_t byte_count = count * input_tensor.element_size(); + size_t u32_count = (byte_count + sizeof(uint32_t) - 1) / sizeof(uint32_t); + + uint32_t* input_ptr = reinterpret_cast(input_tensor.data_ptr()); + uint32_t* output_ptr = reinterpret_cast(output_tensor.data_ptr()); + + int device_index = input_tensor.device().index(); + hipStream_t stream = convert_torch_stream_to_hip(stream_obj, device_index); + + return self(input_ptr, output_ptr, u32_count, stream); + }, + py::arg("input"), + py::arg("output"), + py::arg("count"), + py::arg("stream") = py::none(), + "Execute ReduceScatter SDMA operation") + .def("start_async", + [](mori::collective::ReduceScatterSdma& self, + const torch::Tensor& input_tensor, + const torch::Tensor& output_tensor, + size_t count, + py::object stream_obj) -> bool { + + if (input_tensor.dim() != 1 || output_tensor.dim() != 1) + throw std::runtime_error("Tensors must be 1-dimensional"); + if (!input_tensor.is_cuda() || !output_tensor.is_cuda()) + throw std::runtime_error("Tensors must be CUDA tensors"); + + size_t byte_count = count * input_tensor.element_size(); + size_t u32_count = (byte_count + sizeof(uint32_t) - 1) / sizeof(uint32_t); + + uint32_t* input_ptr = reinterpret_cast(input_tensor.data_ptr()); + uint32_t* output_ptr = reinterpret_cast(output_tensor.data_ptr()); + + int device_index = input_tensor.device().index(); + hipStream_t stream = convert_torch_stream_to_hip(stream_obj, device_index); + + return self.start_async(input_ptr, output_ptr, u32_count, stream); + }, + py::arg("input"), + py::arg("output"), + py::arg("count"), + py::arg("stream") = py::none(), + "Start asynchronous ReduceScatter SDMA operation") + .def("wait_async", + [](mori::collective::ReduceScatterSdma& self, + py::object stream_obj) -> double { + hipStream_t stream = convert_torch_stream_to_hip(stream_obj); + return self.wait_async(stream); + }, + py::arg("stream") = py::none(), + "Wait for asynchronous ReduceScatter SDMA operation to complete") + .def("is_async_in_progress", + &mori::collective::ReduceScatterSdma::is_async_in_progress, + "Check if async operation is in progress") + .def("cancel_async", + &mori::collective::ReduceScatterSdma::cancel_async, + "Cancel ongoing async operation") + .def("reset_flags", + &mori::collective::ReduceScatterSdma::resetFlags, + "Reset synchronization flags") + .def("get_transit_buffer", + [](mori::collective::ReduceScatterSdma& self, + py::object device_obj, + py::object dtype_obj) -> torch::Tensor { + void* buffer_ptr = self.getTransitBuffer(); + size_t buffer_size = self.getTransitBufferSize(); + + if (buffer_ptr == nullptr) + throw std::runtime_error("Transit buffer is null"); + + torch::Dtype torch_dtype = torch::kUInt32; + if (!dtype_obj.is_none()) + torch_dtype = py::cast(dtype_obj); + size_t elem_size = torch::elementSize(torch_dtype); + size_t num_elements = buffer_size / elem_size; + + int device_index = 0; + if (!device_obj.is_none()) { + py::object torch_module = py::module_::import("torch"); + py::object tensor_class = torch_module.attr("Tensor"); + if (py::isinstance(device_obj, tensor_class)) { + torch::Tensor t = device_obj.cast(); + if (t.is_cuda()) { + device_index = t.device().index(); + } else { + throw std::runtime_error("device tensor must be a CUDA tensor"); + } + } else { + try { + device_index = device_obj.cast(); + } catch (const py::cast_error&) { + throw std::runtime_error("device must be an int, a CUDA tensor, or None"); + } + } + } else { + device_index = at::cuda::current_device(); + } + + return torch::from_blob( + buffer_ptr, + {static_cast(num_elements)}, + torch::TensorOptions().dtype(torch_dtype).device(torch::kCUDA, device_index) + ); + }, + py::arg("device") = py::none(), + py::arg("dtype") = py::none(), + "Get transit buffer as a PyTorch tensor"); + // Keep old function-based interface for backward compatibility (optional) m.def("allreduce_sdma", [](py::array_t input_array, diff --git a/tests/python/ccl/bench_ag_overlap_sweep.sh b/tests/python/ccl/bench_ag_overlap_sweep.sh new file mode 100755 index 00000000..ff6204e9 --- /dev/null +++ b/tests/python/ccl/bench_ag_overlap_sweep.sh @@ -0,0 +1,77 @@ +#!/bin/bash +# AllGather + GEMM Overlap Sweep: SDMA vs RCCL across GEMM sizes +# Fixed AllGather at 128MB/PE, sweeps GEMM M=N=K from 4096 to 32768 + +export PATH=$HOME/.local/bin:$PATH +export PYTHONPATH=$HOME/mori:$PYTHONPATH +export HSA_NO_SCRATCH_RECLAIM=1 + +MORI_DIR="$HOME/mori" +RESULTS_FILE="$MORI_DIR/ag_overlap_sweep_results.csv" + +AG_ELEMS=33554432 # 128 MB/PE +WORLD_SIZE=8 +ITERS=10 +WARMUP=10 + +GEMM_SIZES="4096 8192 16384" + +echo "GEMM_MNK,SDMA_AG_ms,SDMA_GEMM_ms,SDMA_Overlap_ms,SDMA_SeqTotal_ms,SDMA_Speedup,RCCL_AG_ms,RCCL_GEMM_ms,RCCL_Overlap_ms,RCCL_SeqTotal_ms,RCCL_Speedup" > "$RESULTS_FILE" + +for GEMM_SIZE in $GEMM_SIZES; do + echo "" + echo "================================================================" + echo " GEMM M=N=K=${GEMM_SIZE}, AllGather ${AG_ELEMS} elems (128MB/PE)" + echo "================================================================" + + # --- SDMA ON (MORI AllgatherSdma + GEMM overlap) --- + echo "[GEMM=${GEMM_SIZE}] Running SDMA + GEMM overlap ..." + SDMA_OUT=$(cd "$MORI_DIR" && python3 ./tests/python/ccl/test_allgather_overlap.py \ + --elems "$AG_ELEMS" --world-size "$WORLD_SIZE" \ + --iterations "$ITERS" --warmup "$WARMUP" --enable-sdma 1 \ + --use-custom-stream --test-gemm-overlap \ + --gemm-m "$GEMM_SIZE" --gemm-n "$GEMM_SIZE" --gemm-k "$GEMM_SIZE" 2>&1) + + SDMA_AG=$(echo "$SDMA_OUT" | grep "AllGather avg:" | tail -1 | awk '{print $3}' | sed 's/s//') + SDMA_GEMM=$(echo "$SDMA_OUT" | grep "GEMM avg:" | tail -1 | awk '{print $3}' | sed 's/s//') + SDMA_OVERLAP=$(echo "$SDMA_OUT" | grep "Avg time (measured):" | tail -1 | awk '{print $4}' | sed 's/s//') + SDMA_SEQ=$(echo "$SDMA_OUT" | grep "Sequential baseline time:" | tail -1 | awk '{print $4}' | sed 's/s//') + SDMA_SPEEDUP=$(echo "$SDMA_OUT" | grep "Speedup:" | tail -1 | awk '{print $2}') + + echo " SDMA: AG=${SDMA_AG}s GEMM=${SDMA_GEMM}s Overlap=${SDMA_OVERLAP}s Speedup=${SDMA_SPEEDUP}" + + # --- SDMA OFF (RCCL + GEMM overlap) --- + echo "[GEMM=${GEMM_SIZE}] Running RCCL + GEMM overlap ..." + RCCL_OUT=$(cd "$MORI_DIR" && python3 ./tests/python/ccl/test_rccl_allgather.py \ + --elems "$AG_ELEMS" --world-size "$WORLD_SIZE" \ + --iterations "$ITERS" --warmup "$WARMUP" \ + --use-custom-stream --test-gemm-overlap \ + --gemm-m "$GEMM_SIZE" --gemm-n "$GEMM_SIZE" --gemm-k "$GEMM_SIZE" 2>&1) + + RCCL_AG=$(echo "$RCCL_OUT" | grep "AllGather avg:" | tail -1 | awk '{print $3}' | sed 's/s//') + RCCL_GEMM=$(echo "$RCCL_OUT" | grep "GEMM avg:" | tail -1 | awk '{print $3}' | sed 's/s//') + RCCL_OVERLAP=$(echo "$RCCL_OUT" | grep "Avg time (measured):" | tail -1 | awk '{print $4}' | sed 's/s//') + RCCL_SEQ=$(echo "$RCCL_OUT" | grep "Sequential baseline time:" | tail -1 | awk '{print $4}' | sed 's/s//') + RCCL_SPEEDUP=$(echo "$RCCL_OUT" | grep "Speedup:" | tail -1 | awk '{print $2}') + + echo " RCCL: AG=${RCCL_AG}s GEMM=${RCCL_GEMM}s Overlap=${RCCL_OVERLAP}s Speedup=${RCCL_SPEEDUP}" + + # Convert to ms for CSV + SDMA_AG_MS=$(echo "$SDMA_AG" | awk '{printf "%.3f", $1 * 1000}') + SDMA_GEMM_MS=$(echo "$SDMA_GEMM" | awk '{printf "%.3f", $1 * 1000}') + SDMA_OVERLAP_MS=$(echo "$SDMA_OVERLAP" | awk '{printf "%.3f", $1 * 1000}') + SDMA_SEQ_MS=$(echo "$SDMA_SEQ" | awk '{printf "%.3f", $1 * 1000}') + RCCL_AG_MS=$(echo "$RCCL_AG" | awk '{printf "%.3f", $1 * 1000}') + RCCL_GEMM_MS=$(echo "$RCCL_GEMM" | awk '{printf "%.3f", $1 * 1000}') + RCCL_OVERLAP_MS=$(echo "$RCCL_OVERLAP" | awk '{printf "%.3f", $1 * 1000}') + RCCL_SEQ_MS=$(echo "$RCCL_SEQ" | awk '{printf "%.3f", $1 * 1000}') + + echo "${GEMM_SIZE},${SDMA_AG_MS},${SDMA_GEMM_MS},${SDMA_OVERLAP_MS},${SDMA_SEQ_MS},${SDMA_SPEEDUP},${RCCL_AG_MS},${RCCL_GEMM_MS},${RCCL_OVERLAP_MS},${RCCL_SEQ_MS},${RCCL_SPEEDUP}" >> "$RESULTS_FILE" +done + +echo "" +echo "================================================================" +echo " AG OVERLAP SWEEP COMPLETE — Results: $RESULTS_FILE" +echo "================================================================" +echo "" +cat "$RESULTS_FILE" | column -t -s',' diff --git a/tests/python/ccl/bench_allgather_sweep.sh b/tests/python/ccl/bench_allgather_sweep.sh new file mode 100755 index 00000000..7d8b4f95 --- /dev/null +++ b/tests/python/ccl/bench_allgather_sweep.sh @@ -0,0 +1,68 @@ +#!/bin/bash +# AllGather Latency Sweep: SDMA (MORI) vs RCCL +# Sweeps buffer sizes from 10MB to 128MB per PE + +export PATH=$HOME/.local/bin:$PATH +export PYTHONPATH=$HOME/mori:$PYTHONPATH +export HSA_NO_SCRATCH_RECLAIM=1 + +MORI_DIR="$HOME/mori" +RESULTS_FILE="$MORI_DIR/allgather_sweep_results.csv" + +# uint32 = 4 bytes per element +declare -A SIZE_TO_ELEMS +SIZE_TO_ELEMS[10]=2621440 +SIZE_TO_ELEMS[20]=5242880 +SIZE_TO_ELEMS[40]=10485760 +SIZE_TO_ELEMS[80]=20971520 +SIZE_TO_ELEMS[128]=33554432 + +WORLD_SIZE=8 +ITERS=10 +WARMUP=10 + +echo "DataSize_MB,SDMA_AvgTime_ms,SDMA_BW_GBs,RCCL_AvgTime_ms,RCCL_BW_GBs" > "$RESULTS_FILE" + +for SIZE_MB in 10 20 40 80 128; do + ELEMS=${SIZE_TO_ELEMS[$SIZE_MB]} + TOTAL_MB=$((SIZE_MB * WORLD_SIZE)) + echo "" + echo "================================================================" + echo " Buffer: ${SIZE_MB}MB/PE (${TOTAL_MB}MB total, ${ELEMS} uint32 elements)" + echo "================================================================" + + # --- SDMA ON (MORI AllgatherSdma) --- + echo "[${SIZE_MB}MB] Running AllgatherSdma (SDMA ON) ..." + SDMA_OUT=$(cd "$MORI_DIR" && python3 ./tests/python/ccl/test_allgather_overlap.py \ + --elems "$ELEMS" --world-size "$WORLD_SIZE" \ + --iterations "$ITERS" --warmup "$WARMUP" --enable-sdma 1 2>&1) + + SDMA_AVG=$(echo "$SDMA_OUT" | grep "Avg time:" | tail -1 | awk '{print $3}' | sed 's/s//') + SDMA_BW=$(echo "$SDMA_OUT" | grep "Bandwidth:" | tail -1 | awk '{print $2}') + + echo " SDMA Avg time: ${SDMA_AVG}s, BW: ${SDMA_BW} GB/s" + + # --- SDMA OFF (RCCL dist.all_gather) --- + echo "[${SIZE_MB}MB] Running RCCL AllGather (SDMA OFF) ..." + RCCL_OUT=$(cd "$MORI_DIR" && python3 ./tests/python/ccl/test_rccl_allgather.py \ + --elems "$ELEMS" --world-size "$WORLD_SIZE" \ + --iterations "$ITERS" --warmup "$WARMUP" 2>&1) + + RCCL_AVG=$(echo "$RCCL_OUT" | grep "Avg time:" | tail -1 | awk '{print $3}' | sed 's/s//') + RCCL_BW=$(echo "$RCCL_OUT" | grep "Bandwidth:" | tail -1 | awk '{print $2}') + + echo " RCCL Avg time: ${RCCL_AVG}s, BW: ${RCCL_BW} GB/s" + + # Convert seconds to milliseconds for CSV + SDMA_MS=$(echo "$SDMA_AVG" | awk '{printf "%.3f", $1 * 1000}') + RCCL_MS=$(echo "$RCCL_AVG" | awk '{printf "%.3f", $1 * 1000}') + + echo "${SIZE_MB},${SDMA_MS},${SDMA_BW},${RCCL_MS},${RCCL_BW}" >> "$RESULTS_FILE" +done + +echo "" +echo "================================================================" +echo " ALLGATHER SWEEP COMPLETE — Results: $RESULTS_FILE" +echo "================================================================" +echo "" +cat "$RESULTS_FILE" | column -t -s',' diff --git a/tests/python/ccl/bench_gemm_standalone.py b/tests/python/ccl/bench_gemm_standalone.py new file mode 100644 index 00000000..f83f3fc4 --- /dev/null +++ b/tests/python/ccl/bench_gemm_standalone.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 +""" +GEMM standalone benchmark — measure baseline latency without any CCL overlap. +Runs on a single GPU (default GPU 0). +""" + +import argparse +import torch +import numpy as np + +try: + import aiter + HAS_AITER = True +except ImportError: + HAS_AITER = False + + +def bench_gemm(M, N, K, iterations, warmup, device_id=0): + device = torch.device(f"cuda:{device_id}") + A_q = torch.randint(-127, 127, (M, K), dtype=torch.int8, device=device) + B_q = torch.randint(-127, 127, (K, N), dtype=torch.int8, device=device) + A_scale = torch.randn(M, dtype=torch.float32, device=device) + B_scale = torch.randn(N, dtype=torch.float32, device=device) + bias = torch.randn(N, dtype=torch.bfloat16, device=device) + + stream = torch.cuda.Stream(device=device) + start_ev = torch.cuda.Event(enable_timing=True) + end_ev = torch.cuda.Event(enable_timing=True) + + times = [] + total_iters = warmup + iterations + + for i in range(total_iters): + torch.cuda.synchronize(device) + start_ev.record(stream) + with torch.cuda.stream(stream): + _ = aiter.gemm_a8w8_CK(A_q, B_q, A_scale, B_scale, bias, torch.bfloat16) + end_ev.record(stream) + stream.synchronize() + + t = start_ev.elapsed_time(end_ev) # ms + if i >= warmup: + times.append(t) + + return times + + +def main(): + parser = argparse.ArgumentParser(description="GEMM standalone benchmark") + parser.add_argument("--sizes", type=int, nargs="+", default=[4096, 8192, 16384], + help="List of M=N=K sizes to benchmark") + parser.add_argument("--iterations", type=int, default=20) + parser.add_argument("--warmup", type=int, default=10) + parser.add_argument("--device", type=int, default=0, help="GPU device id") + args = parser.parse_args() + + if not HAS_AITER: + print("ERROR: aiter is required for GEMM benchmark") + return + + print(f"GEMM Standalone Benchmark (INT8 A8W8, aiter.gemm_a8w8_CK)") + print(f" Device: cuda:{args.device}") + print(f" Iterations: {args.iterations}, Warmup: {args.warmup}") + print(f" Sizes: {args.sizes}") + print("=" * 60) + + results = {} + for size in args.sizes: + M = N = K = size + print(f"\nBenchmarking GEMM M=N=K={size} ...") + times = bench_gemm(M, N, K, args.iterations, args.warmup, args.device) + + avg = np.mean(times) + mn = np.min(times) + mx = np.max(times) + std = np.std(times) + results[size] = {"avg": avg, "min": mn, "max": mx, "std": std} + + print(f" Min: {mn:.4f} ms") + print(f" Avg: {avg:.4f} ms") + print(f" Max: {mx:.4f} ms") + print(f" Std: {std:.4f} ms") + print(f" All times (ms): {[f'{t:.4f}' for t in times]}") + + print(f"\n{'='*60}") + print(f"Summary") + print(f"{'='*60}") + print(f"| {'GEMM Size':>10} | {'Min (ms)':>10} | {'Avg (ms)':>10} | {'Max (ms)':>10} | {'Std (ms)':>10} |") + print(f"|{'-'*12}|{'-'*12}|{'-'*12}|{'-'*12}|{'-'*12}|") + for size in args.sizes: + r = results[size] + print(f"| {size:>10} | {r['min']:>10.4f} | {r['avg']:>10.4f} | {r['max']:>10.4f} | {r['std']:>10.4f} |") + print(f"{'='*60}") + + +if __name__ == "__main__": + main() diff --git a/tests/python/ccl/bench_reducescatter_sweep.sh b/tests/python/ccl/bench_reducescatter_sweep.sh new file mode 100755 index 00000000..05a5edf5 --- /dev/null +++ b/tests/python/ccl/bench_reducescatter_sweep.sh @@ -0,0 +1,68 @@ +#!/bin/bash +# ReduceScatter Latency Sweep: SDMA (MORI) vs RCCL +# Sweeps buffer sizes from 10MB to 128MB per PE (output) + +export PATH=$HOME/.local/bin:$PATH +export PYTHONPATH=$HOME/mori:$PYTHONPATH +export HSA_NO_SCRATCH_RECLAIM=1 + +MORI_DIR="$HOME/mori" +RESULTS_FILE="$MORI_DIR/reducescatter_sweep_results.csv" + +# uint32 = 4 bytes per element +# --elems = output elements per PE +declare -A SIZE_TO_ELEMS +SIZE_TO_ELEMS[10]=2621440 +SIZE_TO_ELEMS[20]=5242880 +SIZE_TO_ELEMS[40]=10485760 +SIZE_TO_ELEMS[80]=20971520 +SIZE_TO_ELEMS[128]=33554432 + +WORLD_SIZE=8 +ITERS=10 +WARMUP=10 + +echo "OutputSize_MB,SDMA_AvgTime_ms,SDMA_BW_GBs,RCCL_AvgTime_ms,RCCL_BW_GBs" > "$RESULTS_FILE" + +for SIZE_MB in 10 20 40 80 128; do + ELEMS=${SIZE_TO_ELEMS[$SIZE_MB]} + INPUT_MB=$((SIZE_MB * WORLD_SIZE)) + echo "" + echo "================================================================" + echo " Output: ${SIZE_MB}MB/PE, Input: ${INPUT_MB}MB/PE (${ELEMS} elements)" + echo "================================================================" + + # --- SDMA ON (MORI ReduceScatterSdma) --- + echo "[${SIZE_MB}MB] Running ReduceScatterSdma (SDMA ON) ..." + SDMA_OUT=$(cd "$MORI_DIR" && python3 ./tests/python/ccl/test_reducescatter_overlap.py \ + --elems "$ELEMS" --world-size "$WORLD_SIZE" \ + --iterations "$ITERS" --warmup "$WARMUP" --enable-sdma 1 2>&1) + + SDMA_AVG=$(echo "$SDMA_OUT" | grep "Avg:" | tail -1 | awk '{print $4}' | sed 's/s,//') + SDMA_BW=$(echo "$SDMA_OUT" | grep "Bandwidth:" | tail -1 | awk '{print $2}') + + echo " SDMA Avg time: ${SDMA_AVG}s, BW: ${SDMA_BW} GB/s" + + # --- SDMA OFF (RCCL dist.reduce_scatter) --- + echo "[${SIZE_MB}MB] Running RCCL ReduceScatter (SDMA OFF) ..." + RCCL_OUT=$(cd "$MORI_DIR" && python3 ./tests/python/ccl/test_rccl_reducescatter.py \ + --elems "$ELEMS" --world-size "$WORLD_SIZE" \ + --iterations "$ITERS" --warmup "$WARMUP" 2>&1) + + RCCL_AVG=$(echo "$RCCL_OUT" | grep "Avg:" | tail -1 | awk '{print $4}' | sed 's/s,//') + RCCL_BW=$(echo "$RCCL_OUT" | grep "Bandwidth:" | tail -1 | awk '{print $2}') + + echo " RCCL Avg time: ${RCCL_AVG}s, BW: ${RCCL_BW} GB/s" + + SDMA_MS=$(echo "$SDMA_AVG" | awk '{printf "%.3f", $1 * 1000}') + RCCL_MS=$(echo "$RCCL_AVG" | awk '{printf "%.3f", $1 * 1000}') + + echo "${SIZE_MB},${SDMA_MS},${SDMA_BW},${RCCL_MS},${RCCL_BW}" >> "$RESULTS_FILE" +done + +echo "" +echo "================================================================" +echo " REDUCESCATTER SWEEP COMPLETE — Results: $RESULTS_FILE" +echo "================================================================" +echo "" +cat "$RESULTS_FILE" | column -t -s',' diff --git a/tests/python/ccl/bench_rs_overlap_sweep.sh b/tests/python/ccl/bench_rs_overlap_sweep.sh new file mode 100755 index 00000000..46c244b5 --- /dev/null +++ b/tests/python/ccl/bench_rs_overlap_sweep.sh @@ -0,0 +1,76 @@ +#!/bin/bash +# ReduceScatter + GEMM Overlap Sweep: SDMA vs RCCL across GEMM sizes +# Fixed ReduceScatter at 128MB/PE output, sweeps GEMM M=N=K from 4096 to 16384 + +export PATH=$HOME/.local/bin:$PATH +export PYTHONPATH=$HOME/mori:$PYTHONPATH +export HSA_NO_SCRATCH_RECLAIM=1 + +MORI_DIR="$HOME/mori" +RESULTS_FILE="$MORI_DIR/rs_overlap_sweep_results.csv" + +RS_ELEMS=33554432 # 128 MB/PE output +WORLD_SIZE=8 +ITERS=10 +WARMUP=10 + +GEMM_SIZES="4096 8192 16384" + +echo "GEMM_MNK,SDMA_RS_ms,SDMA_GEMM_ms,SDMA_Overlap_ms,SDMA_SeqTotal_ms,SDMA_Speedup,RCCL_RS_ms,RCCL_GEMM_ms,RCCL_Overlap_ms,RCCL_SeqTotal_ms,RCCL_Speedup" > "$RESULTS_FILE" + +for GEMM_SIZE in $GEMM_SIZES; do + echo "" + echo "================================================================" + echo " GEMM M=N=K=${GEMM_SIZE}, ReduceScatter ${RS_ELEMS} output elems (128MB/PE)" + echo "================================================================" + + # --- SDMA ON --- + echo "[GEMM=${GEMM_SIZE}] Running SDMA + GEMM overlap ..." + SDMA_OUT=$(cd "$MORI_DIR" && python3 ./tests/python/ccl/test_reducescatter_overlap.py \ + --elems "$RS_ELEMS" --world-size "$WORLD_SIZE" \ + --iterations "$ITERS" --warmup "$WARMUP" --enable-sdma 1 \ + --use-custom-stream --test-gemm-overlap \ + --gemm-m "$GEMM_SIZE" --gemm-n "$GEMM_SIZE" --gemm-k "$GEMM_SIZE" 2>&1) + + SDMA_RS=$(echo "$SDMA_OUT" | grep "ReduceScatter avg:" | tail -1 | awk '{print $3}' | sed 's/s//') + SDMA_GEMM=$(echo "$SDMA_OUT" | grep "GEMM avg:" | tail -1 | awk '{print $3}' | sed 's/s//') + SDMA_OVERLAP=$(echo "$SDMA_OUT" | grep "Overlap time (measured):" | tail -1 | awk '{print $4}' | sed 's/s//') + SDMA_SEQ=$(echo "$SDMA_OUT" | grep "Sequential baseline:" | tail -1 | awk '{print $3}' | sed 's/s//') + SDMA_SPEEDUP=$(echo "$SDMA_OUT" | grep "Speedup:" | tail -1 | awk '{print $2}') + + echo " SDMA: RS=${SDMA_RS}s GEMM=${SDMA_GEMM}s Overlap=${SDMA_OVERLAP}s Speedup=${SDMA_SPEEDUP}" + + # --- RCCL --- + echo "[GEMM=${GEMM_SIZE}] Running RCCL + GEMM overlap ..." + RCCL_OUT=$(cd "$MORI_DIR" && python3 ./tests/python/ccl/test_rccl_reducescatter.py \ + --elems "$RS_ELEMS" --world-size "$WORLD_SIZE" \ + --iterations "$ITERS" --warmup "$WARMUP" \ + --use-custom-stream --test-gemm-overlap \ + --gemm-m "$GEMM_SIZE" --gemm-n "$GEMM_SIZE" --gemm-k "$GEMM_SIZE" 2>&1) + + RCCL_RS=$(echo "$RCCL_OUT" | grep "ReduceScatter avg:" | tail -1 | awk '{print $3}' | sed 's/s//') + RCCL_GEMM=$(echo "$RCCL_OUT" | grep "GEMM avg:" | tail -1 | awk '{print $3}' | sed 's/s//') + RCCL_OVERLAP=$(echo "$RCCL_OUT" | grep "Overlap time (measured):" | tail -1 | awk '{print $4}' | sed 's/s//') + RCCL_SEQ=$(echo "$RCCL_OUT" | grep "Sequential baseline:" | tail -1 | awk '{print $3}' | sed 's/s//') + RCCL_SPEEDUP=$(echo "$RCCL_OUT" | grep "Speedup:" | tail -1 | awk '{print $2}') + + echo " RCCL: RS=${RCCL_RS}s GEMM=${RCCL_GEMM}s Overlap=${RCCL_OVERLAP}s Speedup=${RCCL_SPEEDUP}" + + SDMA_RS_MS=$(echo "$SDMA_RS" | awk '{printf "%.3f", $1 * 1000}') + SDMA_GEMM_MS=$(echo "$SDMA_GEMM" | awk '{printf "%.3f", $1 * 1000}') + SDMA_OVERLAP_MS=$(echo "$SDMA_OVERLAP" | awk '{printf "%.3f", $1 * 1000}') + SDMA_SEQ_MS=$(echo "$SDMA_SEQ" | awk '{printf "%.3f", $1 * 1000}') + RCCL_RS_MS=$(echo "$RCCL_RS" | awk '{printf "%.3f", $1 * 1000}') + RCCL_GEMM_MS=$(echo "$RCCL_GEMM" | awk '{printf "%.3f", $1 * 1000}') + RCCL_OVERLAP_MS=$(echo "$RCCL_OVERLAP" | awk '{printf "%.3f", $1 * 1000}') + RCCL_SEQ_MS=$(echo "$RCCL_SEQ" | awk '{printf "%.3f", $1 * 1000}') + + echo "${GEMM_SIZE},${SDMA_RS_MS},${SDMA_GEMM_MS},${SDMA_OVERLAP_MS},${SDMA_SEQ_MS},${SDMA_SPEEDUP},${RCCL_RS_MS},${RCCL_GEMM_MS},${RCCL_OVERLAP_MS},${RCCL_SEQ_MS},${RCCL_SPEEDUP}" >> "$RESULTS_FILE" +done + +echo "" +echo "================================================================" +echo " RS OVERLAP SWEEP COMPLETE — Results: $RESULTS_FILE" +echo "================================================================" +echo "" +cat "$RESULTS_FILE" | column -t -s',' diff --git a/tests/python/ccl/test_allgather_overlap.py b/tests/python/ccl/test_allgather_overlap.py index f1003f72..bf5138e6 100644 --- a/tests/python/ccl/test_allgather_overlap.py +++ b/tests/python/ccl/test_allgather_overlap.py @@ -19,7 +19,7 @@ print("Warning: aiter not available, gemm timing will be disabled") -def _test_allgather(rank, world_size, port, elems, iterations, warmup, use_custom_stream, test_gemm_overlap): +def _test_allgather(rank, world_size, port, elems, iterations, warmup, use_custom_stream, test_gemm_overlap, gemm_m=4096, gemm_n=4096, gemm_k=4096): """Worker function for each process""" with TorchDistContext(rank=rank, world_size=world_size, master_port=port): @@ -84,8 +84,7 @@ def _test_allgather(rank, world_size, port, elems, iterations, warmup, use_custo # Prepare GEMM test data if testing overlap A_q = B_q = A_scale = B_scale = bias = None if test_gemm_overlap and HAS_AITER: - # Create sample GEMM matrices for testing overlap - M, N, K = 4096, 4096, 4096 + M, N, K = gemm_m, gemm_n, gemm_k A_q = torch.randint(-127, 127, (M, K), dtype=torch.int8, device=device) B_q = torch.randint(-127, 127, (K, N), dtype=torch.int8, device=device) A_scale = torch.randn(M, dtype=torch.float32, device=device) @@ -618,13 +617,13 @@ def _test_allgather(rank, world_size, port, elems, iterations, warmup, use_custo raise AssertionError(f"PE {rank}: Allgather verification failed") -def test_allgather(elems=67108864, world_size=8, iterations=10, warmup=1, use_custom_stream=False, test_gemm_overlap=False): +def test_allgather(elems=67108864, world_size=8, iterations=10, warmup=1, use_custom_stream=False, test_gemm_overlap=False, gemm_m=4096, gemm_n=4096, gemm_k=4096): """Run Allgather SDMA test""" os.environ.setdefault('MORI_ENABLE_SDMA', '1') port = get_free_port() torch.multiprocessing.spawn( _test_allgather, - args=(world_size, port, elems, iterations, warmup, use_custom_stream, test_gemm_overlap), + args=(world_size, port, elems, iterations, warmup, use_custom_stream, test_gemm_overlap, gemm_m, gemm_n, gemm_k), nprocs=world_size, join=True, ) @@ -644,6 +643,9 @@ def test_allgather(elems=67108864, world_size=8, iterations=10, warmup=1, use_cu parser.add_argument("--enable-sdma", type=int, default=1, choices=[0, 1], help="Enable SDMA") parser.add_argument("--use-custom-stream", action="store_true", help="Use custom CUDA stream instead of default stream") parser.add_argument("--test-gemm-overlap", action="store_true", help="Test GEMM and AllGather overlap on different streams") + parser.add_argument("--gemm-m", type=int, default=4096, help="GEMM M dimension (default: 4096)") + parser.add_argument("--gemm-n", type=int, default=4096, help="GEMM N dimension (default: 4096)") + parser.add_argument("--gemm-k", type=int, default=4096, help="GEMM K dimension (default: 4096)") args = parser.parse_args() os.environ['MORI_ENABLE_SDMA'] = str(args.enable_sdma) @@ -654,8 +656,10 @@ def test_allgather(elems=67108864, world_size=8, iterations=10, warmup=1, use_cu print(f" Warmup: {args.warmup}") print(f" Custom Stream: {args.use_custom_stream}") print(f" Test GEMM Overlap: {args.test_gemm_overlap}") - if args.test_gemm_overlap and not HAS_AITER: - print(f" WARNING: aiter not available, GEMM testing will be skipped") + if args.test_gemm_overlap: + print(f" GEMM Dimensions: M={args.gemm_m}, N={args.gemm_n}, K={args.gemm_k}") + if not HAS_AITER: + print(f" WARNING: aiter not available, GEMM testing will be skipped") print("-" * 60) - test_allgather(args.elems, args.world_size, args.iterations, args.warmup, args.use_custom_stream, args.test_gemm_overlap) + test_allgather(args.elems, args.world_size, args.iterations, args.warmup, args.use_custom_stream, args.test_gemm_overlap, args.gemm_m, args.gemm_n, args.gemm_k) diff --git a/tests/python/ccl/test_rccl_allgather.py b/tests/python/ccl/test_rccl_allgather.py index ec151b3b..26bb91b6 100644 --- a/tests/python/ccl/test_rccl_allgather.py +++ b/tests/python/ccl/test_rccl_allgather.py @@ -17,7 +17,7 @@ print("Warning: aiter not available, gemm timing will be disabled") -def _test_allgather(rank, world_size, port, elems, iterations, warmup, use_custom_stream, test_gemm_overlap): +def _test_allgather(rank, world_size, port, elems, iterations, warmup, use_custom_stream, test_gemm_overlap, gemm_m=4096, gemm_n=4096, gemm_k=4096): """Worker function for each process""" with TorchDistContext(rank=rank, world_size=world_size, master_port=port): @@ -42,14 +42,14 @@ def _test_allgather(rank, world_size, port, elems, iterations, warmup, use_custo # Allocate GPU memory device = torch.device(f"cuda:{rank}") - input_tensor = torch.zeros(elems_per_pe, dtype=torch.uint32, device=device) + input_tensor = torch.zeros(elems_per_pe, dtype=torch.int32, device=device) # Prepare output tensor list for all_gather - output_tensor_list = [torch.zeros(elems_per_pe, dtype=torch.uint32, device=device) for _ in range(npes)] + output_tensor_list = [torch.zeros(elems_per_pe, dtype=torch.int32, device=device) for _ in range(npes)] # Prepare data: Each PE has unique value = (rank + 1) * 1000 value = (rank + 1) * 1000 - input_data_cpu = np.full(elems_per_pe, value, dtype=np.uint32) + input_data_cpu = np.full(elems_per_pe, value, dtype=np.int32) # Copy to GPU input_tensor.copy_(torch.from_numpy(input_data_cpu)) @@ -71,8 +71,7 @@ def _test_allgather(rank, world_size, port, elems, iterations, warmup, use_custo # Prepare GEMM test data if testing overlap A_q = B_q = A_scale = B_scale = bias = None if test_gemm_overlap and HAS_AITER: - # Create sample GEMM matrices for testing overlap - M, N, K = 4096, 4096, 4096 + M, N, K = gemm_m, gemm_n, gemm_k A_q = torch.randint(-127, 127, (M, K), dtype=torch.int8, device=device) B_q = torch.randint(-127, 127, (K, N), dtype=torch.int8, device=device) A_scale = torch.randn(M, dtype=torch.float32, device=device) @@ -543,12 +542,12 @@ def _test_allgather(rank, world_size, port, elems, iterations, warmup, use_custo raise AssertionError(f"PE {rank}: Allgather verification failed") -def test_allgather(elems=67108864, world_size=8, iterations=10, warmup=1, use_custom_stream=False, test_gemm_overlap=False): +def test_allgather(elems=67108864, world_size=8, iterations=10, warmup=1, use_custom_stream=False, test_gemm_overlap=False, gemm_m=4096, gemm_n=4096, gemm_k=4096): """Run Allgather RCCL test""" port = get_free_port() torch.multiprocessing.spawn( _test_allgather, - args=(world_size, port, elems, iterations, warmup, use_custom_stream, test_gemm_overlap), + args=(world_size, port, elems, iterations, warmup, use_custom_stream, test_gemm_overlap, gemm_m, gemm_n, gemm_k), nprocs=world_size, join=True, ) @@ -567,6 +566,9 @@ def test_allgather(elems=67108864, world_size=8, iterations=10, warmup=1, use_cu parser.add_argument("--warmup", type=int, default=20, help="Warmup iterations") parser.add_argument("--use-custom-stream", action="store_true", help="Use custom CUDA stream instead of default stream") parser.add_argument("--test-gemm-overlap", action="store_true", help="Test GEMM and AllGather overlap on different streams") + parser.add_argument("--gemm-m", type=int, default=4096, help="GEMM M dimension (default: 4096)") + parser.add_argument("--gemm-n", type=int, default=4096, help="GEMM N dimension (default: 4096)") + parser.add_argument("--gemm-k", type=int, default=4096, help="GEMM K dimension (default: 4096)") args = parser.parse_args() print(f"Allgather RCCL Test") @@ -576,8 +578,10 @@ def test_allgather(elems=67108864, world_size=8, iterations=10, warmup=1, use_cu print(f" Warmup: {args.warmup}") print(f" Custom Stream: {args.use_custom_stream}") print(f" Test GEMM Overlap: {args.test_gemm_overlap}") - if args.test_gemm_overlap and not HAS_AITER: - print(f" WARNING: aiter not available, GEMM testing will be skipped") + if args.test_gemm_overlap: + print(f" GEMM Dimensions: M={args.gemm_m}, N={args.gemm_n}, K={args.gemm_k}") + if not HAS_AITER: + print(f" WARNING: aiter not available, GEMM testing will be skipped") print("-" * 60) - test_allgather(args.elems, args.world_size, args.iterations, args.warmup, args.use_custom_stream, args.test_gemm_overlap) + test_allgather(args.elems, args.world_size, args.iterations, args.warmup, args.use_custom_stream, args.test_gemm_overlap, args.gemm_m, args.gemm_n, args.gemm_k) diff --git a/tests/python/ccl/test_rccl_reducescatter.py b/tests/python/ccl/test_rccl_reducescatter.py new file mode 100644 index 00000000..4421bc93 --- /dev/null +++ b/tests/python/ccl/test_rccl_reducescatter.py @@ -0,0 +1,352 @@ +#!/usr/bin/env python3 +""" +ReduceScatter RCCL Test using torch.distributed and multiprocessing +""" + +import os +import numpy as np +import torch +import torch.distributed as dist +from tests.python.utils import TorchDistContext, get_free_port + +try: + import aiter + HAS_AITER = True +except ImportError: + HAS_AITER = False + print("Warning: aiter not available, gemm timing will be disabled") + + +def _test_reducescatter(rank, world_size, port, elems, iterations, warmup, + use_custom_stream, test_gemm_overlap, + gemm_m=4096, gemm_n=4096, gemm_k=4096): + """Worker function for each process""" + + with TorchDistContext(rank=rank, world_size=world_size, master_port=port): + npes = world_size + + elems_per_pe = elems + total_elems = elems_per_pe * npes + input_bytes = total_elems * 4 + output_bytes = elems_per_pe * 4 + + if rank == 0: + print(f"\n{'='*60}") + print(f"ReduceScatter RCCL Test") + print(f"World size: {world_size}") + print(f"Elements per PE (output): {elems_per_pe:,}") + print(f"Total elements per PE (input): {total_elems:,}") + print(f"Data size: {input_bytes / (1024**2):.2f} MB input, {output_bytes / (1024**2):.2f} MB output per PE") + print(f"Iterations: {iterations}" + (f" (warmup: {warmup})" if warmup > 0 else "")) + print(f"Custom Stream: {'Yes' if use_custom_stream else 'No (default stream)'}") + print(f"{'='*60}\n") + + print(f"PE {rank}/{world_size}: Initialized") + + device = torch.device(f"cuda:{rank}") + + # Each PE has total_elems input elements, divided into npes chunks. + # Chunk[i] = (rank + 1) * 1000 + i for verification. + input_tensor = torch.zeros(total_elems, dtype=torch.int32, device=device) + for i in range(npes): + start = i * elems_per_pe + end = (i + 1) * elems_per_pe + input_tensor[start:end] = (rank + 1) * 1000 + i + + # Output: elems_per_pe elements (the reduced chunk for this rank) + output_tensor = torch.zeros(elems_per_pe, dtype=torch.int32, device=device) + + # For dist.reduce_scatter, input is a list of tensors (one per PE's contribution) + input_list = list(input_tensor.chunk(npes)) + + if rank == 0: + print(f"\n=== Data Pattern ===") + print(f"Each PE contributes {npes} chunks of {elems_per_pe:,} elements") + print(f"PE r, chunk i has value: (r+1)*1000 + i") + print(f"\nAfter ReduceScatter, PE r gets reduced chunk r:") + for r in range(npes): + expected = sum((pe + 1) * 1000 + r for pe in range(npes)) + print(f" PE {r} output = sum over all PEs of chunk[{r}] = {expected}") + print() + + # Prepare GEMM test data if testing overlap + A_q = B_q = A_scale = B_scale = bias = None + if test_gemm_overlap and HAS_AITER: + M, N, K = gemm_m, gemm_n, gemm_k + A_q = torch.randint(-127, 127, (M, K), dtype=torch.int8, device=device) + B_q = torch.randint(-127, 127, (K, N), dtype=torch.int8, device=device) + A_scale = torch.randn(M, dtype=torch.float32, device=device) + B_scale = torch.randn(N, dtype=torch.float32, device=device) + bias = torch.randn(N, dtype=torch.bfloat16, device=device) + if rank == 0: + print(f"PE {rank}: Prepared GEMM test data (M={M}, N={N}, K={K})") + + stream_gemm = None + if use_custom_stream: + stream = torch.cuda.Stream(device=device) + if test_gemm_overlap and HAS_AITER: + stream_gemm = torch.cuda.Stream(device=device) + if rank == 0: + print(f"PE {rank}: Created separate CUDA streams for RS and GEMM") + else: + if rank == 0: + print(f"PE {rank}: Created custom CUDA stream") + else: + stream = None + if rank == 0: + print(f"PE {rank}: Using default CUDA stream") + + torch.cuda.synchronize() + dist.barrier() + + exec_times = [] + gemm_times = [] + overlap_times = [] + sequential_rs_times = [] + sequential_gemm_times = [] + total_iters = warmup + iterations + + rs_start = torch.cuda.Event(enable_timing=True) + rs_end = torch.cuda.Event(enable_timing=True) + + if test_gemm_overlap and HAS_AITER and stream_gemm is not None: + gemm_start = torch.cuda.Event(enable_timing=True) + gemm_end = torch.cuda.Event(enable_timing=True) + overlap_start = torch.cuda.Event(enable_timing=True) + overlap_end = torch.cuda.Event(enable_timing=True) + + # Step 1: Sequential baseline (if overlap test) + if use_custom_stream and test_gemm_overlap and HAS_AITER and stream_gemm is not None: + if rank == 0: + print(f"\n{'='*60}") + print(f"Step 1: Sequential Baseline Tests") + print(f"{'='*60}") + + if rank == 0: + print(f"\nTesting ReduceScatter sequentially (baseline)...") + for iter_idx in range(total_iters): + torch.cuda.synchronize() + if use_custom_stream: + rs_start.record(stream) + with torch.cuda.stream(stream): + dist.reduce_scatter(output_tensor, input_list) + rs_end.record(stream) + stream.synchronize() + else: + rs_start.record() + dist.reduce_scatter(output_tensor, input_list) + rs_end.record() + torch.cuda.synchronize() + rs_time = rs_start.elapsed_time(rs_end) / 1000.0 + if iter_idx >= warmup: + sequential_rs_times.append(rs_time) + elif rank == 0: + print(f" Warmup {iter_idx + 1}/{warmup}: {rs_time:.6f}s") + + if rank == 0: + print(f"\nTesting GEMM sequentially (baseline)...") + for iter_idx in range(total_iters): + torch.cuda.synchronize() + gemm_start.record(stream_gemm) + with torch.cuda.stream(stream_gemm): + _ = aiter.gemm_a8w8_CK(A_q, B_q, A_scale, B_scale, bias, torch.bfloat16) + gemm_end.record(stream_gemm) + stream_gemm.synchronize() + gemm_time = gemm_start.elapsed_time(gemm_end) / 1000.0 + if iter_idx >= warmup: + sequential_gemm_times.append(gemm_time) + elif rank == 0: + print(f" Warmup {iter_idx + 1}/{warmup}: {gemm_time:.6f}s") + + if rank == 0: + seq_rs_avg = np.mean(sequential_rs_times) + seq_gemm_avg = np.mean(sequential_gemm_times) + print(f"\nSequential Baseline Results:") + print(f" ReduceScatter: Min={np.min(sequential_rs_times):.6f}s, Avg={seq_rs_avg:.6f}s, Max={np.max(sequential_rs_times):.6f}s") + print(f" GEMM: Min={np.min(sequential_gemm_times):.6f}s, Avg={seq_gemm_avg:.6f}s, Max={np.max(sequential_gemm_times):.6f}s") + print(f" Total sequential: {seq_rs_avg + seq_gemm_avg:.6f}s") + print(f"\n{'='*60}") + print(f"Step 2: Concurrent Overlap Tests") + print(f"{'='*60}\n") + + # Step 2: Main test loop + for iter_idx in range(total_iters): + if use_custom_stream and test_gemm_overlap and HAS_AITER and stream_gemm is not None: + torch.cuda.synchronize() + overlap_start.record() + rs_start.record(stream) + gemm_start.record(stream_gemm) + + with torch.cuda.stream(stream): + dist.reduce_scatter(output_tensor, input_list) + with torch.cuda.stream(stream_gemm): + _ = aiter.gemm_a8w8_CK(A_q, B_q, A_scale, B_scale, bias, torch.bfloat16) + + rs_end.record(stream) + gemm_end.record(stream_gemm) + stream.synchronize() + stream_gemm.synchronize() + overlap_end.record() + torch.cuda.synchronize() + + rs_time = rs_start.elapsed_time(rs_end) / 1000.0 + gemm_time = gemm_start.elapsed_time(gemm_end) / 1000.0 + overlap_time = overlap_start.elapsed_time(overlap_end) / 1000.0 + + if iter_idx >= warmup: + exec_times.append(rs_time) + gemm_times.append(gemm_time) + overlap_times.append(overlap_time) + elif rank == 0: + print(f"Warmup {iter_idx+1}/{warmup}: RS={rs_time:.6f}s, GEMM={gemm_time:.6f}s, Overlap={overlap_time:.6f}s") + else: + if use_custom_stream: + rs_start.record(stream) + with torch.cuda.stream(stream): + dist.reduce_scatter(output_tensor, input_list) + rs_end.record(stream) + stream.synchronize() + else: + rs_start.record() + dist.reduce_scatter(output_tensor, input_list) + rs_end.record() + torch.cuda.synchronize() + rs_time = rs_start.elapsed_time(rs_end) / 1000.0 + if iter_idx >= warmup: + exec_times.append(rs_time) + elif rank == 0: + print(f"Warmup {iter_idx+1}/{warmup}: {rs_time:.6f}s") + + if use_custom_stream and stream: + stream.synchronize() + torch.cuda.synchronize() + + # Stats + avg_time = np.mean(exec_times) if exec_times else 0.0 + min_time = np.min(exec_times) if exec_times else 0.0 + max_time = np.max(exec_times) if exec_times else 0.0 + gemm_avg_time = np.mean(gemm_times) if gemm_times else 0.0 + overlap_avg_time = np.mean(overlap_times) if overlap_times else 0.0 + seq_rs_avg = np.mean(sequential_rs_times) if sequential_rs_times else 0.0 + seq_gemm_avg = np.mean(sequential_gemm_times) if sequential_gemm_times else 0.0 + + if rank == 0: + print(f"\n{'='*60}") + print(f"Performance Statistics") + print(f"{'='*60}") + print(f"ReduceScatter Times:") + print(f" Min time: {min_time:.6f}s") + print(f" Max time: {max_time:.6f}s") + print(f" Avg time: {avg_time:.6f}s") + bw = input_bytes / avg_time / (1024**3) if avg_time > 0 else 0 + print(f" Bandwidth: {bw:.2f} GB/s") + + if gemm_times: + print(f"\nSequential Baseline:") + print(f" ReduceScatter avg: {seq_rs_avg:.6f}s") + print(f" GEMM avg: {seq_gemm_avg:.6f}s") + print(f" Sequential total: {seq_rs_avg + seq_gemm_avg:.6f}s") + print(f"\nConcurrent:") + print(f" ReduceScatter avg: {avg_time:.6f}s") + print(f" GEMM avg: {gemm_avg_time:.6f}s") + + if overlap_times: + ideal = max(avg_time, gemm_avg_time) + seq_total = seq_rs_avg + seq_gemm_avg + speedup = seq_total / overlap_avg_time if overlap_avg_time > 0 else 0 + efficiency = (ideal / overlap_avg_time * 100) if overlap_avg_time > 0 else 0 + print(f"\nOverlap Analysis:") + print(f" Overlap time (measured): {overlap_avg_time:.6f}s") + print(f" Theoretical best: {ideal:.6f}s") + print(f" Sequential baseline: {seq_total:.6f}s") + print(f" Time saved: {seq_total - overlap_avg_time:.6f}s") + print(f" Speedup: {speedup:.2f}x") + print(f" Concurrency efficiency: {efficiency:.2f}%") + + # Verify + output_cpu = output_tensor.cpu().numpy() + expected = sum((pe + 1) * 1000 + rank for pe in range(npes)) + success = np.all(output_cpu == expected) + if success: + print(f"PE {rank}: Verification PASSED (all values = {expected})") + else: + print(f"PE {rank}: Verification FAILED! Expected {expected}, got unique: {np.unique(output_cpu)}") + + torch.cuda.synchronize() + dist.barrier() + + # Global stats + min_t = torch.tensor([min_time], dtype=torch.float64) + max_t = torch.tensor([max_time], dtype=torch.float64) + avg_t = torch.tensor([avg_time], dtype=torch.float64) + success_t = torch.tensor([1 if success else 0], dtype=torch.int32) + + dist.all_reduce(min_t, op=dist.ReduceOp.MIN) + dist.all_reduce(max_t, op=dist.ReduceOp.MAX) + dist.all_reduce(avg_t, op=dist.ReduceOp.SUM) + dist.all_reduce(success_t, op=dist.ReduceOp.SUM) + + if rank == 0: + g_avg = avg_t.item() / npes + g_bw = input_bytes / g_avg / (1024**3) if g_avg > 0 else 0 + print(f"\n{'='*60}") + print(f"Global Results") + print(f"{'='*60}") + print(f" Min: {min_t.item():.6f}s, Avg: {g_avg:.6f}s, Max: {max_t.item():.6f}s") + print(f" Bandwidth: {g_bw:.2f} GB/s") + print(f" PEs passed: {success_t.item()}/{npes}") + print(f"\n=== Test {'PASSED' if success_t.item() == npes else 'FAILED'} ===") + print(f"{'='*60}\n") + + torch.cuda.synchronize() + dist.barrier() + + if not success: + raise AssertionError(f"PE {rank}: ReduceScatter verification failed") + + +def test_reducescatter(elems=67108864, world_size=8, iterations=10, warmup=1, + use_custom_stream=False, test_gemm_overlap=False, + gemm_m=4096, gemm_n=4096, gemm_k=4096): + port = get_free_port() + torch.multiprocessing.spawn( + _test_reducescatter, + args=(world_size, port, elems, iterations, warmup, use_custom_stream, + test_gemm_overlap, gemm_m, gemm_n, gemm_k), + nprocs=world_size, + join=True, + ) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Test ReduceScatter RCCL (torch.distributed)", + formatter_class=argparse.RawDescriptionHelpFormatter + ) + parser.add_argument("--elems", type=int, default=33554432, help="Output elements per PE") + parser.add_argument("--world-size", type=int, default=8, help="Number of processes") + parser.add_argument("--iterations", type=int, default=50, help="Number of iterations") + parser.add_argument("--warmup", type=int, default=20, help="Warmup iterations") + parser.add_argument("--use-custom-stream", action="store_true") + parser.add_argument("--test-gemm-overlap", action="store_true") + parser.add_argument("--gemm-m", type=int, default=4096, help="GEMM M dimension") + parser.add_argument("--gemm-n", type=int, default=4096, help="GEMM N dimension") + parser.add_argument("--gemm-k", type=int, default=4096, help="GEMM K dimension") + args = parser.parse_args() + + print(f"ReduceScatter RCCL Test") + print(f" Output elements per PE: {args.elems:,}") + print(f" World size: {args.world_size}") + print(f" Iterations: {args.iterations}") + print(f" Warmup: {args.warmup}") + if args.test_gemm_overlap: + print(f" GEMM Dimensions: M={args.gemm_m}, N={args.gemm_n}, K={args.gemm_k}") + if not HAS_AITER: + print(f" WARNING: aiter not available") + print("-" * 60) + + test_reducescatter(args.elems, args.world_size, args.iterations, args.warmup, + args.use_custom_stream, args.test_gemm_overlap, + args.gemm_m, args.gemm_n, args.gemm_k) diff --git a/tests/python/ccl/test_reducescatter_overlap.py b/tests/python/ccl/test_reducescatter_overlap.py new file mode 100644 index 00000000..7f6e65e0 --- /dev/null +++ b/tests/python/ccl/test_reducescatter_overlap.py @@ -0,0 +1,378 @@ +#!/usr/bin/env python3 +""" +ReduceScatter SDMA Test using MORI ReduceScatterSdma and multiprocessing +""" + +import os +import numpy as np +import torch +import torch.distributed as dist +import mori.shmem as shmem +from mori.ccl import ReduceScatterSdma +from tests.python.utils import TorchDistContext, get_free_port + +try: + import aiter + HAS_AITER = True +except ImportError: + HAS_AITER = False + print("Warning: aiter not available, gemm timing will be disabled") + + +def _test_reducescatter(rank, world_size, port, elems, iterations, warmup, + use_custom_stream, test_gemm_overlap, + gemm_m=4096, gemm_n=4096, gemm_k=4096): + """Worker function for each process""" + + with TorchDistContext(rank=rank, world_size=world_size, master_port=port): + shmem.shmem_torch_process_group_init("default") + + my_pe = shmem.shmem_mype() + npes = shmem.shmem_npes() + + assert my_pe == rank + assert npes == world_size + + # ReduceScatter: input = elems_per_pe * npes, output = elems_per_pe + elems_per_pe = elems + total_elems = elems_per_pe * npes + input_bytes = total_elems * 4 + output_bytes = elems_per_pe * 4 + # Transit buffer needs to hold the gather buffer = total_elems * sizeof(T) + transit_buffer_bytes = total_elems * 4 + + if rank == 0: + print(f"\n{'='*60}") + print(f"ReduceScatter SDMA Test") + print(f"World size: {world_size}") + print(f"Elements per PE (output): {elems_per_pe:,}") + print(f"Total elements per PE (input): {total_elems:,}") + print(f"Data size: {input_bytes / (1024**2):.2f} MB input, {output_bytes / (1024**2):.2f} MB output per PE") + print(f"Transit buffer: {transit_buffer_bytes / (1024**2):.2f} MB") + print(f"Iterations: {iterations}" + (f" (warmup: {warmup})" if warmup > 0 else "")) + print(f"Custom Stream: {'Yes' if use_custom_stream else 'No (default stream)'}") + print(f"{'='*60}\n") + + print(f"PE {rank}/{world_size}: SHMEM initialized, myPe={my_pe}, npes={npes}") + + # Create ReduceScatter object + rs = ReduceScatterSdma(my_pe, npes, + input_buffer_size=input_bytes, + output_buffer_size=transit_buffer_bytes) + print(f"PE {rank}: Created ReduceScatterSdma object") + + device = torch.device(f"cuda:{rank}") + + # Input: total_elems elements. Chunk[i] = (rank+1)*1000 + i + input_tensor = torch.zeros(total_elems, dtype=torch.uint32, device=device) + for i in range(npes): + start = i * elems_per_pe + end = (i + 1) * elems_per_pe + val = (my_pe + 1) * 1000 + i + input_data = np.full(elems_per_pe, val, dtype=np.uint32) + input_tensor[start:end] = torch.from_numpy(input_data).to(device) + + # Output: elems_per_pe elements + output_tensor = torch.zeros(elems_per_pe, dtype=torch.uint32, device=device) + + if rank == 0: + print(f"\n=== Data Pattern ===") + print(f"Each PE contributes {npes} chunks of {elems_per_pe:,} elements") + print(f"PE r, chunk i has value: (r+1)*1000 + i") + print(f"\nAfter ReduceScatter, PE r gets reduced chunk r:") + for r in range(npes): + expected = sum((pe + 1) * 1000 + r for pe in range(npes)) + print(f" PE {r} output = {expected}") + print() + + # GEMM setup + A_q = B_q = A_scale = B_scale = bias = None + if test_gemm_overlap and HAS_AITER: + M, N, K = gemm_m, gemm_n, gemm_k + A_q = torch.randint(-127, 127, (M, K), dtype=torch.int8, device=device) + B_q = torch.randint(-127, 127, (K, N), dtype=torch.int8, device=device) + A_scale = torch.randn(M, dtype=torch.float32, device=device) + B_scale = torch.randn(N, dtype=torch.float32, device=device) + bias = torch.randn(N, dtype=torch.bfloat16, device=device) + if rank == 0: + print(f"PE {rank}: Prepared GEMM test data (M={M}, N={N}, K={K})") + + stream_gemm = None + if use_custom_stream: + stream = torch.cuda.Stream(device=device) + if test_gemm_overlap and HAS_AITER: + stream_gemm = torch.cuda.Stream(device=device) + if rank == 0: + print(f"PE {rank}: Created separate CUDA streams for RS and GEMM") + else: + if rank == 0: + print(f"PE {rank}: Created custom CUDA stream") + else: + stream = None + + torch.cuda.synchronize() + dist.barrier() + + exec_times = [] + gemm_times = [] + overlap_times = [] + sequential_rs_times = [] + sequential_gemm_times = [] + total_iters = warmup + iterations + + rs_start = torch.cuda.Event(enable_timing=True) + rs_end = torch.cuda.Event(enable_timing=True) + + if test_gemm_overlap and HAS_AITER and stream_gemm is not None: + gemm_start = torch.cuda.Event(enable_timing=True) + gemm_end = torch.cuda.Event(enable_timing=True) + overlap_start = torch.cuda.Event(enable_timing=True) + overlap_end = torch.cuda.Event(enable_timing=True) + + # Step 1: Sequential baseline + if use_custom_stream and test_gemm_overlap and HAS_AITER and stream_gemm is not None: + if rank == 0: + print(f"\n{'='*60}") + print(f"Step 1: Sequential Baseline Tests") + print(f"{'='*60}") + print(f"\nTesting ReduceScatter sequentially (baseline)...") + + for iter_idx in range(total_iters): + torch.cuda.synchronize() + if use_custom_stream: + rs_start.record(stream) + with torch.cuda.stream(stream): + success = rs(input_tensor, output_tensor, total_elems) + rs_end.record(stream) + stream.synchronize() + else: + rs_start.record() + success = rs(input_tensor, output_tensor, total_elems) + rs_end.record() + torch.cuda.synchronize() + rs_time = rs_start.elapsed_time(rs_end) / 1000.0 + if iter_idx >= warmup: + sequential_rs_times.append(rs_time) + elif rank == 0: + print(f" Warmup {iter_idx+1}/{warmup}: {rs_time:.6f}s") + + if rank == 0: + print(f"\nTesting GEMM sequentially (baseline)...") + for iter_idx in range(total_iters): + torch.cuda.synchronize() + gemm_start.record(stream_gemm) + with torch.cuda.stream(stream_gemm): + _ = aiter.gemm_a8w8_CK(A_q, B_q, A_scale, B_scale, bias, torch.bfloat16) + gemm_end.record(stream_gemm) + stream_gemm.synchronize() + gemm_time = gemm_start.elapsed_time(gemm_end) / 1000.0 + if iter_idx >= warmup: + sequential_gemm_times.append(gemm_time) + elif rank == 0: + print(f" Warmup {iter_idx+1}/{warmup}: {gemm_time:.6f}s") + + if rank == 0: + seq_rs_avg = np.mean(sequential_rs_times) + seq_gemm_avg = np.mean(sequential_gemm_times) + print(f"\nSequential Baseline Results:") + print(f" ReduceScatter: Min={np.min(sequential_rs_times):.6f}s, Avg={seq_rs_avg:.6f}s, Max={np.max(sequential_rs_times):.6f}s") + print(f" GEMM: Min={np.min(sequential_gemm_times):.6f}s, Avg={seq_gemm_avg:.6f}s, Max={np.max(sequential_gemm_times):.6f}s") + print(f" Total sequential: {seq_rs_avg + seq_gemm_avg:.6f}s") + print(f"\n{'='*60}") + print(f"Step 2: Concurrent Overlap Tests") + print(f"{'='*60}\n") + + # Step 2: Main test loop + for iter_idx in range(total_iters): + op_success = True + if use_custom_stream and test_gemm_overlap and HAS_AITER and stream_gemm is not None: + torch.cuda.synchronize() + overlap_start.record() + rs_start.record(stream) + gemm_start.record(stream_gemm) + + with torch.cuda.stream(stream): + op_success = rs(input_tensor, output_tensor, total_elems) + with torch.cuda.stream(stream_gemm): + _ = aiter.gemm_a8w8_CK(A_q, B_q, A_scale, B_scale, bias, torch.bfloat16) + + rs_end.record(stream) + gemm_end.record(stream_gemm) + stream.synchronize() + stream_gemm.synchronize() + overlap_end.record() + torch.cuda.synchronize() + + rs_time = rs_start.elapsed_time(rs_end) / 1000.0 + gemm_time = gemm_start.elapsed_time(gemm_end) / 1000.0 + overlap_time = overlap_start.elapsed_time(overlap_end) / 1000.0 + + if iter_idx >= warmup: + exec_times.append(rs_time) + gemm_times.append(gemm_time) + overlap_times.append(overlap_time) + elif rank == 0: + print(f"Warmup {iter_idx+1}/{warmup}: RS={rs_time:.6f}s, GEMM={gemm_time:.6f}s, Overlap={overlap_time:.6f}s") + else: + if use_custom_stream: + rs_start.record(stream) + with torch.cuda.stream(stream): + op_success = rs(input_tensor, output_tensor, total_elems) + rs_end.record(stream) + stream.synchronize() + else: + rs_start.record() + op_success = rs(input_tensor, output_tensor, total_elems) + rs_end.record() + torch.cuda.synchronize() + rs_time = rs_start.elapsed_time(rs_end) / 1000.0 + if iter_idx >= warmup: + exec_times.append(rs_time) + elif rank == 0: + print(f"Warmup {iter_idx+1}/{warmup}: {rs_time:.6f}s") + + if not op_success: + print(f"PE {rank}: ReduceScatter failed at iteration {iter_idx}") + break + + if use_custom_stream and stream: + stream.synchronize() + torch.cuda.synchronize() + + # Stats + avg_time = np.mean(exec_times) if exec_times else 0.0 + min_time = np.min(exec_times) if exec_times else 0.0 + max_time = np.max(exec_times) if exec_times else 0.0 + gemm_avg_time = np.mean(gemm_times) if gemm_times else 0.0 + overlap_avg_time = np.mean(overlap_times) if overlap_times else 0.0 + seq_rs_avg = np.mean(sequential_rs_times) if sequential_rs_times else 0.0 + seq_gemm_avg = np.mean(sequential_gemm_times) if sequential_gemm_times else 0.0 + + if rank == 0: + bw = input_bytes / avg_time / (1024**3) if avg_time > 0 else 0 + print(f"\n{'='*60}") + print(f"Performance Statistics") + print(f"{'='*60}") + print(f"ReduceScatter Times:") + print(f" Min: {min_time:.6f}s, Avg: {avg_time:.6f}s, Max: {max_time:.6f}s") + print(f" Bandwidth: {bw:.2f} GB/s") + + if gemm_times: + print(f"\nSequential Baseline:") + print(f" ReduceScatter avg: {seq_rs_avg:.6f}s") + print(f" GEMM avg: {seq_gemm_avg:.6f}s") + print(f" Sequential total: {seq_rs_avg + seq_gemm_avg:.6f}s") + print(f"\nConcurrent:") + print(f" ReduceScatter avg: {avg_time:.6f}s") + print(f" GEMM avg: {gemm_avg_time:.6f}s") + + if overlap_times: + ideal = max(avg_time, gemm_avg_time) + seq_total = seq_rs_avg + seq_gemm_avg + speedup = seq_total / overlap_avg_time if overlap_avg_time > 0 else 0 + efficiency = (ideal / overlap_avg_time * 100) if overlap_avg_time > 0 else 0 + print(f"\nOverlap Analysis:") + print(f" Overlap time (measured): {overlap_avg_time:.6f}s") + print(f" Theoretical best: {ideal:.6f}s") + print(f" Sequential baseline: {seq_total:.6f}s") + print(f" Time saved: {seq_total - overlap_avg_time:.6f}s") + print(f" Speedup: {speedup:.2f}x") + print(f" Concurrency efficiency: {efficiency:.2f}%") + print(f"{'='*60}") + + # Verify + output_cpu = output_tensor.cpu().numpy() + expected = sum((pe + 1) * 1000 + rank for pe in range(npes)) + # For uint32, cast expected + expected_val = np.uint32(expected) + success = np.all(output_cpu == expected_val) + if success: + print(f"PE {rank}: Verification PASSED (all values = {expected})") + else: + print(f"PE {rank}: Verification FAILED! Expected {expected}, got unique: {np.unique(output_cpu)}") + + torch.cuda.synchronize() + dist.barrier() + + # Global stats + min_t = torch.tensor([min_time], dtype=torch.float64) + max_t = torch.tensor([max_time], dtype=torch.float64) + avg_t = torch.tensor([avg_time], dtype=torch.float64) + success_t = torch.tensor([1 if success else 0], dtype=torch.int32) + + dist.all_reduce(min_t, op=dist.ReduceOp.MIN) + dist.all_reduce(max_t, op=dist.ReduceOp.MAX) + dist.all_reduce(avg_t, op=dist.ReduceOp.SUM) + dist.all_reduce(success_t, op=dist.ReduceOp.SUM) + + if rank == 0: + g_avg = avg_t.item() / npes + g_bw = input_bytes / g_avg / (1024**3) if g_avg > 0 else 0 + print(f"\n{'='*60}") + print(f"Global Results") + print(f"{'='*60}") + print(f" Min: {min_t.item():.6f}s, Avg: {g_avg:.6f}s, Max: {max_t.item():.6f}s") + print(f" Bandwidth: {g_bw:.2f} GB/s") + print(f" PEs passed: {success_t.item()}/{npes}") + print(f"\n=== Test {'PASSED' if success_t.item() == npes else 'FAILED'} ===") + print(f"{'='*60}\n") + + # Cleanup + torch.cuda.synchronize() + dist.barrier() + del rs + dist.barrier() + shmem.shmem_finalize() + + if not success: + raise AssertionError(f"PE {rank}: ReduceScatter verification failed") + + +def test_reducescatter(elems=33554432, world_size=8, iterations=10, warmup=1, + use_custom_stream=False, test_gemm_overlap=False, + gemm_m=4096, gemm_n=4096, gemm_k=4096): + os.environ.setdefault('MORI_ENABLE_SDMA', '1') + port = get_free_port() + torch.multiprocessing.spawn( + _test_reducescatter, + args=(world_size, port, elems, iterations, warmup, use_custom_stream, + test_gemm_overlap, gemm_m, gemm_n, gemm_k), + nprocs=world_size, + join=True, + ) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Test ReduceScatter SDMA (MORI)", + formatter_class=argparse.RawDescriptionHelpFormatter + ) + parser.add_argument("--elems", type=int, default=33554432, help="Output elements per PE") + parser.add_argument("--world-size", type=int, default=8, help="Number of processes") + parser.add_argument("--iterations", type=int, default=50, help="Number of iterations") + parser.add_argument("--warmup", type=int, default=20, help="Warmup iterations") + parser.add_argument("--enable-sdma", type=int, default=1, choices=[0, 1]) + parser.add_argument("--use-custom-stream", action="store_true") + parser.add_argument("--test-gemm-overlap", action="store_true") + parser.add_argument("--gemm-m", type=int, default=4096, help="GEMM M dimension") + parser.add_argument("--gemm-n", type=int, default=4096, help="GEMM N dimension") + parser.add_argument("--gemm-k", type=int, default=4096, help="GEMM K dimension") + args = parser.parse_args() + os.environ['MORI_ENABLE_SDMA'] = str(args.enable_sdma) + + print(f"ReduceScatter SDMA Test") + print(f" Output elements per PE: {args.elems:,}") + print(f" World size: {args.world_size}") + print(f" Iterations: {args.iterations}") + print(f" Warmup: {args.warmup}") + if args.test_gemm_overlap: + print(f" GEMM Dimensions: M={args.gemm_m}, N={args.gemm_n}, K={args.gemm_k}") + if not HAS_AITER: + print(f" WARNING: aiter not available") + print("-" * 60) + + test_reducescatter(args.elems, args.world_size, args.iterations, args.warmup, + args.use_custom_stream, args.test_gemm_overlap, + args.gemm_m, args.gemm_n, args.gemm_k)