diff --git a/CMakeLists.txt b/CMakeLists.txt index d775c41cf..3e4d5e4d6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -27,6 +27,7 @@ option(USE_GLOO "Whether to build Gloo or not" ON) option(USE_RCCL "Whether to build RCCL or not" OFF) option(USE_RCCLX "Whether to build RCCLX or not" OFF) option(USE_XCCL "Whether to build XCCL or not" OFF) +option(USE_HCCL "Whether to build HCCL or not" OFF) option(USE_TRANSPORT "Whether to build TRANSPORT or not" ON) option(USE_TRITON "Whether to build Triton device bitcode or not" OFF) option(BUILD_TESTS "Whether to build tests or not" OFF) @@ -36,6 +37,7 @@ message(STATUS " USE_GLOO : ${USE_GLOO}") message(STATUS " USE_RCCL : ${USE_RCCL}") message(STATUS " USE_RCCLX : ${USE_RCCLX}") message(STATUS " USE_XCCL : ${USE_XCCL}") +message(STATUS " USE_HCCL : ${USE_HCCL}") message(STATUS " USE_TRANSPORT : ${USE_TRANSPORT}") message(STATUS " USE_TRITON : ${USE_TRITON}") message(STATUS " BUILD_TESTS : ${BUILD_TESTS}") @@ -240,6 +242,9 @@ endif() if (USE_XCCL) include(comms/torchcomms/xccl/CMakeLists.txt) endif() +if (USE_HCCL) + include(comms/torchcomms/hccl/CMakeLists.txt) +endif() if (USE_TRANSPORT) include(comms/torchcomms/transport/CMakeLists.txt) endif() diff --git a/README.md b/README.md index ea08252cf..f9d6e3c4d 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,7 @@ torchcomms requires the following software and hardware: - PyTorch 2.8 or higher - CUDA-capable GPU (for NCCL/NCCLX or RCCL backends) - Intel XPU (for XCCL backend) +- Huawei Ascend NPU (for HCCL backend) ## Installation @@ -167,6 +168,23 @@ export USE_TRANSPORT=OFF pip install --no-build-isolation -v . 
``` +##### HCCL Backend + +Source Ascend toolkit environment (update path to your Ascend installation) +```bash +export ASCEND_TOOLKIT_HOME=/usr/local/Ascend/ascend-toolkit/latest +source $ASCEND_TOOLKIT_HOME/set_env.sh +``` + +Enable HCCL backend and install +```bash +export USE_HCCL=ON +export USE_NCCL=OFF +export USE_NCCLX=OFF +export USE_TRANSPORT=OFF +pip install --no-build-isolation -v . +``` + #### Install torchcomms: @@ -201,6 +219,7 @@ export USE_GLOO=ON # Default: ON export USE_RCCL=OFF # Default: OFF export USE_RCCLX=OFF # Default: OFF export USE_XCCL=OFF # Default: OFF +export USE_HCCL=OFF # Default: OFF ``` Then run: diff --git a/comms/torchcomms/device/npu/NpuApi.cpp b/comms/torchcomms/device/npu/NpuApi.cpp new file mode 100644 index 000000000..b07ddc805 --- /dev/null +++ b/comms/torchcomms/device/npu/NpuApi.cpp @@ -0,0 +1,293 @@ +#include "comms/torchcomms/device/npu/NpuApi.hpp" +#include +#include +#include +#include +#include +#include +#include "comms/torchcomms/utils/Logging.hpp" + +namespace torch::comms { + +npu_result_t DefaultNpuApi::setDevice(int device) { + try { + ::c10_npu::set_device(device); + return NPU_SUCCESS; + } catch (const std::exception&) { + return NPU_ERROR_INVALID_VALUE; + } +} + +npu_result_t DefaultNpuApi::getDeviceProperties( + npuDeviceProp* prop, + int device) { + if (!prop) { + return NPU_ERROR_INVALID_VALUE; + } + + // Get device name + // ACL does not provide a simple "get device name" API here; use a + // descriptive default name instead. 
+  snprintf(prop->name, sizeof(prop->name), "Ascend NPU %d", device);
+
+  // Set device before getting memory info
+  auto result = setDevice(device);
+  if (result != NPU_SUCCESS) {
+    return result;
+  }
+
+  size_t free_mem = 0;
+  size_t total_mem = 0;
+  if (aclrtGetMemInfo(ACL_HBM_MEM, &free_mem, &total_mem) != ACL_SUCCESS) {
+    return NPU_ERROR_INVALID_VALUE;
+  }
+  prop->totalGlobalMem = total_mem;
+
+  if (aclGetDeviceCapability(
+          device, ACL_DEVICE_INFO_AI_CORE_NUM, &prop->cubeCoreNum) !=
+      ACL_SUCCESS) {
+    return NPU_ERROR_INVALID_VALUE;
+  }
+  return NPU_SUCCESS;
+}
+
+npu_result_t DefaultNpuApi::memGetInfo(size_t* free, size_t* total) {
+  if (!free || !total) {
+    return NPU_ERROR_INVALID_VALUE;
+  }
+
+  if (aclrtGetMemInfo(ACL_HBM_MEM, free, total) != ACL_SUCCESS) {
+    *total = 0;
+    *free = 0;
+    return NPU_ERROR_INVALID_VALUE;
+  }
+  return NPU_SUCCESS;
+}
+
+npu_result_t DefaultNpuApi::getDeviceCount(int* count) {
+  if (!count) {
+    return NPU_ERROR_INVALID_VALUE;
+  }
+
+  try {
+    *count = static_cast<int>(::c10_npu::device_count());
+    return NPU_SUCCESS;
+  } catch (const std::exception&) {
+    return NPU_ERROR_INVALID_VALUE;
+  }
+}
+
+npu_result_t DefaultNpuApi::streamCreateWithPriority(
+    npuStream_t& stream,
+    unsigned int flags,
+    int priority) {
+  (void)flags;
+  try {
+    bool is_high_priority = priority != 0;
+    auto device_index = ::c10_npu::current_device();
+    stream = ::c10_npu::getStreamFromPool(is_high_priority, device_index);
+    return NPU_SUCCESS;
+  } catch (const std::exception&) {
+    return NPU_ERROR_INVALID_VALUE;
+  }
+}
+
+npu_result_t DefaultNpuApi::streamDestroy(const npuStream_t& stream) {
+  (void)stream;
+  // Stream is managed by torch_npu
+  return NPU_SUCCESS;
+}
+
+npu_result_t DefaultNpuApi::streamWaitEvent(
+    const npuStream_t& stream,
+    npuEvent_t& event,
+    unsigned int flags) {
+  (void)flags;
+  try {
+    event.block(stream);
+    return NPU_SUCCESS;
+  } catch (const std::exception&) {
+    return NPU_ERROR_INVALID_HANDLE;
+  }
+}
+
+npuStream_t
DefaultNpuApi::getCurrentNPUStream(int device_index) { + return ::c10_npu::getCurrentNPUStream(device_index); +} + +npu_result_t DefaultNpuApi::streamSynchronize(const npuStream_t& stream) { + try { + stream.synchronize(); + return NPU_SUCCESS; + } catch (const std::exception&) { + return NPU_ERROR_INVALID_HANDLE; + } +} + +npu_result_t DefaultNpuApi::streamIsCapturing( + npuStream_t stream, + npuStreamCaptureStatus* pCaptureStatus) { + if (!pCaptureStatus) { + return NPU_ERROR_INVALID_VALUE; + } + + // NPU/ACL doesn't support stream capture + *pCaptureStatus = npuStreamCaptureStatusNone; + return NPU_SUCCESS; +} + +npu_result_t DefaultNpuApi::streamGetCaptureInfo( + npuStream_t stream, + npuStreamCaptureStatus* pCaptureStatus, + unsigned long long* pId) { + if (!pCaptureStatus) { + return NPU_ERROR_INVALID_VALUE; + } + + *pCaptureStatus = npuStreamCaptureStatusNone; + if (pId) { + *pId = 0; + } + return NPU_SUCCESS; +} + +npu_result_t DefaultNpuApi::malloc(void** devPtr, size_t size) { + if (!devPtr) { + return NPU_ERROR_INVALID_VALUE; + } + + if (size == 0) { + *devPtr = nullptr; + return NPU_SUCCESS; + } + + if (aclrtMalloc(devPtr, size, ACL_MEM_MALLOC_HUGE_FIRST) != ACL_SUCCESS) { + *devPtr = nullptr; + return NPU_ERROR_OUT_OF_MEMORY; + } + return NPU_SUCCESS; +} + +npu_result_t DefaultNpuApi::free(void* devPtr) { + if (!devPtr) { + return NPU_SUCCESS; + } + + return aclrtFree(devPtr) == ACL_SUCCESS ? 
NPU_SUCCESS + : NPU_ERROR_INVALID_VALUE; +} + +npu_result_t DefaultNpuApi::memcpyAsync( + void* dst, + const void* src, + size_t count, + npuStream_t stream) { + if (!dst || !src) { + return NPU_ERROR_INVALID_VALUE; + } + + if (count == 0) { + return NPU_SUCCESS; + } + + return aclrtMemcpyAsync( + dst, count, src, count, ACL_MEMCPY_DEVICE_TO_DEVICE, stream.stream()); +} + +npu_result_t DefaultNpuApi::eventCreate(npuEvent_t& event) { + try { + event = ::c10_npu::NPUEvent(); + return NPU_SUCCESS; + } catch (const std::exception&) { + return NPU_ERROR_INVALID_VALUE; + } +} + +npu_result_t DefaultNpuApi::eventCreateWithFlags( + npuEvent_t& event, + unsigned int flags) { + try { + event = ::c10_npu::NPUEvent(flags); + return NPU_SUCCESS; + } catch (const std::exception&) { + return NPU_ERROR_INVALID_VALUE; + } +} + +npu_result_t DefaultNpuApi::eventDestroy(npuEvent_t& event) { + (void)event; + // NPUEvent is RAII, nothing to do + return NPU_SUCCESS; +} + +npu_result_t DefaultNpuApi::eventRecord( + npuEvent_t& event, + const npuStream_t& stream) { + try { + event.record(stream); + return NPU_SUCCESS; + } catch (const std::exception&) { + return NPU_ERROR_INVALID_HANDLE; + } +} + +npu_result_t DefaultNpuApi::eventQuery(const npuEvent_t& event) { + try { + return event.query() ? 
NPU_SUCCESS : NPU_ERROR_NOT_READY; + } catch (const std::exception&) { + return NPU_ERROR_INVALID_HANDLE; + } +} + +// Graph Operations (Unsupported) +npu_result_t DefaultNpuApi::userObjectCreate( + npuUserObject_t* object_out, + void* ptr, + npuHostFn_t destroy, + unsigned int initialRefcount, + unsigned int flags) { + // NPU/ACL doesn't support user objects + return NPU_ERROR_UNSUPPORTED; +} + +npu_result_t DefaultNpuApi::graphRetainUserObject( + npuGraph_t graph, + npuUserObject_t object, + unsigned int count, + unsigned int flags) { + // Currently, NPU/ACL doesn't support graphs + return NPU_ERROR_UNSUPPORTED; +} + +npu_result_t DefaultNpuApi::streamGetCaptureInfo_v2( + npuStream_t stream, + npuStreamCaptureStatus* captureStatus_out, + unsigned long long* id_out, + npuGraph_t* graph_out, + const npuGraphNode_t** dependencies_out, + size_t* numDependencies_out) { + // Currently, NPU/ACL doesn't support graphs + return NPU_ERROR_UNSUPPORTED; +} + +// Error Handling +const char* DefaultNpuApi::getErrorString(npu_result_t error) { + // ACL provides aclGetRecentErrMsg() for detailed errors + // For now, return basic error descriptions + switch (error) { + case ACL_SUCCESS: + return "success"; + case ACL_ERROR_INVALID_PARAM: + return "invalid parameter"; + case ACL_ERROR_INVALID_RESOURCE_HANDLE: + return "invalid handle"; + case ACL_ERROR_RT_MEMORY_ALLOCATION: + return "memory allocation failed"; + case ACL_ERROR_RT_FEATURE_NOT_SUPPORT: + return "feature not supported"; + default: + return aclGetRecentErrMsg(); + } +} + +} // namespace torch::comms diff --git a/comms/torchcomms/device/npu/NpuApi.hpp b/comms/torchcomms/device/npu/NpuApi.hpp new file mode 100644 index 000000000..508734497 --- /dev/null +++ b/comms/torchcomms/device/npu/NpuApi.hpp @@ -0,0 +1,200 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace torch::comms { + +using npuStream_t = ::c10_npu::NPUStream; +using npuEvent_t = ::c10_npu::NPUEvent; + +struct npuDeviceProp { + 
char name[256]; + size_t totalGlobalMem; + int64_t cubeCoreNum; +}; + +// Graph-related types (placeholder - may not be supported in NPU) +using npuGraph_t = void*; +using npuGraphNode_t = void*; +using npuUserObject_t = void*; +using npuHostFn_t = void (*)(void*); + +// Stream capture status (may not be supported in NPU) +enum npuStreamCaptureStatus { + npuStreamCaptureStatusNone = 0, +}; + +// Error code type +using npu_result_t = aclError; +constexpr npu_result_t NPU_SUCCESS = ACL_SUCCESS; +constexpr npu_result_t NPU_ERROR_INVALID_VALUE = ACL_ERROR_INVALID_PARAM; +constexpr npu_result_t NPU_ERROR_NOT_READY = ACL_ERROR_RT_FEATURE_NOT_SUPPORT; +constexpr npu_result_t NPU_ERROR_INVALID_HANDLE = + ACL_ERROR_INVALID_RESOURCE_HANDLE; +constexpr npu_result_t NPU_ERROR_OUT_OF_MEMORY = ACL_ERROR_RT_MEMORY_ALLOCATION; +constexpr npu_result_t NPU_ERROR_UNSUPPORTED = ACL_ERROR_RT_FEATURE_NOT_SUPPORT; + +#define NPU_CHECK(npu_api, call, err_str) \ + do { \ + npu_result_t status = call; \ + if (status != NPU_SUCCESS) { \ + std::stringstream ss; \ + ss << err_str << ": " << npu_api->getErrorString(status) << " at " \ + << __FILE__ << ":" << __LINE__; \ + throw std::runtime_error(ss.str()); \ + } \ + } while (0) + +/** + * Abstract interface for NPU API operations. + * This allows for dependency injection and testing by providing + * a way to override NPU API calls. 
+ */ +class NpuApi { + public: + virtual ~NpuApi() = default; + + // Device management + virtual npu_result_t setDevice(int device) = 0; + virtual npu_result_t getDeviceProperties(npuDeviceProp* prop, int device) = 0; + virtual npu_result_t memGetInfo(size_t* free, size_t* total) = 0; + virtual npu_result_t getDeviceCount(int* count) = 0; + + // Stream management + virtual npu_result_t streamCreateWithPriority( + npuStream_t& stream, + unsigned int flags, + int priority) = 0; + virtual npu_result_t streamDestroy(const npuStream_t& stream) = 0; + virtual npu_result_t streamWaitEvent( + const npuStream_t& stream, + npuEvent_t& event, + unsigned int flags) = 0; + virtual npuStream_t getCurrentNPUStream(int device_index) = 0; + virtual npu_result_t streamSynchronize(const npuStream_t& stream) = 0; + virtual npu_result_t streamIsCapturing( + npuStream_t stream, + npuStreamCaptureStatus* pCaptureStatus) = 0; + virtual npu_result_t streamGetCaptureInfo( + npuStream_t stream, + npuStreamCaptureStatus* pCaptureStatus, + unsigned long long* pId) = 0; + + // Memory management + virtual npu_result_t malloc(void** devPtr, size_t size) = 0; + virtual npu_result_t free(void* devPtr) = 0; + virtual npu_result_t + memcpyAsync(void* dst, const void* src, size_t count, npuStream_t stream) = 0; + + // Event management + virtual npu_result_t eventCreate(npuEvent_t& event) = 0; + virtual npu_result_t eventCreateWithFlags( + npuEvent_t& event, + unsigned int flags) = 0; + virtual npu_result_t eventDestroy(npuEvent_t& event) = 0; + virtual npu_result_t eventRecord( + npuEvent_t& event, + const npuStream_t& stream) = 0; + virtual npu_result_t eventQuery(const npuEvent_t& event) = 0; + + // Graph operations (unsupported, kept for API compatibility) + virtual npu_result_t userObjectCreate( + npuUserObject_t* object_out, + void* ptr, + npuHostFn_t destroy, + unsigned int initialRefcount, + unsigned int flags) = 0; + virtual npu_result_t graphRetainUserObject( + npuGraph_t graph, + 
npuUserObject_t object, + unsigned int count, + unsigned int flags) = 0; + virtual npu_result_t streamGetCaptureInfo_v2( + npuStream_t stream, + npuStreamCaptureStatus* captureStatus_out, + unsigned long long* id_out, + npuGraph_t* graph_out, + const npuGraphNode_t** dependencies_out, + size_t* numDependencies_out) = 0; + + // Error handling + virtual const char* getErrorString(npu_result_t error) = 0; +}; + +class DefaultNpuApi : public NpuApi { + public: + ~DefaultNpuApi() override = default; + + // Device management + npu_result_t setDevice(int device) override; + npu_result_t getDeviceProperties(npuDeviceProp* prop, int device) override; + npu_result_t memGetInfo(size_t* free, size_t* total) override; + npu_result_t getDeviceCount(int* count) override; + + // Stream management + npu_result_t streamCreateWithPriority( + npuStream_t& stream, + unsigned int flags, + int priority) override; + npu_result_t streamDestroy(const npuStream_t& stream) override; + npu_result_t streamWaitEvent( + const npuStream_t& stream, + npuEvent_t& event, + unsigned int flags) override; + npuStream_t getCurrentNPUStream(int device_index) override; + npu_result_t streamSynchronize(const npuStream_t& stream) override; + npu_result_t streamIsCapturing( + npuStream_t stream, + npuStreamCaptureStatus* pCaptureStatus) override; + npu_result_t streamGetCaptureInfo( + npuStream_t stream, + npuStreamCaptureStatus* pCaptureStatus, + unsigned long long* pId) override; + + // Memory management + npu_result_t malloc(void** devPtr, size_t size) override; + npu_result_t free(void* devPtr) override; + npu_result_t memcpyAsync( + void* dst, + const void* src, + size_t count, + npuStream_t stream) override; + + // Event management + npu_result_t eventCreate(npuEvent_t& event) override; + npu_result_t eventCreateWithFlags(npuEvent_t& event, unsigned int flags) + override; + npu_result_t eventDestroy(npuEvent_t& event) override; + npu_result_t eventRecord(npuEvent_t& event, const npuStream_t& stream) + 
override; + npu_result_t eventQuery(const npuEvent_t& event) override; + + // Graph operations (unsupported) + npu_result_t userObjectCreate( + npuUserObject_t* object_out, + void* ptr, + npuHostFn_t destroy, + unsigned int initialRefcount, + unsigned int flags) override; + npu_result_t graphRetainUserObject( + npuGraph_t graph, + npuUserObject_t object, + unsigned int count, + unsigned int flags) override; + npu_result_t streamGetCaptureInfo_v2( + npuStream_t stream, + npuStreamCaptureStatus* captureStatus_out, + unsigned long long* id_out, + npuGraph_t* graph_out, + const npuGraphNode_t** dependencies_out, + size_t* numDependencies_out) override; + + // Error handling + const char* getErrorString(npu_result_t error) override; +}; + +} // namespace torch::comms diff --git a/comms/torchcomms/hccl/CMakeLists.txt b/comms/torchcomms/hccl/CMakeLists.txt new file mode 100644 index 000000000..79e3f9c31 --- /dev/null +++ b/comms/torchcomms/hccl/CMakeLists.txt @@ -0,0 +1,117 @@ +# Extension: torchcomms._comms_hccl + +# Check if ASCEND_TOOLKIT_HOME is set +if(NOT DEFINED ENV{ASCEND_TOOLKIT_HOME}) + message( + WARNING + "HCCL environment not found (ASCEND_TOOLKIT_HOME not set), Skipping HCCL backend compilation." + ) + return() +endif() + +# Set HCCL paths +set(HCCL_INCLUDE "$ENV{ASCEND_TOOLKIT_HOME}/include") +set(HCCL_LIB_PATH "$ENV{ASCEND_TOOLKIT_HOME}/aarch64-linux/lib64") +set(HCCL_SHARED_LIB "${HCCL_LIB_PATH}/libhccl.so") +set(ACL_SHARED_LIB "${HCCL_LIB_PATH}/libascendcl.so") +execute_process( + COMMAND ${Python3_EXECUTABLE} -c "import importlib.util, pathlib, sys; spec = importlib.util.find_spec('torch_npu'); sys.stdout.write(str(pathlib.Path(spec.origin).resolve().parent) if spec and spec.origin else '')" + OUTPUT_VARIABLE TORCH_NPU_PKG_DIR + OUTPUT_STRIP_TRAILING_WHITESPACE + ERROR_QUIET + RESULT_VARIABLE TORCH_NPU_DISCOVERY_RESULT +) + +if(NOT TORCH_NPU_DISCOVERY_RESULT EQUAL 0 OR NOT TORCH_NPU_PKG_DIR) + message(WARNING "torch_npu package not found. 
Skipping HCCL backend compilation.") + return() +endif() + +set(TORCH_NPU_INCLUDE "${TORCH_NPU_PKG_DIR}/include") +set(TORCH_NPU_LIB_DIR "${TORCH_NPU_PKG_DIR}/lib") +set(TORCH_NPU_SHARED_LIB "${TORCH_NPU_LIB_DIR}/libtorch_npu.so") + +# Validate HCCL installation +if(NOT EXISTS "${HCCL_INCLUDE}" OR NOT EXISTS "${HCCL_SHARED_LIB}") + message(WARNING "Invalid HCCL path. Skipping HCCL backend compilation.") + message(STATUS "HCCL include path : ${HCCL_INCLUDE}") + message(STATUS "HCCL library : ${HCCL_SHARED_LIB}") + return() +endif() + +message(STATUS "HCCL include path : ${HCCL_INCLUDE}") +message(STATUS "HCCL library : ${HCCL_SHARED_LIB}") +message(STATUS "ACL library : ${ACL_SHARED_LIB}") +message(STATUS "Torch NPU include : ${TORCH_NPU_INCLUDE}") +message(STATUS "Torch NPU library : ${TORCH_NPU_SHARED_LIB}") + +file(GLOB TORCHCOMMS_HCCL_SOURCES "comms/torchcomms/hccl/*.cpp") +file(GLOB TORCHCOMMS_NPU_API_SOURCE "comms/torchcomms/device/npu/NpuApi.cpp") + +include(FindPackageHandleStandardArgs) + +add_library(torchcomms_comms_hccl MODULE + ${TORCHCOMMS_HCCL_SOURCES} + ${TORCHCOMMS_NPU_API_SOURCE} +) +set_target_properties(torchcomms_comms_hccl PROPERTIES + PREFIX "" + OUTPUT_NAME "_comms_hccl" + SUFFIX ".${Python3_SOABI}.so" + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/comms/torchcomms" + BUILD_RPATH "$ORIGIN" + INSTALL_RPATH "$ORIGIN" +) +target_include_directories(torchcomms_comms_hccl PRIVATE + ${ROOT} + ${HCCL_INCLUDE} + ${CONDA_INCLUDE} + ${Python3_INCLUDE_DIRS} + ${TORCH_NPU_INCLUDE} +) +target_compile_features(torchcomms_comms_hccl PRIVATE cxx_std_20) +target_link_directories(torchcomms_comms_hccl PRIVATE + ${CONDA_LIB} + ${TORCH_NPU_LIB_DIR} +) +target_link_libraries(torchcomms_comms_hccl PRIVATE + ${TORCH_LIBRARIES} + ${TORCH_PYTHON_LIB} + torchcomms +) + +target_link_libraries(torchcomms_comms_hccl PRIVATE + ${HCCL_SHARED_LIB} + ${ACL_SHARED_LIB} + ${TORCH_NPU_SHARED_LIB} +) + +if(USE_SYSTEM_LIBS) + target_link_libraries(torchcomms_comms_hccl 
PRIVATE + "-lglog" + "-lgflags" + "-lfmt" + ) +else() + target_include_directories(torchcomms_comms_hccl PRIVATE + ${ROOT}/third-party/fmt/include + ) + target_link_libraries(torchcomms_comms_hccl PRIVATE + "-l:libglog.a" + "-l:libgflags.a" + "-l:libfmt.a" + ) +endif() + +set_property(TARGET torchcomms_comms_hccl APPEND PROPERTY BUILD_RPATH + "${CONDA_LIB}" + "${TORCH_NPU_LIB_DIR}" +) +set_property(TARGET torchcomms_comms_hccl APPEND PROPERTY INSTALL_RPATH + "${CONDA_LIB}" + "${TORCH_NPU_LIB_DIR}" +) + +install(TARGETS torchcomms_comms_hccl + LIBRARY DESTINATION . +) diff --git a/comms/torchcomms/hccl/HcclApi.cpp b/comms/torchcomms/hccl/HcclApi.cpp new file mode 100644 index 000000000..96d3144b5 --- /dev/null +++ b/comms/torchcomms/hccl/HcclApi.cpp @@ -0,0 +1,224 @@ +#include "comms/torchcomms/hccl/HcclApi.hpp" +#include "comms/torchcomms/utils/Logging.hpp" + +namespace torch::comms { + +const char* DefaultHcclApi::getErrorString(HcclResult result) { + // HCCL error codes are typically defined in hccl_types.h + switch (result) { + case HCCL_SUCCESS: + return "success"; + case HCCL_E_PARA: + return "invalid parameter"; + case HCCL_E_PTR: + return "invalid pointer"; + case HCCL_E_MEMORY: + return "memory error"; + case HCCL_E_NOT_SUPPORT: + return "not supported"; + case HCCL_E_NOT_FOUND: + return "not found"; + case HCCL_E_UNAVAIL: + return "unavailable"; + case HCCL_E_SYSCALL: + return "system call error"; + case HCCL_E_TIMEOUT: + return "timeout"; + case HCCL_E_OPEN_FILE_FAILURE: + return "open file failure"; + case HCCL_E_TCP_CONNECT: + return "TCP connect error"; + case HCCL_E_ROCE_CONNECT: + return "ROCE connect error"; + case HCCL_E_TCP_TRANSFER: + return "TCP transfer error"; + case HCCL_E_ROCE_TRANSFER: + return "ROCE transfer error"; + case HCCL_E_RUNTIME: + return "runtime error"; + case HCCL_E_DRV: + return "driver error"; + case HCCL_E_PROFILING: + return "profiling error"; + case HCCL_E_CCE: + return "CCE error"; + case HCCL_E_NETWORK: + return "network 
error";
+    case HCCL_E_RESERVED:
+      return "reserved error";
+    case HCCL_E_INTERNAL:
+      return "internal error";
+    case HCCL_E_AGAIN:
+      return "try again";
+    default:
+      return "unknown error";
+  }
+}
+
+HcclResult DefaultHcclApi::getUniqueId(HcclRootInfo* uniqueId) {
+  return HcclGetRootInfo(uniqueId);
+}
+
+HcclResult DefaultHcclApi::commInitRankConfig(
+    HcclComm* comm,
+    int nranks,
+    HcclRootInfo rootInfo,
+    int rank,
+    const HcclCommConfig* config) {
+  if (config) {
+    return HcclCommInitRootInfoConfig(
+        nranks, &rootInfo, rank, config, comm);
+  } else {
+    return HcclCommInitRootInfo(nranks, &rootInfo, rank, comm);
+  }
+}
+
+HcclResult DefaultHcclApi::commDestroy(HcclComm comm) {
+  return HcclCommDestroy(comm);
+}
+
+HcclResult DefaultHcclApi::commAbort(HcclComm comm) {
+  // HCCL may not have an explicit abort function
+  // Destroy is the closest equivalent
+  return HcclCommDestroy(comm);
+}
+
+HcclResult DefaultHcclApi::commGetAsyncError(
+    HcclComm comm,
+    HcclResult* asyncError) {
+  // HCCL may not support async error checking
+  // Return not supported
+  if (asyncError) {
+    *asyncError = HCCL_SUCCESS;
+  }
+  return HCCL_E_NOT_SUPPORT;
+}
+
+HcclResult DefaultHcclApi::send(
+    const void* sendbuff,
+    uint64_t count,
+    HcclDataType datatype,
+    uint32_t peer,
+    HcclComm comm,
+    npuStream_t stream) {
+  // HCCL C API doesn't use const for sendbuff, but it shouldn't modify it
+  return HcclSend(
+      const_cast<void*>(sendbuff), count, datatype, peer, comm, stream);
+}
+
+HcclResult DefaultHcclApi::recv(
+    void* recvbuff,
+    uint64_t count,
+    HcclDataType datatype,
+    uint32_t peer,
+    HcclComm comm,
+    npuStream_t stream) {
+  return HcclRecv(recvbuff, count, datatype, peer, comm, stream);
+}
+
+HcclResult DefaultHcclApi::broadcast(
+    void* buff,
+    uint64_t count,
+    HcclDataType datatype,
+    uint32_t root,
+    HcclComm comm,
+    npuStream_t stream) {
+  return HcclBroadcast(buff, count, datatype, root, comm, stream);
+}
+
+HcclResult DefaultHcclApi::allReduce(
+    const void* sendbuff,
+    void* recvbuff,
+    uint64_t count,
+    HcclDataType datatype,
+    HcclReduceOp op,
+    HcclComm comm,
+    npuStream_t stream) {
+  // HCCL C API doesn't use const for sendbuff, but it shouldn't modify it
+  return HcclAllReduce(
+      const_cast<void*>(sendbuff), recvbuff, count, datatype, op, comm, stream);
+}
+
+HcclResult DefaultHcclApi::reduce(
+    const void* sendbuff,
+    void* recvbuff,
+    uint64_t count,
+    HcclDataType datatype,
+    HcclReduceOp op,
+    uint32_t root,
+    HcclComm comm,
+    npuStream_t stream) {
+  // HCCL C API doesn't use const for sendbuff, but it shouldn't modify it
+  return HcclReduce(
+      const_cast<void*>(sendbuff),
+      recvbuff,
+      count,
+      datatype,
+      op,
+      root,
+      comm,
+      stream);
+}
+
+HcclResult DefaultHcclApi::allGather(
+    const void* sendbuff,
+    void* recvbuff,
+    uint64_t sendcount,
+    HcclDataType datatype,
+    HcclComm comm,
+    npuStream_t stream) {
+  // HCCL C API doesn't use const for sendbuff, but it shouldn't modify it
+  return HcclAllGather(
+      const_cast<void*>(sendbuff), recvbuff, sendcount, datatype, comm, stream);
+}
+
+HcclResult DefaultHcclApi::reduceScatter(
+    const void* sendbuff,
+    void* recvbuff,
+    uint64_t recvcount,
+    HcclDataType datatype,
+    HcclReduceOp op,
+    HcclComm comm,
+    npuStream_t stream) {
+  // HCCL C API doesn't use const for sendbuff, but it shouldn't modify it
+  return HcclReduceScatter(
+      const_cast<void*>(sendbuff),
+      recvbuff,
+      recvcount,
+      datatype,
+      op,
+      comm,
+      stream);
+}
+
+HcclResult DefaultHcclApi::allToAll(
+    const void* sendbuff,
+    void* recvbuff,
+    uint64_t count,
+    HcclDataType datatype,
+    HcclComm comm,
+    npuStream_t stream) {
+  // HCCL may not have a direct alltoall, might need to use alltoallv
+  // For now, return not supported
+  return HCCL_E_NOT_SUPPORT;
+}
+
+HcclResult DefaultHcclApi::allToAllv(
+    const void* sendbuff,
+    const uint64_t* sendcounts,
+    const uint64_t* sdispls,
+    HcclDataType sendtype,
+    void* recvbuff,
+    const uint64_t* recvcounts,
+    const uint64_t* rdispls,
+    HcclDataType recvtype,
+    HcclComm comm,
+    npuStream_t stream) {
+  // HCCL alltoallv may have a different
signature
+  // This is a placeholder implementation
+  return HCCL_E_NOT_SUPPORT;
+}
+
+HcclResult DefaultHcclApi::groupStart() {
+  // HCCL may not support explicit group operations
+  // Return success as a no-op
+  return HCCL_SUCCESS;
+}
+
+HcclResult DefaultHcclApi::groupEnd() {
+  // HCCL may not support explicit group operations
+  // Return success as a no-op
+  return HCCL_SUCCESS;
+}
+
+HcclResult DefaultHcclApi::getRankId(HcclComm comm, uint32_t* userRank) {
+  return HcclGetRankId(comm, userRank);
+}
+
+HcclResult DefaultHcclApi::getRankSize(HcclComm comm, uint32_t* count) {
+  return HcclGetRankSize(comm, count);
+}
+
+} // namespace torch::comms
diff --git a/comms/torchcomms/hccl/HcclApi.hpp b/comms/torchcomms/hccl/HcclApi.hpp
new file mode 100644
index 000000000..b8e99b351
--- /dev/null
+++ b/comms/torchcomms/hccl/HcclApi.hpp
@@ -0,0 +1,242 @@
+#pragma once
+
+#include <hccl/hccl.h>
+#include <hccl/hccl_types.h>
+
+#include "comms/torchcomms/device/npu/NpuApi.hpp"
+
+namespace torch::comms {
+
+class HcclApi {
+ public:
+  virtual ~HcclApi() = default;
+
+  virtual const char* getErrorString(HcclResult result) = 0;
+
+  virtual HcclResult getUniqueId(HcclRootInfo* uniqueId) = 0;
+
+  virtual HcclResult commInitRankConfig(
+      HcclComm* comm,
+      int nranks,
+      HcclRootInfo commId,
+      int rank,
+      const HcclCommConfig* config) = 0;
+
+  virtual HcclResult commDestroy(HcclComm comm) = 0;
+
+  virtual HcclResult commAbort(HcclComm comm) = 0;
+
+  virtual HcclResult commGetAsyncError(
+      HcclComm comm,
+      HcclResult* asyncError) = 0;
+
+  // Point-to-point operations
+  virtual HcclResult send(
+      const void* sendbuff,
+      uint64_t count,
+      HcclDataType datatype,
+      uint32_t peer,
+      HcclComm comm,
+      npuStream_t stream) = 0;
+
+  virtual HcclResult recv(
+      void* recvbuff,
+      uint64_t count,
+      HcclDataType datatype,
+      uint32_t peer,
+      HcclComm comm,
+      npuStream_t stream) = 0;
+
+  // Collective operations
+  virtual HcclResult broadcast(
+      void* buff,
+      uint64_t count,
+      HcclDataType datatype,
+      uint32_t root,
+      HcclComm
comm, + npuStream_t stream) = 0; + + virtual HcclResult allReduce( + const void* sendbuff, + void* recvbuff, + uint64_t count, + HcclDataType datatype, + HcclReduceOp op, + HcclComm comm, + npuStream_t stream) = 0; + + virtual HcclResult reduce( + const void* sendbuff, + void* recvbuff, + uint64_t count, + HcclDataType datatype, + HcclReduceOp op, + uint32_t root, + HcclComm comm, + npuStream_t stream) = 0; + + virtual HcclResult allGather( + const void* sendbuff, + void* recvbuff, + uint64_t sendcount, + HcclDataType datatype, + HcclComm comm, + npuStream_t stream) = 0; + + virtual HcclResult reduceScatter( + const void* sendbuff, + void* recvbuff, + uint64_t recvcount, + HcclDataType datatype, + HcclReduceOp op, + HcclComm comm, + npuStream_t stream) = 0; + + virtual HcclResult allToAll( + const void* sendbuff, + void* recvbuff, + uint64_t count, + HcclDataType datatype, + HcclComm comm, + npuStream_t stream) = 0; + + virtual HcclResult allToAllv( + const void* sendbuff, + const uint64_t* sendcounts, + const uint64_t* sdispls, + HcclDataType sendtype, + void* recvbuff, + const uint64_t* recvcounts, + const uint64_t* rdispls, + HcclDataType recvtype, + HcclComm comm, + npuStream_t stream) = 0; + + // Group operations + virtual HcclResult groupStart() = 0; + virtual HcclResult groupEnd() = 0; + + virtual HcclResult getRankId(HcclComm comm, uint32_t* userRank) = 0; + virtual HcclResult getRankSize(HcclComm comm, uint32_t* count) = 0; +}; + +/** + * Default implementation that calls the underlying HCCL APIs directly. 
+ */ +class DefaultHcclApi : public HcclApi { + public: + ~DefaultHcclApi() override = default; + + // Error handling + const char* getErrorString(HcclResult result) override; + + // Unique ID generation + HcclResult getUniqueId(HcclRootInfo* uniqueId) override; + + // Communicator management + HcclResult commInitRankConfig( + HcclComm* comm, + int nranks, + HcclRootInfo commId, + int rank, + const HcclCommConfig* config) override; + + HcclResult commDestroy(HcclComm comm) override; + + HcclResult commAbort(HcclComm comm) override; + + HcclResult commGetAsyncError( + HcclComm comm, + HcclResult* asyncError) override; + + // Point-to-point operations + HcclResult send( + const void* sendbuff, + uint64_t count, + HcclDataType datatype, + uint32_t peer, + HcclComm comm, + npuStream_t stream) override; + + HcclResult recv( + void* recvbuff, + uint64_t count, + HcclDataType datatype, + uint32_t peer, + HcclComm comm, + npuStream_t stream) override; + + // Collective operations + HcclResult broadcast( + void* buff, + uint64_t count, + HcclDataType datatype, + uint32_t root, + HcclComm comm, + npuStream_t stream) override; + + HcclResult allReduce( + const void* sendbuff, + void* recvbuff, + uint64_t count, + HcclDataType datatype, + HcclReduceOp op, + HcclComm comm, + npuStream_t stream) override; + + HcclResult reduce( + const void* sendbuff, + void* recvbuff, + uint64_t count, + HcclDataType datatype, + HcclReduceOp op, + uint32_t root, + HcclComm comm, + npuStream_t stream) override; + + HcclResult allGather( + const void* sendbuff, + void* recvbuff, + uint64_t sendcount, + HcclDataType datatype, + HcclComm comm, + npuStream_t stream) override; + + HcclResult reduceScatter( + const void* sendbuff, + void* recvbuff, + uint64_t recvcount, + HcclDataType datatype, + HcclReduceOp op, + HcclComm comm, + npuStream_t stream) override; + + HcclResult allToAll( + const void* sendbuff, + void* recvbuff, + uint64_t count, + HcclDataType datatype, + HcclComm comm, + npuStream_t 
stream) override; + + HcclResult allToAllv( + const void* sendbuff, + const uint64_t* sendcounts, + const uint64_t* sdispls, + HcclDataType sendtype, + void* recvbuff, + const uint64_t* recvcounts, + const uint64_t* rdispls, + HcclDataType recvtype, + HcclComm comm, + npuStream_t stream) override; + + // Group operations + HcclResult groupStart() override; + HcclResult groupEnd() override; + + HcclResult getRankId(HcclComm comm, uint32_t* userRank) override; + HcclResult getRankSize(HcclComm comm, uint32_t* count) override; +}; + +} // namespace torch::comms diff --git a/comms/torchcomms/hccl/TorchCommHCCL.cpp b/comms/torchcomms/hccl/TorchCommHCCL.cpp new file mode 100644 index 000000000..c66dfe01d --- /dev/null +++ b/comms/torchcomms/hccl/TorchCommHCCL.cpp @@ -0,0 +1,586 @@ +#include "comms/torchcomms/hccl/TorchCommHCCL.hpp" + +#include +#include +#include +#include "comms/torchcomms/TorchCommFactory.hpp" +#include "comms/torchcomms/utils/Logging.hpp" +#include "comms/torchcomms/utils/TracingGuard.hpp" +#include "comms/torchcomms/utils/Utils.hpp" +#include "comms/torchcomms/hccl/TorchCommHCCLBootstrap.hpp" + +namespace torch::comms { + +HcclResult HCCLException::getResult() const { + return result_; +} + +static void preReduce(at::Tensor& tensor, const ReduceOp& r) { + if (r.type() == ReduceOp::RedOpType::PREMUL_SUM) { + std::visit([&tensor](auto&& arg) { tensor.mul_(arg); }, *r.factor()); + } +} + +TorchCommHCCL::TorchCommHCCL() + : hccl_comm_{nullptr}, + device_(at::kPrivateUse1), + init_state_(InitializationState::UNINITIALIZED), + shutdown_(false) {} + +TorchCommHCCL::TorchCommHCCL(const HcclComm hccl_comm) + : hccl_comm_(hccl_comm), + device_(at::kPrivateUse1), + init_state_(InitializationState::UNINITIALIZED), + shutdown_(false) {} + +TorchCommHCCL::~TorchCommHCCL() { + if (init_state_ == InitializationState::INITIALIZED) { + TC_LOG(ERROR) << "TorchCommHCCL was not finalized before destruction"; + + // If finalize was not called, we need to clean up the 
timeout thread + if (timeout_thread_.joinable()) { + shutdown_.store(true); + timeout_thread_.join(); + } + } +} + +void TorchCommHCCL::init( + at::Device device, + const std::string& name, + const CommOptions& options) { + // Initialize private members + device_ = device; + name_ = name; + options_ = options; + + // Only initialize once + if (init_state_ == InitializationState::INITIALIZED) { + throw std::runtime_error("TorchCommHCCL already initialized"); + } else if (init_state_ == InitializationState::FINALIZED) { + throw std::runtime_error("TorchCommHCCL already finalized"); + } + init_state_ = InitializationState::INITIALIZED; + + // Initialize default HCCL API implementation if not already set + if (!hccl_api_) { + hccl_api_ = std::make_shared(); + } + + // Initialize default NPU API implementation if not already set + if (!npu_api_) { + npu_api_ = std::make_shared(); + } + + if (device_.index() == -1 || hccl_comm_ == nullptr) { + auto bootstrap = new TorchCommHCCLBootstrap( + options_.store, device_, hccl_api_, npu_api_, options_.timeout); + device_ = bootstrap->getDevice(); + + if (hccl_comm_ == nullptr) { + hccl_comm_ = bootstrap->createHcclComm(name, options); + } + + delete bootstrap; + } + + // Set NPU device and verify it's accessible + NPU_CHECK( + npu_api_, + npu_api_->setDevice(device_.index()), + "Failed to set NPU device to " + std::to_string(device_.index())); + + // Verify device properties and memory availability + [[maybe_unused]] npuDeviceProp device_prop = {}; + NPU_CHECK( + npu_api_, + npu_api_->getDeviceProperties(&device_prop, device_.index()), + "Failed to get device properties for device " + + std::to_string(device_.index())); + + // Check available memory + [[maybe_unused]] size_t free_memory, total_memory; + NPU_CHECK( + npu_api_, + npu_api_->memGetInfo(&free_memory, &total_memory), + "Failed to get memory info for device " + + std::to_string(device_.index())); + + // Read hints and store them + for (auto const& [key, val] : 
options_.hints) { + if (key.starts_with("torchcomm::hccl::")) { + if (key == "torchcomm::hccl::high_priority_stream") { + high_priority_stream_ = string_to_bool(val); + } else { + throw std::runtime_error("Unrecognized hint " + key); + } + } else { + // Ignore keys that do not start with "torchcomm::hccl::" + } + } + + // Create internal stream + int stream_priority = 0; + + // Check for high priority stream hint + if (high_priority_stream_) { + stream_priority = -1; + } + + // Initialize internal stream + npuStream_t temp_stream = npu_api_->getCurrentNPUStream(device_.index()); + NPU_CHECK( + npu_api_, + npu_api_->streamCreateWithPriority( + temp_stream, /*flags=*/0, stream_priority), + "Failed to create internal NPU stream on device " + + std::to_string(device_.index())); + internal_stream_ = std::move(temp_stream); + + // Create dependency event for stream synchronization + npuEvent_t temp_event; + NPU_CHECK( + npu_api_, + npu_api_->eventCreate(temp_event), + "Failed to create dependency event on device " + + std::to_string(device_.index())); + dependency_event_ = std::move(temp_event); + + // Allocate NPU buffer for barrier operations + NPU_CHECK( + npu_api_, + npu_api_->malloc(&barrier_buffer_, sizeof(float)), + "Failed to allocate barrier buffer"); + + if (options_.hints.contains("torchcomm::hccl::max_event_pool_size")) { + max_event_pool_size_ = + std::stoull(options_.hints.at("torchcomm::hccl::max_event_pool_size")); + } else { + max_event_pool_size_ = kMaxEventPoolSize; + } + + // Give up our internal reference to the store object here. The caller + // would still need to keep a reference to the store object till the init + // call returns, at which point the HCCL communicator would already be + // created. 
+ if (options_.store) { + options_.store.reset(); + } + + uint32_t rank_u32; + HcclResult hcclErr = hccl_api_->getRankId(hccl_comm_, &rank_u32); + if (hcclErr != HCCL_SUCCESS) { + throw std::runtime_error("HCCL getRankId failed"); + } + rank_ = static_cast(rank_u32); + + tryTorchCommLoggingInit("torchcomm"); + + uint32_t comm_size_u32; + hcclErr = hccl_api_->getRankSize(hccl_comm_, &comm_size_u32); + if (hcclErr != HCCL_SUCCESS) { + throw std::runtime_error("HCCL getRankSize failed"); + } + comm_size_ = static_cast(comm_size_u32); + + TracingGuard tracingGuard(name_, comm_size_, "init", rank_); + + // Start timeout watchdog thread + timeout_thread_ = std::thread(&TorchCommHCCL::timeoutWatchdog, this); +} + +void TorchCommHCCL::finalize() { + if (init_state_ == InitializationState::UNINITIALIZED) { + throw std::runtime_error("TorchCommHCCL not initialized"); + } else if (init_state_ == InitializationState::FINALIZED) { + throw std::runtime_error("TorchCommHCCL already finalized"); + } + init_state_ = InitializationState::FINALIZED; + + // Signal shutdown to timeout watchdog + shutdown_ = true; + + // Wake up the timeout watchdog thread + { + std::lock_guard lock(timeout_mutex_); + timeout_cv_.notify_all(); + } + + // Wait for timeout thread to finish + if (timeout_thread_.joinable()) { + timeout_thread_.join(); + } + + // Wait for all pending work objects to complete and get final status + auto work_status = workq_.finalize(); + + if (work_status == TorchWorkHCCL::WorkStatus::NOT_STARTED || + work_status == TorchWorkHCCL::WorkStatus::INPROGRESS) { + throw std::runtime_error( + "WorkQ finalize returned in progress or not started state"); + } + + // Update comm_state_ based on the work status + if (work_status == TorchWorkHCCL::WorkStatus::TIMEDOUT) { + comm_state_ = CommState::TIMEOUT; + abortHcclComm(); + throw std::runtime_error("Work timed out during finalize"); + } else if (work_status == TorchWorkHCCL::WorkStatus::ERROR) { + comm_state_ = CommState::ERROR; + 
HcclResult asyncErr; + hccl_api_->commGetAsyncError(hccl_comm_, &asyncErr); + HCCLException hcclException(*hccl_api_, "HCCL Async Error", asyncErr); + abortHcclComm(); + throw hcclException; + } + + // Clean up event pool + { + std::lock_guard lock(event_pool_mutex_); + while (!event_pool_.empty()) { + npuEvent_t event = std::move(event_pool_.front()); + event_pool_.pop(); + NPU_CHECK( + npu_api_, npu_api_->eventDestroy(event), "Failed to destroy event"); + } + } + + // Free barrier buffer + if (barrier_buffer_) { + NPU_CHECK( + npu_api_, + npu_api_->free(barrier_buffer_), + "Failed to free barrier buffer"); + barrier_buffer_ = nullptr; + } + + // Destroy dependency event + if (dependency_event_.has_value()) { + NPU_CHECK( + npu_api_, + npu_api_->eventDestroy(dependency_event_.value()), + "Failed to destroy dependency event"); + dependency_event_.reset(); + } + + // Destroy internal stream + if (internal_stream_.has_value()) { + NPU_CHECK( + npu_api_, + npu_api_->streamDestroy(internal_stream_.value()), + "Failed to destroy internal stream"); + internal_stream_.reset(); + } + + // Destroy HCCL communicator + if (hccl_comm_) { + hccl_api_->commDestroy(hccl_comm_); + hccl_comm_ = nullptr; + } +} + +void TorchCommHCCL::abortHcclComm() { + if (hccl_comm_) { + hccl_api_->commAbort(hccl_comm_); + hccl_comm_ = nullptr; + } + if (options_.abort_process_on_timeout_or_error) { + TC_LOG(ERROR) << "Aborting process due to timeout"; + abort(); + } +} + +int TorchCommHCCL::getRank() const { + checkInitialized(); + + uint32_t rank; + HcclResult hcclErr = hccl_api_->getRankId(hccl_comm_, &rank); + if (hcclErr != HCCL_SUCCESS) { + throw HCCLException(*hccl_api_, "HCCL getRankId failed", hcclErr); + } + return static_cast(rank); +} + +int TorchCommHCCL::getSize() const { + checkInitialized(); + + uint32_t comm_size; + HcclResult hcclErr = hccl_api_->getRankSize(hccl_comm_, &comm_size); + if (hcclErr != HCCL_SUCCESS) { + throw HCCLException(*hccl_api_, "HCCL getRankSize failed", 
hcclErr); + } + return static_cast(comm_size); +} + +std::string_view TorchCommHCCL::getBackendName() const { + return kBackendName; +} + +std::string_view TorchCommHCCL::getCommName() const { + return name_; +} + +static inline std::chrono::milliseconds getOperationTimeout( + std::chrono::milliseconds timeout, + std::chrono::milliseconds default_timeout) { + // If timeout is kNoTimeout (0ms), use the default timeout from options + if (timeout == kNoTimeout) { + return default_timeout; + } + return timeout; +} + +// Point-to-Point Operations +c10::intrusive_ptr TorchCommHCCL::send( + const at::Tensor& tensor, + int dst, + bool async_op, + const SendOptions& options) { + throw std::runtime_error( + "HCCL send is not supported now and will be added later"); +} + +c10::intrusive_ptr TorchCommHCCL::recv( + at::Tensor& tensor, + int src, + bool async_op, + const RecvOptions& options) { + throw std::runtime_error( + "HCCL recv is not supported now and will be added later"); +} + +// Batch P2P Operations +c10::intrusive_ptr TorchCommHCCL::batch_op_issue( + const std::vector& ops, + bool async_op, + const BatchP2POptions& options) { + throw std::runtime_error( + "HCCL batch_op_issue is not supported now and will be added later"); +} + +// Collective Operations +c10::intrusive_ptr TorchCommHCCL::broadcast( + at::Tensor& tensor, + int root, + bool async_op, + const BroadcastOptions& options) { + throw std::runtime_error( + "HCCL broadcast is not supported now and will be added later"); +} + +c10::intrusive_ptr TorchCommHCCL::all_reduce( + at::Tensor& tensor, + const ReduceOp& op, + bool async_op, + const AllReduceOptions& options) { + checkInitialized(); + checkAndAbortIfTimedOutOrError(); + // Ensure correct device is set before HCCL calls + NPU_CHECK( + npu_api_, + npu_api_->setDevice(device_.index()), + "Failed to set NPU device to " + std::to_string(device_.index())); + ensureTensorContiguous(tensor); + + TracingGuard tracingGuard( + name_, comm_size_, "all_reduce", 
rank_, tensor, tensor); + + npuStream_t stream = getOperationStream(async_op); + auto work = createWork( + stream, getOperationTimeout(options.timeout, options_.timeout), {tensor}); + + work->recordStart(); + + // No-op for empty input tensor + if (tensor.numel() == 0) [[unlikely]] { + TC_LOG(WARNING) << "all_reduce called with empty input tensor"; + work->recordEnd(); + enqueueWork(work, stream); + return work; + } + + // HCCL handles premul sum differently, apply locally if comm_size is 1 + if (comm_size_ == 1) { + preReduce(tensor, op); + } + + const auto dataType = getHcclDataType(tensor); + HcclResult result = hccl_api_->allReduce( + tensor.data_ptr(), + tensor.data_ptr(), // In-place operation + tensor.numel(), + dataType, + getHcclReduceOp(op, hccl_comm_, dataType), + hccl_comm_, + stream); + + if (result != HCCL_SUCCESS) { + throw HCCLException(*hccl_api_, "HCCL AllReduce failed", result); + } + + work->recordEnd(); + + enqueueWork(work, stream); + + return work; +} + +c10::intrusive_ptr TorchCommHCCL::reduce( + const at::Tensor& tensor, + int root, + const ReduceOp& op, + bool async_op, + const ReduceOptions& options) { + throw std::runtime_error( + "HCCL reduce is not supported now and will be added later"); +} + +c10::intrusive_ptr TorchCommHCCL::all_gather( + const std::vector& tensor_list, + const at::Tensor& tensor, + bool async_op, + const AllGatherOptions& options) { + throw std::runtime_error( + "HCCL all_gather is not supported now and will be added later"); +} + +c10::intrusive_ptr TorchCommHCCL::all_gather_v( + const std::vector& tensor_list, + const at::Tensor& tensor, + bool async_op, + const AllGatherOptions& options) { + throw std::runtime_error("all_gather_v is not supported in HCCL backend"); +} + +c10::intrusive_ptr TorchCommHCCL::all_gather_single( + at::Tensor& output, + const at::Tensor& input, + bool async_op, + const AllGatherSingleOptions& options) { + throw std::runtime_error( + "HCCL all_gather_single is not supported now and will 
be added later"); +} + +c10::intrusive_ptr TorchCommHCCL::reduce_scatter( + at::Tensor& output, + const std::vector& input_list, + const ReduceOp& op, + bool async_op, + const ReduceScatterOptions& options) { + throw std::runtime_error( + "HCCL reduce_scatter is not supported now and will be added later"); +} + +c10::intrusive_ptr TorchCommHCCL::reduce_scatter_v( + at::Tensor& output, + const std::vector& input_list, + const ReduceOp& op, + bool async_op, + const ReduceScatterOptions& options) { + throw std::runtime_error("reduce_scatter_v is not supported in HCCL backend"); +} + +c10::intrusive_ptr TorchCommHCCL::reduce_scatter_single( + at::Tensor& output, + const at::Tensor& input, + const ReduceOp& op, + bool async_op, + const ReduceScatterSingleOptions& options) { + throw std::runtime_error( + "HCCL reduce_scatter_single is not supported now and will be added later"); +} + +c10::intrusive_ptr TorchCommHCCL::all_to_all_single( + at::Tensor& output, + const at::Tensor& input, + bool async_op, + const AllToAllSingleOptions& options) { + throw std::runtime_error( + "HCCL all_to_all_single is not supported now and will be added later"); +} + +c10::intrusive_ptr TorchCommHCCL::all_to_all_v_single( + at::Tensor& output, + const at::Tensor& input, + const std::vector& output_split_sizes, + const std::vector& input_split_sizes, + bool async_op, + const AllToAllvSingleOptions& options) { + throw std::runtime_error( + "HCCL all_to_all_v_single is not supported now and will be added later"); +} + +c10::intrusive_ptr TorchCommHCCL::all_to_all( + const std::vector& output_tensor_list, + const std::vector& input_tensor_list, + bool async_op, + const AllToAllOptions& options) { + throw std::runtime_error( + "HCCL all_to_all is not supported now and will be added later"); +} + +c10::intrusive_ptr TorchCommHCCL::barrier( + bool async_op, + const BarrierOptions& options) { + throw std::runtime_error( + "HCCL barrier is not supported now and will be added later"); +} + 
+c10::intrusive_ptr TorchCommHCCL::scatter( + at::Tensor& output_tensor, + const std::vector& input_tensor_list, + int root, + bool async_op, + const ScatterOptions& options) { + throw std::runtime_error( + "HCCL scatter is not supported now and will be added later"); +} + +c10::intrusive_ptr TorchCommHCCL::gather( + const std::vector& output_tensor_list, + const at::Tensor& input_tensor, + int root, + bool async_op, + const GatherOptions& options) { + throw std::runtime_error( + "HCCL gather is not supported now and will be added later"); +} + +std::shared_ptr TorchCommHCCL::split( + const std::vector& ranks, + const std::string& name, + const CommOptions& options) { + throw std::runtime_error( + "HCCL split is not supported now and will be added later"); +} + +std::shared_ptr TorchCommHCCL::getMemAllocator() { + throw std::runtime_error( + "HCCL getMemAllocator is not supported now and will be added later"); +} + +HCCLException::HCCLException( + HcclApi& hccl_api, + const std::string& message, + HcclResult result) + : message_(message + ": " + hccl_api.getErrorString(result)), + result_(result) {} + +const char* HCCLException::what() const noexcept { + return message_.c_str(); +} + +} // namespace torch::comms + +namespace { +class HCCLRegistration { + public: + HCCLRegistration() { + torch::comms::TorchCommFactory::get().register_backend("hccl", []() { + return std::make_shared(); + }); + } +}; + +static HCCLRegistration registration{}; +} // namespace diff --git a/comms/torchcomms/hccl/TorchCommHCCL.hpp b/comms/torchcomms/hccl/TorchCommHCCL.hpp new file mode 100644 index 000000000..2c63463d4 --- /dev/null +++ b/comms/torchcomms/hccl/TorchCommHCCL.hpp @@ -0,0 +1,314 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include // @manual=//caffe2:torch-cpp + +#include "comms/torchcomms/TorchComm.hpp" +#include "comms/torchcomms/TorchCommBackend.hpp" +#include 
"comms/torchcomms/TorchCommBatch.hpp" +#include "comms/torchcomms/device/npu/NpuApi.hpp" +#include "comms/torchcomms/hccl/HcclApi.hpp" +#include "comms/torchcomms/hccl/TorchWorkHCCL.hpp" + +namespace torch::comms { + +constexpr size_t kMaxEventPoolSize = 1000; + +// Custom exception class for better error handling +class HCCLException : public std::exception { + public: + HCCLException(HcclApi& api, const std::string& message, HcclResult result); + + const char* what() const noexcept override; + HcclResult getResult() const; + + private: + std::string message_; + HcclResult result_; +}; + +class TorchCommHCCL : public TorchCommBackend, + public std::enable_shared_from_this { + public: + static constexpr std::string_view kBackendName = "hccl"; + + TorchCommHCCL(); + ~TorchCommHCCL() override; + + // Delete copy and move operations + TorchCommHCCL(const TorchCommHCCL&) = delete; + TorchCommHCCL(TorchCommHCCL&&) = delete; + TorchCommHCCL& operator=(const TorchCommHCCL&) = delete; + TorchCommHCCL& operator=(TorchCommHCCL&&) = delete; + + void init( + at::Device device, + const std::string& name, + const CommOptions& options = {}) override; + void finalize() override; + int getRank() const override; + int getSize() const override; + std::string_view getBackendName() const override; + std::string_view getCommName() const override; + + // Point-to-Point Operations + c10::intrusive_ptr send( + const at::Tensor& tensor, + int dst, + bool async_op, + const SendOptions& options = {}) override; + c10::intrusive_ptr recv( + at::Tensor& tensor, + int src, + bool async_op, + const RecvOptions& options = {}) override; + + // Batch P2P Operations + c10::intrusive_ptr batch_op_issue( + const std::vector& ops, + bool async_op, + const BatchP2POptions& options = {}) override; + + // Collective Operations + c10::intrusive_ptr broadcast( + at::Tensor& tensor, + int root, + bool async_op, + const BroadcastOptions& options = {}) override; + c10::intrusive_ptr all_reduce( + at::Tensor& 
tensor, + const ReduceOp& op, + bool async_op, + const AllReduceOptions& options = {}) override; + c10::intrusive_ptr reduce( + const at::Tensor& tensor, + int root, + const ReduceOp& op, + bool async_op, + const ReduceOptions& options = {}) override; + c10::intrusive_ptr all_gather( + const std::vector& tensor_list, + const at::Tensor& tensor, + bool async_op, + const AllGatherOptions& options = {}) override; + c10::intrusive_ptr all_gather_v( + const std::vector& tensor_list, + const at::Tensor& tensor, + bool async_op, + const AllGatherOptions& options = {}) override; + c10::intrusive_ptr all_gather_single( + at::Tensor& output, + const at::Tensor& input, + bool async_op, + const AllGatherSingleOptions& options = {}) override; + c10::intrusive_ptr reduce_scatter( + at::Tensor& output, + const std::vector& input_list, + const ReduceOp& op, + bool async_op, + const ReduceScatterOptions& options = {}) override; + c10::intrusive_ptr reduce_scatter_v( + at::Tensor& output, + const std::vector& input_list, + const ReduceOp& op, + bool async_op, + const ReduceScatterOptions& options = {}) override; + c10::intrusive_ptr reduce_scatter_single( + at::Tensor& output, + const at::Tensor& input, + const ReduceOp& op, + bool async_op, + const ReduceScatterSingleOptions& options = {}) override; + c10::intrusive_ptr all_to_all_single( + at::Tensor& output, + const at::Tensor& input, + bool async_op, + const AllToAllSingleOptions& options = {}) override; + c10::intrusive_ptr all_to_all_v_single( + at::Tensor& output, + const at::Tensor& input, + const std::vector& output_split_sizes, + const std::vector& input_split_sizes, + bool async_op, + const AllToAllvSingleOptions& options = {}) override; + c10::intrusive_ptr all_to_all( + const std::vector& output_tensor_list, + const std::vector& input_tensor_list, + bool async_op, + const AllToAllOptions& options = {}) override; + c10::intrusive_ptr barrier( + bool async_op, + const BarrierOptions& options = {}) override; + + // Scatter 
and Gather Operations + c10::intrusive_ptr scatter( + at::Tensor& output_tensor, + const std::vector& input_tensor_list, + int root, + bool async_op, + const ScatterOptions& options = {}) override; + c10::intrusive_ptr gather( + const std::vector& output_tensor_list, + const at::Tensor& input_tensor, + int root, + bool async_op, + const GatherOptions& options = {}) override; + + // Communicator Management + std::shared_ptr split( + const std::vector& ranks, + const std::string& name, + const CommOptions& options = {}) override; + + std::shared_ptr getMemAllocator(); + + // Friend access for TorchCommHCCL + friend class TorchWorkHCCL; + + // Getter for NPU API (for friend classes) + NpuApi* getNpuApi() const { + return npu_api_.get(); + } + + // Getter for HCCL API (for friend classes) + HcclApi* getHcclApi() const { + return hccl_api_.get(); + } + + // Method to override the HCCL API implementation for testing + void setHcclApi(std::shared_ptr api) { + hccl_api_ = std::move(api); + } + + // Method to override the NPU API implementation for testing + void setNpuApi(std::shared_ptr api) { + npu_api_ = std::move(api); + } + + const CommOptions& getOptions() const override { + return options_; + } + + const at::Device& getDevice() const override { + return device_; + } + + protected: + // Event management for friend classes + npuEvent_t getEvent(); + void returnEvent(npuEvent_t&& event); + void abortHcclComm(); + + enum class CommState { + NORMAL, + ERROR, + TIMEOUT, + }; + + std::atomic comm_state_{ + CommState::NORMAL}; // State of the communicator + + HcclDataType getHcclDataType(const at::Tensor& tensor); + c10::intrusive_ptr createWork( + npuStream_t stream, + std::chrono::milliseconds timeout, + const std::vector& inputTensors); + + private: + // Helper that automatically cleans up premul sums. 
+ struct RedOpRAII { + /* implicit */ RedOpRAII(HcclReduceOp op); + + // Constructor for Premulsum Reduction + explicit RedOpRAII( + const ReduceOp& op, + const HcclComm comm, + const HcclDataType dataType, + std::shared_ptr hccl_api); + + RedOpRAII() = delete; + RedOpRAII(const RedOpRAII&) = delete; + RedOpRAII& operator=(const RedOpRAII&) = delete; + RedOpRAII(RedOpRAII&& tmp) = delete; + RedOpRAII& operator=(RedOpRAII&&) = delete; + ~RedOpRAII(); + + operator HcclReduceOp() const { + return hcclRedOp_; + } + + HcclReduceOp hcclRedOp_{HCCL_REDUCE_SUM}; + HcclComm comm_{nullptr}; + std::shared_ptr hccl_api_; + }; + + // Constructor for split communicators + explicit TorchCommHCCL(const HcclComm hccl_comm); + + // Private utility methods + size_t wordSize(HcclDataType type) const; + RedOpRAII getHcclReduceOp( + const ReduceOp& op, + const HcclComm comm, + const HcclDataType dataType); + void timeoutWatchdog() noexcept; + void checkInitialized() const; + void checkAndAbortIfTimedOutOrError(); + void checkWorkQueue(bool isMainThread); + void enqueueWork(c10::intrusive_ptr work, npuStream_t stream); + npuStream_t getOperationStream(bool async_op); + void ensureTensorContiguous(const at::Tensor& tensor); + + // Member variables + HcclComm hccl_comm_{}; + at::Device device_; + int comm_size_{}; + int rank_{}; + CommOptions options_; + size_t max_event_pool_size_{}; + std::optional internal_stream_; // Initialized in init() + std::optional + dependency_event_; // Pre-allocated event for stream dependencies + void* barrier_buffer_{}; // Pre-allocated NPU buffer for barrier operations + enum class InitializationState { + UNINITIALIZED, + INITIALIZED, + FINALIZED, + } init_state_; + + // HCCL API abstraction + std::shared_ptr hccl_api_; + + // NPU API abstraction + std::shared_ptr npu_api_; + + // Event pool management + std::queue event_pool_; + std::mutex event_pool_mutex_; + + // Work tracking per stream + TorchWorkHCCLQueue workq_; + + // Timeout monitoring + 
std::thread timeout_thread_; + std::atomic shutdown_; + std::condition_variable timeout_cv_; + std::mutex timeout_mutex_; + + bool high_priority_stream_{false}; + std::string name_; +}; + +} // namespace torch::comms diff --git a/comms/torchcomms/hccl/TorchCommHCCLBootstrap.cpp b/comms/torchcomms/hccl/TorchCommHCCLBootstrap.cpp new file mode 100644 index 000000000..4c309d0c1 --- /dev/null +++ b/comms/torchcomms/hccl/TorchCommHCCLBootstrap.cpp @@ -0,0 +1,263 @@ +#include "comms/torchcomms/hccl/TorchCommHCCLBootstrap.hpp" +#include +#include +#include +#include +#include "comms/torchcomms/utils/StoreManager.hpp" +#include "comms/torchcomms/utils/Logging.hpp" +#include "comms/torchcomms/utils/Utils.hpp" +#include "comms/torchcomms/hccl/TorchCommHCCL.hpp" + +namespace torch::comms { + +// Initialize the static counter +int TorchCommHCCLBootstrap::counter_ = 0; + +const std::string kUniqueidXchgMethodAuto = "auto"; +const std::string kUniqueidXchgMethodTCPStore = "tcpstore"; +const std::string kUniqueidXchgMethodDefault = kUniqueidXchgMethodAuto; + +TorchCommHCCLBootstrap::TorchCommHCCLBootstrap( + c10::intrusive_ptr store, + c10::Device device, + std::shared_ptr hccl_api, + std::shared_ptr npu_api, + std::chrono::milliseconds timeout) + : timeout_(timeout), + store_(store), + created_internal_store_(false), + device_(device), + hccl_api_(hccl_api), + npu_api_(npu_api) { + // Query rank and size using the utility function + auto ranksize = query_ranksize(); + rank_ = ranksize.first; + comm_size_ = ranksize.second; + + const char* uniqueid_xchg_env = + std::getenv("TORCHCOMM_HCCL_BOOTSTRAP_UNIQUEID_EXCHANGE_METHOD"); + if (uniqueid_xchg_env == nullptr) { + TC_LOG(INFO) + << "TORCHCOMM_HCCL_BOOTSTRAP_UNIQUEID_EXCHANGE_METHOD not set, " + << "defaulting to " << kUniqueidXchgMethodDefault; + uniqueid_xchg_method_ = kUniqueidXchgMethodDefault; + } else { + uniqueid_xchg_method_ = uniqueid_xchg_env; + } + std::transform( + uniqueid_xchg_method_.begin(), + 
uniqueid_xchg_method_.end(), + uniqueid_xchg_method_.begin(), + [](unsigned char c) { return std::tolower(c); }); + + if (device_.index() == -1) { + int device_count; + NPU_CHECK( + npu_api_, + npu_api_->getDeviceCount(&device_count), + "Failed to get NPU device count"); + + device_ = c10::Device(c10::kPrivateUse1, rank_ % device_count); + TC_LOG(INFO) << "User did not provide device ID; using device npu:" + << static_cast(device_.index()); + } + + NPU_CHECK( + npu_api_, + npu_api_->setDevice(device_.index()), + "Failed to set device to " + std::to_string(device_.index())); + + // Allocate NPU memory for a single float32 value used in barrier operations + NPU_CHECK( + npu_api_, + npu_api_->malloc(&barrier_buffer_, sizeof(float)), + "Failed to allocate barrier buffer"); +} + +TorchCommHCCLBootstrap::~TorchCommHCCLBootstrap() { + if (barrier_buffer_ != nullptr) { + try { + NPU_CHECK( + npu_api_, + npu_api_->free(barrier_buffer_), + "Failed to free barrier buffer"); + } catch (const std::exception& e) { + TC_LOG(ERROR) << e.what(); + } + barrier_buffer_ = nullptr; + } +} + +std::string TorchCommHCCLBootstrap::getHCCLStoreKey() { + std::string key = getHCCLStoreKeyPrefix() + std::to_string(counter_); + counter_++; + return key; +} + +std::string TorchCommHCCLBootstrap::getHCCLStoreKeyPrefix() { + return "hccl_storekey_"; +}; + +int TorchCommHCCLBootstrap::getHCCLStoreKeyCounter() { + return counter_; +} + +HcclRootInfo TorchCommHCCLBootstrap::exchangeUniqueIdStore() { + HcclRootInfo uniqueId; + + auto key = getHCCLStoreKey(); + + if (rank_ == 0) { + // Generate unique ID on rank 0 + HcclResult hcclErr = hccl_api_->getUniqueId(&uniqueId); + + if (hcclErr != HCCL_SUCCESS) { + throw std::runtime_error( + "Failed to get HCCL unique ID: " + + std::string(hccl_api_->getErrorString(hcclErr))); + } + + // Set the unique ID in the store + std::vector vec( + reinterpret_cast(&uniqueId), + reinterpret_cast(&uniqueId) + sizeof(uniqueId)); + store_->set(key, vec); + } else { + // 
Other ranks read the broadcast ID + auto vec = store_->get(key); + + if (vec.size() != sizeof(HcclRootInfo)) { + throw std::runtime_error("Invalid HCCL unique ID size"); + } + uniqueId = *(reinterpret_cast(vec.data())); + } + + return uniqueId; +} + +HcclRootInfo TorchCommHCCLBootstrap::exchangeUniqueIdTCPStore( + std::string_view name) { + store_ = createPrefixStore(std::string(name), timeout_); + created_internal_store_ = true; + + return exchangeUniqueIdStore(); +} + +bool TorchCommHCCLBootstrap::isTCPStoreEnabled() { + return std::getenv("MASTER_ADDR") && std::getenv("MASTER_PORT"); +} + +HcclRootInfo TorchCommHCCLBootstrap::exchangeUniqueId(std::string_view name) { + if (store_ != nullptr) { + return exchangeUniqueIdStore(); + } + + bool is_tcp_store_enabled = isTCPStoreEnabled(); + if (uniqueid_xchg_method_ != kUniqueidXchgMethodAuto && + uniqueid_xchg_method_ != kUniqueidXchgMethodTCPStore) { + throw std::runtime_error( + "Invalid unique ID exchange method " + uniqueid_xchg_method_); + } + if (!is_tcp_store_enabled) { + throw std::runtime_error("No way to exchange unique ID"); + } + return exchangeUniqueIdTCPStore(name); +} + +void TorchCommHCCLBootstrap::cleanupTCPStore(HcclComm hccl_comm) { + if (created_internal_store_) { + // Delete the internal store object and do a barrier to ensure that all + // processes have deleted their store object too. This way, when we + // create the next torchcomm, we can use the same port to create a new store + // object. 
+ store_.reset(); + + auto stream = npu_api_->getCurrentNPUStream(device_.index()); + HcclResult result = hccl_api_->allReduce( + barrier_buffer_, + barrier_buffer_, + 1, + HCCL_DATA_TYPE_FP32, + HCCL_REDUCE_SUM, + hccl_comm, + stream); + if (result != HCCL_SUCCESS) { + TC_LOG(ERROR) << "HCCL AllReduce failed: " + << hccl_api_->getErrorString(result); + } + + NPU_CHECK( + npu_api_, + npu_api_->streamSynchronize(stream), + "Stream synchronization failed"); + } +} + +// Helper function to populate HCCL config from hints +void populateHcclConfigFromHints( + HcclCommConfig& config, + const CommOptions& options, + const std::string& name) { + // Iterate over the hints and set the corresponding fields in the config + for (const auto& [key, val] : options.hints) { + if (key == "deterministic") { + config.hcclDeterministic = std::stoi(val); + TC_LOG(INFO) << "[comm=" << name << "] Setting config.hcclDeterministic=" + << config.hcclDeterministic; + } else if (key == "hcclBufferSize" || key == "hccl_buffer_size") { + TC_LOG(INFO) + << "[comm=" << name + << "] HCCL hint 'hcclBufferSize' is recognized but may not be applicable"; + } else if ( + key == "blocking" || key == "cgaClusterSize" || + key == "cga_cluster_size" || key == "minCTAs" || key == "min_ctas" || + key == "maxCTAs" || key == "max_ctas" || key == "netName" || + key == "splitShare" || key == "split_share" || key == "trafficClass" || + key == "traffic_class" || key == "commName" || key == "collnetEnable" || + key == "collnet_enable" || key == "CTAPolicy" || key == "cta_policy" || + key == "shrinkShare" || key == "nvlsCTAs" || key == "nvls_ctas" || + key == "nChannelsPerNetPeer" || key == "n_channels_per_net_peer" || + key == "nvlinkCentricSched" || key == "nvlink_centric_sched") { + TC_LOG(WARNING) << "HCCL hint '" << key + << "' is NCCL/XCCL-specific and not supported by HCCL, " + "ignoring for comm '" + << name << "'"; + } else { + TC_LOG(WARNING) + << "HCCL hint '" << key + << "' is not supported in this HCCL 
version, ignoring for comm '" + << name << "'"; + } + } +} + +HcclComm TorchCommHCCLBootstrap::createHcclComm( + const std::string& name, + const CommOptions& options) { + HcclRootInfo uniqueId; + HcclComm hccl_comm = nullptr; + + uniqueId = exchangeUniqueId(name); + + HcclCommConfig config; + // Initialize config with HCCL defaults before applying hints + HcclCommConfigInit(&config); + + // NOTE: root info must be identical across ranks. + // We generate it on rank 0 and broadcast via store_ in exchangeUniqueId(). + // Do NOT call HcclGetRootInfo() on every rank here. + HcclResult hcclErr = hccl_api_->commInitRankConfig( + &hccl_comm, comm_size_, uniqueId, rank_, &config); + + if (hcclErr != HCCL_SUCCESS || hccl_comm == nullptr) { + throw std::runtime_error( + "Failed to initialize HCCL communicator: " + + std::string(hccl_api_->getErrorString(hcclErr))); + } + + cleanupTCPStore(hccl_comm); + + return hccl_comm; +} + +} // namespace torch::comms diff --git a/comms/torchcomms/hccl/TorchCommHCCLBootstrap.hpp b/comms/torchcomms/hccl/TorchCommHCCLBootstrap.hpp new file mode 100644 index 000000000..b314e58ed --- /dev/null +++ b/comms/torchcomms/hccl/TorchCommHCCLBootstrap.hpp @@ -0,0 +1,80 @@ +#pragma once + +#include + +#include +#include // @manual=//caffe2:torch-cpp + +#include +#include +#include "comms/torchcomms/TorchCommOptions.hpp" +#include "comms/torchcomms/device/npu/NpuApi.hpp" +#include "comms/torchcomms/hccl/HcclApi.hpp" + +namespace torch::comms { + +constexpr uint16_t kTCPStorePort = 29500; + +class TorchCommHCCLBootstrap { + public: + TorchCommHCCLBootstrap( + c10::intrusive_ptr store, + c10::Device device, + std::shared_ptr hccl_api, + std::shared_ptr npu_api, + std::chrono::milliseconds timeout); + ~TorchCommHCCLBootstrap(); + + // Delete copy and move operations + TorchCommHCCLBootstrap(const TorchCommHCCLBootstrap&) = delete; + TorchCommHCCLBootstrap& operator=(const TorchCommHCCLBootstrap&) = delete; + 
  TorchCommHCCLBootstrap(TorchCommHCCLBootstrap&&) = delete;
  TorchCommHCCLBootstrap& operator=(TorchCommHCCLBootstrap&&) = delete;

  // Creates and initializes the HCCL communicator for this rank.
  // Hints in `options` are applied to the HCCL config before init.
  HcclComm createHcclComm(
      const std::string& name,
      const CommOptions& options = {});
  // Store-key helpers; the counter disambiguates successive communicators
  // created against the same store.
  static std::string getHCCLStoreKey();
  static std::string getHCCLStoreKeyPrefix();
  static int getHCCLStoreKeyCounter();

  int getRank() {
    return rank_;
  }
  int getSize() {
    return comm_size_;
  }
  c10::Device getDevice() {
    return device_;
  }

 private:
  // Root-info exchange: dispatches to the store- or TCPStore-based path.
  HcclRootInfo exchangeUniqueId(std::string_view name);
  HcclRootInfo exchangeUniqueIdStore();
  HcclRootInfo exchangeUniqueIdTCPStore(std::string_view name);
  bool isTCPStoreEnabled();
  // Tears down an internally-created TCPStore; barriers on `hccl_comm` so
  // peers finish reading first.
  void cleanupTCPStore(HcclComm hccl_comm);

 private:
  const std::chrono::milliseconds timeout_;
  static int counter_;

  c10::intrusive_ptr store_;
  bool created_internal_store_;  // true when this object owns the TCPStore
  c10::Device device_;
  std::shared_ptr hccl_api_;
  std::shared_ptr npu_api_;
  void* barrier_buffer_{nullptr};  // device scratch for the barrier allreduce
  int rank_;
  int comm_size_;

  std::string uniqueid_xchg_method_;  // which exchange path was selected
};

