From c2544109d1b75e14f48a07a72a79bcf2ccde531b Mon Sep 17 00:00:00 2001 From: zyuli <2436472829@qq.com> Date: Wed, 29 Apr 2026 08:52:09 +0800 Subject: [PATCH] [TLERaw] per_token_group_quant_fp8 && half2 --- .../PatternTritonGPUOpToLLVM.h | 5 + include/triton/Dialect/Triton/IR/TritonOps.td | 26 ++ lib/Conversion/TritonGPUToLLVM/CMakeLists.txt | 1 + .../TritonGPUToLLVM/ExternCallOpToLLVM.cpp | 61 ++++ .../TritonToTritonGPUPass.cpp | 1 + lib/Dialect/Triton/IR/Ops.cpp | 18 ++ python/src/ir.cc | 8 + .../experimental/tle/language/raw/__init__.py | 4 +- .../experimental/tle/language/raw/core.py | 5 + .../experimental/tle/raw/cuda/runtime.py | 26 +- python/triton/language/core.py | 88 ++++++ .../tle/raw/cuda/01-vector-add-half2.cu | 21 ++ .../tle/raw/cuda/01-vector-add-half2.py | 42 +++ .../per-token-group-quant-fp8.cu | 132 +++++++++ .../per-token-group-quant-fp8.py | 272 ++++++++++++++++++ .../vectorization.cuh | 29 ++ .../vectorization_utils.cuh | 176 ++++++++++++ third_party/nvidia/backend/compiler.py | 61 +++- third_party/nvidia/language/cuda/__init__.py | 2 + .../nvidia/language/cuda/libnvshmem_device.py | 45 +++ .../TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp | 2 + third_party/tle/triton_tle_raw.cc | 3 +- 22 files changed, 1015 insertions(+), 13 deletions(-) create mode 100644 lib/Conversion/TritonGPUToLLVM/ExternCallOpToLLVM.cpp create mode 100644 python/tutorials/tle/raw/cuda/01-vector-add-half2.cu create mode 100644 python/tutorials/tle/raw/cuda/01-vector-add-half2.py create mode 100644 python/tutorials/tle/raw/cuda/04-per-token-group-quant-fp8/per-token-group-quant-fp8.cu create mode 100644 python/tutorials/tle/raw/cuda/04-per-token-group-quant-fp8/per-token-group-quant-fp8.py create mode 100644 python/tutorials/tle/raw/cuda/04-per-token-group-quant-fp8/vectorization.cuh create mode 100644 python/tutorials/tle/raw/cuda/04-per-token-group-quant-fp8/vectorization_utils.cuh create mode 100644 third_party/nvidia/language/cuda/libnvshmem_device.py diff --git a/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h b/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h index 680bf0e045..69102ed61b 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -102,6 +102,11 @@ void populatePrintOpToLLVMPattern(LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, PatternBenefit benefit); +void populateExternCallOpToLLVMPattern(LLVMTypeConverter &typeConverter, + RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, + PatternBenefit benefit); + void populateInstrumentationToLLVMPatterns(LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo, RewritePatternSet &patterns, diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index 7fd215f9de..95239b3cb9 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -845,6 +845,32 @@ def TT_MapElementwiseReturnOp: TT_Op<"map_elementwise.return", let assemblyFormat = "attr-dict ($result^ `:` type($result))?"; } +// +// External Call op +// +def TT_ExternCallOp : TT_Op<"extern_call", [ + DeclareOpInterfaceMethods, + ConditionallySpeculatable, +]> { + + let description = [{ + call an external function $symbol implemented in $libpath/$libname with $args + return $libpath/$libname:$symbol($args...) 
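
    Example (illustrative only; library path, name and symbol are placeholders):

      %ret = tt.extern_call %x, %y {libname = "libfoo", libpath = "/path/to/libfoo",
                                    pure = true, symbol = "foo_add_f32"} : (f32, f32) -> f32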
+ }]; + + let arguments = (ins Variadic:$srcs, StrAttr:$libname, StrAttr:$libpath, StrAttr:$symbol, BoolAttr:$pure); + + let results = (outs Variadic:$result); + + let assemblyFormat = "operands attr-dict `:` functional-type(operands, $result)"; + + let extraClassDeclaration = [{ + // Interface method for ConditionallySpeculatable. + Speculation::Speculatability getSpeculatability(); + }]; + +} + // // External Elementwise op // diff --git a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt index d4f49c8d18..44e52cf37c 100644 --- a/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt +++ b/lib/Conversion/TritonGPUToLLVM/CMakeLists.txt @@ -7,6 +7,7 @@ add_triton_library(TritonGPUToLLVM AssertOpToLLVM.cpp ControlFlowOpToLLVM.cpp ConvertLayoutOpToLLVM.cpp + ExternCallOpToLLVM.cpp ElementwiseOpToLLVM.cpp FuncOpToLLVM.cpp GatherOpToLLVM.cpp diff --git a/lib/Conversion/TritonGPUToLLVM/ExternCallOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ExternCallOpToLLVM.cpp new file mode 100644 index 0000000000..81b6bf256e --- /dev/null +++ b/lib/Conversion/TritonGPUToLLVM/ExternCallOpToLLVM.cpp @@ -0,0 +1,61 @@ +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" +#include "mlir/Support/LLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h" +#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" +#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +namespace { + +class ExternCallOpConversion + : public ConvertOpToLLVMPattern { +public: + ExternCallOpConversion(const LLVMTypeConverter &converter, + const PatternBenefit &benefit) + : ConvertOpToLLVMPattern(converter, benefit) {} + + LogicalResult + matchAndRewrite(triton::ExternCallOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + + if (op->getNumResults() > 1) { + llvm::errs() << "ExternCallConversion does not support multi outs."; + return failure(); + } + + LLVM::LLVMVoidType voidTy = void_ty(op->getContext()); + auto newOperands = adaptor.getOperands(); + Type retType = + op->getNumResults() == 0 + ? 
voidTy + : this->getTypeConverter()->convertType(op->getResult(0).getType()); + std::string funcName = op.getSymbol().str(); + StringRef libname = op.getLibname(); + StringRef libpath = op.getLibpath(); + + Operation *externCallOp; + Type funcType = mlir::triton::gpu::getFunctionType(retType, newOperands); + LLVM::LLVMFuncOp funcOp = mlir::triton::gpu::appendOrGetExternFuncOp( + rewriter, op, funcName, funcType, libname, libpath); + externCallOp = LLVM::createLLVMCallOp(rewriter, loc, funcOp, newOperands); + + if (op->getNumResults() == 0) { + rewriter.eraseOp(op); + } else { + rewriter.replaceOp(op, externCallOp->getResult(0)); + } + + return success(); + } +}; + +} // namespace + +void mlir::triton::populateExternCallOpToLLVMPattern( + LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, + const TargetInfoBase &targetInfo, PatternBenefit benefit) { + patterns.add(typeConverter, benefit); +} diff --git a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp index b49871486d..45c6134655 100644 --- a/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp +++ b/lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp @@ -600,6 +600,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter, GenericOpPattern, GenericOpPattern, GenericOpPattern, + GenericOpPattern, GenericOpPattern, GenericOpPattern, GenericOpPattern, diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 27fa26554d..854ae1f2bc 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -1313,6 +1313,24 @@ Speculation::Speculatability ExternElementwiseOp::getSpeculatability() { return Speculation::NotSpeculatable; } +// -- ExternCallOp -- +void ExternCallOp::getEffects( + SmallVectorImpl> + &effects) { + if (getPure()) + return; + effects.emplace_back(MemoryEffects::Write::get(), + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Read::get(), + SideEffects::DefaultResource::get()); +} + +Speculation::Speculatability ExternCallOp::getSpeculatability() { + if (getPure()) + return Speculation::Speculatable; + return Speculation::NotSpeculatable; +} + // -- GatherOp -- LogicalResult GatherOp::verify() { RankedTensorType indicesTy = getIndices().getType(); diff --git a/python/src/ir.cc b/python/src/ir.cc index e5591dcb08..9abfe8472f 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -1673,6 +1673,14 @@ void init_triton_ir(py::module &&m) { return self.create(retType, argList, libName, libPath, symbol, isPure); }) + .def("create_extern_call", + [](TritonOpBuilder &self, const std::string &libName, + const std::string &libPath, const std::string &symbol, + std::vector &argList, const std::vector &retTypes, + bool isPure) -> OpState { + return self.create(retTypes, argList, libName, + libPath, symbol, isPure); + }) // Built-in instruction .def("create_get_program_id", [](TritonOpBuilder &self, int axis) -> Value { diff --git a/python/triton/experimental/tle/language/raw/__init__.py b/python/triton/experimental/tle/language/raw/__init__.py index d66f6cd381..790c65cf3f 100644 --- a/python/triton/experimental/tle/language/raw/__init__.py +++ b/python/triton/experimental/tle/language/raw/__init__.py @@ -1,3 +1,3 @@ -from .core import call, call_smem +from .core import call, call_smem, call_nvshmem -__all__ = ["call", "call_smem"] +__all__ = ["call", "call_smem", "call_nvshmem"] diff --git a/python/triton/experimental/tle/language/raw/core.py 
b/python/triton/experimental/tle/language/raw/core.py index d7029380eb..7e25385313 100644 --- a/python/triton/experimental/tle/language/raw/core.py +++ b/python/triton/experimental/tle/language/raw/core.py @@ -46,3 +46,8 @@ def call_smem(func, args, _semantic=None): return buffer_tensors[0] else: return tl.tuple(buffer_tensors) + + +@builtin +def call_nvshmem(func, outputs, inputs, _semantic=None): + func.make_cubin() diff --git a/python/triton/experimental/tle/raw/cuda/runtime.py b/python/triton/experimental/tle/raw/cuda/runtime.py index 9015f2bf7f..ce264a8ded 100644 --- a/python/triton/experimental/tle/raw/cuda/runtime.py +++ b/python/triton/experimental/tle/raw/cuda/runtime.py @@ -3,21 +3,26 @@ from pathlib import Path import subprocess from typing import Any, Final +import torch from triton._C.libtriton import llvm # pyright: ignore[reportMissingImports] from triton._C.libtriton.tle.llvm import parse_llvm_ir # pyright: ignore[reportMissingImports] # TODO: We use cli tools to compile CUDA code temporarily, and plan to replace it with LLVM components Python bindings in the future. CLANG = os.getenv("CLANG", "clang") +NVCC = os.getenv("NVCC", "nvcc") class CUDAJITFunction(object): def __init__(self, fn: Any, file: Path, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) + super().__init__() self.fn: Final[Any] = fn self.code: Final[str] = file.read_text() + self.file: Final[Path] = file + self.libs = kwargs.get("library", {}) self.__triton_builtin__: Final[bool] = True + os.environ["USE_CLANG"] = "True" def make_llvm(self, mlir_context) -> str: build = subprocess.run( @@ -40,3 +45,22 @@ def make_llvm(self, mlir_context) -> str: llvm_context = llvm.context() module = parse_llvm_ir(build.stdout.decode(), llvm_context, mlir_context) return f"{module}" + + def make_cubin(self): + src = self.file + dst = Path(src).with_suffix('.o') + include_dirs = [] + for lib_name, lib_path in self.libs.items(): + # TODO: Remove the method of passing information by setting environment variables. + os.environ[(lib_name + "_home").upper()] = lib_path + include_dirs.append(os.path.join(lib_path, "include")) + include_flags = [f"-I{inc_dir}" for inc_dir in include_dirs] + prop = torch.cuda.get_device_properties(torch.cuda.current_device()) + arch = f"-arch=sm_{prop.major}{prop.minor}" + build = subprocess.run([NVCC, "-rdc=true", arch, *include_flags, "--extended-lambda", "-c", "-o", dst, src], + capture_output=True) + assert build.returncode == 0, (f"nvcc failed\nstderr:\n{build.stderr.decode()}") + # TODO: Remove the method of passing information by setting environment variables. 
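        # Illustrative shape of the nvcc invocation assembled above (the arch and the include
        # path depend on the local GPU and on the `library` mapping handed to the dialect
        # decorator), e.g.:
        #   nvcc -rdc=true -arch=sm_90 -I<lib_path>/include --extended-lambda -c -o kernel.o kernel.cu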
+ os.environ["USE_NVCC"] = 'True' + os.environ["CUDA_CUBIN"] = str(dst) + return diff --git a/python/triton/language/core.py b/python/triton/language/core.py index 439671c3a1..5a9c56b12c 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -3433,6 +3433,94 @@ def binary_op_type_legalization(lhs, rhs, semantic): return semantic.binary_op_type_checking_impl(lhs, rhs) +def dispatch_ec(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool, + _semantic=None): + ''' + Dispatch a function to a library + :param func: the function to dispatch + :param lib_name: the name of the library + :param lib_path: the path of the library + :param args: the arguments of the function + :param arg_type_symbol_dict: the type of the arguments + :param ret_shape: the shape of the return value + :param _semantic: the builder + :return: the return value of the function + ''' + if len(arg_type_symbol_dict) == 0: + raise ValueError("arg_type_symbol_dict is empty") + + num_args = len(list(arg_type_symbol_dict.keys())[0]) + if len(args) != num_args: + raise ValueError(f"length of input args does not match." + f"Expect {len(args)}, got {num_args}") + + arg_types = [] + arg_list = [] + for arg in args: + if isinstance(arg, tensor): + arg_types.append(arg.dtype) + arg_list.append(arg.handle) + else: + arg_types.append(type(arg)) + arg_list.append(arg) + arg_types = tuple(arg_types) + + if arg_types not in arg_type_symbol_dict: + raise ValueError(f"input arg type does not match." + f"Expect one of {arg_type_symbol_dict.keys()}, got {arg_types}") + else: + symbol = arg_type_symbol_dict[arg_types][0] + ret_types = arg_type_symbol_dict[arg_types][1] + if not isinstance(ret_types, (builtins.list, builtins.tuple)): + ret_types = [ret_types] + + if symbol == "": + raise ValueError("Symbol can not be empty") + call = func(lib_name, lib_path, symbol, arg_list, [ret_type.to_ir(_semantic.builder) for ret_type in ret_types], + is_pure) + + if len(ret_types) == 0: + return tensor(call, void) + if len(ret_types) == 1: + return tensor(call.get_result(0), ret_types[0]) + return tuple(tensor(call.get_result(i), ty) for i, ty in enumerate(ret_types)) + + +@builtin +def extern_call(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool, _semantic=None): + ''' + Dispatch an function to a library + :param lib_name: the name of the library + :param lib_path: the path of the library + :param args: the arguments of the function + :param arg_type_symbol_dict: the type of the arguments + :param is_pure: whether the function is pure + :param _semantic: the semantic + :return: the return value of the function + ''' + dispatch_args = args.copy() + all_scalar = True + arg_types = [] + for i in builtins.range(len(dispatch_args)): + dispatch_args[i] = _semantic.to_tensor(dispatch_args[i]) + arg_types.append(dispatch_args[i].dtype) + if dispatch_args[i].type.is_block(): + all_scalar = False + if not all_scalar: + raise ValueError("extern call only support inputs with scalr type") + + if len(arg_type_symbol_dict) == 0: + raise ValueError("arg_type_symbol_dict is empty") + + num_args = len(list(arg_type_symbol_dict.keys())[0]) + if len(args) != num_args: + raise ValueError(f"length of input args does not match." 
+ f"Expect {len(args)}, got {num_args}") + + func = _semantic.builder.create_extern_call + return dispatch_ec(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, is_pure, _semantic) + + def extern(fn): """A decorator for external functions.""" return builtin(fn) diff --git a/python/tutorials/tle/raw/cuda/01-vector-add-half2.cu b/python/tutorials/tle/raw/cuda/01-vector-add-half2.cu new file mode 100644 index 0000000000..e3adc6d995 --- /dev/null +++ b/python/tutorials/tle/raw/cuda/01-vector-add-half2.cu @@ -0,0 +1,21 @@ +#include + +__device__ void +VectorAddHalf2(__attribute__((address_space(1))) __half *C, + __attribute__((address_space(1))) const __half *A, + __attribute__((address_space(1))) const __half *B, const int N) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + + const __half2 *A2 = reinterpret_cast(A); + const __half2 *B2 = reinterpret_cast(B); + __half2 *C2 = reinterpret_cast<__half2 *>(C); + + for (int i = idx; i < N / 2; i += blockDim.x * gridDim.x) { + C2[i] = __hadd2(A2[i], B2[i]); + } + + if (idx == 0 && N % 2 != 0) { + int last = N - 1; + C[last] = __hadd(A[last], B[last]); + } +} diff --git a/python/tutorials/tle/raw/cuda/01-vector-add-half2.py b/python/tutorials/tle/raw/cuda/01-vector-add-half2.py new file mode 100644 index 0000000000..1d804a68c5 --- /dev/null +++ b/python/tutorials/tle/raw/cuda/01-vector-add-half2.py @@ -0,0 +1,42 @@ +from pathlib import Path + +import torch +import triton +import triton.language as tl +from triton.experimental.tle.raw import dialect +import triton.experimental.tle.language.raw as tle_raw + +DEVICE = triton.runtime.driver.active.get_active_torch_device() + + +@dialect(name="cuda", file=Path(__file__).parent / "01-vector-add-half2.cu") +def edsl_half2(*args, **kwargs): + ... + + +@triton.jit +def add_half2_kernel( + x_ptr, + y_ptr, + output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, +): + tle_raw.call(edsl_half2, [output_ptr, x_ptr, y_ptr, n_elements]) + + +def add_half2(x: torch.Tensor, y: torch.Tensor): + output = torch.empty_like(x) + assert x.device == DEVICE and y.device == DEVICE and output.device == DEVICE + n_elements = output.numel() + grid = lambda meta: (triton.cdiv(n_elements // 2, meta["BLOCK_SIZE"]), ) + add_half2_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024) + return output + + +if __name__ == "__main__": + x = torch.randn(16384 * 256, device=DEVICE, dtype=torch.float16) + y = torch.randn(16384 * 256, device=DEVICE, dtype=torch.float16) + z_half2 = add_half2(x, y) + + assert torch.allclose(x + y, z_half2), (x + y, z_half2) diff --git a/python/tutorials/tle/raw/cuda/04-per-token-group-quant-fp8/per-token-group-quant-fp8.cu b/python/tutorials/tle/raw/cuda/04-per-token-group-quant-fp8/per-token-group-quant-fp8.cu new file mode 100644 index 0000000000..ab9c67ce60 --- /dev/null +++ b/python/tutorials/tle/raw/cuda/04-per-token-group-quant-fp8/per-token-group-quant-fp8.cu @@ -0,0 +1,132 @@ +#include "vectorization.cuh" +#include "vectorization_utils.cuh" +#include +#include + +__device__ __forceinline__ float GroupReduceMax(float val) { + unsigned mask = threadIdx.x % 32 >= 16 ? 
0xffff0000 : 0x0000ffff; + + val = fmaxf(val, __shfl_xor_sync(mask, val, 8)); + val = fmaxf(val, __shfl_xor_sync(mask, val, 4)); + val = fmaxf(val, __shfl_xor_sync(mask, val, 2)); + val = fmaxf(val, __shfl_xor_sync(mask, val, 1)); + return val; +} + +// template +__device__ __forceinline__ float +ComputeGroupScale(const float *__restrict__ group_input, + float *__restrict__ smem_group, const int group_size, + const int lane_id, const int threads_per_group, + const float eps, const float max_8bit) { + float local_absmax = eps; + + constexpr int vec_size = 16 / sizeof(float); + + // copy global -> shared & compute absmax + auto scalar_op_cache = [&] __device__(float &dst, const float &src) { + float abs_v = fabsf(static_cast(src)); + local_absmax = fmaxf(local_absmax, abs_v); + dst = src; + }; + + vllm::vectorize_with_alignment(group_input, // in + smem_group, // out (shared) + group_size, // elements per group + lane_id, // thread id + threads_per_group, // stride in group + scalar_op_cache); // scalar handler + + local_absmax = GroupReduceMax(local_absmax); + + float y_s = local_absmax / max_8bit; + // if constexpr (SCALE_UE8M0) { + // y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f)))); + // } + + return y_s; +} + +// template +__device__ __forceinline__ void +QuantizeGroup(const float *__restrict__ smem_group, + __nv_fp8_e4m3 *__restrict__ group_output, const int group_size, + const int lane_id, const int threads_per_group, const float y_s, + const float min_8bit, const float max_8bit) { + constexpr int vec_size = 16 / sizeof(float); + + // quantize shared -> global 8-bit + auto scalar_op_quant = [&] __device__(__nv_fp8_e4m3 & dst, const float &src) { + float q = fminf(fmaxf(static_cast(src) / y_s, min_8bit), max_8bit); + dst = __nv_fp8_e4m3(q); + }; + + vllm::vectorize_with_alignment( + smem_group, // in (shared) + group_output, // out (global quant tensor) + group_size, // elements + lane_id, // tid + threads_per_group, // stride + scalar_op_quant); // scalar handler +} + +// T: float; DST_DTYPE: __nv_fp8_e4m3 +// template +// __global__ void per_token_group_quant_8bit_kernel( +extern "C" __device__ void per_token_group_quant_8bit( + const float *__restrict__ input, void *__restrict__ output_q, + float *__restrict__ output_s, const int group_size, const int num_groups, + const int groups_per_block, const float eps, const float min_8bit, + const float max_8bit) { + const int threads_per_group = 16; + const int64_t local_group_id = threadIdx.x / threads_per_group; + const int lane_id = threadIdx.x % threads_per_group; + + const int64_t block_group_id = blockIdx.x * groups_per_block; + const int64_t global_group_id = block_group_id + local_group_id; + const int64_t block_group_offset = global_group_id * group_size; + + static_assert(sizeof(float) % sizeof(float) == 0); + + const float *group_input = input + block_group_offset; + __nv_fp8_e4m3 *group_output = + static_cast<__nv_fp8_e4m3 *>(output_q) + block_group_offset; + float *scale_output; + + // bool IS_COLUMN_MAJOR = false; + // if (IS_COLUMN_MAJOR) { + // const int num_elems_per_pack = + // static_cast(sizeof(float) / sizeof(float)); + // const int scale_num_rows_element = scale_num_rows * num_elems_per_pack; + // const int row_idx = global_group_id / scale_num_rows_element; + // const int col_idx_raw = global_group_id % scale_num_rows_element; + // const int col_idx = col_idx_raw / num_elems_per_pack; + // const int pack_idx = col_idx_raw % num_elems_per_pack; + // scale_output = reinterpret_cast(output_s) + + // (col_idx * 
scale_stride * num_elems_per_pack + + // row_idx * num_elems_per_pack + pack_idx); + // } else { + scale_output = output_s + global_group_id; + // } + + // shared memory to cache each group's data to avoid double DRAM reads. + extern __shared__ __align__(16) char smem_raw[]; + float *smem = reinterpret_cast(smem_raw); + float *smem_group = smem + local_group_id * group_size; + + const float y_s = + ComputeGroupScale(group_input, smem_group, group_size, lane_id, + threads_per_group, eps, max_8bit); + + float y_s_quant = y_s; + + if (lane_id == 0) { + *scale_output = y_s_quant; + } + + __syncthreads(); + + QuantizeGroup(smem_group, group_output, group_size, lane_id, + threads_per_group, y_s, min_8bit, max_8bit); +} diff --git a/python/tutorials/tle/raw/cuda/04-per-token-group-quant-fp8/per-token-group-quant-fp8.py b/python/tutorials/tle/raw/cuda/04-per-token-group-quant-fp8/per-token-group-quant-fp8.py new file mode 100644 index 0000000000..9611613ba8 --- /dev/null +++ b/python/tutorials/tle/raw/cuda/04-per-token-group-quant-fp8/per-token-group-quant-fp8.py @@ -0,0 +1,272 @@ +from pathlib import Path + +from typing import Optional +import logging +import torch +import triton +import triton.language as tl +from triton.experimental.tle.raw import dialect +import triton.experimental.tle.language.raw as tle_raw + +from triton.language.extra.cuda import libnvshmem_device + +torch.cuda.set_device(1) +DEVICE = triton.runtime.driver.active.get_active_torch_device() +logger = logging.getLogger(__name__) + + +@dialect(name="cuda", file=(Path(__file__).parent / "per-token-group-quant-fp8.cu").resolve(), + library={"torch": "/home/zyuli/miniconda3/envs/flagtree_nvshmem/lib/python3.12/site-packages/torch/"}) +def edsl(*args, **kwargs): + ... + + +@triton.jit +def test_kernel( + x_ptr, + x_q_ptr, + x_s_ptr, + group_size, + num_groups, + groups_per_block, + eps, + fp8_min, + fp8_max, +): + tle_raw.call_nvshmem(edsl, [], []) + libnvshmem_device.per_token_group_quant_8bit(x_ptr, x_q_ptr, x_s_ptr, group_size, num_groups, groups_per_block, eps, + fp8_min, fp8_max) + + +def get_groups_per_block(num_groups: int) -> int: + # Removing this branch gives better performance. + # if (num_groups % 16 == 0): + # return 16 + if (num_groups % 8 == 0): + return 8 + elif (num_groups % 4 == 0): + return 4 + elif (num_groups % 2 == 0): + return 2 + else: + return 1 + + +def per_token_group_quant_fp8_tle( + x: torch.Tensor, + group_size: int, + eps: float = 1e-10, + dtype: Optional[torch.dtype] = None, + # column_major_scales: bool = False, + # scale_ue8m0: bool = False, +): + logger.debug("GEMS PER TOKEN GROUP QUANT FP8") + assert x.shape[-1] % group_size == 0, (f"the last dimension of `x` {x.shape[-1]} must be divisible " + f"by `group_size` {group_size}") + assert x.stride(-1) == 1, "`x` groups must be contiguous" + + fp8_dtype = torch.float8_e4m3fn if dtype is None else dtype + finfo = torch.finfo(fp8_dtype) + fp8_min = finfo.min + fp8_max = finfo.max + + x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype) + shape = x.shape[:-1] + (x.shape[-1] // group_size, ) + x_s = torch.empty(shape, device=x.device, dtype=torch.float32) + + # num_groups + num_groups = x.numel() // group_size + groups_per_block = get_groups_per_block(num_groups) + + # num_blocks + THREADS_PER_GROUP = 16 + num_blocks = num_groups // groups_per_block + num_warps = max(groups_per_block * THREADS_PER_GROUP // 32, 1) + + # The .cu device function uses `extern __shared__` for groups_per_block * group_size floats. 
+ # Triton's compiler cannot infer this smem requirement from the extern_call, so we patch + # packed_metadata after warmup compilation to include the extra shared memory bytes. + smem_bytes = groups_per_block * group_size * x.element_size() # float32 = 4 bytes + + kernel = test_kernel.run( + x, + x_q, + x_s, + group_size, + num_groups, + groups_per_block, + eps, + fp8_min, + fp8_max, + grid=(num_blocks, ), + warmup=True, + num_warps=num_warps, + ) + + # Resolve async future if needed + if hasattr(kernel, "result"): + kernel = kernel.result() + + old_meta = kernel.packed_metadata + new_shared = max(old_meta[2], smem_bytes) + kernel.packed_metadata = old_meta[:2] + (new_shared, ) + old_meta[3:] + + test_kernel[(num_blocks, )]( + x, + x_q, + x_s, + group_size, + num_groups, + groups_per_block, + eps, + fp8_min, + fp8_max, + num_warps=num_warps, + ) + + return x_q, x_s + + +@triton.jit +def _per_token_group_quant_fp8( + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + y_num_columns, + y_row_stride, + eps, + fp8_min, + fp8_max, + scale_ue8m0, + BLOCK: tl.constexpr, +): + groups_per_row = y_num_columns // group_size + + g_id = tl.program_id(0) + row = g_id // groups_per_row + row_g_id = g_id % groups_per_row + + y_ptr += (row * y_row_stride) + (row_g_id * group_size) + y_q_ptr += g_id * group_size + y_s_ptr += g_id + + cols = tl.arange(0, BLOCK) + mask = cols < group_size + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + + if scale_ue8m0: + y_s = tl.exp2(tl.ceil(tl.log2(tl.maximum(tl.abs(y_s), 1e-10)))) + + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +def per_token_group_quant_fp8_triton( + x: torch.Tensor, + group_size: int, + eps: float = 1e-10, + dtype: Optional[torch.dtype] = None, + column_major_scales: bool = False, + scale_ue8m0: bool = False, +): + logger.debug("GEMS PER TOKEN GROUP QUANT FP8") + # dtype: The dype of output tensor. 
Note that only `torch.float8_e4m3fn` + fp8_dtype = torch.float8_e4m3fn if dtype is None else dtype + assert x.shape[-1] % group_size == 0, (f"the last dimension of `x` {x.shape[-1]} must be divisible " + f"by `group_size` {group_size}") + assert x.stride(-1) == 1, "`x` groups must be contiguous" + + finfo = torch.finfo(fp8_dtype) + fp8_min = finfo.min + fp8_max = finfo.max + + x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype) + M = x.numel() // group_size + N = group_size + + if column_major_scales: + shape = (x.shape[-1] // group_size, ) + x.shape[:-1] + x_s = torch.empty(shape, device=x.device, dtype=torch.float32).permute(-1, -2) + else: + shape = x.shape[:-1] + (x.shape[-1] // group_size, ) + x_s = torch.empty(shape, device=x.device, dtype=torch.float32) + + BLOCK = triton.next_power_of_2(N) + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + _per_token_group_quant_fp8[(M, )]( + x, + x_q, + x_s, + group_size, + x.shape[1], + x.stride(0), + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + scale_ue8m0=scale_ue8m0, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + + return x_q, x_s + + +if __name__ == "__main__": + x = torch.randn((16384, 32768), device=DEVICE, dtype=torch.float32) + group_size = 128 + + x_q_triton, x_s_triton = per_token_group_quant_fp8_triton(x, group_size) + x_q_tle, x_s_tle = per_token_group_quant_fp8_tle(x, group_size) + + q_tri = x_q_triton.to(torch.float32) + q_tle = x_q_tle.to(torch.float32) + + b_tri = x_q_triton.view(torch.int8).to(torch.int16) # promote to avoid int8 overflow + b_tle = x_q_tle.view(torch.int8).to(torch.int16) + bit_diff = (b_tri - b_tle).abs() # 0 = exact match, 1 = 1-ULP, >1 = real bug + num_exact = bit_diff.eq(0).sum().item() + num_1ulp = bit_diff.eq(1).sum().item() + num_beyond = bit_diff.gt(1).sum().item() + if num_beyond == 0: + if num_1ulp == 0: + print("✅ x_q Triton and TLE match (bit-exact)") + else: + print(f"✅ x_q Triton and TLE match (1-ULP diff={num_1ulp}, " + f"expected from RTZ vs RTNE rounding)") + else: + q_tri = x_q_triton.to(torch.float32) + q_tle = x_q_tle.to(torch.float32) + float_diff = (q_tri - q_tle).abs() + print(f"❌ x_q Triton and TLE differ: " + f"exact={num_exact}, 1-ULP={num_1ulp}, >1-ULP={num_beyond}") + beyond_idx = bit_diff.gt(1).nonzero() + for idx in beyond_idx[:10]: + r, c = idx[0].item(), idx[1].item() + group_id = r * (q_tri.shape[1] // group_size) + c // group_size + pos_in_group = c % group_size + print(f" [{r},{c}] group={group_id} pos_in_group={pos_in_group}" + f" triton_bits={b_tri[r,c].item():4d} tle_bits={b_tle[r,c].item():4d}" + f" bit_diff={bit_diff[r,c].item()}" + f" triton={q_tri[r,c].item():.1f} tle={q_tle[r,c].item():.1f}" + f" float_diff={float_diff[r,c].item():.1f}" + f" x={x[r,c].item():.6f} scale={x_s_triton.view(-1)[group_id].item():.8f}") + + if torch.allclose(x_s_triton, x_s_tle, atol=0.125, rtol=0): + print("✅ x_s Triton and TLE match") + else: + print("❌ x_s Triton and TLE differ") + + # perf + mean_ms_triton = triton.testing.do_bench(lambda: per_token_group_quant_fp8_triton(x, group_size)) + mean_ms_tle = triton.testing.do_bench(lambda: per_token_group_quant_fp8_tle(x, group_size)) + + print(f"Triton time: {mean_ms_triton:.3f} ms") + print(f"TLE time: {mean_ms_tle:.3f} ms") diff --git a/python/tutorials/tle/raw/cuda/04-per-token-group-quant-fp8/vectorization.cuh b/python/tutorials/tle/raw/cuda/04-per-token-group-quant-fp8/vectorization.cuh new file mode 100644 index 0000000000..6348c97b4c --- /dev/null +++ 
b/python/tutorials/tle/raw/cuda/04-per-token-group-quant-fp8/vectorization.cuh @@ -0,0 +1,29 @@ +#pragma once +/** + * __device__ datatypes vectorized by 4 + */ + +// Include both AMD and NVIDIA fp8 types to avoid circular import +#include +#include + +namespace vllm { + +// Vectorization containers +template +struct __align__(vec_size * sizeof(scalar_t)) vec_n_t { + scalar_t val[vec_size]; +}; + +template +struct __align__(vec_size * sizeof(quant_type_t)) q8_n_t { + static_assert(std::is_same_v || + std::is_same_v || + std::is_same_v); + quant_type_t val[vec_size]; +}; + +template using vec4_t = vec_n_t; +template using q8x4_t = q8_n_t; + +} // namespace vllm diff --git a/python/tutorials/tle/raw/cuda/04-per-token-group-quant-fp8/vectorization_utils.cuh b/python/tutorials/tle/raw/cuda/04-per-token-group-quant-fp8/vectorization_utils.cuh new file mode 100644 index 0000000000..be42dad348 --- /dev/null +++ b/python/tutorials/tle/raw/cuda/04-per-token-group-quant-fp8/vectorization_utils.cuh @@ -0,0 +1,176 @@ +#pragma once +#include "vectorization.cuh" + +namespace vllm { + +template +struct DefaultVecOp { + ScaOp scalar_op; + + __device__ __forceinline__ void + operator()(vec_n_t &dst, + const vec_n_t &src) const { +#pragma unroll + for (int i = 0; i < VEC_SIZE; ++i) { + scalar_op(dst.val[i], src.val[i]); + } + } +}; + +template +__device__ inline void +vectorize_with_alignment(const InT *in, OutT *out, int len, int tid, int stride, + VecOp &&vec_op, // vec_n_t -> vec_n_t + ScaOp &&scalar_op) { // InT -> OutT + static_assert(VEC_SIZE > 0 && (VEC_SIZE & (VEC_SIZE - 1)) == 0, + "VEC_SIZE must be a positive power-of-two"); + constexpr int WIDTH = VEC_SIZE * sizeof(InT); // eg: 64 B + uintptr_t addr = reinterpret_cast(in); + + // fast path when the whole region is already aligned + // Note: currently the output is guaranteed to be same as the input, so we + // don't check it here, comments here just for future reference. + bool can_vec = ((addr & (WIDTH - 1)) == 0) && ((len & (VEC_SIZE - 1)) == 0); + if (can_vec) { + int num_vec = len / VEC_SIZE; + + using vin_t = vec_n_t; + using vout_t = vec_n_t; + auto *v_in = reinterpret_cast(in); + auto *v_out = reinterpret_cast(out); + + for (int i = tid; i < num_vec; i += stride) { + vout_t tmp; + // Make a local copy of the entire pack + vin_t src = v_in[i]; // <- encourages a single vector ld + vec_op(tmp, src); + v_out[i] = tmp; // <- encourages a single vector st + } + return; + } + + int misalignment_offset = addr & (WIDTH - 1); // addr % 64 + int alignment_bytes = WIDTH - misalignment_offset; // 64 - (addr % 64) + int prefix_elems = alignment_bytes & (WIDTH - 1); // handle 64 + prefix_elems /= sizeof(InT); + prefix_elems = min(prefix_elems, len); // 0 ≤ prefix < 16 + + // 1. prefill the when it is unsafe to vectorize + for (int i = tid; i < prefix_elems; i += stride) { + scalar_op(out[i], in[i]); + } + + in += prefix_elems; + out += prefix_elems; + len -= prefix_elems; + + int num_vec = len / VEC_SIZE; + using vin_t = vec_n_t; + using vout_t = vec_n_t; + auto *v_in = reinterpret_cast(in); + auto *v_out = reinterpret_cast(out); + + // 2. vectorize the main part + for (int i = tid; i < num_vec; i += stride) { + vout_t tmp; + // Make a local copy of the entire pack + vin_t src = v_in[i]; // <- encourages a single vector ld + vec_op(tmp, src); + v_out[i] = tmp; // <- encourages a single vector st + } + + // 3. 
handle the tail + int tail_start = num_vec * VEC_SIZE; + for (int i = tid + tail_start; i < len; i += stride) { + scalar_op(out[i], in[i]); + } +} + +template +__device__ __forceinline__ void +vectorize_with_alignment(const InT *in, OutT *out, int len, int tid, int stride, + ScaOp &&scalar_op) { + using Vec = DefaultVecOp>; + vectorize_with_alignment(in, out, len, tid, stride, Vec{scalar_op}, + std::forward(scalar_op)); +} + +template struct DefaultReadVecOp { + ScaOp scalar_op; + + __device__ __forceinline__ void + operator()(const vec_n_t &src) const { +#pragma unroll + for (int i = 0; i < VEC_SIZE; ++i) { + scalar_op(src.val[i]); + } + } +}; + +// read-only version: iterate over the input with alignment guarantees +template +__device__ inline void +vectorize_read_with_alignment(const InT *in, int len, int tid, int stride, + VecOp &&vec_op, ScaOp &&scalar_op) { + static_assert(VEC_SIZE > 0 && (VEC_SIZE & (VEC_SIZE - 1)) == 0, + "VEC_SIZE must be a positive power-of-two"); + constexpr int WIDTH = VEC_SIZE * sizeof(InT); + uintptr_t addr = reinterpret_cast(in); + + // fast path when the whole region is already aligned + bool can_vec = ((addr & (WIDTH - 1)) == 0) && ((len & (VEC_SIZE - 1)) == 0); + if (can_vec) { + int num_vec = len / VEC_SIZE; + + using vin_t = vec_n_t; + auto *v_in = reinterpret_cast(in); + + for (int i = tid; i < num_vec; i += stride) { + vin_t tmp = v_in[i]; + vec_op(tmp); + } + return; + } + + int misalignment_offset = addr & (WIDTH - 1); + int alignment_bytes = WIDTH - misalignment_offset; + int prefix_elems = alignment_bytes & (WIDTH - 1); + prefix_elems /= sizeof(InT); + prefix_elems = min(prefix_elems, len); + + // 1. handle the possibly unaligned prefix with scalar access. + for (int i = tid; i < prefix_elems; i += stride) { + scalar_op(in[i]); + } + + in += prefix_elems; + len -= prefix_elems; + + int num_vec = len / VEC_SIZE; + using vin_t = vec_n_t; + auto *v_in = reinterpret_cast(in); + + // 2. vectorized traversal of the main aligned region. + for (int i = tid; i < num_vec; i += stride) { + vec_op(v_in[i]); + } + + // 3. handle remaining tail elements. 
+ int tail_start = num_vec * VEC_SIZE; + for (int i = tid + tail_start; i < len; i += stride) { + scalar_op(in[i]); + } +} + +// overload that requires only a scalar_op +template +__device__ __forceinline__ void +vectorize_read_with_alignment(const InT *in, int len, int tid, int stride, + ScaOp &&scalar_op) { + using Vec = DefaultReadVecOp>; + vectorize_read_with_alignment(in, len, tid, stride, Vec{scalar_op}, + std::forward(scalar_op)); +} + +} // namespace vllm diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index 632b79075d..0021ec85f8 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -422,7 +422,8 @@ def make_llir(self, src, metadata, options, capability): passes.common.add_symbol_dce(pm) passes.convert.add_nvvm_to_llvm(pm) - if not knobs.compilation.disable_line_info and not knobs.compilation.dump_ir_extract_di_local_variables: + use_clang = os.getenv("USE_CLANG", '').lower() in ('1', 'true') + if not knobs.compilation.disable_line_info and not knobs.compilation.dump_ir_extract_di_local_variables and not use_clang: passes.llvmir.add_di_scope(pm) if CUDABackend.instrumentation: @@ -487,12 +488,25 @@ def make_llir(self, src, metadata, options, capability): def make_ptx(self, src, metadata, opt, capability): ptx_version = get_ptx_version_from_options(opt, self.target.arch) - triple = 'nvptx64-nvidia-cuda' proc = sm_arch_from_capability(capability) - features = get_features(opt, self.target.arch) - flags = ["nvptx-mad-wide-opt"] - ret = llvm.translate_to_asm(src, triple, proc, features, flags, opt.enable_fp_fusion, False) + + use_clang = os.getenv("USE_CLANG", '').lower() in ('1', 'true') + if use_clang: + with tempfile.NamedTemporaryFile(delete=False, mode='w', suffix='.llir') as fsrc: + fsrc.write(src) + fsrc.flush() + fptx = fsrc.name + '.ptx' + llc = os.getenv("llc", "llc") + llc_cmd = [llc, '-march=nvptx64', f'-mcpu={proc}', f'-mattr=+ptx{ptx_version}', fsrc.name, '-o', fptx] + build = subprocess.run(llc_cmd, capture_output=True) + assert build.returncode == 0, (f"llc failed\nstderr:\n{build.stderr.decode()}") + with open(fptx, 'r') as f: + ret = f.read() + else: + features = get_features(opt, self.target.arch) + flags = ["nvptx-mad-wide-opt"] + ret = llvm.translate_to_asm(src, triple, proc, features, flags, opt.enable_fp_fusion, False) # Find kernel names (there should only be one) names = re.findall(r".visible .entry ([a-zA-Z_][a-zA-Z0-9_]*)", ret) assert len(names) == 1 @@ -519,6 +533,8 @@ def make_cubin(self, src, metadata, opt, capability): fsrc.flush() fbin = fsrc.name + '.o' + use_nvcc = os.getenv("USE_NVCC", '').lower() in ('1', 'true') + os.environ.pop("USE_NVCC", None) debug_info = [] if knobs.compilation.disable_line_info: # This option is ignored if used without -lineinfo @@ -539,9 +555,11 @@ def make_cubin(self, src, metadata, opt, capability): # Accept more ptxas options if provided ptx_extra_options = opt.ptx_options.split(" ") if opt.ptx_options else [] + # If use nvshmem, we need to compile the ptx file into a relocatable object file and then link it with nvshmem library + compile_only = ["-c"] if use_nvcc else [] ptxas_cmd = [ - ptxas, *debug_info, *fmad, '-v', *disable_opt, *ptx_extra_options, f'--gpu-name={arch}', fsrc.name, - '-o', fbin + ptxas, *compile_only, *debug_info, *fmad, '-v', *disable_opt, *ptx_extra_options, f'--gpu-name={arch}', + fsrc.name, '-o', fbin ] try: subprocess.run(ptxas_cmd, check=True, close_fds=False, stderr=flog) @@ -581,8 +599,33 @@ def 
make_cubin(self, src, metadata, opt, capability): """) raise PTXASError(error) - with open(fbin, 'rb') as f: - cubin = f.read() + if use_nvcc: + NVLINK = os.getenv("NVLINK", "nvlink") + fbin_combined = fbin + ".combined.cubin" + cuda_cubin = os.getenv("CUDA_CUBIN") + nvlink_cmds = [ + NVLINK, + f"-arch={arch}", + fbin, + cuda_cubin, + "-o", + fbin_combined, + ] + try: + subprocess.run(nvlink_cmds, check=True, close_fds=False, stderr=flog) + except Exception as e: + import logging + logging.error(f"error runing nvlink: {nvlink_cmds}") + logging.exception(e) + + if use_nvcc: + with open(fbin_combined, 'rb') as f: + cubin = f.read() + if os.path.exists(fbin_combined): + os.remove(fbin_combined) + else: + with open(fbin, 'rb') as f: + cubin = f.read() if os.path.exists(fbin): os.remove(fbin) return cubin diff --git a/third_party/nvidia/language/cuda/__init__.py b/third_party/nvidia/language/cuda/__init__.py index fbececf1de..39207bf60f 100644 --- a/third_party/nvidia/language/cuda/__init__.py +++ b/third_party/nvidia/language/cuda/__init__.py @@ -1,10 +1,12 @@ from . import libdevice +from . import libnvshmem_device from .utils import (globaltimer, num_threads, num_warps, smid, convert_custom_float8_sm70, convert_custom_float8_sm80) from .gdc import (gdc_launch_dependents, gdc_wait) __all__ = [ "libdevice", + 'libnvshmem_device', "globaltimer", "num_threads", "num_warps", diff --git a/third_party/nvidia/language/cuda/libnvshmem_device.py b/third_party/nvidia/language/cuda/libnvshmem_device.py new file mode 100644 index 0000000000..3816194496 --- /dev/null +++ b/third_party/nvidia/language/cuda/libnvshmem_device.py @@ -0,0 +1,45 @@ +from triton.language import core +import triton.language as tl + + +def _pointer_type_hash(self): + return hash((self.name, self.element_ty, "tt_ptr")) + + +def patch_hash_method_for_pointer_type(): + elem_dtype_list = tl.core.dtype.SINT_TYPES + tl.core.dtype.UINT_TYPES + tl.core.dtype.FP_TYPES + tl.core.dtype.OTHER_TYPES + for elem_dtype in elem_dtype_list: + ptr_ty = type(tl.core.pointer_type(tl.core.dtype(elem_dtype))) + ptr_ty.__hash__ = _pointer_type_hash + + +patch_hash_method_for_pointer_type() + + +@core.extern +def per_token_group_quant_8bit(x_ptr, x_q_ptr, x_s_ptr, group_size, num_groups, groups_per_block, eps, fp8_min, fp_max, + _semantic=None): + return core.extern_call( + "", + "", + [ + x_ptr, + tl.cast(x_q_ptr, tl.pointer_type(core.dtype("void")), _semantic=_semantic), x_s_ptr, group_size, num_groups, + groups_per_block, eps, fp8_min, fp_max + ], + { + ( + core.pointer_type(core.dtype("fp32")), + core.pointer_type(core.dtype("void")), + core.pointer_type(core.dtype("fp32")), + core.dtype("int32"), + core.dtype("int32"), + core.dtype("int32"), + core.dtype("fp32"), + core.dtype("fp32"), + core.dtype("fp32"), + ): ("per_token_group_quant_8bit", ()), + }, + is_pure=False, + _semantic=_semantic, + ) diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp index 368a666292..0fc8242c54 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp @@ -222,6 +222,8 @@ struct ConvertTritonGPUToLLVM targetInfo, benefit); mlir::triton::populatePrintOpToLLVMPattern(typeConverter, patterns, targetInfo, benefit); + mlir::triton::populateExternCallOpToLLVMPattern(typeConverter, patterns, + targetInfo, benefit); mlir::triton::populateControlFlowOpToLLVMPattern(typeConverter, patterns, 
targetInfo, benefit); mlir::triton::NVIDIA::populateSPMDOpToLLVMPattern(typeConverter, patterns, diff --git a/third_party/tle/triton_tle_raw.cc b/third_party/tle/triton_tle_raw.cc index 96c31e2b47..61cb0bea42 100644 --- a/third_party/tle/triton_tle_raw.cc +++ b/third_party/tle/triton_tle_raw.cc @@ -102,7 +102,8 @@ createTLERawRegionByLLVMFunc(TritonOpBuilder &self, std::string_view text, assert(module && "Failed to parse LLVM IR text"); LLVM::LLVMFuncOp func = nullptr; for (auto op : module->getOps()) { - if (!op.empty() && op.getLinkage() != LLVM::Linkage::Internal) { + if (!op.empty() && op.getLinkage() != LLVM::linkage::Linkage::Internal && + op.getLinkage() != LLVM::linkage::Linkage::LinkonceODR) { if (func) { llvm_unreachable("Multiple functions found in LLVM IR text"); } else {