Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
16 changes: 12 additions & 4 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ elseif(FLAGTREE_BACKEND STREQUAL "mthreads")
set(ENV{PATH} "$ENV{LLVM_SYSPATH}/bin:$ENV{PATH}")
set(CMAKE_C_COMPILER clang)
set(CMAKE_CXX_COMPILER clang++)
set(ENV{FLAGTREE_PLUGIN} $ENV{FLAGTREE_BACKEND})
set(FLAGTREE_TLE OFF)
remove_definitions(-D__TLE__)
elseif(FLAGTREE_BACKEND STREQUAL "aipu")
set(CMAKE_C_COMPILER clang-16)
set(CMAKE_CXX_COMPILER clang++-16)
Expand Down Expand Up @@ -281,14 +282,19 @@ if(TRITON_BUILD_PYTHON_MODULE)
include_directories(${PROJECT_BINARY_DIR}/third_party/${FLAGTREE_BACKEND})
add_subdirectory(third_party/hcu/proton/Dialect)
add_subdirectory(third_party/nvidia)
elseif(FLAGTREE_BACKEND AND FLAGTREE_BACKEND STREQUAL "mthreads")
include_directories(${PROJECT_BINARY_DIR}/third_party/${FLAGTREE_BACKEND})
add_subdirectory(third_party/mthreads/proton/Dialect)
else()
list(APPEND TRITON_PLUGIN_NAMES "proton")
add_subdirectory(third_party/proton/Dialect)
endif()

# Add TLE plugin
list(APPEND TRITON_PLUGIN_NAMES "tle")
add_subdirectory(third_party/tle)
if(FLAGTREE_TLE)
list(APPEND TRITON_PLUGIN_NAMES "tle")
add_subdirectory(third_party/tle)
endif()

if (DEFINED TRITON_PLUGIN_DIRS)
foreach(PLUGIN_DIR ${TRITON_PLUGIN_DIRS})
Expand Down Expand Up @@ -499,7 +505,9 @@ if(NOT TRITON_BUILD_PYTHON_MODULE)
endforeach()
add_subdirectory(third_party/proton/Dialect)
# flagtree tle
add_subdirectory(third_party/tle)
if(FLAGTREE_TLE)
add_subdirectory(third_party/tle)
endif()
endif()

find_package(Threads REQUIRED)
Expand Down
123 changes: 123 additions & 0 deletions python/setup_tools/utils/mthreads.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import sys
import shutil
import inspect
from pathlib import Path

from setuptools import find_packages

MTHREADS_PYTHON_ROOT = "third_party/mthreads/python"
FLAGTREE_PYTHON_ROOT = "python"
TLE_PACKAGE = "triton.experimental.tle"


def skip_package_dir(package):
return package == "triton" or package.startswith("triton.")


def get_package_dir():
return {
"": MTHREADS_PYTHON_ROOT,
}


def _is_backend_package(package):
return package == "triton.backends" or package.startswith("triton.backends.")


def _is_language_extra_package(package):
return package == "triton.language.extra" or package.startswith("triton.language.extra.")


def _merge_mthreads_packages(existing_packages):
packages = []
seen = set()

def add(package):
if package not in seen:
packages.append(package)
seen.add(package)

for package in find_packages(where=MTHREADS_PYTHON_ROOT, include=["triton", "triton.*"]):
add(package)

for package in find_packages(where=FLAGTREE_PYTHON_ROOT, include=[TLE_PACKAGE, f"{TLE_PACKAGE}.*"]):
add(package)

for package in existing_packages:
if (not package.startswith("triton.") or _is_backend_package(package) or _is_language_extra_package(package)
or package == "triton.profiler" or package.startswith("triton.profiler.")):
add(package)

return packages


def _merge_mthreads_package_dir(existing_package_dir):
package_dir = dict(existing_package_dir or {})
package_dir[""] = MTHREADS_PYTHON_ROOT