// Helper function to populate HCCL config from hints
void populateHcclConfigFromHints(
    HcclCommConfig& config,
    const CommOptions& options,
    const std::string& name);

} // namespace torch::comms
diff --git a/comms/torchcomms/hccl/TorchCommHCCLPy.cpp b/comms/torchcomms/hccl/TorchCommHCCLPy.cpp
new file mode 100644
index 000000000..05e007e60
--- /dev/null
+++ b/comms/torchcomms/hccl/TorchCommHCCLPy.cpp
@@ -0,0 +1,16 @@
#include
#include
#include
#include
#include

#include "comms/torchcomms/hccl/TorchCommHCCL.hpp"

namespace py = pybind11;
using namespace torch::comms;

// Minimal binding: exposes the TorchCommHCCL type to Python without any
// methods; construction happens on the C++ side.
PYBIND11_MODULE(_comms_hccl, m) {
  m.doc() = "HCCL specific python bindings for TorchComm";

  py::class_>(m, "TorchCommHCCL");
}
diff --git a/comms/torchcomms/hccl/TorchCommHCCLUtils.cpp b/comms/torchcomms/hccl/TorchCommHCCLUtils.cpp
new file mode 100644
index 000000000..15826148e
--- /dev/null
+++
b/comms/torchcomms/hccl/TorchCommHCCLUtils.cpp
@@ -0,0 +1,318 @@
#include
#include
#include
#include
#include "comms/torchcomms/utils/Logging.hpp"
#include "comms/torchcomms/hccl/TorchCommHCCL.hpp"

