Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 125 additions & 0 deletions include/mori/collective/reducescatter/reducescatter_sdma_class.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
// Copyright © Advanced Micro Devices, Inc. All rights reserved.
//
// MIT License
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

#ifndef REDUCESCATTER_SDMA_CLASS_HPP
#define REDUCESCATTER_SDMA_CLASS_HPP

#include <hip/hip_runtime.h>
#include <mpi.h>
#include <memory>
#include <cstdint>
#include <atomic>

#include "mori/application/application.hpp"
#include "mori/shmem/shmem.hpp"
#include "mori/collective/collective_pub.hpp"

namespace mori {
namespace collective {

struct CrossPeBarrier;

/**
 * @brief Host-side handle for a ReduceScatter collective carried out over SDMA.
 *
 * Only declarations live in this header; the member functions are defined
 * elsewhere. The class is non-copyable, and because the copy operations are
 * user-declared no implicit move operations are generated either.
 *
 * @tparam T Element type of the input/output buffers.
 */
template <typename T>
class ReduceScatterSdma {
private:
// PE identity. Presumably initialised from the constructor's myPe / npes
// arguments (constructor bodies are not visible here).
int myPe_;
int npes_;
// NOTE(review): looks like sizeof(T) cached at construction -- confirm in the .cpp.
size_t dtype_size_;
// NOTE(review): meaning of "blocks" (kernel launch blocks?) is not shown in this header.
int max_blocks_;

// SDMA completion flags
// flagsObj_ is the symmetric-memory object handed out by getFlagsObj();
// flags_ owns a uint64_t array released through ShmemDeleter.
application::SymmMemObjPtr flagsObj_;
std::unique_ptr<uint64_t[], ShmemDeleter> flags_;

// Device-scope barrier for block-0-to-all broadcast
// barrierPtr_ is a raw, non-owning-looking pointer; barrierMem_ is an owning
// pointer released through ShmemDeleter (presumably the storage barrierPtr_
// points into -- TODO confirm).
CrossPeBarrier* barrierPtr_;
std::unique_ptr<void, ShmemDeleter> barrierMem_;

// Transit buffer (gather buffer): npes * chunkSize slots for SDMA scatter
// transit_buffer_ / transit_buffer_size_ / transit_buffer_obj_ are exposed
// through the public getters below; transit_buffer_ptr_ is the owning handle.
void* transit_buffer_;
size_t transit_buffer_size_;
application::SymmMemObjPtr transit_buffer_obj_;
std::unique_ptr<void, ShmemDeleter> transit_buffer_ptr_;

// Async state
// async_in_progress_ is the flag reported by is_async_in_progress(). The
// remaining fields presumably record the arguments of the pending
// start_async() call and its start timestamp -- TODO confirm in the .cpp.
std::atomic<bool> async_in_progress_;
T* async_input_;
T* async_output_;
size_t async_total_count_;
hipStream_t async_stream_;
double async_start_time_;

// Mirrors the copy_output_to_user constructor argument (see its @param).
bool copy_output_to_user_;

// Non-copyable.
ReduceScatterSdma(const ReduceScatterSdma&) = delete;
ReduceScatterSdma& operator=(const ReduceScatterSdma&) = delete;

// Helper taking a buffer, its owning pointer, its current size and its
// symmetric-memory object by reference, plus a required size and a name.
// Presumably grows/reallocates the buffer when current_size < required_size
// and returns false on failure -- TODO confirm; the definition is not in view.
bool ensure_buffer_size(void*& buffer,
std::unique_ptr<void, ShmemDeleter>& buffer_ptr,
size_t& current_size,
application::SymmMemObjPtr& buffer_obj,
size_t required_size,
const char* buffer_name);

// Presumably copies this rank's reduced shard into `output` on `stream`
// (cf. copy_output_to_user_) -- TODO confirm.
void copy_result_to_user(T* output, size_t total_count, hipStream_t stream);

public:
/**
 * @param myPe Current PE ID
 * @param npes Total number of PEs
 * @param transit_buffer_size Transit buffer size in bytes (default 512MB)
 * @param copy_output_to_user If true, copy reduced shard to user output buffer
 */
ReduceScatterSdma(int myPe, int npes, size_t transit_buffer_size = 512 * 1024 * 1024,
bool copy_output_to_user = true);

/**
 * @brief Overload taking separate input and output buffer sizes.
 * @param myPe Current PE ID
 * @param npes Total number of PEs
 * @param input_buffer_size Presumably the input buffer size in bytes -- TODO confirm units
 * @param output_buffer_size Presumably the output buffer size in bytes -- TODO confirm units
 * @param copy_output_to_user If true, copy reduced shard to user output buffer
 *
 * NOTE(review): a call such as (pe, npes, 1024, 2048) with plain int
 * arguments appears ambiguous with the (size_t, bool) overload above, since
 * int->bool and int->size_t are both conversion-rank. Pass size_t values for
 * both sizes to select this overload.
 */
ReduceScatterSdma(int myPe, int npes, size_t input_buffer_size, size_t output_buffer_size,
bool copy_output_to_user = true);

~ReduceScatterSdma();

/**
 * @brief Synchronous ReduceScatter via SDMA
 * @param input Input data — total_count elements per rank
 * @param output Output data — total_count/npes reduced elements per rank
 * @param total_count Number of input elements per PE
 * @param stream HIP stream
 * @return bool status; presumably true on success -- TODO confirm.
 */
bool operator()(T* input, T* output, size_t total_count, hipStream_t stream = nullptr);

// Asynchronous variant: start_async() takes the same arguments as operator();
// wait_async() returns a double (NOTE(review): looks like an elapsed time
// derived from async_start_time_ -- confirm value and units);
// cancel_async() presumably abandons a pending operation.
bool start_async(T* input, T* output, size_t total_count, hipStream_t stream = nullptr);
double wait_async(hipStream_t stream = nullptr);
// Reads the atomic flag (implicit, sequentially consistent load).
bool is_async_in_progress() const { return async_in_progress_; }
void cancel_async();

// Accessors: each returns the corresponding member by value.
application::SymmMemObjPtr getFlagsObj() const { return flagsObj_; }
void* getTransitBuffer() const { return transit_buffer_; }
size_t getTransitBufferSize() const { return transit_buffer_size_; }
application::SymmMemObjPtr getTransitBufferObj() const { return transit_buffer_obj_; }

// Presumably clears the SDMA completion flags held in flags_ -- TODO confirm.
void resetFlags();
};

} // namespace collective
} // namespace mori