for package in find_packages(where=MTHREADS_PYTHON_ROOT, include=["triton", "triton.*"]):
rel_package_path = package.replace(".", "/")
package_dir[package] = f"{MTHREADS_PYTHON_ROOT}/{rel_package_path}"

for package in find_packages(where=FLAGTREE_PYTHON_ROOT, include=[TLE_PACKAGE, f"{TLE_PACKAGE}.*"]):
rel_package_path = package.replace(".", "/")
package_dir[package] = f"{FLAGTREE_PYTHON_ROOT}/{rel_package_path}"

return package_dir


def _patch_mthreads_cmdclass(existing_cmdclass):
cmdclass = dict(existing_cmdclass or {})
original_build_py = cmdclass.get("build_py")
if original_build_py is None:
return cmdclass

class MthreadsBuildPy(original_build_py):

def run(self):
self.force = True
build_triton_dir = Path(self.build_lib) / "triton"
if build_triton_dir.exists():
shutil.rmtree(build_triton_dir)
return super().run()

cmdclass["build_py"] = MthreadsBuildPy
return cmdclass


def _wrap_setup(original_setup):
if getattr(original_setup, "_mthreads_python_root_patched", False):
return original_setup

def setup_with_mthreads_python_root(*args, **kwargs):
kwargs["packages"] = _merge_mthreads_packages(kwargs.get("packages", []))
kwargs["package_dir"] = _merge_mthreads_package_dir(kwargs.get("package_dir", {}))
kwargs["cmdclass"] = _patch_mthreads_cmdclass(kwargs.get("cmdclass", {}))
return original_setup(*args, **kwargs)

setup_with_mthreads_python_root._mthreads_python_root_patched = True
setup_with_mthreads_python_root._mthreads_original_setup = original_setup
return setup_with_mthreads_python_root


def _patch_setup_for_mthreads_python_root():
patched = False

frame = inspect.currentframe()
while frame is not None:
setup_func = frame.f_globals.get("setup")
if callable(setup_func):
frame.f_globals["setup"] = _wrap_setup(setup_func)
patched = True
frame = frame.f_back

main_module = sys.modules.get("__main__")
if main_module is not None and hasattr(main_module, "setup"):
main_module.setup = _wrap_setup(main_module.setup)
patched = True

if not patched:
raise RuntimeError("mthreads setup hook could not find setup() to patch")


_patch_setup_for_mthreads_python_root()
2 changes: 0 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,8 +704,6 @@ def get_packages():

if helper.flagtree_backend == "xpu":
yield f"triton.language.extra.xpu"
elif helper.flagtree_backend == "mthreads":
yield f"triton/language/extra/musa"

if check_env_flag("TRITON_BUILD_PROTON", "ON"): # Default ON
yield "triton.profiler"
Expand Down
23 changes: 23 additions & 0 deletions third_party/mthreads/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
include_directories(${CMAKE_CURRENT_BINARY_DIR})
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/include)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/musa/include)
include_directories(${CMAKE_CURRENT_BINARY_DIR}/musa/include)
add_subdirectory(include)
add_subdirectory(lib)
add_subdirectory(musa)
if(TRITON_BUILD_PYTHON_MODULE)
add_triton_plugin(TritonMthreads ${CMAKE_CURRENT_SOURCE_DIR}/triton_mthreads.cc
LINK_LIBS TritonMUSAGPUToLLVM MTGPUToLLVM TritonMUSAGPUTransforms)
add_dependencies(TritonMthreads
MUSATableGen
MUSAAttrDefsIncGen
MTGPUTableGen
MTGPUTypesIncGen
MTGPUConversionPassIncGen
TritonMUSAGPUConversionPassIncGen
TritonMUSAGPUTransformsIncGen)
target_link_libraries(TritonMthreads PRIVATE Python3::Module pybind11::headers)
endif()
add_subdirectory(bin)
Empty file.
Loading
Loading