namespace torch::comms {

namespace {

// Maps an ATen scalar type to the corresponding HCCL data type.
// Throws for types HCCL has no equivalent for.
HcclDataType getHcclDataTypeInternal(const at::Tensor& tensor) {
  switch (tensor.scalar_type()) {
    case at::ScalarType::Float:
      return HCCL_DATA_TYPE_FP32;
    case at::ScalarType::Double:
      return HCCL_DATA_TYPE_FP64;
    case at::ScalarType::Half:
      return HCCL_DATA_TYPE_FP16;
    case at::ScalarType::BFloat16:
      return HCCL_DATA_TYPE_BFP16;
    case at::ScalarType::Int:
      return HCCL_DATA_TYPE_INT32;
    case at::ScalarType::Long:
      return HCCL_DATA_TYPE_INT64;
    case at::ScalarType::Char:
      return HCCL_DATA_TYPE_INT8;
    case at::ScalarType::Byte:
      return HCCL_DATA_TYPE_UINT8;
    default:
      throw std::runtime_error("Unsupported tensor data type for HCCL");
  }
}

// NOTE: `factor`, `comm` and `hccl_api` are intentionally unused here —
// HCCL has no native premul_sum, so this degrades to plain SUM and the
// pre-multiplication is expected to happen elsewhere (see preReduce note
// below) — TODO confirm callers actually apply the factor.
template
void createPreMulSum(
    HcclReduceOp* op,
    const PreMulSumFactorT& factor,
    const HcclComm& comm,
    HcclApi* hccl_api) {
  // HCCL doesn't support premul_sum, so we just use sum
  // The premul operation will be handled separately
  *op = HCCL_REDUCE_SUM;
}

} // namespace

// Trivial wrapper for a plain (non-premul) reduce op; owns no resources.
TorchCommHCCL::RedOpRAII::RedOpRAII(HcclReduceOp op)
    : hcclRedOp_(op), comm_(nullptr) {}

// Premul-sum constructor: validates the op type, then selects SUM since
// HCCL has no native premul_sum support.
TorchCommHCCL::RedOpRAII::RedOpRAII(
    const ReduceOp& op,
    const HcclComm comm,
    const HcclDataType dataType,
    std::shared_ptr hccl_api)
    : comm_(comm), hccl_api_(std::move(hccl_api)) {
  TORCH_INTERNAL_ASSERT(
      op == ReduceOp::RedOpType::PREMUL_SUM,
      "Constructing premul_sum RedOpRAII with non-premul_sum RedOpType");

  // No scaling factor supplied: premul_sum degenerates to plain sum.
  if (!op.factor().has_value()) {
    hcclRedOp_ = HCCL_REDUCE_SUM;
    comm_ = nullptr;
    return;
  }

  // HCCL doesn't support premul_sum natively, just use sum
  // The premul operation is handled in the preReduce function
  const auto& factor = op.factor().value();
  switch (dataType) {
    case HCCL_DATA_TYPE_FP16:
      createPreMulSum(
          &hcclRedOp_, factor, comm, hccl_api_.get());
      break;
    case HCCL_DATA_TYPE_FP32:
      createPreMulSum(
          &hcclRedOp_, factor, comm, hccl_api_.get());
      break;
    case HCCL_DATA_TYPE_BFP16:
      createPreMulSum(
          &hcclRedOp_, factor, comm, hccl_api_.get());
      break;
    case HCCL_DATA_TYPE_FP64:
      createPreMulSum(
          &hcclRedOp_, factor, comm, hccl_api_.get());
      break;
    default:
      throw std::runtime_error(
          "PreMulSum Data type must be half, float, bfloat16 or double");
  }
}

TorchCommHCCL::RedOpRAII::~RedOpRAII() {
  // HCCL doesn't need cleanup for reduce ops
}

// Size in bytes of one element of the given HCCL data type; 0 for unknown
// types (callers should treat 0 as unsupported).
size_t TorchCommHCCL::wordSize(HcclDataType type) const {
  switch (type) {
    case HCCL_DATA_TYPE_INT8:
    case HCCL_DATA_TYPE_UINT8:
      return 1;
    case HCCL_DATA_TYPE_FP16:
    case HCCL_DATA_TYPE_BFP16:
      return 2;
    case HCCL_DATA_TYPE_INT32:
    case HCCL_DATA_TYPE_UINT32:
    case HCCL_DATA_TYPE_FP32:
      return 4;
    case HCCL_DATA_TYPE_INT64:
    case HCCL_DATA_TYPE_FP64:
      return 8;
    default:
      return 0;
  }
}

HcclDataType TorchCommHCCL::getHcclDataType(const at::Tensor& tensor) {
  return getHcclDataTypeInternal(tensor);
}

// Converts a torchcomms ReduceOp to an HCCL reduce op wrapped in RAII.
// Ops HCCL cannot express (AVG, bitwise ops) throw.
TorchCommHCCL::RedOpRAII TorchCommHCCL::getHcclReduceOp(
    const ReduceOp& op,
    const HcclComm comm,
    const HcclDataType dataType) {
  switch (op) {
    case ReduceOp::RedOpType::SUM:
      return HCCL_REDUCE_SUM;
    case ReduceOp::RedOpType::PRODUCT:
      return HCCL_REDUCE_PROD;
    case ReduceOp::RedOpType::MIN:
      return HCCL_REDUCE_MIN;
    case ReduceOp::RedOpType::MAX:
      return HCCL_REDUCE_MAX;
    case ReduceOp::RedOpType::PREMUL_SUM:
      return RedOpRAII(op, comm, dataType, hccl_api_);
    case ReduceOp::RedOpType::AVG:
      // HCCL doesn't support AVG natively
      throw std::runtime_error("AVG reduce operation not supported in HCCL");
    case ReduceOp::RedOpType::BAND:
      // HCCL doesn't have bitwise AND
      throw std::runtime_error("Unsupported BAND reduce operation");
    case ReduceOp::RedOpType::BOR:
      // HCCL doesn't have bitwise OR
      throw std::runtime_error("Unsupported BOR reduce operation");
    case
ReduceOp::RedOpType::BXOR:
      // HCCL doesn't have bitwise XOR
      throw std::runtime_error("Unsupported BXOR reduce operation");
    default:
      throw std::runtime_error("Unsupported reduce operation");
  }
}

// Garbage-collects finished work and folds any timeout/error into the
// communicator state.
void TorchCommHCCL::checkWorkQueue(bool isMainThread) {
  TorchWorkHCCL::WorkStatus status = workq_.garbageCollect(isMainThread);

  switch (status) {
    case TorchWorkHCCL::WorkStatus::TIMEDOUT:
      comm_state_ = CommState::TIMEOUT;
      break;
    case TorchWorkHCCL::WorkStatus::ERROR:
      comm_state_ = CommState::ERROR;
      break;
    default:
      // For COMPLETED, NOT_STARTED, and INPROGRESS, no state change needed
      break;
  }
}

// The timeout thread cannot make HCCL calls. The only NPU call it can make
// is npuEventQuery.
// Polls roughly once per second; wakes early on shutdown. When an error or
// timeout is detected and abort_process_on_timeout_or_error is set, the
// whole process is aborted (aborting just the HCCL comm from this thread is
// not safe).
void TorchCommHCCL::timeoutWatchdog() noexcept {
  TC_LOG(INFO) << "Timeout thread starting for rank: " << rank_;
  while (!shutdown_) {
    {
      std::unique_lock lock(timeout_mutex_);
      // Wait for a shorter interval to check work objects periodically
      // Wake up either after 1 second or immediately if shutdown is requested
      timeout_cv_.wait_for(
          lock, std::chrono::seconds(1), [this]() { return shutdown_.load(); });

      // If we're shutting down, exit the loop
      if (shutdown_) {
        break;
      }
    }

    // Check work objects for completion or timeout
    checkWorkQueue(false);
    if (comm_state_ != CommState::NORMAL &&
        options_.abort_process_on_timeout_or_error) {
      // Log the error and abort the process. We cannot abort the HCCL
      // communicator as it is not safe to call HCCL operations from
      // multiple threads at the same time.
      if (comm_state_ == CommState::TIMEOUT) {
        TC_LOG(ERROR) << "Aborting process due to timeout on rank " << rank_
                      << " - timeout watchdog detected operation timeout";
      } else if (comm_state_ == CommState::ERROR) {
        TC_LOG(ERROR) << "Aborting process due to error on rank " << rank_
                      << " - timeout watchdog detected operation error. ";
      }
      abort();
    }
  }

  TC_LOG(INFO) << "Timeout thread exiting for rank: " << rank_;
}

// Throws unless init() has completed successfully.
void TorchCommHCCL::checkInitialized() const {
  if (init_state_ != InitializationState::INITIALIZED) {
    throw std::runtime_error("TorchCommHCCL not initialized");
  }
}

// Main-thread check: on timeout/error, aborts the HCCL communicator and
// either aborts the process or throws, per options_.
void TorchCommHCCL::checkAndAbortIfTimedOutOrError() {
  // First, check work queue status
  checkWorkQueue(true);

  if (comm_state_ == CommState::TIMEOUT) {
    abortHcclComm();
    if (options_.abort_process_on_timeout_or_error) {
      TC_LOG(ERROR) << "Aborting process due to timeout";
      abort();
    } else {
      throw std::runtime_error("HCCL operation timed out");
    }
  } else if (comm_state_ == CommState::ERROR) {
    HcclResult asyncErr;
    hccl_api_->commGetAsyncError(hccl_comm_, &asyncErr);
    HCCLException hcclException(*hccl_api_, "HCCL Async Error", asyncErr);
    abortHcclComm();
    if (options_.abort_process_on_timeout_or_error) {
      TC_LOG(ERROR) << "Aborting process due to error: "
                    << hcclException.what();
      abort();
    } else {
      throw hcclException;
    }
  }
}

// Creates a work object; the caller is responsible for enqueueWork() after
// recording the start/end events around the collective.
c10::intrusive_ptr TorchCommHCCL::createWork(
    npuStream_t stream,
    std::chrono::milliseconds timeout,
    const std::vector& inputTensors) {
  // Only create the work object without enqueuing it
  auto work = c10::make_intrusive(
      shared_from_this(), stream, timeout, inputTensors);
  return work;
}

void TorchCommHCCL::enqueueWork(
    c10::intrusive_ptr work,
    npuStream_t stream) {
  // Add work to stream's queue after events have been recorded
  workq_.enqueueWork(std::move(work), stream);
}

// Returns the stream collectives should run on. Async ops run on the
// internal stream, ordered after the caller's current stream via an event;
// sync ops run directly on the current stream.
npuStream_t TorchCommHCCL::getOperationStream(bool async_op) {
  if (async_op) {
    // Get current PyTorch NPU stream for this device
    npuStream_t current_stream = npu_api_->getCurrentNPUStream(device_.index());

    // Record event on current stream and wait for it on internal stream
    NPU_CHECK(
        npu_api_,
        npu_api_->eventRecord(dependency_event_.value(), current_stream),
        "Failed to record dependency event");

    NPU_CHECK(
        npu_api_,
        npu_api_->streamWaitEvent(
            internal_stream_.value(), dependency_event_.value(), 0),
        "Failed to make internal stream wait for dependency event");

    return internal_stream_.value();
  } else {
    // Use the current PyTorch NPU stream for synchronous operations
    return npu_api_->getCurrentNPUStream(device_.index());
  }
}

void TorchCommHCCL::ensureTensorContiguous(const at::Tensor& tensor) {
  if (!tensor.is_contiguous()) {
    throw std::runtime_error("Tensor must be contiguous for HCCL operations");
  }
}

// Protected methods (not in the private section of the header)
// Pops an event from the pool, or creates one when the pool is empty.
npuEvent_t TorchCommHCCL::getEvent() {
  std::lock_guard lock(event_pool_mutex_);

  if (!event_pool_.empty()) {
    npuEvent_t event = std::move(event_pool_.front());
    event_pool_.pop();
    return event;
  }

  // Create new event if pool is empty
  npuEvent_t event;
  NPU_CHECK(
      npu_api_,
      npu_api_->eventCreate(event),
      "Failed to create event");
  return event;
}

// Returns an event to the pool, or destroys it when the pool is full.
void TorchCommHCCL::returnEvent(npuEvent_t&& event) {
  std::lock_guard lock(event_pool_mutex_);

  if (event_pool_.size() < max_event_pool_size_) {
    event_pool_.push(std::move(event));
  } else {
    // Pool is full, destroy the event
    NPU_CHECK(
        npu_api_, npu_api_->eventDestroy(event), "Failed to destroy event");
  }
}
} // namespace torch::comms
diff --git a/comms/torchcomms/hccl/TorchWorkHCCL.cpp b/comms/torchcomms/hccl/TorchWorkHCCL.cpp
new file mode 100644
index 000000000..5c59a4902
--- /dev/null
+++ b/comms/torchcomms/hccl/TorchWorkHCCL.cpp
@@ -0,0 +1,136 @@
#include "comms/torchcomms/hccl/TorchWorkHCCL.hpp"
#include "comms/torchcomms/utils/Logging.hpp"
#include "comms/torchcomms/utils/TracingGuard.hpp"
#include "comms/torchcomms/hccl/TorchCommHCCL.hpp"

