Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ void populatePrintOpToLLVMPattern(LLVMTypeConverter &typeConverter,
const TargetInfoBase &targetInfo,
PatternBenefit benefit);

void populateExternCallOpToLLVMPattern(LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
const TargetInfoBase &targetInfo,
PatternBenefit benefit);

void populateInstrumentationToLLVMPatterns(LLVMTypeConverter &typeConverter,
const TargetInfoBase &targetInfo,
RewritePatternSet &patterns,
Expand Down
26 changes: 26 additions & 0 deletions include/triton/Dialect/Triton/IR/TritonOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -845,6 +845,32 @@ def TT_MapElementwiseReturnOp: TT_Op<"map_elementwise.return",
let assemblyFormat = "attr-dict ($result^ `:` type($result))?";
}

//
// External Call op
//
def TT_ExternCallOp : TT_Op<"extern_call", [
    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
    ConditionallySpeculatable,
  ]> {

  let summary = "call an external function from a dynamically linked library";

  let description = [{
    Call an external function $symbol implemented in $libpath/$libname with $args.
    Returns $libpath/$libname:$symbol($args...).

    The $pure attribute controls the reported memory effects and
    speculatability: a pure call has no side effects and may be speculated.
  }];

  let arguments = (ins Variadic<TT_Type>:$srcs, StrAttr:$libname, StrAttr:$libpath, StrAttr:$symbol, BoolAttr:$pure);

  let results = (outs Variadic<TT_Type>:$result);

  let assemblyFormat = "operands attr-dict `:` functional-type(operands, $result)";

  let extraClassDeclaration = [{
    // Interface method for ConditionallySpeculatable.
    Speculation::Speculatability getSpeculatability();
  }];

}

//
// External Elementwise op
//
Expand Down
1 change: 1 addition & 0 deletions lib/Conversion/TritonGPUToLLVM/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ add_triton_library(TritonGPUToLLVM
AssertOpToLLVM.cpp
ControlFlowOpToLLVM.cpp
ConvertLayoutOpToLLVM.cpp
ExternCallOpToLLVM.cpp
ElementwiseOpToLLVM.cpp
FuncOpToLLVM.cpp
GatherOpToLLVM.cpp
Expand Down
61 changes: 61 additions & 0 deletions lib/Conversion/TritonGPUToLLVM/ExternCallOpToLLVM.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Support/LLVM.h"
#include "triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h"
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
#include "triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h"
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

namespace {

// Lowers triton::ExternCallOp to an LLVM call of an external function that is
// declared (or reused, if already declared) in the enclosing module.
class ExternCallOpConversion
    : public ConvertOpToLLVMPattern<triton::ExternCallOp> {
public:
  ExternCallOpConversion(const LLVMTypeConverter &converter,
                         const PatternBenefit &benefit)
      : ConvertOpToLLVMPattern<triton::ExternCallOp>(converter, benefit) {}

  LogicalResult
  matchAndRewrite(triton::ExternCallOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();

    // Multi-result calls are not supported yet. Report the failure through
    // the rewriter (instead of printing to stderr) so the conversion driver
    // can diagnose and recover.
    if (op->getNumResults() > 1)
      return rewriter.notifyMatchFailure(
          op, "ExternCallOpConversion does not support multiple results");

    auto newOperands = adaptor.getOperands();

    // A zero-result extern_call lowers to a call of a void function.
    Type retType = void_ty(op->getContext());
    if (op->getNumResults() == 1) {
      retType = getTypeConverter()->convertType(op->getResult(0).getType());
      if (!retType)
        return rewriter.notifyMatchFailure(op,
                                           "failed to convert result type");
    }

    std::string funcName = op.getSymbol().str();
    StringRef libname = op.getLibname();
    StringRef libpath = op.getLibpath();

    // Declare (or fetch the existing declaration of) the external symbol,
    // then emit the call.
    Type funcType = mlir::triton::gpu::getFunctionType(retType, newOperands);
    LLVM::LLVMFuncOp funcOp = mlir::triton::gpu::appendOrGetExternFuncOp(
        rewriter, op, funcName, funcType, libname, libpath);
    auto callOp = LLVM::createLLVMCallOp(rewriter, loc, funcOp, newOperands);

    if (op->getNumResults() == 0) {
      // Nothing to replace; simply drop the original op.
      rewriter.eraseOp(op);
    } else {
      rewriter.replaceOp(op, callOp->getResult(0));
    }

    return success();
  }
};

} // namespace

// Registers the tt.extern_call -> llvm.call lowering pattern.
// NOTE(review): targetInfo is currently unused; presumably kept so the
// signature matches the sibling populate*ToLLVMPattern entry points — confirm.
void mlir::triton::populateExternCallOpToLLVMPattern(
    LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
    const TargetInfoBase &targetInfo, PatternBenefit benefit) {
  patterns.add<ExternCallOpConversion>(typeConverter, benefit);
}
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,7 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
GenericOpPattern<triton::HistogramOp>,
GenericOpPattern<triton::GatherOp>,
GenericOpPattern<triton::ExternElementwiseOp>,
GenericOpPattern<triton::ExternCallOp>,
GenericOpPattern<triton::PrintOp>,
GenericOpPattern<triton::AssertOp>,
GenericOpPattern<triton::AtomicCASOp>,
Expand Down
18 changes: 18 additions & 0 deletions lib/Dialect/Triton/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1313,6 +1313,24 @@ Speculation::Speculatability ExternElementwiseOp::getSpeculatability() {
return Speculation::NotSpeculatable;
}

// -- ExternCallOp --
// MemoryEffectsOpInterface: a call marked pure advertises no effects; an
// impure call conservatively reports both a write and a read on the default
// resource, since the external function's behavior is unknown.
void ExternCallOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  if (!getPure()) {
    effects.emplace_back(MemoryEffects::Write::get(),
                         SideEffects::DefaultResource::get());
    effects.emplace_back(MemoryEffects::Read::get(),
                         SideEffects::DefaultResource::get());
  }
}

