diff --git a/python/setup_tools/setup_helper.py b/python/setup_tools/setup_helper.py
index eaf7fff9a8..d465c70f61 100644
--- a/python/setup_tools/setup_helper.py
+++ b/python/setup_tools/setup_helper.py
@@ -293,7 +293,8 @@ def handle_flagtree_backend():
     print(f"\033[1;32m[INFO] FlagtreeBackend is {configs.flagtree_backend}\033[0m")
     configs.extend_backends.append(configs.flagtree_backend)
     if "editable_wheel" in sys.argv and configs.flagtree_backend not in configs.plugin_backends:
-        ext_sourcedir = os.path.abspath(f"./third_party/{configs.flagtree_backend}/python/{ext_sourcedir}") + "/"
+        configs.ext_sourcedir = os.path.abspath(
+            f"./third_party/{configs.flagtree_backend}/python/{configs.ext_sourcedir}") + "/"
 
 
 def handle_plugin_backend(editable):
@@ -310,6 +311,8 @@ def handle_plugin_backend(editable):
         os.makedirs(dst_build_plugin_dir)
     dst_build_plugin_path = dst_build_plugin_dir / flagtree_plugin_so
     shutil.copy(src_build_plugin_path, dst_build_plugin_path)
+    dst_build_plugin_path = Path(__file__).resolve().parent.parent.parent / "python" / "triton" / "_C"
+    shutil.copy(src_build_plugin_path, dst_build_plugin_path)
     src_install_plugin_path = flagtree_backend_dir / flagtree_plugin_so
     if flagtree_backend in ("mthreads", "sunrise"):
         dst_install_plugin_dir = Path(
@@ -362,8 +365,6 @@ def uninstall_triton():
 download_flagtree_third_party("flir", condition=(flagtree_backend == "aipu"), hock=utils.aipu.precompile_hock,
                               required=True)
 
-handle_plugin_backend(False)
-
 handle_flagtree_backend()
 
 cache = FlagTreeCache()
diff --git a/setup.py b/setup.py
index 62aea0f91d..9d86eb59b5 100644
--- a/setup.py
+++ b/setup.py
@@ -453,7 +453,8 @@ def build_extension(self, ext):
         thirdparty_cmake_args = get_thirdparty_packages([get_llvm_package_info()])
         thirdparty_cmake_args += self.get_pybind11_cmake_args()
         extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.path)))
-        wheeldir = os.path.dirname(extdir)
+        print(f"extdir {extdir}")
+        wheeldir = os.path.dirname(os.path.dirname(extdir))
 
         # create build directories
         if not os.path.exists(self.build_temp):
@@ -713,6 +714,11 @@ def add_link_to_backends(external_only):
 package_data_tools = ["compile.h", "compile.c"]
 if helper.flagtree_backend == "xpu":
     package_data_tools += ["compile_xpu.h", "compile_xpu.c"]
+
+if helper.flagtree_backend == "sunrise":
+    package_data = {"": ["*TritonPlugin.so"]}
+else:
+    package_data = {}
 # package_data = {
 #     "triton/tools/extra": sum((b.tools_package_data for b in backends), []),
 #     **{f"triton/backends/{b.name}": b.package_data
@@ -855,8 +861,9 @@ def get_git_version_suffix():
     packages=list(get_packages()),
     package_dir=dict(get_package_dirs()),
     entry_points=get_entry_points(),
+    package_data=package_data,
     include_package_data=True,
-    ext_modules=[CMakeExtension("triton", "triton/_C/")],
+    ext_modules=[CMakeExtension("triton", helper.configs.ext_sourcedir)],
     cmdclass={
         "bdist_wheel": plugin_bdist_wheel,
         "build_ext": CMakeBuild,
diff --git a/third_party/sunrise/CMakeLists.txt b/third_party/sunrise/CMakeLists.txt
index d7f0e11e39..571fc1126e 100644
--- a/third_party/sunrise/CMakeLists.txt
+++ b/third_party/sunrise/CMakeLists.txt
@@ -5,7 +5,7 @@ option(EDITABLE_MODE "Build in developer (editable) mode" OFF)
 if(FLAGTREE_PLUGIN)
   set(SUNRISE_PLUGIN_DIR "${Python3_SITELIB}/triton/_C")
 elseif(EDITABLE_MODE)
-  set(SUNRISE_PLUGIN_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
+  set(SUNRISE_PLUGIN_DIR "${CMAKE_CURRENT_SOURCE_DIR}/python/triton/_C")
 else()
   set(SUNRISE_PLUGIN_DIR "${Python3_SITELIB}/triton/_C")
 endif()
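Note on the package_data wiring above: in setuptools, an empty-string key in package_data applies its glob patterns to every package, which is what lets the sunrise wheel pick up *TritonPlugin.so wherever it lands inside a package directory. A minimal self-contained sketch of that behavior (the my_pkg name is hypothetical, not part of this change):

    # sketch_package_data.py - illustrates the setuptools semantics relied on above
    from setuptools import setup, find_packages

    setup(
        name="my_pkg",  # hypothetical distribution name, for illustration only
        packages=find_packages(),
        # The "" key means "all packages": any file matching a pattern inside
        # a package directory is copied into the built wheel alongside the code.
        package_data={"": ["*TritonPlugin.so"]},
        include_package_data=True,
    )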
diff --git a/third_party/sunrise/python/triton/backends b/third_party/sunrise/python/triton/backends
new file mode 120000
index 0000000000..13a83a85ce
--- /dev/null
+++ b/third_party/sunrise/python/triton/backends
@@ -0,0 +1 @@
+../../../../python/triton/backends
\ No newline at end of file
diff --git a/third_party/sunrise/python/triton/backends/__init__.py b/third_party/sunrise/python/triton/backends/__init__.py
deleted file mode 100644
index 69a8dab0a3..0000000000
--- a/third_party/sunrise/python/triton/backends/__init__.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import importlib
-import inspect
-import sys
-from dataclasses import dataclass
-from typing import Type, TypeVar, Union
-from types import ModuleType
-from .driver import DriverBase
-from .compiler import BaseBackend
-
-if sys.version_info >= (3, 10):
-    from importlib.metadata import entry_points
-else:
-    from importlib_metadata import entry_points
-
-T = TypeVar("T", bound=Union[BaseBackend, DriverBase])
-
-
-def _find_concrete_subclasses(module: ModuleType, base_class: Type[T]) -> Type[T]:
-    ret: list[Type[T]] = []
-    for attr_name in dir(module):
-        attr = getattr(module, attr_name)
-        if isinstance(attr, type) and issubclass(attr, base_class) and not inspect.isabstract(attr):
-            ret.append(attr)
-    if len(ret) == 0:
-        raise RuntimeError(f"Found 0 concrete subclasses of {base_class} in {module}: {ret}")
-    if len(ret) > 1:
-        raise RuntimeError(f"Found >1 concrete subclasses of {base_class} in {module}: {ret}")
-    return ret[0]
-
-
-@dataclass(frozen=True)
-class Backend:
-    compiler: Type[BaseBackend]
-    driver: Type[DriverBase]
-
-
-def _discover_backends() -> dict[str, Backend]:
-    backends = dict()
-    for ep in entry_points().select(group="triton.backends"):
-        compiler = importlib.import_module(f"{ep.value}.compiler")
-        driver = importlib.import_module(f"{ep.value}.driver")
-        backends[ep.name] = Backend(_find_concrete_subclasses(compiler, BaseBackend),  # type: ignore
-                                    _find_concrete_subclasses(driver, DriverBase))  # type: ignore
-    return backends
-
-
-backends: dict[str, Backend] = _discover_backends()
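The deleted backends/__init__.py above duplicated the canonical python/triton/backends/__init__.py that the new symlink now points at; backends are discovered through the "triton.backends" entry-point group. A hedged sketch of what that discovery sees at runtime (output names are illustrative only):

    # sketch_discovery.py - enumerate installed triton backend entry points,
    # mirroring the mechanism in the canonical __init__.py
    import sys

    if sys.version_info >= (3, 10):
        from importlib.metadata import entry_points
    else:
        from importlib_metadata import entry_points

    # Each backend advertises a module path; discovery then imports
    # <value>.compiler and <value>.driver and picks the concrete subclasses.
    for ep in entry_points().select(group="triton.backends"):
        print(ep.name, "->", ep.value)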
diff --git a/third_party/sunrise/python/triton/backends/compiler.py b/third_party/sunrise/python/triton/backends/compiler.py
deleted file mode 100644
index 9bbc5eadbd..0000000000
--- a/third_party/sunrise/python/triton/backends/compiler.py
+++ /dev/null
@@ -1,90 +0,0 @@
-from abc import ABCMeta, abstractmethod
-from dataclasses import dataclass
-from enum import Enum
-from typing import Dict, Union
-from types import ModuleType
-
-
-@dataclass(frozen=True)
-class GPUTarget(object):
-    # Target backend, e.g., cuda, hip
-    backend: str
-    # Target architecture, e.g., 90 (for cuda compute capability), gfx940 (for hip)
-    arch: Union[int, str]
-    warp_size: int
-
-
-class Language(Enum):
-    """The input language being compiled by the backend."""
-    TRITON = 0
-    GLUON = 1
-
-
-class BaseBackend(metaclass=ABCMeta):
-
-    def __init__(self, target: GPUTarget) -> None:
-        self.target = target
-        assert self.supports_target(target)
-
-    @staticmethod
-    @abstractmethod
-    def supports_target(target: GPUTarget):
-        raise NotImplementedError
-
-    @abstractmethod
-    def hash(self) -> str:
-        """Returns a unique identifier for this backend"""
-        raise NotImplementedError
-
-    @abstractmethod
-    def parse_options(self, options: dict) -> object:
-        """
-        Converts an `options` dictionary into an arbitrary object and returns it.
-        This function may contain target-specific heuristics and check the legality of the provided options
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def add_stages(self, stages: dict, options: object) -> None:
-        """
-        Populates `stages` dictionary with entries of the form:
-        ir_name [str] => Function[(src: str, metadata: dict) -> str|bytes]
-        The value of each entry may populate a `metadata` dictionary.
-        Stages will be run sequentially (in inseriton order) and can communicate using `metadata`.
-        All stages are expected to return a `str` object, except for the last stage which returns
-        a `bytes` object for execution by the launcher.
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def load_dialects(self, context):
-        """
-        Load additional MLIR dialects into the provided `context`
-        """
-        raise NotImplementedError
-
-    @abstractmethod
-    def get_module_map(self) -> Dict[str, ModuleType]:
-        """
-        Return a map of interface modules to their device-specific implementations
-        """
-        raise NotImplementedError
-
-    @staticmethod
-    def parse_attr(desc):
-        assert isinstance(desc, str)
-        ret = []
-        if "D" in desc:
-            ret += [["tt.divisibility", 16]]
-        return ret
-
-    @staticmethod
-    def get_arg_specialization(arg, ty, **kwargs):
-        """
-        Return a string unique to each possible specialization of the argument
-        """
-        if ty == "int" and arg % 16 == 0 and kwargs.get("align", False):
-            return "D"
-        if ty == "tensor" and arg.data_ptr() % 16 == 0 and kwargs.get("align", False):
-            return "D"
-        return ""
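The deleted compiler.py likewise duplicated the canonical triton copy. Its only non-obvious logic is the argument-specialization helper; a standalone restatement of the integer branch is sketched below (pure Python, no triton import, slightly simplified) to show what the "D" code encodes:

    # sketch_specialization.py - the divisibility-16 specialization from the
    # deleted BaseBackend helpers, restated for clarity
    def get_arg_specialization(arg, ty, align=False):
        # Integers divisible by 16 specialize to "D" (tt.divisibility = 16).
        if ty == "int" and align and arg % 16 == 0:
            return "D"
        return ""

    assert get_arg_specialization(32, "int", align=True) == "D"
    assert get_arg_specialization(30, "int", align=True) == ""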
- """ - raise NotImplementedError - - def __init__(self) -> None: - pass - - -class GPUDriver(DriverBase): - - def __init__(self): - # TODO: support other frameworks than torch - import torch - try: - import torch_ptpu - _is_ptpu = True - except ImportError as e: - _is_ptpu = False - if _is_ptpu: - self.get_device_capability = torch.ptpu.get_device_capability - self.get_current_stream = lambda dev_idx: torch.ptpu.current_stream(dev_idx).ptpu_stream - self.get_current_device = torch.ptpu.current_device - self.set_current_device = torch.ptpu.set_device - return - - try: - from torch._C import _cuda_getCurrentRawStream - _is_cuda = True - except ImportError as e: - _cuda_getCurrentRawStream = None - _is_cuda = True if torch.version.cuda else False - if _is_cuda: - self.get_device_capability = torch.cuda.get_device_capability - if _cuda_getCurrentRawStream is not None: - self.get_current_stream = _cuda_getCurrentRawStream - else: - self.get_current_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).cuda_stream - self.get_current_device = torch.cuda.current_device - self.set_current_device = torch.cuda.set_device - return - - try: - import torch_dipu - _is_dipu = True - except ImportError as e: - _is_dipu = False - if _is_dipu: - self.get_device_capability = torch.cuda.get_device_capability - self.get_current_stream = lambda dev_idx: torch.cuda.current_stream(dev_idx).dipu_stream - self.get_current_device = torch.cuda.current_device - self.set_current_device = torch.cuda.set_device - return - - # TODO: remove once TMA is cleaned up - def assemble_tensormap_to_arg(self, tensormaps_info, args): - return args