namespace torch::comms {

// Work object tracking one HCCL operation via start/end NPU events taken
// from the comm's event pool.
TorchWorkHCCL::TorchWorkHCCL(
    std::shared_ptr comm,
    npuStream_t stream,
    std::chrono::milliseconds timeout_ms,
    const std::vector& inputTensors)
    : inputTensors_(inputTensors),
      comm_(std::move(comm)),
      stream_(stream),
      timeout_ms_(timeout_ms),
      state_(WorkStatus::NOT_STARTED) {
  // If not in graph capture mode, create the events for start and end
  // recording
  start_event_ = comm_->getEvent();
  end_event_ = comm_->getEvent();

  // Events will be recorded around the actual HCCL operations
}

TorchWorkHCCL::~TorchWorkHCCL() {
  if (!comm_) {
    return;
  }
  // If not in graph capture mode, return the events to the pool
  comm_->returnEvent(std::move(start_event_));
  comm_->returnEvent(std::move(end_event_));
}

// Records the start event on the work's stream; call just before the
// HCCL operation is issued.
void TorchWorkHCCL::recordStart() {
  NPU_CHECK(
      comm_->getNpuApi(),
      comm_->getNpuApi()->eventRecord(start_event_, stream_),
      "Failed to record start event");
}

// Records the end event on the work's stream; call just after the
// HCCL operation is issued.
void TorchWorkHCCL::recordEnd() {
  NPU_CHECK(
      comm_->getNpuApi(),
      comm_->getNpuApi()->eventRecord(end_event_, stream_),
      "Failed to record end event");
}

// Non-blocking status poll driven by event queries. Timeout is measured
// from when the start event completes, not from enqueue.
// NOTE(review): NPU_ERROR_UNSUPPORTED is tolerated for the start-event
// query but treated as ERROR for the end-event query — confirm this
// asymmetry is intentional.
TorchWorkHCCL::WorkStatus TorchWorkHCCL::checkStatus() {
  // If already marked as completed, return COMPLETED
  if (state_ == WorkStatus::COMPLETED || state_ == WorkStatus::ERROR ||
      state_ == WorkStatus::TIMEDOUT) {
    return state_;
  }

  // Step 1: If start_completed_time_ doesn't have a value yet, query the start
  // event
  if (!start_completed_time_.has_value()) {
    npu_result_t start_status = comm_->getNpuApi()->eventQuery(start_event_);

    if (start_status == NPU_SUCCESS) {
      // Start event has completed, store the current time
      start_completed_time_ = std::chrono::steady_clock::now();
      state_ = WorkStatus::INPROGRESS;
    } else if (
        start_status != NPU_ERROR_NOT_READY &&
        start_status != NPU_ERROR_UNSUPPORTED) {
      // Some other error occurred with the start event
      TC_LOG(ERROR) << "NPU error during start event query: "
                    << comm_->getNpuApi()->getErrorString(start_status) << " ("
                    << start_status << ")";
      state_ = WorkStatus::ERROR;
    }
  }
  if (state_ == WorkStatus::NOT_STARTED || state_ == WorkStatus::ERROR) {
    return state_;
  }

  // Step 2: If we get here, start event has completed, so query the end event
  npu_result_t end_status = comm_->getNpuApi()->eventQuery(end_event_);

  if (end_status == NPU_SUCCESS) {
    // End event has completed, mark the work as completed
    state_ = WorkStatus::COMPLETED;

    // Release the input tensors to keep the lifetime of the tensors short
    inputTensors_.clear();
  } else if (end_status == NPU_ERROR_NOT_READY) {
    // End event has not completed yet, check for timeout
    auto current_time = std::chrono::steady_clock::now();
    auto elapsed_milliseconds =
        std::chrono::duration_cast(
            current_time - start_completed_time_.value());

    // Check if the operation has timed out
    if (elapsed_milliseconds > timeout_ms_) {
      // Operation has timed out
      state_ = WorkStatus::TIMEDOUT;
    }
  } else {
    // Some other error occurred with the end event
    TC_LOG(ERROR) << "NPU error during end event query: "
                  << comm_->getNpuApi()->getErrorString(end_status) << " ("
                  << end_status << ")";
    state_ = WorkStatus::ERROR;
  }
  return state_;
}

// Stream-ordered wait: makes the caller's current stream wait on the end
// event. This does NOT block the host thread.
void TorchWorkHCCL::wait() {
  // If already completed, return immediately
  WorkStatus local_state = state_;
  if (local_state == WorkStatus::COMPLETED ||
      local_state == WorkStatus::ERROR || local_state == WorkStatus::TIMEDOUT) {
    return;
  }

  TracingGuard tracingGuard(
      std::string(comm_->getCommName()),
      comm_->getSize(),
      "wait",
      comm_->getRank());

  // Get the current stream using the device from the comm object
  npuStream_t current_stream =
      comm_->getNpuApi()->getCurrentNPUStream(comm_->device_.index());

  // Add a dependency from the work's stream to the current stream
  // This makes the current stream wait for the end_event_ recorded on the
  // work's stream
  NPU_CHECK(
      comm_->getNpuApi(),
      comm_->getNpuApi()->streamWaitEvent(current_stream, end_event_, 0),
      "Failed to make stream wait for event");
}
} // namespace torch::comms
diff --git a/comms/torchcomms/hccl/TorchWorkHCCL.hpp b/comms/torchcomms/hccl/TorchWorkHCCL.hpp
new file mode 100644
index
000000000..b89e3a4e2
--- /dev/null
+++ b/comms/torchcomms/hccl/TorchWorkHCCL.hpp
@@ -0,0 +1,92 @@
#pragma once