#endif // REDUCESCATTER_SDMA_CLASS_HPP
3 changes: 2 additions & 1 deletion python/mori/ccl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,6 @@
from .collective import All2allSdma
from .collective import AllgatherSdma
from .collective import AllreduceSdma
from .collective import ReduceScatterSdma

__all__ = ['All2allSdma', 'AllgatherSdma', 'AllreduceSdma']
__all__ = ['All2allSdma', 'AllgatherSdma', 'AllreduceSdma', 'ReduceScatterSdma']
58 changes: 58 additions & 0 deletions python/mori/ccl/collective.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,3 +410,61 @@ def get_output_transit_buffer(self, device=None):
while an operation is in progress.
"""
return self._handle.get_output_transit_buffer(device)


def _cpp_reducescatter_factory(entity_name: str):
    """Return the attribute called ``entity_name`` from the ``mori_cpp`` module.

    Raises:
        AttributeError: if ``mori_cpp`` has no attribute with that name.
    """
    cpp_entity = getattr(mori_cpp, entity_name)
    return cpp_entity


class ReduceScatterSdma:
    """Python wrapper for ReduceScatterSdma C++ class.

    Performs ReduceScatter via SDMA: each rank contributes total_count
    elements; the result is total_count/npes reduced elements per rank.
    """

    # Transit buffer size in bytes used when the caller supplies no size (512 MiB).
    _DEFAULT_TRANSIT_BUFFER_SIZE = 512 * 1024 * 1024

    def __init__(self, my_pe: int, npes: int,
                 input_buffer_size: Optional[int] = None,
                 output_buffer_size: Optional[int] = None,
                 transit_buffer_size: Optional[int] = None,
                 copy_output_to_user: bool = True):
        """Create the underlying ``ReduceScatterSdmaHandle``.

        Args:
            my_pe: Current PE ID.
            npes: Total number of PEs.
            input_buffer_size: Input buffer size. Only honoured together
                with ``output_buffer_size``.
            output_buffer_size: Output buffer size. Only honoured together
                with ``input_buffer_size``.
            transit_buffer_size: Transit buffer size in bytes. Used when the
                input/output pair is not fully specified; defaults to 512 MiB.
            copy_output_to_user: Forwarded unchanged to the handle.

        Warns:
            UserWarning: if exactly one of ``input_buffer_size`` /
                ``output_buffer_size`` is given. The lone value is ignored
                and the transit-buffer path is used instead.
        """
        self.my_pe = my_pe
        self.npes = npes
        handle_class = _cpp_reducescatter_factory("ReduceScatterSdmaHandle")

        has_input = input_buffer_size is not None
        has_output = output_buffer_size is not None

        if has_input and has_output:
            # Fully specified pair: use the five-argument handle constructor.
            self._handle = handle_class(my_pe, npes, input_buffer_size,
                                        output_buffer_size, copy_output_to_user)
            return

        if has_input != has_output:
            # Half a pair used to be dropped silently; keep the fallback but
            # tell the caller that their argument had no effect.
            import warnings
            warnings.warn(
                "ReduceScatterSdma: input_buffer_size and output_buffer_size "
                "must be given together; the single value provided is ignored "
                "and the transit buffer size is used instead.",
                UserWarning,
                stacklevel=2,
            )

        if transit_buffer_size is None:
            transit_buffer_size = self._DEFAULT_TRANSIT_BUFFER_SIZE
        self._handle = handle_class(my_pe, npes, transit_buffer_size, copy_output_to_user)

    def __call__(self, input_data, output_data, count: int, stream=None) -> bool:
        """Execute ReduceScatter SDMA operation.

        Args:
            input_data: Input CUDA tensor (total_count elements per rank)
            output_data: Output CUDA tensor (total_count/npes elements per rank)
            count: Total number of input elements per PE
            stream: Optional HIP stream

        Returns:
            The handle's boolean result.
        """
        return self._handle(input_data, output_data, count, stream)

    def start_async(self, input_data, output_data, count: int, stream=None) -> bool:
        """Forward to the handle's ``start_async`` with the same arguments as ``__call__``."""
        return self._handle.start_async(input_data, output_data, count, stream)

    def wait_async(self, stream=None) -> float:
        """Forward to the handle's ``wait_async`` and return its float result.

        NOTE(review): presumably an elapsed time; confirm meaning and units
        against the C++ implementation.
        """
        return self._handle.wait_async(stream)

    def is_async_in_progress(self) -> bool:
        """Return the handle's ``is_async_in_progress()`` result."""
        return self._handle.is_async_in_progress()

    def cancel_async(self):
        """Forward to the handle's ``cancel_async``."""
        self._handle.cancel_async()

    def reset_flags(self):
        """Forward to the handle's ``reset_flags``."""
        self._handle.reset_flags()

    def get_transit_buffer(self, device=None, dtype=None):
        """Return whatever the handle's ``get_transit_buffer(device, dtype)`` returns."""
        return self._handle.get_transit_buffer(device, dtype)
1 change: 1 addition & 0 deletions src/collective/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ set(COLLECTIVE_SOURCES
core/oneshot_all2all_sdma_class.cpp
core/oneshot_allgather_sdma_class.cpp
core/twoshot_allreduce_sdma_class.cpp
core/reducescatter_sdma_class.cpp
inter_node/executors/ring_1d_executor.cpp
inter_node/executors/one_shot_executor.cpp
# Note: intra_node_executor is header-only template class
Expand Down
Loading