// ConditionallySpeculatable: only calls explicitly marked pure may be
// speculatively executed (e.g. hoisted out of conditionals).
Speculation::Speculatability ExternCallOp::getSpeculatability() {
  return getPure() ? Speculation::Speculatable : Speculation::NotSpeculatable;
}

// -- GatherOp --
LogicalResult GatherOp::verify() {
RankedTensorType indicesTy = getIndices().getType();
Expand Down
8 changes: 8 additions & 0 deletions python/src/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1673,6 +1673,14 @@ void init_triton_ir(py::module &&m) {
return self.create<ExternElementwiseOp>(retType, argList, libName,
libPath, symbol, isPure);
})
.def("create_extern_call",
[](TritonOpBuilder &self, const std::string &libName,
const std::string &libPath, const std::string &symbol,
std::vector<Value> &argList, const std::vector<Type> &retTypes,
bool isPure) -> OpState {
return self.create<ExternCallOp>(retTypes, argList, libName,
libPath, symbol, isPure);
})
// Built-in instruction
.def("create_get_program_id",
[](TritonOpBuilder &self, int axis) -> Value {
Expand Down
4 changes: 2 additions & 2 deletions python/triton/experimental/tle/language/raw/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .core import call, call_smem
from .core import call, call_smem, call_nvshmem

__all__ = ["call", "call_smem"]
__all__ = ["call", "call_smem", "call_nvshmem"]
5 changes: 5 additions & 0 deletions python/triton/experimental/tle/language/raw/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,8 @@ def call_smem(func, args, _semantic=None):
return buffer_tensors[0]
else:
return tl.tuple(buffer_tensors)


@builtin
def call_nvshmem(func, outputs, inputs, _semantic=None):
    """Compile the wrapped CUDA source of ``func`` to an object file via nvcc.

    Only triggers ``func.make_cubin()``; no IR is emitted here.

    NOTE(review): ``outputs``, ``inputs`` and ``_semantic`` are currently
    unused — confirm whether they are reserved for a future call lowering.
    """
    func.make_cubin()
26 changes: 25 additions & 1 deletion python/triton/experimental/tle/raw/cuda/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,26 @@
from pathlib import Path
import subprocess
from typing import Any, Final
import torch

from triton._C.libtriton import llvm # pyright: ignore[reportMissingImports]
from triton._C.libtriton.tle.llvm import parse_llvm_ir # pyright: ignore[reportMissingImports]

# TODO: We use cli tools to compile CUDA code temporarily, and plan to replace it with LLVM components Python bindings in the future.
CLANG = os.getenv("CLANG", "clang")
NVCC = os.getenv("NVCC", "nvcc")


class CUDAJITFunction(object):

def __init__(self, fn: Any, file: Path, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
super().__init__()
self.fn: Final[Any] = fn
self.code: Final[str] = file.read_text()
self.file: Final[Path] = file
self.libs = kwargs.get("library", {})
self.__triton_builtin__: Final[bool] = True
os.environ["USE_CLANG"] = "True"

def make_llvm(self, mlir_context) -> str:
build = subprocess.run(
Expand All @@ -40,3 +45,22 @@ def make_llvm(self, mlir_context) -> str:
llvm_context = llvm.context()
module = parse_llvm_ir(build.stdout.decode(), llvm_context, mlir_context)
return f"{module}"

def make_cubin(self):
    """Compile this function's CUDA source file to an object file with nvcc.

    Side effects:
      - Writes the compiled object next to the source (same stem, ``.o``).
      - Exports ``<LIBNAME>_HOME`` for each registered library, plus
        ``USE_NVCC`` and ``CUDA_CUBIN`` environment variables, which are read
        later in the build pipeline (TODO: replace env-var plumbing).

    Raises:
        AssertionError: if nvcc exits non-zero (message includes stderr).
    """
    src = self.file
    dst = Path(src).with_suffix('.o')
    include_dirs = []
    for lib_name, lib_path in self.libs.items():
        # TODO: Remove the method of passing information by setting environment variables.
        os.environ[(lib_name + "_home").upper()] = lib_path
        include_dirs.append(os.path.join(lib_path, "include"))
    include_flags = [f"-I{inc_dir}" for inc_dir in include_dirs]
    # Target the compute capability of the current CUDA device.
    prop = torch.cuda.get_device_properties(torch.cuda.current_device())
    arch = f"-arch=sm_{prop.major}{prop.minor}"
    # -rdc=true: relocatable device code, required to link __device__ functions
    # compiled separately from the Triton kernel.
    build = subprocess.run([NVCC, "-rdc=true", arch, *include_flags, "--extended-lambda", "-c", "-o", dst, src],
                           capture_output=True)
    assert build.returncode == 0, (f"nvcc failed\nstderr:\n{build.stderr.decode()}")
    # TODO: Remove the method of passing information by setting environment variables.
    os.environ["USE_NVCC"] = 'True'
    os.environ["CUDA_CUBIN"] = str(dst)
    return
88 changes: 88 additions & 0 deletions python/triton/language/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3433,6 +3433,94 @@ def binary_op_type_legalization(lhs, rhs, semantic):
return semantic.binary_op_type_checking_impl(lhs, rhs)


def dispatch_ec(func, lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool,
                _semantic=None):
    '''
    Dispatch a function to a library.

    :param func: the builder method used to create the extern call op
    :param lib_name: the name of the library
    :param lib_path: the path of the library
    :param args: the arguments of the function
    :param arg_type_symbol_dict: maps a tuple of argument types to a
        (symbol, return types) pair
    :param is_pure: whether the dispatched function is free of side effects
    :param _semantic: the semantic helper providing the IR builder
    :return: the return value(s) of the function
    '''
    if len(arg_type_symbol_dict) == 0:
        raise ValueError("arg_type_symbol_dict is empty")

    # All keys share the same arity; validate against the first one.
    num_args = len(list(arg_type_symbol_dict.keys())[0])
    if len(args) != num_args:
        raise ValueError(f"length of input args does not match. "
                         f"Expect {num_args}, got {len(args)}")

    # Split tensors into (dtype, IR handle); pass scalars through as-is.
    arg_types = []
    arg_list = []
    for arg in args:
        if isinstance(arg, tensor):
            arg_types.append(arg.dtype)
            arg_list.append(arg.handle)
        else:
            arg_types.append(type(arg))
            arg_list.append(arg)
    arg_types = tuple(arg_types)

    if arg_types not in arg_type_symbol_dict:
        raise ValueError(f"input arg type does not match. "
                         f"Expect one of {arg_type_symbol_dict.keys()}, got {arg_types}")
    symbol = arg_type_symbol_dict[arg_types][0]
    ret_types = arg_type_symbol_dict[arg_types][1]
    # Normalize a single return type to a one-element list.
    if not isinstance(ret_types, (builtins.list, builtins.tuple)):
        ret_types = [ret_types]

    if symbol == "":
        raise ValueError("Symbol can not be empty")
    call = func(lib_name, lib_path, symbol, arg_list, [ret_type.to_ir(_semantic.builder) for ret_type in ret_types],
                is_pure)

    if len(ret_types) == 0:
        # No results: wrap the op itself as a void-typed tensor.
        return tensor(call, void)
    if len(ret_types) == 1:
        return tensor(call.get_result(0), ret_types[0])
    return tuple(tensor(call.get_result(i), ty) for i, ty in enumerate(ret_types))


@builtin
def extern_call(lib_name: str, lib_path: str, args: list, arg_type_symbol_dict: dict, is_pure: bool, _semantic=None):
    '''
    Dispatch a function to a library.

    :param lib_name: the name of the library
    :param lib_path: the path of the library
    :param args: the arguments of the function
    :param arg_type_symbol_dict: maps a tuple of argument types to a
        (symbol, return types) pair
    :param is_pure: whether the function is pure
    :param _semantic: the semantic helper providing the IR builder
    :return: the return value of the function
    '''
    dispatch_args = args.copy()
    all_scalar = True
    for i in builtins.range(len(dispatch_args)):
        # Promote plain Python values to IR tensors before dispatching.
        dispatch_args[i] = _semantic.to_tensor(dispatch_args[i])
        if dispatch_args[i].type.is_block():
            all_scalar = False
    if not all_scalar:
        raise ValueError("extern call only supports inputs with scalar type")

    if len(arg_type_symbol_dict) == 0:
        raise ValueError("arg_type_symbol_dict is empty")

    # All keys share the same arity; validate against the first one.
    num_args = len(list(arg_type_symbol_dict.keys())[0])
    if len(args) != num_args:
        raise ValueError(f"length of input args does not match. "
                         f"Expect {num_args}, got {len(args)}")

    func = _semantic.builder.create_extern_call
    return dispatch_ec(func, lib_name, lib_path, dispatch_args, arg_type_symbol_dict, is_pure, _semantic)


def extern(fn):
"""A decorator for external functions."""
return builtin(fn)
Expand Down
21 changes: 21 additions & 0 deletions python/tutorials/tle/raw/cuda/01-vector-add-half2.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
#include <cuda_fp16.h>

// Element-wise half-precision vector add: C[i] = A[i] + B[i] for i in [0, N).
// Pairs of __half elements are processed as packed __half2 values; a single
// trailing element (odd N) is handled separately by thread 0.
__device__ void
VectorAddHalf2(__attribute__((address_space(1))) __half *C,
               __attribute__((address_space(1))) const __half *A,
               __attribute__((address_space(1))) const __half *B, const int N) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int stride = blockDim.x * gridDim.x;

  // Reinterpret the buffers as packed pairs for 2-way vectorized adds.
  const __half2 *a2 = reinterpret_cast<const __half2 *>(A);
  const __half2 *b2 = reinterpret_cast<const __half2 *>(B);
  __half2 *c2 = reinterpret_cast<__half2 *>(C);

  const int numPairs = N / 2;
  for (int i = tid; i < numPairs; i += stride) {
    c2[i] = __hadd2(a2[i], b2[i]);
  }

  // Odd N leaves one scalar element; let a single thread add it.
  if (tid == 0 && N % 2 != 0) {
    const int last = N - 1;
    C[last] = __hadd(A[last], B[last]);
  }
}
42 changes: 42 additions & 0 deletions python/tutorials/tle/raw/cuda/01-vector-add-half2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from pathlib import Path

import torch
import triton
import triton.language as tl
from triton.experimental.tle.raw import dialect
import triton.experimental.tle.language.raw as tle_raw

DEVICE = triton.runtime.driver.active.get_active_torch_device()


@dialect(name="cuda", file=Path(__file__).parent / "01-vector-add-half2.cu")
def edsl_half2(*args, **kwargs):
    # Stub body: the implementation is supplied by the CUDA source file named
    # in the @dialect decorator above, not by Python code.
    ...


@triton.jit
def add_half2_kernel(
    x_ptr,  # pointer to the first input vector
    y_ptr,  # pointer to the second input vector
    output_ptr,  # pointer to the output vector
    n_elements,  # total number of (scalar) elements
    BLOCK_SIZE: tl.constexpr,  # used only by the launch grid computation
):
    # Delegate the whole computation to the external CUDA __device__ function;
    # argument order matches the VectorAddHalf2(C, A, B, N) signature.
    tle_raw.call(edsl_half2, [output_ptr, x_ptr, y_ptr, n_elements])


def add_half2(x: torch.Tensor, y: torch.Tensor):
    """Element-wise add of two fp16 CUDA tensors via the half2 kernel."""
    out = torch.empty_like(x)
    assert x.device == DEVICE and y.device == DEVICE and out.device == DEVICE
    n_elements = out.numel()

    # Each program handles BLOCK_SIZE half2 pairs, hence n_elements // 2.
    def grid(meta):
        return (triton.cdiv(n_elements // 2, meta["BLOCK_SIZE"]), )

    add_half2_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=1024)
    return out


if __name__ == "__main__":
    # Smoke test: compare against eager torch addition.
    num_elements = 16384 * 256
    x = torch.randn(num_elements, device=DEVICE, dtype=torch.float16)
    y = torch.randn(num_elements, device=DEVICE, dtype=torch.float16)
    z_half2 = add_half2(x, y)

    assert torch.allclose(x + y, z_half2), (x + y, z_half2)
Loading
Loading