#include
#include
#include
#include
#include
#include
#include

#include
#include "comms/torchcomms/TorchWork.hpp"
#include "comms/torchcomms/device/npu/NpuApi.hpp"

namespace torch::comms {

// Forward declaration
class TorchCommHCCL;

// Tracks a single HCCL operation via NPU start/end events recorded on the
// operation's stream.
class TorchWorkHCCL : public TorchWork {
 public:
  // Status of a work object
  enum class WorkStatus {
    NOT_STARTED, // Work has not started yet
    INPROGRESS, // Work is still in progress,
    COMPLETED, // Work has completed successfully
    TIMEDOUT, // Work has timed out
    ERROR // Work has encountered an error
  };

  TorchWorkHCCL(
      std::shared_ptr comm,
      npuStream_t stream,
      std::chrono::milliseconds timeout_ms,
      const std::vector& inputTensors);
  ~TorchWorkHCCL() override;

  // Delete copy and move operations
  TorchWorkHCCL(const TorchWorkHCCL&) = delete;
  TorchWorkHCCL(TorchWorkHCCL&&) = delete;
  TorchWorkHCCL& operator=(const TorchWorkHCCL&) = delete;
  TorchWorkHCCL& operator=(TorchWorkHCCL&&) = delete;

  // Stream-ordered wait on the end event (does not block the host).
  void wait() override;

