From 8667b5463ffc2d2a3ba93d5120be812c6a1c3d55 Mon Sep 17 00:00:00 2001 From: zyuli <2436472829@qq.com> Date: Tue, 28 Apr 2026 18:58:35 +0800 Subject: [PATCH] [TLERaw] Support NVSHMEM --- .../PatternTritonGPUOpToLLVM.h | 5 + include/triton/Dialect/Triton/IR/TritonOps.td | 26 ++++ lib/Conversion/TritonGPUToLLVM/CMakeLists.txt | 1 + .../TritonGPUToLLVM/ExternCallOpToLLVM.cpp | 61 +++++++++ .../TritonToTritonGPUPass.cpp | 1 + lib/Dialect/Triton/IR/Ops.cpp | 18 +++ python/src/ir.cc | 8 ++ .../experimental/tle/language/raw/__init__.py | 4 +- .../experimental/tle/language/raw/core.py | 5 + .../experimental/tle/raw/cuda/runtime.py | 25 +++- python/triton/language/core.py | 88 +++++++++++++ .../01-simple-shift/simple-shift-device.cu | 8 ++ .../01-simple-shift/simple-shift-host.cu | 31 +++++ .../nvshmem/01-simple-shift/simple-shift.py | 122 ++++++++++++++++++ third_party/nvidia/backend/compiler.py | 41 +++++- third_party/nvidia/language/cuda/__init__.py | 2 + .../nvidia/language/cuda/libnvshmem_device.py | 31 +++++ .../TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp | 2 + 18 files changed, 472 insertions(+), 7 deletions(-) create mode 100644 lib/Conversion/TritonGPUToLLVM/ExternCallOpToLLVM.cpp create mode 100644 python/tutorials/tle/raw/nvshmem/01-simple-shift/simple-shift-device.cu create mode 100644 python/tutorials/tle/raw/nvshmem/01-simple-shift/simple-shift-host.cu create mode 100644 python/tutorials/tle/raw/nvshmem/01-simple-shift/simple-shift.py create mode 100644 third_party/nvidia/language/cuda/libnvshmem_device.py diff --git a/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h b/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h index 680bf0e045..69102ed61b 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -102,6 +102,11 @@ void populatePrintOpToLLVMPattern(LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, 
PatternBenefit benefit); +void populateExternCallOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); + void populateInstrumentationToLLVMPatterns(LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, RewritePatternSet &patterns, diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index 7fd215f9de..95239b3cb9 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -845,6 +845,32 @@ def TT_MapElementwiseReturnOp: TT_Op<"map_elementwise.return", let assemblyFormat = "attr-dict ($result^ `:` type($result))?"; } +// +// External Call op +// +def TT_ExternCallOp : TT_Op<"extern_call", [ + DeclareOpInterfaceMethods, + ConditionallySpeculatable, +]> { + + let description = [{ + call an external function $symbol implemented in $libpath/$libname with $args + return $libpath/$libname:$symbol($args...) + }]; + + let arguments = (ins Variadic:$srcs, StrAttr:$libname, StrAttr:$libpath, StrAttr:$symbol, BoolAttr:$pure); + + let results = (outs Variadic:$result); + + let assemblyFormat = "operands attr-dict `:` functional-type(operands, $result)"; + + let extraClassDeclaration = [{ + // Interface method for ConditionallySpeculatable. 
+ Speculation::Speculatability getSpeculatability(); + }]; + +} + // // External Elementwise op // diff --git a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt index d4f49c8d18..44e52cf37c 100644 --- a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt +++ b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -7,6 +7,7 @@ add_triton_library(TritonGPUToLLVM AssertOpToLLVM.cpp ControlFlowOpToLLVM.cpp ConvertLayoutOpToLLVM.cpp + ExternCallOpToLLVM.cpp ElementwiseOpToLLVM.cpp FuncOpToLLVM.cpp GatherOpToLLVM.cpp diff --git a/lib/Conversion/TritonGPUToLLVM/ExternCallOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ExternCallOpToLLVM.cpp new file mode 100644 index 0000000000..81b6bf256e --- /dev/null +++ b/lib/Conversion/TritonGPUToLLVM/ExternCallOpToLLVM.cpp @@ -0,0 +1,61 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Support/LLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +namespace { + +class ExternCallOpConversion + : public ConvertOpToLLVMPattern { +public: + ExternCallOpConversion(const LLVMTypeConverter &converter, + const PatternBenefit &benefit) + : ConvertOpToLLVMPattern(converter, benefit) {} + + LogicalResult + matchAndRewrite(triton::ExternCallOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + + if (op->getNumResults() > 1) { + llvm::errs() << "ExternCallConversion does not support multi outs."; + return failure(); + } + + LLVM::LLVMVoidType voidTy = void_ty(op->getContext()); + auto newOperands = adaptor.getOperands(); + Type retType = + op->getNumResults() == 0 + ? 
voidTy + : this->getTypeConverter()->convertType(op->getResult(0).getType()); + std::string funcName = op.getSymbol().str(); + StringRef libname = op.getLibname(); + StringRef libpath = op.getLibpath(); + + Operation *externCallOp; + Type funcType = mlir::triton::gpu::getFunctionType(retType, newOperands); + LLVM::LLVMFuncOp funcOp = mlir::triton::gpu::appendOrGetExternFuncOp( + rewriter, op, funcName, funcType, libname, libpath); + externCallOp = LLVM::createLLVMCallOp(rewriter, loc, funcOp, newOperands); + + if (op->getNumResults() == 0) { + rewriter.eraseOp(op); + } else { + rewriter.replaceOp(op, externCallOp->getResult(0)); + } + + return success(); + } +}; + +} // namespace + +void mlir::triton::populateExternCallOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, benefit); +} diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index b49871486d..45c6134655 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -600,6 +600,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, GenericOpPattern, GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, GenericOpPattern, GenericOpPattern, diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 27fa26554d..854ae1f2bc 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -1313,6 +1313,24 @@ Speculation::Speculatability ExternElementwiseOp::getSpeculatability() { return Speculation::NotSpeculatable; } +// -- ExternCallOp -- +void ExternCallOp::getEffects( + SmallVectorImpl> + &effects) { + if (getPure()) + return; + effects.emplace_back(MemoryEffects::Write::get(), + SideEffects::DefaultResource::get()); + 
effects.emplace_back(MemoryEffects::Read::get(), + SideEffects::DefaultResource::get()); +} + +Speculation::Speculatability ExternCallOp::getSpeculatability() { + if (getPure()) + return Speculation::Speculatable; + return Speculation::NotSpeculatable; +} + // -- GatherOp -- LogicalResult GatherOp::verify() { RankedTensorType indicesTy = getIndices().getType(); diff --git a/python/src/ir.cc b/python/src/ir.cc index e5591dcb08..9abfe8472f 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -1673,6 +1673,14 @@ void init_triton_ir(py::module &&m) { return self.create(retType, argList, libName, libPath, symbol, isPure); }) + .def("create_extern_call", + [](TritonOpBuilder &self, const std::string &libName, + const std::string &libPath, const std::string &symbol, + std::vector &argList, const std::vector &retTypes, + bool isPure) -> OpState { + return self.create(retTypes, argList, libName, + libPath, symbol, isPure); + }) // Built-in instruction .def("create_get_program_id", [](TritonOpBuilder &self, int axis) -> Value { diff --git a/python/triton/experimental/tle/language/raw/__init__.py b/python/triton/experimental/tle/language/raw/__init__.py index d66f6cd381..790c65cf3f 100644 --- a/python/triton/experimental/tle/language/raw/__init__.py +++ b/python/triton/experimental/tle/language/raw/__init__.py @@ -1,3 +1,3 @@ -from .core import call, call_smem +from .core import call, call_smem, call_nvshmem -__all__ = ["call", "call_smem"] +__all__ = ["call", "call_smem", "call_nvshmem"] diff --git a/python/triton/experimental/tle/language/raw/core.py b/python/triton/experimental/tle/language/raw/core.py index d7029380eb..7e25385313 100644 --- a/python/triton/experimental/tle/language/raw/core.py +++ b/python/triton/experimental/tle/language/raw/core.py @@ -46,3 +46,8 @@ def call_smem(func, args, _semantic=None): return buffer_tensors[0] else: return tl.tuple(buffer_tensors) + + +@builtin +def call_nvshmem(func, outputs, inputs, _semantic=None): + func.make_cubin() diff 
--git a/python/triton/experimental/tle/raw/cuda/runtime.py b/python/triton/experimental/tle/raw/cuda/runtime.py index 9015f2bf7f..955a366e67 100644 --- a/python/triton/experimental/tle/raw/cuda/runtime.py +++ b/python/triton/experimental/tle/raw/cuda/runtime.py @@ -3,20 +3,24 @@ from pathlib import Path import subprocess from typing import Any, Final +import torch from triton._C.libtriton import llvm # pyright: ignore[reportMissingImports] from triton._C.libtriton.tle.llvm import parse_llvm_ir # pyright: ignore[reportMissingImports] # TODO: We use cli tools to compile CUDA code temporarily, and plan to replace it with LLVM components Python bindings in the future. CLANG = os.getenv("CLANG", "clang") +NVCC = os.getenv("NVCC", "nvcc") class CUDAJITFunction(object): def __init__(self, fn: Any, file: Path, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) + super().__init__() self.fn: Final[Any] = fn self.code: Final[str] = file.read_text() + self.file: Final[Path] = file + self.libs = kwargs.get("library", {}) self.__triton_builtin__: Final[bool] = True def make_llvm(self, mlir_context) -> str: @@ -40,3 +44,22 @@ def make_llvm(self, mlir_context) -> str: llvm_context = llvm.context() module = parse_llvm_ir(build.stdout.decode(), llvm_context, mlir_context) return f"{module}" + + def make_cubin(self): + src = self.file + dst = Path(src).with_suffix('.o') + include_dirs = [] + for lib_name, lib_path in self.libs.items(): + # TODO: Remove the method of passing information by setting environment variables. 
+ os.environ[(lib_name + "_home").upper()] = lib_path + include_dirs.append(os.path.join(lib_path, "include")) + include_flags = [f"-I{inc_dir}" for inc_dir in include_dirs] + prop = torch.cuda.get_device_properties(torch.cuda.current_device()) + arch = f"-arch=sm_{prop.major}{prop.minor}" + build = subprocess.run([NVCC, "-rdc=true", arch, *include_flags, "--extended-lambda", "-c", "-o", dst, src], + capture_output=True) + assert build.returncode == 0, (f"nvcc failed\nstderr:\n{build.stderr.decode()}") + # TODO: Remove the method of passing information by setting environment variables. + os.environ["USE_NVCC"] = 'True' + os.environ["CUDA_CUBIN"] = str(dst) + return diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 439671c3a1..5a9c56b12c 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -3433,6 +3433,94 @@ def binary_op_type_legalization(lhs, rhs, semantic): return semantic.binary_op_type_checking_impl(lhs, rhs) +def dispatch_ec(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool, + _semantic=None): + ''' + Dispatch a function to a library + :param func: the function to dispatch + :param lib_name: the name of the library + :param lib_path: the path of the library + :param args: the arguments of the function + :param arg_type_symbol_dict: maps a tuple of argument types to (symbol, return types) + :param is_pure: whether the function is pure + :param _semantic: the builder + :return: the return value of the function + ''' + if len(arg_type_symbol_dict) == 0: + raise ValueError("arg_type_symbol_dict is empty") + + num_args = len(list(arg_type_symbol_dict.keys())[0]) + if len(args) != num_args: + raise ValueError(f"length of input args does not match. "
+ f"Expect {num_args}, got {len(args)}") + + arg_types = [] + arg_list = [] + for arg in args: + if isinstance(arg, tensor): + arg_types.append(arg.dtype) + arg_list.append(arg.handle) + else: + arg_types.append(type(arg)) + arg_list.append(arg) + arg_types = tuple(arg_types) + + if arg_types not in arg_type_symbol_dict: + raise ValueError(f"input arg type does not match. " + f"Expect one of {arg_type_symbol_dict.keys()}, got {arg_types}") + else: + symbol = arg_type_symbol_dict[arg_types][0] + ret_types = arg_type_symbol_dict[arg_types][1] + if not isinstance(ret_types, (builtins.list, builtins.tuple)): + ret_types = [ret_types] + + if symbol == "": + raise ValueError("Symbol can not be empty") + call = func(lib_name, lib_path, symbol, arg_list, [ret_type.to_ir(_semantic.builder) for ret_type in ret_types], + is_pure) + + if len(ret_types) == 0: + return tensor(call, void) + if len(ret_types) == 1: + return tensor(call.get_result(0), ret_types[0]) + return tuple(tensor(call.get_result(i), ty) for i, ty in enumerate(ret_types)) + + +@builtin +def extern_call(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool, _semantic=None): + ''' + Dispatch a function to a library + :param lib_name: the name of the library + :param lib_path: the path of the library + :param args: the arguments of the function + :param arg_type_symbol_dict: maps a tuple of argument types to (symbol, return types) + :param is_pure: whether the function is pure + :param _semantic: the semantic + :return: the return value of the function + ''' + dispatch_args = args.copy() + all_scalar = True + arg_types = [] + for i in builtins.range(len(dispatch_args)): + dispatch_args[i] = _semantic.to_tensor(dispatch_args[i]) + arg_types.append(dispatch_args[i].dtype) + if dispatch_args[i].type.is_block(): + all_scalar = False + if not all_scalar: + raise ValueError("extern call only supports inputs with scalar type") + + if len(arg_type_symbol_dict) == 0: + raise ValueError("arg_type_symbol_dict is empty") + +
num_args = len(list(arg_type_symbol_dict.keys())[0]) + if len(args) != num_args: + raise ValueError(f"length of input args does not match. " + f"Expect {num_args}, got {len(args)}") + + func = _semantic.builder.create_extern_call + return dispatch_ec(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, is_pure, _semantic) + + def extern(fn): """A decorator for external functions.""" return builtin(fn) diff --git a/python/tutorials/tle/raw/nvshmem/01-simple-shift/simple-shift-device.cu b/python/tutorials/tle/raw/nvshmem/01-simple-shift/simple-shift-device.cu new file mode 100644 index 0000000000..f2910883bb --- /dev/null +++ b/python/tutorials/tle/raw/nvshmem/01-simple-shift/simple-shift-device.cu @@ -0,0 +1,8 @@ +#include + +extern "C" __device__ void simple_shift(int *destination) { + int mype = nvshmem_my_pe(); + int npes = nvshmem_n_pes(); + int peer = (mype + 1) % npes; + nvshmem_int_p(destination, mype, peer); +} diff --git a/python/tutorials/tle/raw/nvshmem/01-simple-shift/simple-shift-host.cu b/python/tutorials/tle/raw/nvshmem/01-simple-shift/simple-shift-host.cu new file mode 100644 index 0000000000..be9ea236a1 --- /dev/null +++ b/python/tutorials/tle/raw/nvshmem/01-simple-shift/simple-shift-host.cu @@ -0,0 +1,31 @@ +#include +#include +#include +#include +#include + +extern "C" void nvshmem_init_wrapper() { nvshmem_init(); } + +extern "C" int nvshmemx_cumodule_init_wrapper(CUmodule module) { + return nvshmemx_cumodule_init(module); +} + +extern "C" int nvshmem_team_mype_wrapper() { + int mype_node = nvshmem_team_my_pe(NVSHMEMX_TEAM_NODE); + return mype_node; +} + +// TODO: Adapt to different data types +extern "C" int *nvshmem_alloc_wrapper(int size) { + int *destination = (int *)nvshmem_malloc(sizeof(int) * size); + return destination; +} + +extern "C" void nvshmemx_barrier_warpper(cudaStream_t stream) { + nvshmemx_barrier_all_on_stream(stream); +} + +extern "C" void nvshmem_finalize_wrapper(int *dest) { + nvshmem_free(dest); +
nvshmem_finalize(); +} diff --git a/python/tutorials/tle/raw/nvshmem/01-simple-shift/simple-shift.py b/python/tutorials/tle/raw/nvshmem/01-simple-shift/simple-shift.py new file mode 100644 index 0000000000..2f87d70f6d --- /dev/null +++ b/python/tutorials/tle/raw/nvshmem/01-simple-shift/simple-shift.py @@ -0,0 +1,122 @@ +import os +import subprocess +import ctypes +import torch +import triton +import triton.knobs as knobs +import triton.experimental.tle.language.raw as tle_raw + +from pathlib import Path +from triton.experimental.tle.raw import dialect +from triton.language.extra.cuda import libnvshmem_device + + +@dialect( + name="cuda", + file=(Path(__file__).parent / "simple-shift-device.cu").resolve(), + library={"nvshmem": "/home/zyl/zyuli/envs/nvshmem/lib/python3.12/site-packages/nvidia/nvshmem"}, +) +def edsl(*args, **kwargs): + ... + + +@triton.jit +def simple_shift_kernel(destination_ptr, ): + # TODO: Combine with tle_raw.call, then dispatch + tle_raw.call_nvshmem(edsl, [], [destination_ptr]) + libnvshmem_device.simple_shift(destination_ptr) + + +def cuda_host_compile(cuda_host_path, cuda_host_lib): + NVCC = os.getenv("NVCC", "nvcc") + NVSHMEM_HOME = "/home/zyl/zyuli/envs/nvshmem/lib/python3.12/site-packages/nvidia/nvshmem" + include_path = f"-I{os.path.join(NVSHMEM_HOME, 'include')}" + lib_path = f"-L{os.path.join(NVSHMEM_HOME, 'lib')}" + + prop = torch.cuda.get_device_properties(torch.cuda.current_device()) + arch = f"-arch=sm_{prop.major}{prop.minor}" + tmp_file = Path(cuda_host_lib).with_suffix('.so.tmp') + build = [ + NVCC, "-shared", "-Xcompiler", "-fPIC", "-rdc=true", arch, include_path, lib_path, "-lnvshmem_host", + "-lnvshmem_device", "-o", tmp_file, cuda_host_path + ] + build = subprocess.run(build, capture_output=True) + assert build.returncode == 0, (f"NVCC host failed\nstderr:\n{build.stderr.decode()}") + tmp_file.rename(cuda_host_lib) + + +def simpe_shift(): + cu_file = (Path(__file__).parent / "simple-shift-host.cu").resolve() + lib_file = 
Path(cu_file).with_suffix('.so') + + rank = int(os.getenv("OMPI_COMM_WORLD_RANK", "0")) + if rank == 0: + cuda_host_compile(cu_file, lib_file) + + import time + timeout = 60 + start = time.time() + while True: + if lib_file.exists(): + try: + ctypes.CDLL(str(lib_file)) + break + except OSError: + pass + if time.time() - start > timeout: + raise RuntimeError(f"Timeout waiting for {lib_file}") + time.sleep(0.1) + + lib = ctypes.CDLL(lib_file) + lib.nvshmem_init_wrapper.argtypes = [] + lib.nvshmem_init_wrapper.restype = None + lib.nvshmemx_cumodule_init_wrapper.argtypes = [ctypes.c_void_p] + lib.nvshmemx_cumodule_init_wrapper.restype = ctypes.c_int + lib.nvshmem_team_mype_wrapper.argtypes = [] + lib.nvshmem_team_mype_wrapper.restype = ctypes.c_int + lib.nvshmem_alloc_wrapper.argtypes = [ctypes.c_int] + lib.nvshmem_alloc_wrapper.restype = ctypes.POINTER(ctypes.c_int) + lib.nvshmemx_barrier_warpper.argtypes = [ctypes.c_void_p] + lib.nvshmemx_barrier_warpper.restype = None + lib.nvshmem_finalize_wrapper.argtypes = [ctypes.POINTER(ctypes.c_int)] + lib.nvshmem_finalize_wrapper.restype = None + + lib.nvshmem_init_wrapper() + mype_node = lib.nvshmem_team_mype_wrapper() + torch.cuda.set_device(mype_node) + device = triton.runtime.driver.active.get_active_torch_device() + stream = torch.cuda.Stream() + + dest = lib.nvshmem_alloc_wrapper(1) + dest_addr = ctypes.cast(dest, ctypes.c_void_p).value + storage = torch._C._construct_storage_from_data_pointer(dest_addr, device, 4) + dest_tensor = torch.empty(0, dtype=torch.int32, device=device).set_(storage).view(1) + msg = torch.empty((1, ), dtype=torch.int32, pin_memory=True) + + def cumodule_init_hook(*args, **kwargs): + key = kwargs["key"] + jit_function = kwargs["fn"].jit_function + device = kwargs["compile"]["device"] + kernel_cache = jit_function.device_caches[device][0] + kernel = kernel_cache.get(key, None) + assert kernel is not None + kernel._init_handles() + ret = 
lib.nvshmemx_cumodule_init_wrapper(ctypes.c_void_p(kernel.module)) + assert ret == 0, f"nvshmemx_cumodule_init_wrapper failed: {ret}" + + knobs.runtime.jit_post_compile_hook = cumodule_init_hook + + simple_shift_kernel[(1, )](dest_tensor) + + stream_ptr = stream.cuda_stream + lib.nvshmemx_barrier_warpper(ctypes.c_void_p(stream_ptr)) + with torch.cuda.stream(stream): + msg.copy_(dest_tensor, non_blocking=True) + stream.synchronize() + + lib.nvshmem_finalize_wrapper(dest) + print(f"Rank {mype_node}: {msg}") + + +if __name__ == "__main__": + simpe_shift() diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index 632b79075d..a415d63ec1 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -519,6 +519,8 @@ def make_cubin(self, src, metadata, opt, capability): fsrc.flush() fbin = fsrc.name + '.o' + use_nvcc = os.getenv("USE_NVCC", '').lower() in ('1', 'true') + os.environ.pop("USE_NVCC", None) debug_info = [] if knobs.compilation.disable_line_info: # This option is ignored if used without -lineinfo @@ -539,9 +541,11 @@ def make_cubin(self, src, metadata, opt, capability): # Accept more ptxas options if provided ptx_extra_options = opt.ptx_options.split(" ") if opt.ptx_options else [] + # If use nvshmem, we need to compile the ptx file into a relocatable object file and then link it with nvshmem library + compile_only = ["-c"] if use_nvcc else [] ptxas_cmd = [ - ptxas, *debug_info, *fmad, '-v', *disable_opt, *ptx_extra_options, f'--gpu-name={arch}', fsrc.name, - '-o', fbin + ptxas, *compile_only, *debug_info, *fmad, '-v', *disable_opt, *ptx_extra_options, f'--gpu-name={arch}', + fsrc.name, '-o', fbin ] try: subprocess.run(ptxas_cmd, check=True, close_fds=False, stderr=flog) @@ -581,8 +585,37 @@ def make_cubin(self, src, metadata, opt, capability): """) raise PTXASError(error) - with open(fbin, 'rb') as f: - cubin = f.read() + if use_nvcc: + NVLINK = os.getenv("NVLINK", "nvlink") + 
NVSHMEM_HOME = os.getenv("NVSHMEM_HOME") + fbin_combined = fbin + ".combined.cubin" + cuda_cubin = os.getenv("CUDA_CUBIN") + nvshmem_lib = os.path.join(NVSHMEM_HOME, "lib") + nvlink_cmds = [ + NVLINK, + f"-arch={arch}", + f"-L{nvshmem_lib}", + "-lnvshmem_device", + fbin, + cuda_cubin, + "-o", + fbin_combined, + ] + try: + subprocess.run(nvlink_cmds, check=True, close_fds=False, stderr=flog) + except Exception: + import logging + logging.exception(f"error running nvlink: {nvlink_cmds}") + raise + + if use_nvcc: + with open(fbin_combined, 'rb') as f: + cubin = f.read() + if os.path.exists(fbin_combined): + os.remove(fbin_combined) + else: + with open(fbin, 'rb') as f: + cubin = f.read() if os.path.exists(fbin): os.remove(fbin) return cubin diff --git a/third_party/nvidia/language/cuda/__init__.py b/third_party/nvidia/language/cuda/__init__.py index fbececf1de..39207bf60f 100644 --- a/third_party/nvidia/language/cuda/__init__.py +++ b/third_party/nvidia/language/cuda/__init__.py @@ -1,10 +1,12 @@ from . import libdevice +from .
import libnvshmem_device from .utils import (globaltimer, num_threads, num_warps, smid, convert_custom_float8_sm70, convert_custom_float8_sm80) from .gdc import (gdc_launch_dependents, gdc_wait) __all__ = [ "libdevice", + 'libnvshmem_device', "globaltimer", "num_threads", "num_warps", diff --git a/third_party/nvidia/language/cuda/libnvshmem_device.py b/third_party/nvidia/language/cuda/libnvshmem_device.py new file mode 100644 index 0000000000..3dec30b9a5 --- /dev/null +++ b/third_party/nvidia/language/cuda/libnvshmem_device.py @@ -0,0 +1,31 @@ +from triton.language import core +import triton.language as tl + + +def _pointer_type_hash(self): + return hash((self.name, self.element_ty, "tt_ptr")) + + +def patch_hash_method_for_pointer_type(): + elem_dtype_list = tl.core.dtype.SINT_TYPES + tl.core.dtype.UINT_TYPES + tl.core.dtype.FP_TYPES + tl.core.dtype.OTHER_TYPES + for elem_dtype in elem_dtype_list: + ptr_ty = type(tl.core.pointer_type(tl.core.dtype(elem_dtype))) + ptr_ty.__hash__ = _pointer_type_hash + + +patch_hash_method_for_pointer_type() + + +@core.extern +def simple_shift(dst, _semantic=None): + return core.extern_call( + "", # libname + "", # libpath + [dst], # args + {( + core.pointer_type(core.dtype("int32")), # arg_type_symbol_dict + ): ("simple_shift", ()), # void return type + }, + is_pure=False, + _semantic=_semantic, + ) diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp index 368a666292..0fc8242c54 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp @@ -222,6 +222,8 @@ struct ConvertTritonGPUToLLVM targetInfo, benefit); mlir::triton::populatePrintOpToLLVMPattern(typeConverter, patterns, targetInfo, benefit); + mlir::triton::populateExternCallOpToLLVMPattern(typeConverter, patterns, + targetInfo, benefit); 
mlir::triton::populateControlFlowOpToLLVMPattern(typeConverter, patterns, targetInfo, benefit); mlir::triton::NVIDIA::populateSPMDOpToLLVMPattern(typeConverter, patterns,