 protected:
  // Record the start/end events around the HCCL call on stream_.
  void recordStart();
  void recordEnd();

  friend class TorchCommHCCL;
  friend class TorchWorkHCCLQueue;

 private:
  // Check the status of the work object
  WorkStatus checkStatus();

  std::chrono::milliseconds getTimeout() const {
    return timeout_ms_;
  }
  // Held until completion so inputs outlive the async operation.
  std::vector inputTensors_;

  std::shared_ptr comm_;
  npuEvent_t start_event_;
  npuEvent_t end_event_;
  npuStream_t stream_; // stream is not owned by this class

  std::chrono::milliseconds timeout_ms_;

  // state machine variables. TODO: convert to state machine later
  std::atomic state_;

  // Set when the start event completes; timeout is measured from here.
  std::optional start_completed_time_;
};

// Per-stream FIFO queues of in-flight work, polled by the main thread and
// the timeout watchdog.
class TorchWorkHCCLQueue {
 public:
  TorchWorkHCCLQueue() = default;
  ~TorchWorkHCCLQueue() = default;

  TorchWorkHCCL::WorkStatus garbageCollect(bool isMainThread);
  // Finalize function can only be called from the main thread
  TorchWorkHCCL::WorkStatus finalize();
  void enqueueWork(c10::intrusive_ptr work, npuStream_t stream);

 private:
  std::unordered_map>>
      stream_work_queues_;
  std::vector> completed_work_queue_;
  std::recursive_mutex work_queues_mutex_;
};

} // namespace torch::comms
diff --git a/comms/torchcomms/hccl/TorchWorkHCCLQueue.cpp b/comms/torchcomms/hccl/TorchWorkHCCLQueue.cpp
new file mode 100644
index 000000000..0453db11e
--- /dev/null
+++ b/comms/torchcomms/hccl/TorchWorkHCCLQueue.cpp
@@ -0,0 +1,96 @@
#include "comms/torchcomms/hccl/TorchWorkHCCL.hpp"

namespace torch::comms {

// Pops completed work from the front of each per-stream queue; returns
// early with TIMEDOUT/ERROR as soon as one is seen. Completed work is moved
// to completed_work_queue_ and only freed on the main thread.
TorchWorkHCCL::WorkStatus TorchWorkHCCLQueue::garbageCollect(
    bool isMainThread) {
  std::lock_guard lock(work_queues_mutex_);

  TorchWorkHCCL::WorkStatus last_status = TorchWorkHCCL::WorkStatus::COMPLETED;

  // Keep popping completed elements until we hit an in-progress element
  // or the queue is empty
  // Use an iterator to safely remove empty queues while iterating
  auto it = stream_work_queues_.begin();
  while (it != stream_work_queues_.end()) {
    auto& work_queue = it->second;

    while (!work_queue.empty()) {
      // Get the first work object in the queue
      auto work = work_queue.front();

      // Use the checkStatus function to determine the work status
      TorchWorkHCCL::WorkStatus status = work->checkStatus();
      last_status = status;

      if (status == TorchWorkHCCL::WorkStatus::COMPLETED) {
        // Work is completed, remove it from the work queue
        work_queue.pop();
        completed_work_queue_.push_back(work);
        // Continue to the next element in the queue
      } else if (
          status == TorchWorkHCCL::WorkStatus::TIMEDOUT ||
          status ==
TorchWorkHCCL::WorkStatus::ERROR) {
        // Return the error status immediately
        return status;
      } else {
        // NOT_STARTED or INPROGRESS - stop processing this queue
        break;
      }
    }

    // If the queue is now empty, remove it from the map
    if (work_queue.empty()) {
      it = stream_work_queues_.erase(it);
    } else {
      ++it;
    }
  }

  if (isMainThread) {
    // If we are the main thread, clear the completed work queues
    completed_work_queue_.clear();
  }

  return last_status;
}

// Drains all queues, polling until every work finishes or one fails.
// Returns the terminal status; all work references are released on return.
TorchWorkHCCL::WorkStatus TorchWorkHCCLQueue::finalize() {
  // Because this function is typically called after the timeout thread has
  // already joined, we might not need to lock here. But doing the lock anyway,
  // as defensive programming, just in case someone moves the thread join order
  // later. The cost of the lock itself should be small on modern linux systems
  // (uncontended locks are typically just an atomic operation).
  std::lock_guard lock(work_queues_mutex_);

  // Initialize the status to COMPLETED to cover the case where the queue is
  // empty
  TorchWorkHCCL::WorkStatus status = TorchWorkHCCL::WorkStatus::COMPLETED;
  while (!stream_work_queues_.empty()) {
    status = garbageCollect(true);
    if (status == TorchWorkHCCL::WorkStatus::ERROR ||
        status == TorchWorkHCCL::WorkStatus::TIMEDOUT ||
        status == TorchWorkHCCL::WorkStatus::COMPLETED) {
      break;
    }
  }

  // Clear all work queues & completed work queue.
  //
  // NOTE: finalize MUST return without holding references to any work object,
  // otherwise it may leak object and cause side effects.
  stream_work_queues_.clear();
  completed_work_queue_.clear();

  return status;
}

void TorchWorkHCCLQueue::enqueueWork(
    c10::intrusive_ptr work,
    npuStream_t stream) {
  // Add work to stream's queue after events have been recorded
  std::lock_guard lock(work_queues_mutex_);
  stream_work_queues_[stream].push(work);
}

} // namespace torch::comms
diff --git a/setup.py b/setup.py
index 191b2b869..fdc395968 100644
--- a/setup.py
+++ b/setup.py
@@ -64,6 +64,7 @@ def flag_str(val: bool):
 USE_RCCL = flag_enabled("USE_RCCL", False)
 USE_RCCLX = flag_enabled("USE_RCCLX", False)
 USE_XCCL = flag_enabled("USE_XCCL", False)
+USE_HCCL = flag_enabled("USE_HCCL", False)
 IS_ROCM = hasattr(torch.version, "hip") and torch.version.hip is not None
 # Transport is CUDA-only; disable by default on ROCm but allow explicit opt-in.
 USE_TRANSPORT = flag_enabled("USE_TRANSPORT", not IS_ROCM)
@@ -154,6 +155,7 @@ def build_cmake(self, ext):
             f"-DUSE_RCCL={flag_str(USE_RCCL)}",
             f"-DUSE_RCCLX={flag_str(USE_RCCLX)}",
             f"-DUSE_XCCL={flag_str(USE_XCCL)}",
+            f"-DUSE_HCCL={flag_str(USE_HCCL)}",
             f"-DUSE_TRANSPORT={flag_str(USE_TRANSPORT)}",
             f"-DUSE_TRITON={flag_str(USE_TRITON)}",
         ]
@@ -186,6 +188,7 @@ def build_cmake(self, ext):
         ("rccl", USE_RCCL),
         ("rcclx", USE_RCCLX),
         ("xccl", USE_XCCL),
+        ("hccl", USE_HCCL),
     ]

 ext_modules = [CMakeExtension("torchcomms._comms")]